"""
Defect Model Module

YOLO-based defect detection model for durian quality assessment.
Detects and classifies defects: minor defects, no defects, reject.
"""

import logging
from collections import Counter
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtGui import QImage

from models.base_model import BaseModel
from utils.config import (
    DEFECT_MODEL_PATH,
    DEFECT_CLASS_COLORS,
    DEFECT_CLASS_NAMES,
    YOLO_CONFIDENCE_THRESHOLD,
)

logger = logging.getLogger(__name__)


class DefectModel(BaseModel):
    """
    YOLO-based defect detection model.

    Detects defects in durian images and classifies them into:
    - Minor defects (Pink)
    - No defects (Cyan)
    - Reject (Purple)

    Attributes:
        model: YOLO model instance (set by ``load()``)
        class_colors: BGR color mapping for each class
        class_names: Name mapping for each class
        confidence_threshold: Minimum confidence for detections
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cuda",
        confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
    ):
        """
        Initialize the defect detection model.

        Args:
            model_path: Path to YOLO .pt model file (optional; defaults to
                DEFECT_MODEL_PATH from the project configuration)
            device: Device to use ('cuda' or 'cpu')
            confidence_threshold: Minimum confidence for detections (0.0-1.0)
        """
        if model_path is None:
            model_path = str(DEFECT_MODEL_PATH)

        # NOTE(review): BaseModel is presumed to store model_path / device and
        # to initialise _is_loaded; this class reads all three later.
        super().__init__(model_path, device)

        self.class_colors = DEFECT_CLASS_COLORS
        self.class_names = DEFECT_CLASS_NAMES
        self.confidence_threshold = confidence_threshold

    def load(self) -> bool:
        """
        Load the YOLO model and move it to the configured device.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            model_path = Path(self.model_path)

            if not model_path.exists():
                logger.error(f"Model file does not exist: {model_path}")
                return False

            logger.info(f"Loading defect model from {model_path}")
            self.model = YOLO(str(model_path))

            # Move model to specified device
            self.model.to(self.device)

            self._is_loaded = True
            logger.info(f"Defect model loaded on {self.device}")
            return True

        except Exception as e:
            # Boundary handler: callers rely on a bool, so report (with
            # traceback) instead of propagating.
            logger.exception(f"Failed to load defect model: {e}")
            self._is_loaded = False
            return False

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: np.ndarray,
        confidences: np.ndarray,
        class_ids: np.ndarray
    ) -> Tuple[np.ndarray, list]:
        """
        Draw bounding boxes and labels on a copy of the image.

        Detections below ``self.confidence_threshold`` are skipped.

        Args:
            image: Input image (BGR format); not modified
            boxes: Bounding boxes [N, 4] (xmin, ymin, xmax, ymax)
            confidences: Confidence scores [N]
            class_ids: Class IDs [N]

        Returns:
            Tuple[np.ndarray, list]: (annotated image, detected class names)
        """
        annotated_image = image.copy()
        detected_classes = []

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.8
        thickness = 2

        for box, confidence, class_id in zip(boxes, confidences, class_ids):
            # Skip low confidence detections
            if confidence < self.confidence_threshold:
                continue

            xmin, ymin, xmax, ymax = map(int, box)
            class_id = int(class_id)

            # Get class information (fall back to a generic name / white)
            class_name = self.class_names.get(class_id, f"Class {class_id}")
            color = self.class_colors.get(class_id, (255, 255, 255))

            detected_classes.append(class_name)

            # Draw bounding box
            cv2.rectangle(
                annotated_image,
                (xmin, ymin),
                (xmax, ymax),
                color,
                thickness
            )

            # Draw label with confidence
            label = f"{class_name}: {confidence:.2f}"
            (_, label_height), _ = cv2.getTextSize(
                label, font, font_scale, thickness
            )
            # Default position is just above the box; when the box touches
            # the top edge that would fall outside the image, so place the
            # label just inside the box instead.
            label_y = ymin - 5
            if label_y < label_height:
                label_y = ymin + label_height + 5

            cv2.putText(
                annotated_image,
                label,
                (xmin, label_y),
                font,
                font_scale,
                color,
                thickness,
                lineType=cv2.LINE_AA
            )

        return annotated_image, detected_classes

    def _collect_detections(self, results) -> List[Dict[str, Any]]:
        """Flatten YOLO results into detection dicts above the threshold."""
        detections = []

        for result in results:
            if result.boxes is None or len(result.boxes) == 0:
                continue

            for box in result.boxes.cpu().numpy():
                confidence = float(box.conf[0])
                if confidence < self.confidence_threshold:
                    continue

                xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
                class_id = int(box.cls[0])

                detections.append({
                    'bbox': [xmin, ymin, xmax, ymax],
                    'confidence': confidence,
                    'class_id': class_id,
                    'class_name': self.class_names.get(
                        class_id, f"Class {class_id}"
                    )
                })

        return detections

    @staticmethod
    def _to_qimage(bgr_image: np.ndarray) -> QImage:
        """Convert a BGR image array into a QImage that owns its pixels."""
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w

        # QImage does not take ownership of the buffer it is constructed
        # from. rgb_image is freed when this function returns, so copy() is
        # required to give the QImage its own pixel storage.
        return QImage(
            rgb_image.data,
            w,
            h,
            bytes_per_line,
            QImage.Format_RGB888
        ).copy()

    def predict(self, image_path: str) -> Dict[str, Any]:
        """
        Detect defects in an image.

        Args:
            image_path: Path to input image

        Returns:
            Dict containing:
                - 'success': Whether prediction succeeded
                - 'annotated_image': QImage with bounding boxes (None on failure)
                - 'detections': List of detection dictionaries, each with
                  'bbox', 'confidence', 'class_id' and 'class_name'
                - 'class_counts': Count of each detected class
                - 'primary_class': Most prevalent class, "No detections" when
                  nothing was found, None on failure
                - 'total_detections': Number of detections
                - 'error': Error message if failed, otherwise None

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        try:
            # Load image
            image = cv2.imread(str(image_path))
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")

            # Run YOLO inference. Forward the configured threshold so that
            # values below the library's default cut-off take effect; clamp
            # to the valid range so an out-of-range setting simply yields
            # the same filtering as before rather than an argument error.
            conf = min(max(float(self.confidence_threshold), 0.0), 1.0)
            results = self.model.predict(image, conf=conf)

            detections = self._collect_detections(results)

            # Draw detections
            if detections:
                annotated_image, _ = self._draw_detections(
                    image,
                    np.array([d['bbox'] for d in detections]),
                    np.array([d['confidence'] for d in detections]),
                    np.array([d['class_id'] for d in detections])
                )
            else:
                annotated_image = image

            # Count classes (insertion order = first-seen order)
            class_counts = dict(Counter(d['class_name'] for d in detections))

            # Determine primary class; ties resolve to the first-seen class
            if class_counts:
                primary_class = max(class_counts, key=class_counts.get)
            else:
                primary_class = "No detections"

            q_image = self._to_qimage(annotated_image)

            logger.info(
                f"Detected {len(detections)} objects. "
                f"Primary class: {primary_class}"
            )

            return {
                'success': True,
                'annotated_image': q_image,
                'detections': detections,
                'class_counts': class_counts,
                'primary_class': primary_class,
                'total_detections': len(detections),
                'error': None
            }

        except Exception as e:
            # Boundary handler: callers receive a failure dict rather than an
            # exception; keep the traceback in the log for diagnosis.
            logger.exception(f"Prediction failed: {e}")
            return {
                'success': False,
                'annotated_image': None,
                'detections': [],
                'class_counts': {},
                'primary_class': None,
                'total_detections': 0,
                'error': str(e)
            }

    def predict_batch(self, image_paths: list) -> list:
        """
        Detect defects in multiple images.

        Images are processed one at a time via ``predict()``; a failure on
        one image produces a failure dict for that entry only.

        Args:
            image_paths: List of paths to images

        Returns:
            List[Dict]: List of prediction results, in input order

        Raises:
            RuntimeError: If model is not loaded
        """
        return [self.predict(image_path) for image_path in image_paths]