- """
- Shape Model Module
- YOLO-based shape classification model for durian quality assessment.
- Detects and classifies durian shape: Regular vs Irregular.
- """
- 
- import logging
- from pathlib import Path
- from typing import Any, Dict
- 
- import cv2
- import numpy as np
- from ultralytics import YOLO
- from PyQt5.QtGui import QImage
- 
- from models.base_model import BaseModel
- from utils.config import (
-     YOLO_CONFIDENCE_THRESHOLD,
- )
- 
- logger = logging.getLogger(__name__)
- 
- # Shape class constants
- SHAPE_CLASS_NAMES = {
-     0: "Irregular",
-     1: "Regular",
- }
- 
- # Colors are stored in BGR order, as OpenCV drawing functions expect
- SHAPE_CLASS_COLORS = {
-     0: (254, 0, 86),   # Irregular - Purple
-     1: (199, 252, 0),  # Regular - Cyan
- }
- 
- 
- class ShapeModel(BaseModel):
-     """
-     YOLO-based shape classification model.
- 
-     Classifies durian shape into:
-     - Regular (Class 1)
-     - Irregular (Class 0)
- 
-     Attributes:
-         model: YOLO model instance
-         class_names: Name mapping for each class
-         class_colors: BGR color mapping for visualization
-         confidence_threshold: Minimum confidence for classifications
-     """
- 
-     def __init__(
-         self,
-         model_path: str,
-         device: str = "cuda",
-         confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
-     ):
-         """
-         Initialize the shape classification model.
- 
-         Args:
-             model_path: Path to YOLO .pt model file (shape.pt)
-             device: Device to use ('cuda' or 'cpu')
-             confidence_threshold: Minimum confidence threshold (0.0-1.0)
-         """
-         super().__init__(model_path, device)
-         self.class_names = SHAPE_CLASS_NAMES
-         self.class_colors = SHAPE_CLASS_COLORS
-         self.confidence_threshold = confidence_threshold
-
-     def load(self) -> bool:
-         """
-         Load the YOLO model.
- 
-         Returns:
-             bool: True if loaded successfully, False otherwise
-         """
-         try:
-             model_path = Path(self.model_path)
- 
-             if not model_path.exists():
-                 logger.error(f"Shape model file does not exist: {model_path}")
-                 return False
- 
-             logger.info(f"Loading shape model from {model_path}")
-             self.model = YOLO(str(model_path))
- 
-             # Move model to specified device
-             self.model.to(self.device)
- 
-             self._is_loaded = True
-             logger.info(f"Shape model loaded on {self.device}")
-             return True
- 
-         except Exception as e:
-             logger.error(f"Failed to load shape model: {e}")
-             self._is_loaded = False
-             return False
-
-     def _draw_bounding_box(
-         self,
-         image: np.ndarray,
-         box: Any,
-         class_id: int,
-         confidence: float,
-         shape_class: str
-     ) -> np.ndarray:
-         """
-         Draw bounding box and label on the image.
- 
-         Args:
-             image: Input image (BGR format)
-             box: YOLO box object with xyxy coordinates
-             class_id: Class ID (0=Irregular, 1=Regular)
-             confidence: Confidence score
-             shape_class: Shape class name
- 
-         Returns:
-             Annotated copy of the image with bounding box
-         """
-         annotated = image.copy()
- 
-         # Get bounding box coordinates
-         xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
- 
-         # Get color based on class
-         color = self.class_colors.get(class_id, (255, 255, 255))
- 
-         # Draw bounding box
-         cv2.rectangle(
-             annotated,
-             (xmin, ymin),
-             (xmax, ymax),
-             color,
-             2
-         )
- 
-         # Draw label with confidence. putText anchors text at its bottom-left
-         # corner, so keep the baseline low enough that a box touching the
-         # top edge does not push the label out of the image.
-         label = f"{shape_class}: {confidence:.2f}"
-         cv2.putText(
-             annotated,
-             label,
-             (xmin, max(ymin - 5, 20)),
-             cv2.FONT_HERSHEY_SIMPLEX,
-             0.8,
-             color,
-             2,
-             lineType=cv2.LINE_AA
-         )
- 
-         return annotated
-
-     def predict(self, image_path: str) -> Dict[str, Any]:
-         """
-         Classify the shape of a durian in an image.
- 
-         Args:
-             image_path: Path to input image
- 
-         Returns:
-             Dict containing:
-             - 'success': Whether prediction succeeded
-             - 'shape_class': Detected shape (Regular/Irregular)
-             - 'class_id': Numeric class ID (0=Irregular, 1=Regular)
-             - 'confidence': Confidence score (0.0-1.0)
-             - 'annotated_image': QImage with bounding box (detection
-               models only; None for classification models)
-             - 'error': Error message if failed, otherwise None
- 
-         Raises:
-             RuntimeError: If model is not loaded
-         """
-         if not self._is_loaded:
-             raise RuntimeError("Model not loaded. Call load() first.")
- 
-         try:
-             # Load image
-             image = cv2.imread(image_path)
-             if image is None:
-                 raise ValueError(f"Could not load image: {image_path}")
-
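-             # Run YOLO inference. `conf` filters detection boxes; it has no
-             # effect on the class probabilities of a classification model.
-             results = self.model.predict(
-                 image,
-                 conf=self.confidence_threshold,
-                 device=self.device,
-                 verbose=False,
-             )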
-
-             # Process results - shape.pt is a classification model with probs
-             for result in results:
-                 # Check if results have classification probabilities
-                 if result.probs is not None:
-                     logger.debug(f"Shape model returned probs: {result.probs}")
- 
-                     # Get top class
-                     class_id = int(result.probs.top1)  # Index of highest probability
-                     confidence = float(result.probs.top1conf.cpu().item())
- 
-                     # Get class name
-                     shape_class = self.class_names.get(
-                         class_id,
-                         f"Unknown({class_id})"
-                     )
- 
-                     logger.info(
-                         f"Shape classification (via probs): {shape_class} "
-                         f"(confidence: {confidence:.3f})"
-                     )
- 
-                     return {
-                         'success': True,
-                         'shape_class': shape_class,
-                         'class_id': class_id,
-                         'confidence': confidence,
-                         'annotated_image': None,
-                         'error': None
-                     }
-
-                 # Fallback: check for detection results with class names
-                 # (in case shape.pt is a detection model rather than a classifier)
-                 if result.boxes is not None and len(result.boxes) > 0:
-                     logger.info("Shape model returned detection boxes")
- 
-                     # Use the most confident box explicitly
-                     boxes = result.boxes.cpu().numpy()
-                     box = boxes[int(boxes.conf.argmax())]
-                     class_id = int(box.cls[0])
-                     confidence = float(box.conf[0])
- 
-                     # Get class name
-                     shape_class = self.class_names.get(
-                         class_id,
-                         f"Unknown({class_id})"
-                     )
- 
-                     # Draw bounding box on image
-                     annotated_image_np = self._draw_bounding_box(
-                         image,
-                         box,
-                         class_id,
-                         confidence,
-                         shape_class
-                     )
- 
-                     # Convert to QImage. QImage wraps the NumPy buffer without
-                     # copying it, so .copy() gives the result its own memory
-                     # before rgb_image goes out of scope.
-                     rgb_image = cv2.cvtColor(annotated_image_np, cv2.COLOR_BGR2RGB)
-                     h, w, ch = rgb_image.shape
-                     bytes_per_line = ch * w
-                     annotated_image = QImage(
-                         rgb_image.data,
-                         w,
-                         h,
-                         bytes_per_line,
-                         QImage.Format_RGB888
-                     ).copy()
- 
-                     logger.info(
-                         f"Shape classification (via boxes): {shape_class} "
-                         f"(confidence: {confidence:.3f})"
-                     )
- 
-                     return {
-                         'success': True,
-                         'shape_class': shape_class,
-                         'class_id': class_id,
-                         'confidence': confidence,
-                         'annotated_image': annotated_image,
-                         'error': None
-                     }
-
-             # No results found
-             logger.warning(
-                 f"No shape classification results from model. Results: {results}"
-             )
-             return {
-                 'success': False,
-                 'shape_class': None,
-                 'class_id': None,
-                 'confidence': 0.0,
-                 'annotated_image': None,
-                 'error': 'No classification result from model'
-             }
-
-         except Exception as e:
-             # logger.exception logs at ERROR level and appends the traceback
-             logger.exception(f"Shape prediction failed: {e}")
-             return {
-                 'success': False,
-                 'shape_class': None,
-                 'class_id': None,
-                 'confidence': 0.0,
-                 'annotated_image': None,
-                 'error': str(e)
-             }
-
-     def predict_batch(self, image_paths: list) -> list:
-         """
-         Classify shapes in multiple images.
- 
-         Args:
-             image_paths: List of paths to images
- 
-         Returns:
-             List[Dict]: Prediction results, in the same order as image_paths
-         """
-         return [self.predict(image_path) for image_path in image_paths]
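The classification branch of `predict()` reduces the model's class probabilities to a single label through `result.probs.top1` and `top1conf`. Unmapped IDs fall back to `Unknown(<id>)`. The sketch below reproduces that reduction with the standard library only. It assumes the probabilities arrive as a plain list indexed by class ID. The helper `top1_label` and its optional `min_confidence` cut-off are hypothetical: the module above does not apply a threshold to classification output.

```python
from typing import Dict, List, Optional, Tuple

SHAPE_CLASS_NAMES: Dict[int, str] = {0: "Irregular", 1: "Regular"}


def top1_label(
    probs: List[float],
    class_names: Dict[int, str],
    min_confidence: Optional[float] = None,
) -> Optional[Tuple[int, str, float]]:
    """Return (class_id, name, confidence) for the most probable class.

    Hypothetical helper: returns None when probs is empty or when the top
    probability is below min_confidence.
    """
    if not probs:
        return None
    # Index of the highest probability (the role Probs.top1 plays above)
    class_id = max(range(len(probs)), key=lambda i: probs[i])
    confidence = float(probs[class_id])
    if min_confidence is not None and confidence < min_confidence:
        return None
    # Same naming fallback as ShapeModel.predict() for unmapped IDs
    name = class_names.get(class_id, f"Unknown({class_id})")
    return class_id, name, confidence


print(top1_label([0.2, 0.8], SHAPE_CLASS_NAMES))  # (1, 'Regular', 0.8)
```

Returning `None` below the cut-off is only one possible policy. A caller could instead keep the label and flag it as low-confidence.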
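`cv2.imread` returns images in BGR channel order. As a result, `cv2.rectangle` and `cv2.putText` read the tuples in `SHAPE_CLASS_COLORS` as BGR, and only the finished annotation is converted to RGB for the `QImage`. A color chosen in RGB, for example from a color picker, therefore needs its first and last channels swapped before it is used for drawing. A minimal sketch follows; `swap_rb` is a hypothetical helper, not an OpenCV function:

```python
from typing import Tuple

Color = Tuple[int, int, int]


def swap_rb(color: Color) -> Color:
    """Swap first and last channels: BGR -> RGB, or RGB -> BGR."""
    first, middle, last = color
    return (last, middle, first)


# A purple specified in RGB must be reversed before it is passed to cv2
purple_rgb = (86, 0, 254)
purple_bgr = swap_rb(purple_rgb)
print(purple_bgr)  # (254, 0, 86)
```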
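In the detection fallback, one box is chosen and its float `xyxy` corners are truncated to integer pixels. The label is then drawn just above the box's top edge. The stdlib sketch below walks through that geometry. It assumes each detection is a plain `(xmin, ymin, xmax, ymax, conf, cls)` tuple, mirroring the column order of Ultralytics `Boxes.data` for untracked detections. The helper names and the 20-pixel minimum label baseline are illustrative assumptions.

```python
from typing import List, Optional, Tuple

# (xmin, ymin, xmax, ymax, conf, cls): assumed layout, matching the column
# order of Ultralytics Boxes.data for detections without track IDs
Detection = Tuple[float, float, float, float, float, float]


def best_detection(dets: List[Detection]) -> Optional[Detection]:
    """Return the highest-confidence detection, or None if there are none."""
    if not dets:
        return None
    return max(dets, key=lambda d: d[4])


def box_and_label_origin(
    det: Detection, min_baseline: int = 20
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int]]:
    """Integer box corners plus a label origin kept inside the image.

    cv2.putText positions text by its bottom-left corner, so a box touching
    the top edge would otherwise place the label above row 0.
    """
    xmin, ymin, xmax, ymax = (int(v) for v in det[:4])
    label_origin = (xmin, max(ymin - 5, min_baseline))
    return (xmin, ymin, xmax, ymax), label_origin


dets = [
    (10.7, 2.2, 90.1, 80.9, 0.41, 0.0),
    (12.0, 40.5, 95.0, 120.0, 0.87, 1.0),
]
best = best_detection(dets)
print(int(best[5]), box_and_label_origin(best))  # 1 ((12, 40, 95, 120), (12, 35))
```

Selecting by maximum confidence, rather than taking the first element, keeps the choice correct even if the detections are not already sorted.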
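Every return path of `predict()` builds a dict with the same six keys. One way to keep those paths consistent is a pair of small factories. This is a hypothetical refactoring sketch: `success_result` and `failure_result` are not part of the module, and `annotated_image` defaults to `None` so the example runs without PyQt5.

```python
from typing import Any, Dict, Optional

RESULT_KEYS = (
    "success", "shape_class", "class_id",
    "confidence", "annotated_image", "error",
)


def success_result(
    shape_class: str,
    class_id: int,
    confidence: float,
    annotated_image: Optional[Any] = None,
) -> Dict[str, Any]:
    """Dict for a successful prediction, mirroring predict()'s success paths."""
    return {
        "success": True,
        "shape_class": shape_class,
        "class_id": class_id,
        "confidence": confidence,
        "annotated_image": annotated_image,
        "error": None,
    }


def failure_result(error: str) -> Dict[str, Any]:
    """Dict for a failed prediction, mirroring predict()'s error paths."""
    return {
        "success": False,
        "shape_class": None,
        "class_id": None,
        "confidence": 0.0,
        "annotated_image": None,
        "error": error,
    }


ok = success_result("Regular", 1, 0.93)
bad = failure_result("No classification result from model")
print(sorted(ok) == sorted(bad))  # True
```

Routing the no-result return and the `except` handler through one factory means a new key only needs to be added in one place.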