| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
"""
Defect Model Module

YOLO-based defect detection model for durian quality assessment.
Detects and classifies defects: minor defects, no defects, reject.
"""
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
import logging

import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtGui import QImage

from models.base_model import BaseModel
from utils.config import (
    DEFECT_MODEL_PATH,
    DEFECT_CLASS_COLORS,
    DEFECT_CLASS_NAMES,
    YOLO_CONFIDENCE_THRESHOLD,
)

logger = logging.getLogger(__name__)
class DefectModel(BaseModel):
    """
    YOLO-based defect detection model.

    Detects defects in durian images and classifies them into:
    - Minor defects (Pink)
    - No defects (Cyan)
    - Reject (Purple)

    Attributes:
        model: YOLO model instance
        class_colors: BGR color mapping for each class id
        class_names: Name mapping for each class id
        confidence_threshold: Minimum confidence for detections
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cuda",
        confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
    ):
        """
        Initialize the defect detection model.

        Args:
            model_path: Path to YOLO .pt model file (optional; defaults to
                DEFECT_MODEL_PATH from config)
            device: Device to use ('cuda' or 'cpu')
            confidence_threshold: Minimum confidence for detections (0.0-1.0)
        """
        if model_path is None:
            model_path = str(DEFECT_MODEL_PATH)

        super().__init__(model_path, device)
        self.class_colors = DEFECT_CLASS_COLORS
        self.class_names = DEFECT_CLASS_NAMES
        self.confidence_threshold = confidence_threshold

    def load(self) -> bool:
        """
        Load the YOLO model from ``self.model_path`` onto ``self.device``.

        Returns:
            bool: True if loaded successfully, False otherwise. Never raises;
            failures are logged and reported via the return value.
        """
        try:
            model_path = Path(self.model_path)

            if not model_path.exists():
                logger.error("Model file does not exist: %s", model_path)
                return False

            logger.info("Loading defect model from %s", model_path)
            self.model = YOLO(str(model_path))

            # Move model to the configured device (e.g. 'cuda' or 'cpu')
            self.model.to(self.device)

            self._is_loaded = True
            logger.info("Defect model loaded on %s", self.device)
            return True

        except Exception as e:
            logger.error("Failed to load defect model: %s", e)
            self._is_loaded = False
            return False

    def _draw_detections(
        self,
        image: np.ndarray,
        boxes: np.ndarray,
        confidences: np.ndarray,
        class_ids: np.ndarray
    ) -> Tuple[np.ndarray, list]:
        """
        Draw bounding boxes and labels on a copy of the image.

        Args:
            image: Input image (BGR format); not modified in place
            boxes: Bounding boxes [N, 4] (xmin, ymin, xmax, ymax)
            confidences: Confidence scores [N]
            class_ids: Class IDs [N]

        Returns:
            Tuple[np.ndarray, list]: (annotated image, detected class names
            in drawing order — duplicates preserved for counting)
        """
        annotated_image = image.copy()
        detected_classes = []

        for box, confidence, class_id in zip(boxes, confidences, class_ids):
            # Defensive filter: callers may pass unfiltered detections
            if confidence < self.confidence_threshold:
                continue

            xmin, ymin, xmax, ymax = map(int, box)
            class_id = int(class_id)

            # Fall back to generic name/white box for unknown class ids
            class_name = self.class_names.get(class_id, f"Class {class_id}")
            color = self.class_colors.get(class_id, (255, 255, 255))

            detected_classes.append(class_name)

            # Bounding box
            cv2.rectangle(
                annotated_image,
                (xmin, ymin),
                (xmax, ymax),
                color,
                2
            )

            # Label with confidence, placed just above the box
            label = f"{class_name}: {confidence:.2f}"
            cv2.putText(
                annotated_image,
                label,
                (xmin, ymin - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                color,
                2,
                lineType=cv2.LINE_AA
            )

        return annotated_image, detected_classes

    def predict(self, image_path: str) -> Dict[str, Any]:
        """
        Detect defects in an image.

        Args:
            image_path: Path to input image

        Returns:
            Dict containing:
                - 'success': Whether prediction succeeded
                - 'annotated_image': QImage with bounding boxes (None on failure)
                - 'detections': List of detection dictionaries
                - 'class_counts': Count of each detected class
                - 'primary_class': Most prevalent class
                - 'total_detections': Number of detections kept
                - 'error': Error message if failed, else None

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")

            # Let YOLO prune below-threshold boxes itself instead of
            # post-processing detections we would discard anyway.
            results = self.model.predict(image, conf=self.confidence_threshold)

            detections = []
            all_boxes = []
            all_confidences = []
            all_class_ids = []

            for result in results:
                if result.boxes is None or len(result.boxes) == 0:
                    continue

                boxes = result.boxes.cpu().numpy()

                for box in boxes:
                    confidence = float(box.conf[0])

                    # Kept as a safety net in case the backend returns
                    # boxes below the requested conf threshold.
                    if confidence < self.confidence_threshold:
                        continue

                    xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
                    class_id = int(box.cls[0])
                    class_name = self.class_names.get(class_id, f"Class {class_id}")

                    detections.append({
                        'bbox': [xmin, ymin, xmax, ymax],
                        'confidence': confidence,
                        'class_id': class_id,
                        'class_name': class_name
                    })

                    all_boxes.append([xmin, ymin, xmax, ymax])
                    all_confidences.append(confidence)
                    all_class_ids.append(class_id)

            # Draw detections (no-op when nothing was kept)
            if len(all_boxes) > 0:
                annotated_image, detected_class_names = self._draw_detections(
                    image,
                    np.array(all_boxes),
                    np.array(all_confidences),
                    np.array(all_class_ids)
                )
            else:
                annotated_image = image
                detected_class_names = []

            # Count occurrences of each class name
            class_counts = {}
            for class_name in detected_class_names:
                class_counts[class_name] = class_counts.get(class_name, 0) + 1

            # Primary class = the most frequently detected one
            if class_counts:
                primary_class = max(class_counts.items(), key=lambda x: x[1])[0]
            else:
                primary_class = "No detections"

            # Convert BGR -> RGB and wrap in a QImage. QImage does NOT take
            # ownership of the numpy buffer, so .copy() is required: without
            # it the local rgb_image is garbage-collected on return and the
            # QImage would point at freed memory.
            rgb_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
            h, w, ch = rgb_image.shape
            bytes_per_line = ch * w
            q_image = QImage(
                rgb_image.data,
                w,
                h,
                bytes_per_line,
                QImage.Format_RGB888
            ).copy()

            logger.info(
                "Detected %d objects. Primary class: %s",
                len(detections),
                primary_class,
            )

            return {
                'success': True,
                'annotated_image': q_image,
                'detections': detections,
                'class_counts': class_counts,
                'primary_class': primary_class,
                'total_detections': len(detections),
                'error': None
            }

        except Exception as e:
            logger.error("Prediction failed: %s", e)
            return {
                'success': False,
                'annotated_image': None,
                'detections': [],
                'class_counts': {},
                'primary_class': None,
                'total_detections': 0,
                'error': str(e)
            }

    def predict_batch(self, image_paths: list) -> list:
        """
        Detect defects in multiple images.

        Args:
            image_paths: List of paths to images

        Returns:
            List[Dict]: One prediction result dict per input path, in order.
        """
        return [self.predict(image_path) for image_path in image_paths]
|