| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- """
- Locule Model Module
- YOLO-based segmentation model for durian locule counting.
- Detects and counts locules with colored mask overlays.
- """
- from pathlib import Path
- from typing import Dict, Any, Optional, Tuple
- import logging
- import cv2
- import numpy as np
- from ultralytics import YOLO
- from PyQt5.QtGui import QImage
- from models.base_model import BaseModel
- from utils.config import (
- LOCULE_MODEL_PATH,
- LOCULE_COLORS,
- YOLO_CONFIDENCE_THRESHOLD,
- )
- logger = logging.getLogger(__name__)
class LoculeModel(BaseModel):
    """YOLO segmentation model that detects and counts durian locules.

    Runs instance segmentation on durian cross-section images and
    annotates each accepted detection with a bounding box and a colored
    mask drawn from a ROYGBIV palette (BGR order).

    Attributes:
        model: Underlying YOLO segmentation model instance.
        colors: Palette of BGR colors used for masks and boxes.
        confidence_threshold: Minimum confidence for a detection to count.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cuda",
        confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
    ):
        """Configure the locule counter without loading weights yet.

        Args:
            model_path: Path to a YOLO .pt segmentation checkpoint;
                falls back to LOCULE_MODEL_PATH when omitted.
            device: Inference device identifier ('cuda' or 'cpu').
            confidence_threshold: Detection acceptance cutoff (0.0-1.0).
        """
        resolved_path = str(LOCULE_MODEL_PATH) if model_path is None else model_path
        super().__init__(resolved_path, device)
        self.colors = LOCULE_COLORS
        self.confidence_threshold = confidence_threshold
-
- def load(self) -> bool:
- """
- Load the YOLO segmentation model.
-
- Returns:
- bool: True if loaded successfully, False otherwise
- """
- try:
- model_path = Path(self.model_path)
-
- if not model_path.exists():
- logger.error(f"Model file does not exist: {model_path}")
- return False
-
- logger.info(f"Loading locule model from {model_path}")
- self.model = YOLO(str(model_path))
-
- # Move model to specified device
- self.model.to(self.device)
-
- self._is_loaded = True
- logger.info(f"Locule model loaded on {self.device}")
- return True
-
- except Exception as e:
- logger.error(f"Failed to load locule model: {e}")
- self._is_loaded = False
- return False
-
- def _apply_colored_masks(
- self,
- image: np.ndarray,
- masks: Optional[np.ndarray],
- boxes: np.ndarray,
- confidences: np.ndarray,
- class_names: list
- ) -> Tuple[np.ndarray, int]:
- """
- Apply colored masks and bounding boxes to the image.
-
- Args:
- image: Input image (BGR format)
- masks: Segmentation masks [N, H, W] or None
- boxes: Bounding boxes [N, 4]
- confidences: Confidence scores [N]
- class_names: Class names for each detection
-
- Returns:
- Tuple[np.ndarray, int]: (masked image, valid detection count)
- """
- masked_image = image.copy()
- valid_count = 0
-
- for i, (box, confidence, name) in enumerate(zip(boxes, confidences, class_names)):
- # Skip low confidence detections
- if confidence < self.confidence_threshold:
- continue
-
- valid_count += 1
-
- xmin, ymin, xmax, ymax = map(int, box)
-
- # Get color from ROYGBIV (cycle if more than 7)
- color = self.colors[i % len(self.colors)]
-
- # Draw bounding box
- cv2.rectangle(
- masked_image,
- (xmin, ymin),
- (xmax, ymax),
- color,
- 2
- )
-
- # Draw label with confidence
- label = f"{name}: {confidence:.2f}"
- cv2.putText(
- masked_image,
- label,
- (xmin, ymin - 5),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.8,
- color,
- 2,
- lineType=cv2.LINE_AA
- )
-
- # Apply mask if available
- if masks is not None and i < len(masks):
- mask = masks[i]
-
- # Resize mask to match image dimensions if needed
- if mask.shape[:2] != masked_image.shape[:2]:
- mask = cv2.resize(
- mask,
- (masked_image.shape[1], masked_image.shape[0])
- )
-
- # Convert mask to binary
- mask_binary = (mask * 255).astype(np.uint8)
-
- # Create colored overlay
- colored_mask = np.zeros_like(masked_image, dtype=np.uint8)
- for c in range(3):
- colored_mask[:, :, c] = mask_binary * (color[c] / 255)
-
- # Blend mask with image
- masked_image = cv2.addWeighted(
- masked_image,
- 1.0,
- colored_mask,
- 0.5,
- 0
- )
-
- return masked_image, valid_count
-
- def predict(self, image_path: str) -> Dict[str, Any]:
- """
- Count locules in a durian cross-section image.
-
- Args:
- image_path: Path to input image
-
- Returns:
- Dict containing:
- - 'success': Whether prediction succeeded
- - 'annotated_image': QImage with colored masks
- - 'locule_count': Number of detected locules
- - 'detections': List of detection details
- - 'error': Error message if failed
-
- Raises:
- RuntimeError: If model is not loaded
- """
- if not self._is_loaded:
- raise RuntimeError("Model not loaded. Call load() first.")
-
- try:
- # Load image
- image = cv2.imread(image_path)
- if image is None:
- raise ValueError(f"Could not load image: {image_path}")
-
- # Get original dimensions and preserve aspect ratio
- orig_height, orig_width = image.shape[:2]
-
- # Resize to a standard size while maintaining aspect ratio
- # Use 640 as max dimension (common YOLO input size)
- max_dim = 640
- if orig_width > orig_height:
- new_width = max_dim
- new_height = int((max_dim / orig_width) * orig_height)
- else:
- new_height = max_dim
- new_width = int((max_dim / orig_height) * orig_width)
-
- image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
-
- # Run YOLO segmentation
- results = self.model.predict(image_resized)
-
- detections = []
- all_boxes = []
- all_confidences = []
- all_class_names = []
- all_masks = None
-
- # Process results
- for result in results:
- if result.boxes is None or len(result.boxes) == 0:
- continue
-
- boxes = result.boxes.cpu().numpy()
- masks = result.masks.data.cpu().numpy() if result.masks is not None else None
-
- if masks is not None:
- # Ensure mask count matches box count
- masks = masks[:len(boxes)]
- all_masks = masks
-
- for idx, box in enumerate(boxes):
- confidence = float(box.conf[0])
-
- if confidence < self.confidence_threshold:
- continue
-
- xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
- class_id = int(box.cls[0])
- class_name = self.model.names.get(class_id, f"Locule {idx + 1}")
-
- detections.append({
- 'bbox': [xmin, ymin, xmax, ymax],
- 'confidence': confidence,
- 'class_id': class_id,
- 'class_name': class_name,
- 'index': idx
- })
-
- all_boxes.append([xmin, ymin, xmax, ymax])
- all_confidences.append(confidence)
- all_class_names.append(class_name)
-
- # Apply colored masks
- if len(all_boxes) > 0:
- masked_image, locule_count = self._apply_colored_masks(
- image_resized,
- all_masks,
- np.array(all_boxes),
- np.array(all_confidences),
- all_class_names
- )
- else:
- masked_image = image_resized
- locule_count = 0
-
- # Convert to QImage
- rgb_image = cv2.cvtColor(masked_image, cv2.COLOR_BGR2RGB)
- h, w, ch = rgb_image.shape
- bytes_per_line = ch * w
- q_image = QImage(
- rgb_image.data,
- w,
- h,
- bytes_per_line,
- QImage.Format_RGB888
- )
-
- logger.info(f"Detected {locule_count} locules")
-
- return {
- 'success': True,
- 'annotated_image': q_image,
- 'locule_count': locule_count,
- 'detections': detections,
- 'error': None
- }
-
- except Exception as e:
- logger.error(f"Prediction failed: {e}")
- return {
- 'success': False,
- 'annotated_image': None,
- 'locule_count': 0,
- 'detections': [],
- 'error': str(e)
- }
-
- def predict_batch(self, image_paths: list) -> list:
- """
- Count locules in multiple images.
-
- Args:
- image_paths: List of paths to images
-
- Returns:
- List[Dict]: List of prediction results
- """
- results = []
- for image_path in image_paths:
- result = self.predict(image_path)
- results.append(result)
- return results
|