"""
Locule Model Module

YOLO-based segmentation model for durian locule counting.
Detects and counts locules with colored mask overlays.
"""

from pathlib import Path
from typing import Dict, Any, Optional, Sequence, Tuple
import logging

import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtGui import QImage

from models.base_model import BaseModel
from utils.config import (
    LOCULE_MODEL_PATH,
    LOCULE_COLORS,
    YOLO_CONFIDENCE_THRESHOLD,
)

logger = logging.getLogger(__name__)


class LoculeModel(BaseModel):
    """
    YOLO-based locule segmentation and counting model.

    Detects individual locules in durian cross-section images and applies
    colored masks using ROYGBIV color scheme.

    Attributes:
        model: YOLO segmentation model instance
        colors: ROYGBIV color scheme for masks (BGR format)
        confidence_threshold: Minimum confidence for detections
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cuda",
        confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
    ):
        """
        Initialize the locule counting model.

        Args:
            model_path: Path to YOLO .pt segmentation model (optional);
                defaults to LOCULE_MODEL_PATH from the config.
            device: Device to use ('cuda' or 'cpu')
            confidence_threshold: Minimum confidence for detections (0.0-1.0)
        """
        if model_path is None:
            model_path = str(LOCULE_MODEL_PATH)

        super().__init__(model_path, device)
        self.colors = LOCULE_COLORS
        self.confidence_threshold = confidence_threshold

    def load(self) -> bool:
        """
        Load the YOLO segmentation model.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            model_path = Path(self.model_path)

            if not model_path.exists():
                logger.error(f"Model file does not exist: {model_path}")
                return False

            logger.info(f"Loading locule model from {model_path}")
            self.model = YOLO(str(model_path))

            # Move model to specified device
            self.model.to(self.device)

            self._is_loaded = True
            logger.info(f"Locule model loaded on {self.device}")
            return True

        except Exception as e:
            # Keep the traceback: load failures are otherwise hard to diagnose.
            logger.exception(f"Failed to load locule model: {e}")
            self._is_loaded = False
            return False

    @staticmethod
    def _resize_keep_aspect(image: np.ndarray, max_dim: int = 640) -> np.ndarray:
        """
        Resize an image so its longer side equals ``max_dim``, keeping aspect ratio.

        Args:
            image: Input image (H, W[, C])
            max_dim: Target length of the longer side (640 is a common YOLO input size)

        Returns:
            np.ndarray: Resized image
        """
        orig_height, orig_width = image.shape[:2]

        if orig_width > orig_height:
            new_width = max_dim
            new_height = int((max_dim / orig_width) * orig_height)
        else:
            new_height = max_dim
            new_width = int((max_dim / orig_height) * orig_width)

        # Extremely elongated images could otherwise round a side down to 0,
        # which cv2.resize rejects.
        new_width = max(1, new_width)
        new_height = max(1, new_height)

        return cv2.resize(
            image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
        )

    @staticmethod
    def _to_qimage(bgr_image: np.ndarray) -> QImage:
        """
        Convert a BGR numpy image into a self-contained RGB888 QImage.

        Args:
            bgr_image: Image in BGR channel order (H, W, 3)

        Returns:
            QImage: Image that owns its own pixel buffer
        """
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w
        # QImage(buffer, ...) only wraps the numpy memory; without .copy() the
        # returned image would reference rgb_image's buffer, which is freed once
        # this function returns.
        return QImage(
            rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888
        ).copy()

    def _apply_colored_masks(
        self,
        image: np.ndarray,
        masks: Optional[Sequence[Optional[np.ndarray]]],
        boxes: np.ndarray,
        confidences: np.ndarray,
        class_names: list
    ) -> Tuple[np.ndarray, int]:
        """
        Apply colored masks and bounding boxes to the image.

        Args:
            image: Input image (BGR format)
            masks: Per-detection segmentation masks aligned with ``boxes``
                (``masks[i]`` belongs to ``boxes[i]``; entries may be None),
                an [N, H, W] array, or None for no masks
            boxes: Bounding boxes [N, 4] as (xmin, ymin, xmax, ymax)
            confidences: Confidence scores [N]
            class_names: Class names for each detection

        Returns:
            Tuple[np.ndarray, int]: (masked image, valid detection count)
        """
        masked_image = image.copy()
        valid_count = 0

        for i, (box, confidence, name) in enumerate(zip(boxes, confidences, class_names)):
            # Skip low confidence detections
            if confidence < self.confidence_threshold:
                continue

            valid_count += 1
            xmin, ymin, xmax, ymax = map(int, box)

            # Get color from ROYGBIV (cycle if more than 7)
            color = self.colors[i % len(self.colors)]

            # Draw bounding box
            cv2.rectangle(
                masked_image,
                (xmin, ymin),
                (xmax, ymax),
                color,
                2
            )

            # Draw label with confidence
            label = f"{name}: {confidence:.2f}"
            cv2.putText(
                masked_image,
                label,
                (xmin, ymin - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                color,
                2,
                lineType=cv2.LINE_AA
            )

            # Apply mask if available
            mask = masks[i] if masks is not None and i < len(masks) else None
            if mask is None:
                continue

            # Resize mask to match image dimensions if needed.
            # NOTE(review): YOLO masks may be produced at the letterboxed
            # network input size; a plain resize assumes no padding — confirm
            # alignment for inputs whose sides are not multiples of the stride.
            if mask.shape[:2] != masked_image.shape[:2]:
                mask = cv2.resize(
                    mask,
                    (masked_image.shape[1], masked_image.shape[0])
                )

            # Convert mask to binary
            mask_binary = (mask * 255).astype(np.uint8)

            # Create colored overlay: each channel scaled by the mask intensity
            channel_scale = np.asarray(color[:3], dtype=np.float64) / 255
            colored_mask = (mask_binary[:, :, None] * channel_scale).astype(np.uint8)

            # Blend mask with image
            masked_image = cv2.addWeighted(
                masked_image, 1.0,
                colored_mask, 0.5,
                0
            )

        return masked_image, valid_count

    def predict(self, image_path: str) -> Dict[str, Any]:
        """
        Count locules in a durian cross-section image.

        Args:
            image_path: Path to input image

        Returns:
            Dict containing:
                - 'success': Whether prediction succeeded
                - 'annotated_image': QImage with colored masks
                - 'locule_count': Number of detected locules
                - 'detections': List of detection details
                - 'error': Error message if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")

            # Resize to a standard size while maintaining aspect ratio
            image_resized = self._resize_keep_aspect(image, max_dim=640)

            # Run YOLO segmentation
            results = self.model.predict(image_resized)

            detections = []
            all_boxes = []
            all_confidences = []
            all_class_names = []
            # One entry per kept detection so that kept_masks[i] always
            # belongs to all_boxes[i], even after confidence filtering.
            kept_masks = []

            # Process results
            for result in results:
                if result.boxes is None or len(result.boxes) == 0:
                    continue

                boxes = result.boxes.cpu().numpy()
                masks = result.masks.data.cpu().numpy() if result.masks is not None else None

                for idx, box in enumerate(boxes):
                    confidence = float(box.conf[0])
                    if confidence < self.confidence_threshold:
                        continue

                    xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
                    class_id = int(box.cls[0])
                    class_name = self.model.names.get(class_id, f"Locule {idx + 1}")

                    detections.append({
                        'bbox': [xmin, ymin, xmax, ymax],
                        'confidence': confidence,
                        'class_id': class_id,
                        'class_name': class_name,
                        'index': idx
                    })

                    all_boxes.append([xmin, ymin, xmax, ymax])
                    all_confidences.append(confidence)
                    all_class_names.append(class_name)
                    kept_masks.append(
                        masks[idx] if masks is not None and idx < len(masks) else None
                    )

            # Apply colored masks
            if all_boxes:
                aligned_masks = (
                    kept_masks if any(m is not None for m in kept_masks) else None
                )
                masked_image, locule_count = self._apply_colored_masks(
                    image_resized,
                    aligned_masks,
                    np.array(all_boxes),
                    np.array(all_confidences),
                    all_class_names
                )
            else:
                masked_image = image_resized
                locule_count = 0

            # Convert to QImage
            q_image = self._to_qimage(masked_image)

            logger.info(f"Detected {locule_count} locules")

            return {
                'success': True,
                'annotated_image': q_image,
                'locule_count': locule_count,
                'detections': detections,
                'error': None
            }

        except Exception as e:
            # Boundary handler: report failure in the result dict, but keep the
            # traceback in the log.
            logger.exception(f"Prediction failed: {e}")
            return {
                'success': False,
                'annotated_image': None,
                'locule_count': 0,
                'detections': [],
                'error': str(e)
            }

    def predict_batch(self, image_paths: list) -> list:
        """
        Count locules in multiple images.

        Args:
            image_paths: List of paths to images

        Returns:
            List[Dict]: One prediction result (see ``predict``) per input path,
            in the same order
        """
        return [self.predict(image_path) for image_path in image_paths]