locule_model.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """
  2. Locule Model Module
  3. YOLO-based segmentation model for durian locule counting.
  4. Detects and counts locules with colored mask overlays.
  5. """
  6. from pathlib import Path
  7. from typing import Dict, Any, Optional, Tuple
  8. import logging
  9. import cv2
  10. import numpy as np
  11. from ultralytics import YOLO
  12. from PyQt5.QtGui import QImage
  13. from models.base_model import BaseModel
  14. from utils.config import (
  15. LOCULE_MODEL_PATH,
  16. LOCULE_COLORS,
  17. YOLO_CONFIDENCE_THRESHOLD,
  18. )
  19. logger = logging.getLogger(__name__)
  20. class LoculeModel(BaseModel):
  21. """
  22. YOLO-based locule segmentation and counting model.
  23. Detects individual locules in durian cross-section images
  24. and applies colored masks using ROYGBIV color scheme.
  25. Attributes:
  26. model: YOLO segmentation model instance
  27. colors: ROYGBIV color scheme for masks (BGR format)
  28. confidence_threshold: Minimum confidence for detections
  29. """
  30. def __init__(
  31. self,
  32. model_path: Optional[str] = None,
  33. device: str = "cuda",
  34. confidence_threshold: float = YOLO_CONFIDENCE_THRESHOLD
  35. ):
  36. """
  37. Initialize the locule counting model.
  38. Args:
  39. model_path: Path to YOLO .pt segmentation model (optional)
  40. device: Device to use ('cuda' or 'cpu')
  41. confidence_threshold: Minimum confidence for detections (0.0-1.0)
  42. """
  43. if model_path is None:
  44. model_path = str(LOCULE_MODEL_PATH)
  45. super().__init__(model_path, device)
  46. self.colors = LOCULE_COLORS
  47. self.confidence_threshold = confidence_threshold
  48. def load(self) -> bool:
  49. """
  50. Load the YOLO segmentation model.
  51. Returns:
  52. bool: True if loaded successfully, False otherwise
  53. """
  54. try:
  55. model_path = Path(self.model_path)
  56. if not model_path.exists():
  57. logger.error(f"Model file does not exist: {model_path}")
  58. return False
  59. logger.info(f"Loading locule model from {model_path}")
  60. self.model = YOLO(str(model_path))
  61. # Move model to specified device
  62. self.model.to(self.device)
  63. self._is_loaded = True
  64. logger.info(f"Locule model loaded on {self.device}")
  65. return True
  66. except Exception as e:
  67. logger.error(f"Failed to load locule model: {e}")
  68. self._is_loaded = False
  69. return False
  70. def _apply_colored_masks(
  71. self,
  72. image: np.ndarray,
  73. masks: Optional[np.ndarray],
  74. boxes: np.ndarray,
  75. confidences: np.ndarray,
  76. class_names: list
  77. ) -> Tuple[np.ndarray, int]:
  78. """
  79. Apply colored masks and bounding boxes to the image.
  80. Args:
  81. image: Input image (BGR format)
  82. masks: Segmentation masks [N, H, W] or None
  83. boxes: Bounding boxes [N, 4]
  84. confidences: Confidence scores [N]
  85. class_names: Class names for each detection
  86. Returns:
  87. Tuple[np.ndarray, int]: (masked image, valid detection count)
  88. """
  89. masked_image = image.copy()
  90. valid_count = 0
  91. for i, (box, confidence, name) in enumerate(zip(boxes, confidences, class_names)):
  92. # Skip low confidence detections
  93. if confidence < self.confidence_threshold:
  94. continue
  95. valid_count += 1
  96. xmin, ymin, xmax, ymax = map(int, box)
  97. # Get color from ROYGBIV (cycle if more than 7)
  98. color = self.colors[i % len(self.colors)]
  99. # Draw bounding box
  100. cv2.rectangle(
  101. masked_image,
  102. (xmin, ymin),
  103. (xmax, ymax),
  104. color,
  105. 2
  106. )
  107. # Draw label with confidence
  108. label = f"{name}: {confidence:.2f}"
  109. cv2.putText(
  110. masked_image,
  111. label,
  112. (xmin, ymin - 5),
  113. cv2.FONT_HERSHEY_SIMPLEX,
  114. 0.8,
  115. color,
  116. 2,
  117. lineType=cv2.LINE_AA
  118. )
  119. # Apply mask if available
  120. if masks is not None and i < len(masks):
  121. mask = masks[i]
  122. # Resize mask to match image dimensions if needed
  123. if mask.shape[:2] != masked_image.shape[:2]:
  124. mask = cv2.resize(
  125. mask,
  126. (masked_image.shape[1], masked_image.shape[0])
  127. )
  128. # Convert mask to binary
  129. mask_binary = (mask * 255).astype(np.uint8)
  130. # Create colored overlay
  131. colored_mask = np.zeros_like(masked_image, dtype=np.uint8)
  132. for c in range(3):
  133. colored_mask[:, :, c] = mask_binary * (color[c] / 255)
  134. # Blend mask with image
  135. masked_image = cv2.addWeighted(
  136. masked_image,
  137. 1.0,
  138. colored_mask,
  139. 0.5,
  140. 0
  141. )
  142. return masked_image, valid_count
  143. def predict(self, image_path: str) -> Dict[str, Any]:
  144. """
  145. Count locules in a durian cross-section image.
  146. Args:
  147. image_path: Path to input image
  148. Returns:
  149. Dict containing:
  150. - 'success': Whether prediction succeeded
  151. - 'annotated_image': QImage with colored masks
  152. - 'locule_count': Number of detected locules
  153. - 'detections': List of detection details
  154. - 'error': Error message if failed
  155. Raises:
  156. RuntimeError: If model is not loaded
  157. """
  158. if not self._is_loaded:
  159. raise RuntimeError("Model not loaded. Call load() first.")
  160. try:
  161. # Load image
  162. image = cv2.imread(image_path)
  163. if image is None:
  164. raise ValueError(f"Could not load image: {image_path}")
  165. # Get original dimensions and preserve aspect ratio
  166. orig_height, orig_width = image.shape[:2]
  167. # Resize to a standard size while maintaining aspect ratio
  168. # Use 640 as max dimension (common YOLO input size)
  169. max_dim = 640
  170. if orig_width > orig_height:
  171. new_width = max_dim
  172. new_height = int((max_dim / orig_width) * orig_height)
  173. else:
  174. new_height = max_dim
  175. new_width = int((max_dim / orig_height) * orig_width)
  176. image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
  177. # Run YOLO segmentation
  178. results = self.model.predict(image_resized)
  179. detections = []
  180. all_boxes = []
  181. all_confidences = []
  182. all_class_names = []
  183. all_masks = None
  184. # Process results
  185. for result in results:
  186. if result.boxes is None or len(result.boxes) == 0:
  187. continue
  188. boxes = result.boxes.cpu().numpy()
  189. masks = result.masks.data.cpu().numpy() if result.masks is not None else None
  190. if masks is not None:
  191. # Ensure mask count matches box count
  192. masks = masks[:len(boxes)]
  193. all_masks = masks
  194. for idx, box in enumerate(boxes):
  195. confidence = float(box.conf[0])
  196. if confidence < self.confidence_threshold:
  197. continue
  198. xmin, ymin, xmax, ymax = map(float, box.xyxy[0])
  199. class_id = int(box.cls[0])
  200. class_name = self.model.names.get(class_id, f"Locule {idx + 1}")
  201. detections.append({
  202. 'bbox': [xmin, ymin, xmax, ymax],
  203. 'confidence': confidence,
  204. 'class_id': class_id,
  205. 'class_name': class_name,
  206. 'index': idx
  207. })
  208. all_boxes.append([xmin, ymin, xmax, ymax])
  209. all_confidences.append(confidence)
  210. all_class_names.append(class_name)
  211. # Apply colored masks
  212. if len(all_boxes) > 0:
  213. masked_image, locule_count = self._apply_colored_masks(
  214. image_resized,
  215. all_masks,
  216. np.array(all_boxes),
  217. np.array(all_confidences),
  218. all_class_names
  219. )
  220. else:
  221. masked_image = image_resized
  222. locule_count = 0
  223. # Convert to QImage
  224. rgb_image = cv2.cvtColor(masked_image, cv2.COLOR_BGR2RGB)
  225. h, w, ch = rgb_image.shape
  226. bytes_per_line = ch * w
  227. q_image = QImage(
  228. rgb_image.data,
  229. w,
  230. h,
  231. bytes_per_line,
  232. QImage.Format_RGB888
  233. )
  234. logger.info(f"Detected {locule_count} locules")
  235. return {
  236. 'success': True,
  237. 'annotated_image': q_image,
  238. 'locule_count': locule_count,
  239. 'detections': detections,
  240. 'error': None
  241. }
  242. except Exception as e:
  243. logger.error(f"Prediction failed: {e}")
  244. return {
  245. 'success': False,
  246. 'annotated_image': None,
  247. 'locule_count': 0,
  248. 'detections': [],
  249. 'error': str(e)
  250. }
  251. def predict_batch(self, image_paths: list) -> list:
  252. """
  253. Count locules in multiple images.
  254. Args:
  255. image_paths: List of paths to images
  256. Returns:
  257. List[Dict]: List of prediction results
  258. """
  259. results = []
  260. for image_path in image_paths:
  261. result = self.predict(image_path)
  262. results.append(result)
  263. return results