maturity_model.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. """
  2. Maturity Model Module
  3. Multispectral maturity classification model for durian ripeness detection.
  4. Uses ResNet18 with 8-channel input for multispectral TIFF image analysis.
  5. """
  6. from pathlib import Path
  7. from typing import Dict, Any, Optional, Tuple
  8. import logging
  9. import torch
  10. import torch.nn as nn
  11. import torchvision.models as M
  12. import numpy as np
  13. import cv2
  14. import tifffile
  15. from PyQt5.QtGui import QImage, QPixmap
  16. from models.base_model import BaseModel
  17. from utils.config import get_device
  18. logger = logging.getLogger(__name__)
  19. # ==================== MODEL ARCHITECTURE ====================
  20. def make_resnet18_8ch(num_classes: int):
  21. """Initializes a ResNet18 model with an 8-channel input Conv2d layer."""
  22. try:
  23. m = M.resnet18(weights=None)
  24. except TypeError:
  25. m = M.resnet18(pretrained=False)
  26. # Modify input convolution to accept 8 channels
  27. m.conv1 = nn.Conv2d(8, 64, kernel_size=7, stride=2, padding=3, bias=False)
  28. # Modify final fully-connected layer for the number of classes
  29. m.fc = nn.Linear(m.fc.in_features, num_classes)
  30. return m
  31. # ==================== TIFF LOADING HELPERS ====================
  32. def _read_first_page(path):
  33. """Reads the first (or only) page from a TIFF file."""
  34. with tifffile.TiffFile(path) as tif:
  35. if len(tif.pages) > 1:
  36. arr = tif.pages[0].asarray()
  37. else:
  38. arr = tif.asarray()
  39. return arr
  40. def _split_2x4_mosaic_to_cube(img2d):
  41. """Splits a 2D mosaic (H, W) into an 8-channel cube (h, w, 8)."""
  42. if img2d.ndim != 2:
  43. raise ValueError(f"Expected 2D mosaic, got {img2d.shape}.")
  44. H, W = img2d.shape
  45. if (H % 2 != 0) or (W % 4 != 0):
  46. raise ValueError(f"Image shape {img2d.shape} not divisible by (2,4).")
  47. h2, w4 = H // 2, W // 4
  48. # Tiles are ordered row-by-row
  49. tiles = [img2d[r * h2:(r + 1) * h2, c * w4:(c + 1) * w4] for r in range(2) for c in range(4)]
  50. cube = np.stack(tiles, axis=-1)
  51. return cube
  52. def load_and_split_mosaic(path, *, as_float=True, normalize="uint16", eps=1e-6):
  53. """
  54. Loads an 8-band TIFF. If it's a 2D mosaic, it splits it.
  55. Applies 'uint16' normalization (divide by 65535.0) as used in training.
  56. """
  57. arr = _read_first_page(path)
  58. if arr.ndim == 3 and arr.shape[-1] == 8:
  59. cube = arr # Already a cube
  60. elif arr.ndim == 2:
  61. cube = _split_2x4_mosaic_to_cube(arr) # Mosaic, needs splitting
  62. else:
  63. raise ValueError(f"Unsupported TIFF shape {arr.shape}. Expect 2D mosaic or (H,W,8).")
  64. cube = np.ascontiguousarray(cube)
  65. if normalize == "uint16":
  66. if cube.dtype == np.uint16:
  67. cube = cube.astype(np.float32) / 65535.0
  68. else:
  69. # Fallback for non-uint16
  70. cmin, cmax = float(cube.min()), float(cube.max())
  71. denom = (cmax - cmin) if (cmax - cmin) > eps else 1.0
  72. cube = (cube.astype(np.float32) - cmin) / denom
  73. if as_float and cube.dtype != np.float32:
  74. cube = cube.astype(np.float32)
  75. return cube
  76. # ==================== MASKING PIPELINE ====================
  77. def _odd(k):
  78. k = int(k)
  79. return k if k % 2 == 1 else k + 1
  80. def _robust_u8(x, p_lo=1, p_hi=99, eps=1e-6):
  81. lo, hi = np.percentile(x, [p_lo, p_hi])
  82. if hi - lo < eps:
  83. lo, hi = float(x.min()), float(x.max()) if float(x.max()) > float(x.min()) else (0.0, 1.0)
  84. y = (x - lo) / (hi - lo + eps)
  85. return (np.clip(y, 0, 1) * 255).astype(np.uint8)
  86. def _clear_border(mask255):
  87. lab_n, lab, _, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4)
  88. if lab_n <= 1:
  89. return mask255
  90. H, W = mask255.shape
  91. edge = set(np.concatenate([lab[0, :], lab[H-1, :], lab[:, 0], lab[:, W-1]]).tolist())
  92. edge.discard(0)
  93. out = mask255.copy()
  94. if edge:
  95. out[np.isin(lab, list(edge))] = 0
  96. return out
  97. def _largest_cc(mask255):
  98. num, lab, stats, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4)
  99. if num <= 1:
  100. return mask255
  101. largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
  102. return ((lab == largest).astype(np.uint8) * 255)
  103. def _fill_holes(mask255):
  104. inv = (mask255 == 0).astype(np.uint8)
  105. num, lab, _, _ = cv2.connectedComponentsWithStats(inv, connectivity=4)
  106. if num <= 1:
  107. return mask255
  108. H, W = mask255.shape
  109. border_labels = set(np.concatenate([lab[0, :], lab[H-1, :], lab[:, 0], lab[:, W-1]]).tolist())
  110. out = mask255.copy()
  111. for k in range(1, num):
  112. if k not in border_labels:
  113. out[lab == k] = 255
  114. return out
  115. def otsu_filled_robust_mask(
  116. cube, *, band_index=4, blur_ksize=81, p_lo=1, p_hi=99,
  117. close_ksize=5, close_iters=1, clear_border=False, auto_invert=True,
  118. ):
  119. """The full masking function from Cell 4."""
  120. band = cube[..., band_index].astype(np.float32)
  121. k = _odd(blur_ksize) if blur_ksize and blur_ksize > 1 else 0
  122. bg = cv2.GaussianBlur(band, (k, k), 0) if k else np.median(band)
  123. diff = band - bg
  124. diff8 = _robust_u8(diff, p_lo=p_lo, p_hi=p_hi)
  125. T, _ = cv2.threshold(diff8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  126. mask = (diff8 >= T).astype(np.uint8) * 255
  127. if auto_invert:
  128. H, W = mask.shape
  129. b = max(1, min(H, W) // 10)
  130. border_pixels = np.concatenate([
  131. mask[:b, :].ravel(), mask[-b:, :].ravel(),
  132. mask[:, :b].ravel(), mask[:, -b:].ravel()
  133. ])
  134. if border_pixels.size > 0 and float(border_pixels.mean()) > 127:
  135. mask = 255 - mask
  136. if close_ksize and close_ksize > 1:
  137. se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (_odd(close_ksize), _odd(close_ksize)))
  138. for _ in range(max(1, int(close_iters))):
  139. mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, se)
  140. mask = _largest_cc(mask)
  141. if clear_border:
  142. mask = _clear_border(mask)
  143. mask = _fill_holes(mask)
  144. return mask.astype(np.uint8)
  145. def make_mask_filled(cube, *, band_index=4, blur_ksize=81, close_ksize=4, **kwargs):
  146. """The 'SOT mask maker' from Cell 4b, used in your dataset."""
  147. mask = otsu_filled_robust_mask(
  148. cube,
  149. band_index=band_index,
  150. blur_ksize=blur_ksize,
  151. close_ksize=close_ksize,
  152. **kwargs
  153. )
  154. return mask
  155. # ==================== CROPPING HELPERS ====================
  156. def crop_square_from_mask(cube, mask, pad=8):
  157. """Finds mask bbox, adds padding, and makes it square. Returns (crop, (y0,y1,x0,x1)))."""
  158. H, W = mask.shape
  159. ys, xs = np.where(mask > 0)
  160. if len(ys) == 0:
  161. side = int(0.8 * min(H, W))
  162. y0 = (H - side) // 2
  163. x0 = (W - side) // 2
  164. y1, x1 = y0 + side, x0 + side
  165. else:
  166. y0, y1 = ys.min(), ys.max()
  167. x0, x1 = xs.min(), xs.max()
  168. side = max(y1 - y0 + 1, x1 - x0 + 1) + 2 * pad
  169. cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
  170. y0 = max(0, cy - side // 2)
  171. x0 = max(0, cx - side // 2)
  172. y1 = min(H, y0 + side)
  173. x1 = min(W, x0 + side)
  174. y0 = max(0, y1 - side)
  175. x0 = max(0, x1 - side)
  176. return cube[y0:y1, x0:x1, :], (y0, y1, x0, x1)
  177. def crop_and_resize_square(cube, mask, size=256, pad=8):
  178. """Crops and resizes to the final (size, size) square."""
  179. crop, (y0, y1, x0, x1) = crop_square_from_mask(cube, mask, pad=pad)
  180. crop = cv2.resize(crop, (size, size), interpolation=cv2.INTER_AREA)
  181. return crop, (y0, y1, x0, x1)
  182. # ==================== GRAD-CAM ====================
  183. def gradcam(model, img_tensor, target_class=None, layer_name='layer4'):
  184. """Computes Grad-CAM heatmap for the model."""
  185. model.eval()
  186. activations, gradients = {}, {}
  187. def f_hook(_, __, out):
  188. activations['v'] = out
  189. def b_hook(_, grad_in, grad_out):
  190. gradients['v'] = grad_out[0]
  191. try:
  192. layer = dict([*model.named_modules()])[layer_name]
  193. except KeyError:
  194. raise ValueError(f"Layer '{layer_name}' not found in model.")
  195. fh = layer.register_forward_hook(f_hook)
  196. bh = layer.register_backward_hook(b_hook)
  197. # Forward pass
  198. out = model(img_tensor)
  199. if target_class is None:
  200. target_class = int(out.argmax(1).item())
  201. # Backward pass
  202. loss = out[0, target_class]
  203. model.zero_grad()
  204. loss.backward()
  205. # Get activations and gradients
  206. acts = activations['v'][0] # (C, H, W)
  207. grads = gradients['v'][0] # (C, H, W)
  208. # Compute weights and CAM
  209. weights = grads.mean(dim=(1, 2)) # (C,)
  210. cam = (weights[:, None, None] * acts).sum(0).detach().cpu().numpy()
  211. cam = np.maximum(cam, 0)
  212. cam /= (cam.max() + 1e-6)
  213. fh.remove()
  214. bh.remove()
  215. return cam, target_class
  216. # ==================== MATURITY MODEL CLASS ====================
class MaturityModel(BaseModel):
    """
    Multispectral maturity classification model.

    Uses ResNet18 with 8-channel input for multispectral TIFF image analysis.
    Processes 8-band multispectral images and predicts maturity/ripeness classes.

    Attributes:
        model: ResNet18 model instance (built in load())
        class_names: List of class names, read from the checkpoint
        mean: Mean values for normalization, reshaped to (1, 1, 8)
        std: Std values for normalization, reshaped to (1, 1, 8)
        mask_band_index: Band index used for masking (default: 4, 860nm)
        img_size: Target image size after preprocessing (default: 256)
        img_pad: Padding for cropping (default: 8)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_band_index: int = 4,
        img_size: int = 256,
        img_pad: int = 8
    ):
        """
        Initialize the maturity model.

        Args:
            model_path: Path to the .pt model file; defaults to
                MATURITY_MODEL_PATH from utils.config.
            device: Device to use ('cuda' or 'cpu'); defaults to get_device().
            mask_band_index: Band index for masking (default: 4, 860nm)
            img_size: Target image size (default: 256)
            img_pad: Padding for cropping (default: 8)
        """
        # Local import — presumably to avoid a circular import with
        # utils.config; confirm before hoisting to module level.
        from utils.config import MATURITY_MODEL_PATH
        if model_path is None:
            model_path = str(MATURITY_MODEL_PATH)
        if device is None:
            device = get_device()
        super().__init__(model_path, device)
        self.mask_band_index = mask_band_index
        self.img_size = img_size
        self.img_pad = img_pad
        # Filled in by load(); remain empty/None until a checkpoint is loaded.
        self.class_names = []
        self.mean = None
        self.std = None

    def load(self) -> bool:
        """
        Load the maturity model from checkpoint.

        The checkpoint is expected to provide the keys 'class_names',
        'mean', 'std' and 'state_dict'; mean/std must be array-like and
        reshapeable to (1, 1, 8).

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            model_path = Path(self.model_path)
            if not model_path.exists():
                logger.error(f"Model file does not exist: {model_path}")
                return False
            logger.info(f"Loading maturity model from {model_path}")
            # Load checkpoint.
            # NOTE(review): torch.load unpickles arbitrary objects; consider
            # weights_only=True if the checkpoint source is not trusted.
            ckpt = torch.load(model_path, map_location=self.device)
            # Extract class names, mean, and std
            self.class_names = ckpt['class_names']
            # Reshape mean/std for broadcasting against (H, W, 8) crops (1, 1, 8)
            self.mean = ckpt['mean'].reshape(1, 1, 8)
            self.std = ckpt['std'].reshape(1, 1, 8)
            # Re-create model architecture and restore the trained weights
            self.model = make_resnet18_8ch(num_classes=len(self.class_names))
            self.model.load_state_dict(ckpt['state_dict'])
            self.model.to(self.device)
            self.model.eval()
            self._is_loaded = True
            logger.info(f"Maturity model loaded on {self.device}. Classes: {self.class_names}")
            return True
        except Exception as e:
            # Broad catch is deliberate: load() reports failure through its
            # return value instead of raising.
            logger.error(f"Failed to load maturity model: {e}")
            import traceback
            logger.error(traceback.format_exc())
            self._is_loaded = False
            return False

    def _preprocess(self, tif_path):
        """
        Runs the full preprocessing pipeline.

        1. Load 8-band cube
        2. Get mask
        3. Crop and resize
        4. Apply pixel mask
        5. Normalize
        6. Convert to Tensor

        Args:
            tif_path: Path to the 8-band multispectral TIFF.

        Returns:
            Tuple[torch.Tensor, np.ndarray]: (tensor for model, visual crop for display)
        """
        # 1. Load 8-band cube
        cube = load_and_split_mosaic(str(tif_path), as_float=True, normalize="uint16")
        # 2. Get mask (using the exact params from training)
        mask_full = make_mask_filled(
            cube,
            band_index=self.mask_band_index,
            blur_ksize=81,
            close_ksize=4
        )
        # 3. Crop and resize
        crop, (y0, y1, x0, x1) = crop_and_resize_square(
            cube,
            mask_full,
            size=self.img_size,
            pad=self.img_pad
        )
        # 4. Apply pixel mask (critical step from training)
        mask_crop = mask_full[y0:y1, x0:x1]
        # Nearest-neighbor keeps the mask binary after resizing.
        mask_resz = cv2.resize(mask_crop, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
        # Create a {0, 1} mask and broadcast it
        m = (mask_resz > 0).astype(crop.dtype)
        # Zero out the background pixels in all 8 channels
        crop_masked = crop * m[..., None]
        # 5. Normalize with the checkpoint's per-band mean/std
        norm_crop = (crop_masked - self.mean) / self.std
        # 6. To Tensor
        # (H, W, C) -> (C, H, W) -> (B, C, H, W)
        x = torch.from_numpy(norm_crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
        # Return tensor for model and the visible (un-normalized, masked) crop for visualization
        return x, crop_masked

    def predict(self, tif_path: str) -> Dict[str, Any]:
        """
        Run prediction on a multispectral TIFF file.

        Args:
            tif_path: Path to .tif file (8-band multispectral)

        Returns:
            Dict containing:
                - 'success': Whether prediction succeeded
                - 'prediction': Predicted class name
                - 'confidence': Confidence score (0-1)
                - 'probabilities': Dictionary of class probabilities
                - 'error': Error message if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")
        try:
            # Preprocess
            img_tensor, _ = self._preprocess(tif_path)
            # Run inference (no gradients needed)
            with torch.no_grad():
                logits = self.model(img_tensor)
                probs = torch.softmax(logits, dim=1)[0]
                pred_idx = logits.argmax(1).item()
            probs_cpu = probs.cpu().numpy()
            return {
                'success': True,
                'prediction': self.class_names[pred_idx],
                'confidence': float(probs_cpu[pred_idx]),
                'probabilities': {name: float(prob) for name, prob in zip(self.class_names, probs_cpu)},
                'error': None
            }
        except Exception as e:
            # Failures are reported in the result dict rather than raised,
            # so batch callers can continue past a bad file.
            logger.error(f"Prediction failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return {
                'success': False,
                'prediction': None,
                'confidence': 0.0,
                'probabilities': {},
                'error': str(e)
            }

    def run_gradcam(self, tif_path: str, band_to_show: Optional[int] = None) -> Tuple[Optional[QImage], Optional[str]]:
        """
        Run Grad-CAM visualization on a .tif file.

        Args:
            tif_path: Path to input .tif file
            band_to_show: Which of the 8 bands to use as background (default: mask_band_index)

        Returns:
            Tuple[QImage, str]: (overlay_image, predicted_class_name) or (None, None) if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")
        if band_to_show is None:
            band_to_show = self.mask_band_index
        try:
            img_tensor, crop_img = self._preprocess(tif_path)
            # Run Grad-CAM
            heatmap, pred_idx = gradcam(
                self.model,
                img_tensor,
                target_class=None,
                layer_name='layer4'
            )
            pred_name = self.class_names[pred_idx]
            # Create visualization
            # Get the specific band to show (it's already float [0,1] and masked)
            band_img = crop_img[..., band_to_show]
            # Normalize band to [0,255] uint8 for display
            band_u8 = (np.clip(band_img, 0, 1) * 255).astype(np.uint8)
            vis_img = cv2.cvtColor(band_u8, cv2.COLOR_GRAY2BGR)
            # Resize heatmap and apply colormap
            hm_resized = (cv2.resize(heatmap, (self.img_size, self.img_size)) * 255).astype(np.uint8)
            hm_color = cv2.applyColorMap(hm_resized, cv2.COLORMAP_JET)
            # Create overlay (60% band image, 40% heatmap)
            overlay = cv2.addWeighted(vis_img, 0.6, hm_color, 0.4, 0)
            # Apply overlay only to fruit pixels; background stays black.
            # NOTE(review): fruit pixels whose band value rounds to 0 are
            # also treated as background here — confirm that's acceptable.
            mask = (band_u8 > 0)
            final_vis = np.zeros_like(vis_img)
            final_vis[mask] = overlay[mask]
            # Convert to QImage (Qt expects RGB byte order)
            rgb_image = cv2.cvtColor(final_vis, cv2.COLOR_BGR2RGB)
            rgb_image = np.ascontiguousarray(rgb_image)  # Ensure contiguous memory
            h, w, ch = rgb_image.shape
            bytes_per_line = ch * w
            # Create QImage with copied data to prevent memory issues
            q_image = QImage(rgb_image.tobytes(), w, h, bytes_per_line, QImage.Format_RGB888)
            q_image = q_image.copy()  # Make a deep copy to own the memory
            return q_image, pred_name
        except Exception as e:
            # Visualization failures are non-fatal: signal with (None, None).
            logger.error(f"Grad-CAM failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None, None

    def predict_batch(self, tif_paths: list) -> list:
        """
        Predict maturity for multiple TIFF files.

        Each file is processed independently; a failure on one file yields a
        'success': False entry without aborting the rest of the batch.

        Args:
            tif_paths: List of paths to .tif files

        Returns:
            List[Dict]: List of prediction results (same structure as predict())
        """
        results = []
        for tif_path in tif_paths:
            result = self.predict(tif_path)
            results.append(result)
        return results