shape_model.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
"""
Shape Model Module
YOLO-based shape classification model for durian quality assessment.
Detects and classifies durian shape: Regular vs Irregular.
"""
  6. from pathlib import Path
  7. from typing import Dict, Any, Optional
  8. import logging
  9. import cv2
  10. import numpy as np
  11. from ultralytics import YOLO
  12. from PyQt5.QtGui import QImage
  13. from models.base_model import BaseModel
  14. from utils.config import (
  15. YOLO_CONFIDENCE_THRESHOLD,
  16. )
  17. logger = logging.getLogger(__name__)
  18. # Shape class constants
  19. SHAPE_CLASS_NAMES = {
  20. 0: "Irregular",
  21. 1: "Regular",
  22. }
  23. SHAPE_CLASS_COLORS = {
  24. 0: (86, 0, 254), # Irregular - Purple
  25. 1: (0, 252, 199), # Regular - Cyan
  26. }
  27. class ShapeModel(BaseModel):
  28. """
  29. YOLO-based shape classification model.
  30. Classifies durian shape into:
  31. - Regular (Class 1)
  32. - Irregular (Class 0)
  33. Attributes:
  34. model: YOLO model instance
  35. class_names: Name mapping for each class
  36. class_colors: BGR color mapping for visualization
  37. confidence_threshold: Minimum confidence for classifications
  38. """
  39. def __init__(
  40. self,
  41. model_path: str,
  42. device: str = "cuda",
  43. confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
  44. ):
  45. """
  46. Initialize the shape classification model.
  47. Args:
  48. model_path: Path to YOLO .pt model file (shape.pt)
  49. device: Device to use ('cuda' or 'cpu')
  50. confidence_threshold: Minimum confidence threshold (0.0-1.0)
  51. """
  52. super().__init__(model_path, device)
  53. self.class_names = SHAPE_CLASS_NAMES
  54. self.class_colors = SHAPE_CLASS_COLORS
  55. self.confidence_threshold = confidence_threshold
  56. def load(self) -> bool:
  57. """
  58. Load the YOLO model.
  59. Returns:
  60. bool: True if loaded successfully, False otherwise
  61. """
  62. try:
  63. model_path = Path(self.model_path)
  64. if not model_path.exists():
  65. logger.error(f"Shape model file does not exist: {model_path}")
  66. return False
  67. logger.info(f"Loading shape model from {model_path}")
  68. self.model = YOLO(str(model_path))
  69. # Move model to specified device
  70. self.model.to(self.device)
  71. self._is_loaded = True
  72. logger.info(f"Shape model loaded on {self.device}")
  73. return True
  74. except Exception as e:
  75. logger.error(f"Failed to load shape model: {e}")
  76. self._is_loaded = False
  77. return False
  78. def _draw_bounding_box(
  79. self,
  80. image: np.ndarray,
  81. box: Any,
  82. class_id: int,
  83. confidence: float,
  84. shape_class: str
  85. ) -> np.ndarray:
  86. """
  87. Draw bounding box and label on the image.
  88. Args:
  89. image: Input image (BGR format)
  90. box: YOLO box object with coordinates
  91. class_id: Class ID (0=Irregular, 1=Regular)
  92. confidence: Confidence score
  93. shape_class: Shape class name
  94. Returns:
  95. Annotated image with bounding box
  96. """
  97. annotated = image.copy()
  98. # Get bounding box coordinates
  99. xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
  100. # Get color based on class
  101. color = self.class_colors.get(class_id, (255, 255, 255))
  102. # Draw bounding box
  103. cv2.rectangle(
  104. annotated,
  105. (xmin, ymin),
  106. (xmax, ymax),
  107. color,
  108. 2
  109. )
  110. # Draw label with confidence
  111. label = f"{shape_class}: {confidence:.2f}"
  112. cv2.putText(
  113. annotated,
  114. label,
  115. (xmin, ymin - 5),
  116. cv2.FONT_HERSHEY_SIMPLEX,
  117. 0.8,
  118. color,
  119. 2,
  120. lineType=cv2.LINE_AA
  121. )
  122. return annotated
  123. def predict(self, image_path: str) -> Dict[str, Any]:
  124. """
  125. Classify the shape of a durian in an image.
  126. Args:
  127. image_path: Path to input image
  128. Returns:
  129. Dict containing:
  130. - 'success': Whether prediction succeeded
  131. - 'shape_class': Detected shape (Regular/Irregular)
  132. - 'class_id': Numeric class ID (0=Irregular, 1=Regular)
  133. - 'confidence': Confidence score (0.0-1.0)
  134. - 'annotated_image': QImage with bounding box (if detection model)
  135. - 'error': Error message if failed
  136. Raises:
  137. RuntimeError: If model is not loaded
  138. """
  139. if not self._is_loaded:
  140. raise RuntimeError("Model not loaded. Call load() first.")
  141. try:
  142. # Load image
  143. image = cv2.imread(image_path)
  144. if image is None:
  145. raise ValueError(f"Could not load image: {image_path}")
  146. # Run YOLO inference
  147. results = self.model.predict(image)
  148. shape_class = None
  149. confidence = 0.0
  150. class_id = None
  151. annotated_image = None
  152. # Process results - shape.pt is a classification model with probs
  153. for result in results:
  154. # Check if results have classification probabilities
  155. if result.probs is not None:
  156. logger.info(f"Shape model returned probs: {result.probs}")
  157. # Get top class
  158. class_id = int(result.probs.top1) # Index of highest probability
  159. confidence = float(result.probs.top1conf.cpu().item())
  160. # Get class name
  161. shape_class = self.class_names.get(
  162. class_id,
  163. f"Unknown({class_id})"
  164. )
  165. logger.info(
  166. f"Shape classification (via probs): {shape_class} "
  167. f"(confidence: {confidence:.3f})"
  168. )
  169. return {
  170. 'success': True,
  171. 'shape_class': shape_class,
  172. 'class_id': class_id,
  173. 'confidence': confidence,
  174. 'annotated_image': None,
  175. 'error': None
  176. }
  177. # Fallback: Check for detection results with class names
  178. # (in case shape.pt is detection model instead of classification)
  179. if result.boxes is not None and len(result.boxes) > 0:
  180. logger.info(f"Shape model returned detection boxes")
  181. boxes = result.boxes.cpu().numpy()
  182. if len(boxes) > 0:
  183. box = boxes[0]
  184. class_id = int(box.cls[0])
  185. confidence = float(box.conf[0])
  186. # Get class name
  187. shape_class = self.class_names.get(
  188. class_id,
  189. f"Unknown({class_id})"
  190. )
  191. # Draw bounding box on image
  192. annotated_image_np = self._draw_bounding_box(
  193. image,
  194. box,
  195. class_id,
  196. confidence,
  197. shape_class
  198. )
  199. # Convert to QImage
  200. rgb_image = cv2.cvtColor(annotated_image_np, cv2.COLOR_BGR2RGB)
  201. h, w, ch = rgb_image.shape
  202. bytes_per_line = ch * w
  203. annotated_image = QImage(
  204. rgb_image.data,
  205. w,
  206. h,
  207. bytes_per_line,
  208. QImage.Format_RGB888
  209. )
  210. logger.info(
  211. f"Shape classification (via boxes): {shape_class} "
  212. f"(confidence: {confidence:.3f})"
  213. )
  214. return {
  215. 'success': True,
  216. 'shape_class': shape_class,
  217. 'class_id': class_id,
  218. 'confidence': confidence,
  219. 'annotated_image': annotated_image,
  220. 'error': None
  221. }
  222. # No results found
  223. logger.warning(f"No shape classification results from model. Results: {results}")
  224. return {
  225. 'success': False,
  226. 'shape_class': None,
  227. 'class_id': None,
  228. 'confidence': 0.0,
  229. 'annotated_image': None,
  230. 'error': 'No classification result from model'
  231. }
  232. except Exception as e:
  233. logger.error(f"Shape prediction failed: {e}")
  234. import traceback
  235. logger.error(traceback.format_exc())
  236. return {
  237. 'success': False,
  238. 'shape_class': None,
  239. 'class_id': None,
  240. 'confidence': 0.0,
  241. 'annotated_image': None,
  242. 'error': str(e)
  243. }
  244. def predict_batch(self, image_paths: list) -> list:
  245. """
  246. Classify shapes in multiple images.
  247. Args:
  248. image_paths: List of paths to images
  249. Returns:
  250. List[Dict]: List of prediction results
  251. """
  252. results = []
  253. for image_path in image_paths:
  254. result = self.predict(image_path)
  255. results.append(result)
  256. return results