# defect_model.py
  1. """
  2. Defect Model Module
  3. YOLO-based defect detection model for durian quality assessment.
  4. Detects and classifies defects: minor defects, no defects, reject.
  5. """
  6. from pathlib import Path
  7. from typing import Dict, Any, Optional, Tuple
  8. import logging
  9. import cv2
  10. import numpy as np
  11. from ultralytics import YOLO
  12. from PyQt5.QtGui import QImage
  13. from models.base_model import BaseModel
  14. from utils.config import (
  15. DEFECT_MODEL_PATH,
  16. DEFECT_CLASS_COLORS,
  17. DEFECT_CLASS_NAMES,
  18. YOLO_CONFIDENCE_THRESHOLD,
  19. )
  20. logger = logging.getLogger(__name__)
  21. class DefectModel(BaseModel):
  22. """
  23. YOLO-based defect detection model.
  24. Detects defects in durian images and classifies them into:
  25. - Minor defects (Pink)
  26. - No defects (Cyan)
  27. - Reject (Purple)
  28. Attributes:
  29. model: YOLO model instance
  30. class_colors: BGR color mapping for each class
  31. class_names: Name mapping for each class
  32. confidence_threshold: Minimum confidence for detections
  33. """
  34. def __init__(
  35. self,
  36. model_path: Optional[str] = None,
  37. device: str = "cuda",
  38. confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
  39. ):
  40. """
  41. Initialize the defect detection model.
  42. Args:
  43. model_path: Path to YOLO .pt model file (optional)
  44. device: Device to use ('cuda' or 'cpu')
  45. confidence_threshold: Minimum confidence for detections (0.0-1.0)
  46. """
  47. if model_path is None:
  48. model_path = str(DEFECT_MODEL_PATH)
  49. super().__init__(model_path, device)
  50. self.class_colors = DEFECT_CLASS_COLORS
  51. self.class_names = DEFECT_CLASS_NAMES
  52. self.confidence_threshold = confidence_threshold
  53. def load(self) -> bool:
  54. """
  55. Load the YOLO model.
  56. Returns:
  57. bool: True if loaded successfully, False otherwise
  58. """
  59. try:
  60. model_path = Path(self.model_path)
  61. if not model_path.exists():
  62. logger.error(f"Model file does not exist: {model_path}")
  63. return False
  64. logger.info(f"Loading defect model from {model_path}")
  65. self.model = YOLO(str(model_path))
  66. # Move model to specified device
  67. self.model.to(self.device)
  68. self._is_loaded = True
  69. logger.info(f"Defect model loaded on {self.device}")
  70. return True
  71. except Exception as e:
  72. logger.error(f"Failed to load defect model: {e}")
  73. self._is_loaded = False
  74. return False
  75. def _draw_detections(
  76. self,
  77. image: np.ndarray,
  78. boxes: np.ndarray,
  79. confidences: np.ndarray,
  80. class_ids: np.ndarray
  81. ) -> Tuple[np.ndarray, list]:
  82. """
  83. Draw bounding boxes and labels on the image.
  84. Args:
  85. image: Input image (BGR format)
  86. boxes: Bounding boxes [N, 4] (xmin, ymin, xmax, ymax)
  87. confidences: Confidence scores [N]
  88. class_ids: Class IDs [N]
  89. Returns:
  90. Tuple[np.ndarray, list]: (annotated image, detected class names)
  91. """
  92. annotated_image = image.copy()
  93. detected_classes = []
  94. for box, confidence, class_id in zip(boxes, confidences, class_ids):
  95. # Skip low confidence detections
  96. if confidence < self.confidence_threshold:
  97. continue
  98. xmin, ymin, xmax, ymax = map(int, box)
  99. class_id = int(class_id)
  100. # Get class information
  101. class_name = self.class_names.get(class_id, f"Class {class_id}")
  102. color = self.class_colors.get(class_id, (255, 255, 255))
  103. detected_classes.append(class_name)
  104. # Draw bounding box
  105. cv2.rectangle(
  106. annotated_image,
  107. (xmin, ymin),
  108. (xmax, ymax),
  109. color,
  110. 2
  111. )
  112. # Draw label with confidence
  113. label = f"{class_name}: {confidence:.2f}"
  114. cv2.putText(
  115. annotated_image,
  116. label,
  117. (xmin, ymin - 5),
  118. cv2.FONT_HERSHEY_SIMPLEX,
  119. 0.8,
  120. color,
  121. 2,
  122. lineType=cv2.LINE_AA
  123. )
  124. return annotated_image, detected_classes
  125. def predict(self, image_path: str) -> Dict[str, Any]:
  126. """
  127. Detect defects in an image.
  128. Args:
  129. image_path: Path to input image
  130. Returns:
  131. Dict containing:
  132. - 'success': Whether prediction succeeded
  133. - 'annotated_image': QImage with bounding boxes
  134. - 'detections': List of detection dictionaries
  135. - 'class_counts': Count of each detected class
  136. - 'primary_class': Most prevalent class
  137. - 'error': Error message if failed
  138. Raises:
  139. RuntimeError: If model is not loaded
  140. """
  141. if not self._is_loaded:
  142. raise RuntimeError("Model not loaded. Call load() first.")
  143. try:
  144. # Load image
  145. image = cv2.imread(image_path)
  146. if image is None:
  147. raise ValueError(f"Could not load image: {image_path}")
  148. # Run YOLO inference
  149. results = self.model.predict(image)
  150. detections = []
  151. all_boxes = []
  152. all_confidences = []
  153. all_class_ids = []
  154. # Process results
  155. for result in results:
  156. if result.boxes is None or len(result.boxes) == 0:
  157. continue
  158. boxes = result.boxes.cpu().numpy()
  159. for box in boxes:
  160. confidence = float(box.conf[0])
  161. if confidence < self.confidence_threshold:
  162. continue
  163. xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
  164. class_id = int(box.cls[0])
  165. class_name = self.class_names.get(class_id, f"Class {class_id}")
  166. detections.append({
  167. 'bbox': [xmin, ymin, xmax, ymax],
  168. 'confidence': confidence,
  169. 'class_id': class_id,
  170. 'class_name': class_name
  171. })
  172. all_boxes.append([xmin, ymin, xmax, ymax])
  173. all_confidences.append(confidence)
  174. all_class_ids.append(class_id)
  175. # Draw detections
  176. if len(all_boxes) > 0:
  177. annotated_image, detected_class_names = self._draw_detections(
  178. image,
  179. np.array(all_boxes),
  180. np.array(all_confidences),
  181. np.array(all_class_ids)
  182. )
  183. else:
  184. annotated_image = image
  185. detected_class_names = []
  186. # Count classes
  187. class_counts = {}
  188. for class_name in detected_class_names:
  189. class_counts[class_name] = class_counts.get(class_name, 0) + 1
  190. # Determine primary class
  191. if class_counts:
  192. primary_class = max(class_counts.items(), key=lambda x: x[1])[0]
  193. else:
  194. primary_class = "No detections"
  195. # Convert to QImage
  196. rgb_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
  197. h, w, ch = rgb_image.shape
  198. bytes_per_line = ch * w
  199. q_image = QImage(
  200. rgb_image.data,
  201. w,
  202. h,
  203. bytes_per_line,
  204. QImage.Format_RGB888
  205. )
  206. logger.info(f"Detected {len(detections)} objects. Primary class: {primary_class}")
  207. return {
  208. 'success': True,
  209. 'annotated_image': q_image,
  210. 'detections': detections,
  211. 'class_counts': class_counts,
  212. 'primary_class': primary_class,
  213. 'total_detections': len(detections),
  214. 'error': None
  215. }
  216. except Exception as e:
  217. logger.error(f"Prediction failed: {e}")
  218. return {
  219. 'success': False,
  220. 'annotated_image': None,
  221. 'detections': [],
  222. 'class_counts': {},
  223. 'primary_class': None,
  224. 'total_detections': 0,
  225. 'error': str(e)
  226. }
  227. def predict_batch(self, image_paths: list) -> list:
  228. """
  229. Detect defects in multiple images.
  230. Args:
  231. image_paths: List of paths to images
  232. Returns:
  233. List[Dict]: List of prediction results
  234. """
  235. results = []
  236. for image_path in image_paths:
  237. result = self.predict(image_path)
  238. results.append(result)
  239. return results