""" Maturity Model Module Multispectral maturity classification model for durian ripeness detection. Uses ResNet18 with 8-channel input for multispectral TIFF image analysis. """ from pathlib import Path from typing import Dict, Any, Optional, Tuple import logging import torch import torch.nn as nn import torchvision.models as M import numpy as np import cv2 import tifffile from PyQt5.QtGui import QImage, QPixmap from models.base_model import BaseModel from utils.config import get_device logger = logging.getLogger(__name__) # ==================== MODEL ARCHITECTURE ==================== def make_resnet18_8ch(num_classes: int): """Initializes a ResNet18 model with an 8-channel input Conv2d layer.""" try: m = M.resnet18(weights=None) except TypeError: m = M.resnet18(pretrained=False) # Modify input convolution to accept 8 channels m.conv1 = nn.Conv2d(8, 64, kernel_size=7, stride=2, padding=3, bias=False) # Modify final fully-connected layer for the number of classes m.fc = nn.Linear(m.fc.in_features, num_classes) return m # ==================== TIFF LOADING HELPERS ==================== def _read_first_page(path): """Reads the first (or only) page from a TIFF file.""" with tifffile.TiffFile(path) as tif: if len(tif.pages) > 1: arr = tif.pages[0].asarray() else: arr = tif.asarray() return arr def _split_2x4_mosaic_to_cube(img2d): """Splits a 2D mosaic (H, W) into an 8-channel cube (h, w, 8).""" if img2d.ndim != 2: raise ValueError(f"Expected 2D mosaic, got {img2d.shape}.") H, W = img2d.shape if (H % 2 != 0) or (W % 4 != 0): raise ValueError(f"Image shape {img2d.shape} not divisible by (2,4).") h2, w4 = H // 2, W // 4 # Tiles are ordered row-by-row tiles = [img2d[r * h2:(r + 1) * h2, c * w4:(c + 1) * w4] for r in range(2) for c in range(4)] cube = np.stack(tiles, axis=-1) return cube def load_and_split_mosaic(path, *, as_float=True, normalize="uint16", eps=1e-6): """ Loads an 8-band TIFF. If it's a 2D mosaic, it splits it. Applies 'uint16' normalization (divide by 65535.0) as used in training. """ arr = _read_first_page(path) if arr.ndim == 3 and arr.shape[-1] == 8: cube = arr # Already a cube elif arr.ndim == 2: cube = _split_2x4_mosaic_to_cube(arr) # Mosaic, needs splitting else: raise ValueError(f"Unsupported TIFF shape {arr.shape}. Expect 2D mosaic or (H,W,8).") cube = np.ascontiguousarray(cube) if normalize == "uint16": if cube.dtype == np.uint16: cube = cube.astype(np.float32) / 65535.0 else: # Fallback for non-uint16 cmin, cmax = float(cube.min()), float(cube.max()) denom = (cmax - cmin) if (cmax - cmin) > eps else 1.0 cube = (cube.astype(np.float32) - cmin) / denom if as_float and cube.dtype != np.float32: cube = cube.astype(np.float32) return cube # ==================== MASKING PIPELINE ==================== def _odd(k): k = int(k) return k if k % 2 == 1 else k + 1 def _robust_u8(x, p_lo=1, p_hi=99, eps=1e-6): lo, hi = np.percentile(x, [p_lo, p_hi]) if hi - lo < eps: lo, hi = float(x.min()), float(x.max()) if float(x.max()) > float(x.min()) else (0.0, 1.0) y = (x - lo) / (hi - lo + eps) return (np.clip(y, 0, 1) * 255).astype(np.uint8) def _clear_border(mask255): lab_n, lab, _, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4) if lab_n <= 1: return mask255 H, W = mask255.shape edge = set(np.concatenate([lab[0, :], lab[H-1, :], lab[:, 0], lab[:, W-1]]).tolist()) edge.discard(0) out = mask255.copy() if edge: out[np.isin(lab, list(edge))] = 0 return out def _largest_cc(mask255): num, lab, stats, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4) if num <= 1: return mask255 largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA])) return ((lab == largest).astype(np.uint8) * 255) def _fill_holes(mask255): inv = (mask255 == 0).astype(np.uint8) num, lab, _, _ = cv2.connectedComponentsWithStats(inv, connectivity=4) if num <= 1: return mask255 H, W = mask255.shape border_labels = set(np.concatenate([lab[0, :], lab[H-1, :], lab[:, 0], lab[:, W-1]]).tolist()) out = mask255.copy() for k in range(1, num): if k not in border_labels: out[lab == k] = 255 return out def otsu_filled_robust_mask( cube, *, band_index=4, blur_ksize=81, p_lo=1, p_hi=99, close_ksize=5, close_iters=1, clear_border=False, auto_invert=True, ): """The full masking function from Cell 4.""" band = cube[..., band_index].astype(np.float32) k = _odd(blur_ksize) if blur_ksize and blur_ksize > 1 else 0 bg = cv2.GaussianBlur(band, (k, k), 0) if k else np.median(band) diff = band - bg diff8 = _robust_u8(diff, p_lo=p_lo, p_hi=p_hi) T, _ = cv2.threshold(diff8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) mask = (diff8 >= T).astype(np.uint8) * 255 if auto_invert: H, W = mask.shape b = max(1, min(H, W) // 10) border_pixels = np.concatenate([ mask[:b, :].ravel(), mask[-b:, :].ravel(), mask[:, :b].ravel(), mask[:, -b:].ravel() ]) if border_pixels.size > 0 and float(border_pixels.mean()) > 127: mask = 255 - mask if close_ksize and close_ksize > 1: se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (_odd(close_ksize), _odd(close_ksize))) for _ in range(max(1, int(close_iters))): mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, se) mask = _largest_cc(mask) if clear_border: mask = _clear_border(mask) mask = _fill_holes(mask) return mask.astype(np.uint8) def make_mask_filled(cube, *, band_index=4, blur_ksize=81, close_ksize=4, **kwargs): """The 'SOT mask maker' from Cell 4b, used in your dataset.""" mask = otsu_filled_robust_mask( cube, band_index=band_index, blur_ksize=blur_ksize, close_ksize=close_ksize, **kwargs ) return mask # ==================== CROPPING HELPERS ==================== def crop_square_from_mask(cube, mask, pad=8): """Finds mask bbox, adds padding, and makes it square. Returns (crop, (y0,y1,x0,x1))).""" H, W = mask.shape ys, xs = np.where(mask > 0) if len(ys) == 0: side = int(0.8 * min(H, W)) y0 = (H - side) // 2 x0 = (W - side) // 2 y1, x1 = y0 + side, x0 + side else: y0, y1 = ys.min(), ys.max() x0, x1 = xs.min(), xs.max() side = max(y1 - y0 + 1, x1 - x0 + 1) + 2 * pad cy, cx = (y0 + y1) // 2, (x0 + x1) // 2 y0 = max(0, cy - side // 2) x0 = max(0, cx - side // 2) y1 = min(H, y0 + side) x1 = min(W, x0 + side) y0 = max(0, y1 - side) x0 = max(0, x1 - side) return cube[y0:y1, x0:x1, :], (y0, y1, x0, x1) def crop_and_resize_square(cube, mask, size=256, pad=8): """Crops and resizes to the final (size, size) square.""" crop, (y0, y1, x0, x1) = crop_square_from_mask(cube, mask, pad=pad) crop = cv2.resize(crop, (size, size), interpolation=cv2.INTER_AREA) return crop, (y0, y1, x0, x1) # ==================== GRAD-CAM ==================== def gradcam(model, img_tensor, target_class=None, layer_name='layer4'): """Computes Grad-CAM heatmap for the model.""" model.eval() activations, gradients = {}, {} def f_hook(_, __, out): activations['v'] = out def b_hook(_, grad_in, grad_out): gradients['v'] = grad_out[0] try: layer = dict([*model.named_modules()])[layer_name] except KeyError: raise ValueError(f"Layer '{layer_name}' not found in model.") fh = layer.register_forward_hook(f_hook) bh = layer.register_backward_hook(b_hook) # Forward pass out = model(img_tensor) if target_class is None: target_class = int(out.argmax(1).item()) # Backward pass loss = out[0, target_class] model.zero_grad() loss.backward() # Get activations and gradients acts = activations['v'][0] # (C, H, W) grads = gradients['v'][0] # (C, H, W) # Compute weights and CAM weights = grads.mean(dim=(1, 2)) # (C,) cam = (weights[:, None, None] * acts).sum(0).detach().cpu().numpy() cam = np.maximum(cam, 0) cam /= (cam.max() + 1e-6) fh.remove() bh.remove() return cam, target_class # ==================== MATURITY MODEL CLASS ==================== class MaturityModel(BaseModel): """ Multispectral maturity classification model. Uses ResNet18 with 8-channel input for multispectral TIFF image analysis. Processes 8-band multispectral images and predicts maturity/ripeness classes. Attributes: model: ResNet18 model instance class_names: List of class names mean: Mean values for normalization (1, 1, 8) std: Std values for normalization (1, 1, 8) mask_band_index: Band index used for masking (default: 4, 860nm) img_size: Target image size after preprocessing (default: 256) img_pad: Padding for cropping (default: 8) """ def __init__( self, model_path: Optional[str] = None, device: Optional[str] = None, mask_band_index: int = 4, img_size: int = 256, img_pad: int = 8 ): """ Initialize the maturity model. Args: model_path: Path to the .pt model file (optional) device: Device to use ('cuda' or 'cpu', optional) mask_band_index: Band index for masking (default: 4, 860nm) img_size: Target image size (default: 256) img_pad: Padding for cropping (default: 8) """ from utils.config import MATURITY_MODEL_PATH if model_path is None: model_path = str(MATURITY_MODEL_PATH) if device is None: device = get_device() super().__init__(model_path, device) self.mask_band_index = mask_band_index self.img_size = img_size self.img_pad = img_pad self.class_names = [] self.mean = None self.std = None def load(self) -> bool: """ Load the maturity model from checkpoint. Returns: bool: True if loaded successfully, False otherwise """ try: model_path = Path(self.model_path) if not model_path.exists(): logger.error(f"Model file does not exist: {model_path}") return False logger.info(f"Loading maturity model from {model_path}") # Load checkpoint ckpt = torch.load(model_path, map_location=self.device) # Extract class names, mean, and std self.class_names = ckpt['class_names'] # Reshape mean/std for broadcasting (1, 1, 8) self.mean = ckpt['mean'].reshape(1, 1, 8) self.std = ckpt['std'].reshape(1, 1, 8) # Re-create model architecture self.model = make_resnet18_8ch(num_classes=len(self.class_names)) self.model.load_state_dict(ckpt['state_dict']) self.model.to(self.device) self.model.eval() self._is_loaded = True logger.info(f"Maturity model loaded on {self.device}. Classes: {self.class_names}") return True except Exception as e: logger.error(f"Failed to load maturity model: {e}") import traceback logger.error(traceback.format_exc()) self._is_loaded = False return False def _preprocess(self, tif_path): """ Runs the full preprocessing pipeline. 1. Load 8-band cube 2. Get mask 3. Crop and resize 4. Apply pixel mask 5. Normalize 6. Convert to Tensor Returns: Tuple[torch.Tensor, np.ndarray]: (tensor for model, visual crop for display) """ # 1. Load 8-band cube cube = load_and_split_mosaic(str(tif_path), as_float=True, normalize="uint16") # 2. Get mask (using the exact params from training) mask_full = make_mask_filled( cube, band_index=self.mask_band_index, blur_ksize=81, close_ksize=4 ) # 3. Crop and resize crop, (y0, y1, x0, x1) = crop_and_resize_square( cube, mask_full, size=self.img_size, pad=self.img_pad ) # 4. Apply pixel mask (critical step from training) mask_crop = mask_full[y0:y1, x0:x1] mask_resz = cv2.resize(mask_crop, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST) # Create a {0, 1} mask and broadcast it m = (mask_resz > 0).astype(crop.dtype) # Zero out the background pixels in all 8 channels crop_masked = crop * m[..., None] # 5. Normalize norm_crop = (crop_masked - self.mean) / self.std # 6. To Tensor # (H, W, C) -> (C, H, W) -> (B, C, H, W) x = torch.from_numpy(norm_crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device) # Return tensor for model and the visible (un-normalized, masked) crop for visualization return x, crop_masked def predict(self, tif_path: str) -> Dict[str, Any]: """ Run prediction on a multispectral TIFF file. Args: tif_path: Path to .tif file (8-band multispectral) Returns: Dict containing: - 'success': Whether prediction succeeded - 'prediction': Predicted class name - 'confidence': Confidence score (0-1) - 'probabilities': Dictionary of class probabilities - 'error': Error message if failed Raises: RuntimeError: If model is not loaded """ if not self._is_loaded: raise RuntimeError("Model not loaded. Call load() first.") try: # Preprocess img_tensor, _ = self._preprocess(tif_path) # Run inference with torch.no_grad(): logits = self.model(img_tensor) probs = torch.softmax(logits, dim=1)[0] pred_idx = logits.argmax(1).item() probs_cpu = probs.cpu().numpy() return { 'success': True, 'prediction': self.class_names[pred_idx], 'confidence': float(probs_cpu[pred_idx]), 'probabilities': {name: float(prob) for name, prob in zip(self.class_names, probs_cpu)}, 'error': None } except Exception as e: logger.error(f"Prediction failed: {e}") import traceback logger.error(traceback.format_exc()) return { 'success': False, 'prediction': None, 'confidence': 0.0, 'probabilities': {}, 'error': str(e) } def run_gradcam(self, tif_path: str, band_to_show: Optional[int] = None) -> Tuple[Optional[QImage], Optional[str]]: """ Run Grad-CAM visualization on a .tif file. Args: tif_path: Path to input .tif file band_to_show: Which of the 8 bands to use as background (default: mask_band_index) Returns: Tuple[QImage, str]: (overlay_image, predicted_class_name) or (None, None) if failed """ if not self._is_loaded: raise RuntimeError("Model not loaded. Call load() first.") if band_to_show is None: band_to_show = self.mask_band_index try: img_tensor, crop_img = self._preprocess(tif_path) # Run Grad-CAM heatmap, pred_idx = gradcam( self.model, img_tensor, target_class=None, layer_name='layer4' ) pred_name = self.class_names[pred_idx] # Create visualization # Get the specific band to show (it's already float [0,1] and masked) band_img = crop_img[..., band_to_show] # Normalize band to [0,255] uint8 for display band_u8 = (np.clip(band_img, 0, 1) * 255).astype(np.uint8) vis_img = cv2.cvtColor(band_u8, cv2.COLOR_GRAY2BGR) # Resize heatmap and apply colormap hm_resized = (cv2.resize(heatmap, (self.img_size, self.img_size)) * 255).astype(np.uint8) hm_color = cv2.applyColorMap(hm_resized, cv2.COLORMAP_JET) # Create overlay overlay = cv2.addWeighted(vis_img, 0.6, hm_color, 0.4, 0) # Apply overlay only to fruit pixels mask = (band_u8 > 0) final_vis = np.zeros_like(vis_img) final_vis[mask] = overlay[mask] # Convert to QImage rgb_image = cv2.cvtColor(final_vis, cv2.COLOR_BGR2RGB) rgb_image = np.ascontiguousarray(rgb_image) # Ensure contiguous memory h, w, ch = rgb_image.shape bytes_per_line = ch * w # Create QImage with copied data to prevent memory issues q_image = QImage(rgb_image.tobytes(), w, h, bytes_per_line, QImage.Format_RGB888) q_image = q_image.copy() # Make a deep copy to own the memory return q_image, pred_name except Exception as e: logger.error(f"Grad-CAM failed: {e}") import traceback logger.error(traceback.format_exc()) return None, None def predict_batch(self, tif_paths: list) -> list: """ Predict maturity for multiple TIFF files. Args: tif_paths: List of paths to .tif files Returns: List[Dict]: List of prediction results """ results = [] for tif_path in tif_paths: result = self.predict(tif_path) results.append(result) return results