"""
Shape Model Module

YOLO-based shape classification model for durian quality assessment.
Detects and classifies durian shape: Regular vs Irregular.
"""

from pathlib import Path
from typing import Dict, Any, Optional
import logging

import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtGui import QImage

from models.base_model import BaseModel
from utils.config import (
    YOLO_CONFIDENCE_THRESHOLD,
)

logger = logging.getLogger(__name__)

# Shape class constants: class ID -> human-readable name.
SHAPE_CLASS_NAMES = {
    0: "Irregular",
    1: "Regular",
}

# Class ID -> colour tuple handed straight to the OpenCV drawing calls.
# NOTE(review): the colour names below match the tuples when read as RGB,
# but OpenCV interprets them as BGR - confirm which rendering is intended.
SHAPE_CLASS_COLORS = {
    0: (86, 0, 254),    # Irregular - Purple
    1: (0, 252, 199),   # Regular - Cyan
}


class ShapeModel(BaseModel):
    """
    YOLO-based shape classification model.

    Classifies durian shape into:
    - Regular (Class 1)
    - Irregular (Class 0)

    Attributes:
        model: YOLO model instance (set by load())
        class_names: Name mapping for each class
        class_colors: Color mapping used for visualization
        confidence_threshold: Minimum confidence passed to YOLO inference
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
    ):
        """
        Initialize the shape classification model.

        Args:
            model_path: Path to YOLO .pt model file (shape.pt)
            device: Device to use ('cuda' or 'cpu')
            confidence_threshold: Minimum confidence threshold (0.0-1.0)
        """
        super().__init__(model_path, device)
        self.class_names = SHAPE_CLASS_NAMES
        self.class_colors = SHAPE_CLASS_COLORS
        self.confidence_threshold = confidence_threshold

    def load(self) -> bool:
        """
        Load the YOLO model and move it to the configured device.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            model_path = Path(self.model_path)

            if not model_path.exists():
                logger.error(f"Shape model file does not exist: {model_path}")
                return False

            logger.info(f"Loading shape model from {model_path}")
            self.model = YOLO(str(model_path))

            # Move model to specified device
            self.model.to(self.device)

            self._is_loaded = True
            logger.info(f"Shape model loaded on {self.device}")
            return True

        except Exception as e:
            # Loading is a best-effort boundary: report the failure to the
            # caller through the return value instead of propagating.
            logger.error(f"Failed to load shape model: {e}")
            self._is_loaded = False
            return False

    def _draw_bounding_box(
        self,
        image: np.ndarray,
        box: Any,
        class_id: int,
        confidence: float,
        shape_class: str
    ) -> np.ndarray:
        """
        Draw bounding box and label on a copy of the image.

        Args:
            image: Input image (as loaded by cv2.imread); not modified
            box: YOLO box object exposing coordinates via ``xyxy``
            class_id: Class ID (0=Irregular, 1=Regular)
            confidence: Confidence score
            shape_class: Shape class name

        Returns:
            Annotated copy of the image with bounding box and label
        """
        annotated = image.copy()

        # Get bounding box coordinates
        xmin, ymin, xmax, ymax = map(int, box.xyxy[0])

        # Unknown class IDs fall back to white
        color = self.class_colors.get(class_id, (255, 255, 255))

        # Draw bounding box
        cv2.rectangle(annotated, (xmin, ymin), (xmax, ymax), color, 2)

        # Draw label with confidence
        label = f"{shape_class}: {confidence:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.8
        thickness = 2

        # Keep the label inside the frame: when the box touches the top edge,
        # "ymin - 5" would place the text baseline above the visible image.
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        label_x = max(xmin, 0)
        label_y = max(ymin - 5, text_height + 5)

        cv2.putText(
            annotated,
            label,
            (label_x, label_y),
            font,
            font_scale,
            color,
            thickness,
            lineType=cv2.LINE_AA
        )

        return annotated

    @staticmethod
    def _make_result(
        success: bool,
        shape_class: Optional[str] = None,
        class_id: Optional[int] = None,
        confidence: float = 0.0,
        annotated_image: Optional[QImage] = None,
        error: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the result dictionary returned by predict()."""
        return {
            'success': success,
            'shape_class': shape_class,
            'class_id': class_id,
            'confidence': confidence,
            'annotated_image': annotated_image,
            'error': error
        }

    @staticmethod
    def _to_qimage(image_bgr: np.ndarray) -> QImage:
        """Convert a BGR OpenCV image into a QImage that owns its pixels."""
        rgb_image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        height, width, channels = rgb_image.shape
        bytes_per_line = channels * width
        wrapped = QImage(
            rgb_image.data,
            width,
            height,
            bytes_per_line,
            QImage.Format_RGB888
        )
        # QImage built from an external buffer does not copy it; rgb_image is
        # released when this function returns, so hand back a deep copy.
        return wrapped.copy()

    def predict(self, image_path: str) -> Dict[str, Any]:
        """
        Classify the shape of a durian in an image.

        Classification probabilities are used when the model provides them;
        otherwise the highest-confidence detection box is used as a fallback.

        Args:
            image_path: Path to input image

        Returns:
            Dict containing:
            - 'success': Whether prediction succeeded
            - 'shape_class': Detected shape (Regular/Irregular)
            - 'class_id': Numeric class ID (0=Irregular, 1=Regular)
            - 'confidence': Confidence score (0.0-1.0)
            - 'annotated_image': QImage with bounding box (if detection model)
            - 'error': Error message if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")

            # Run YOLO inference, honouring the configured threshold
            results = self.model.predict(image, conf=self.confidence_threshold)

            # Process results - shape.pt is a classification model with probs
            for result in results:
                # Check if results have classification probabilities
                if result.probs is not None:
                    logger.info(f"Shape model returned probs: {result.probs}")

                    # Index and score of the highest-probability class
                    class_id = int(result.probs.top1)
                    confidence = float(result.probs.top1conf.cpu().item())

                    shape_class = self.class_names.get(
                        class_id,
                        f"Unknown({class_id})"
                    )

                    logger.info(
                        f"Shape classification (via probs): {shape_class} "
                        f"(confidence: {confidence:.3f})"
                    )

                    return self._make_result(
                        True,
                        shape_class=shape_class,
                        class_id=class_id,
                        confidence=confidence
                    )

                # Fallback: Check for detection results with class names
                # (in case shape.pt is detection model instead of classification)
                if result.boxes is not None and len(result.boxes) > 0:
                    logger.info("Shape model returned detection boxes")

                    boxes = result.boxes.cpu().numpy()

                    # Pick the most confident detection explicitly rather than
                    # relying on the order in which boxes are returned.
                    best_index = int(np.argmax(boxes.conf))
                    box = boxes[best_index]
                    class_id = int(box.cls[0])
                    confidence = float(box.conf[0])

                    shape_class = self.class_names.get(
                        class_id,
                        f"Unknown({class_id})"
                    )

                    # Draw bounding box on image and convert for Qt display
                    annotated_image_np = self._draw_bounding_box(
                        image,
                        box,
                        class_id,
                        confidence,
                        shape_class
                    )
                    annotated_image = self._to_qimage(annotated_image_np)

                    logger.info(
                        f"Shape classification (via boxes): {shape_class} "
                        f"(confidence: {confidence:.3f})"
                    )

                    return self._make_result(
                        True,
                        shape_class=shape_class,
                        class_id=class_id,
                        confidence=confidence,
                        annotated_image=annotated_image
                    )

            # No results found
            logger.warning(
                f"No shape classification results from model. Results: {results}"
            )
            return self._make_result(
                False,
                error='No classification result from model'
            )

        except Exception as e:
            # Prediction failures are reported in the result dict, not raised;
            # logger.exception records the traceback alongside the message.
            logger.exception(f"Shape prediction failed: {e}")
            return self._make_result(False, error=str(e))

    def predict_batch(self, image_paths: list) -> list:
        """
        Classify shapes in multiple images.

        Args:
            image_paths: List of paths to images

        Returns:
            List[Dict]: List of prediction results, one per input path,
            in the same order as ``image_paths``
        """
        return [self.predict(image_path) for image_path in image_paths]