| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- """
- Maturity Model Module
- Multispectral maturity classification model for durian ripeness detection.
- Uses ResNet18 with 8-channel input for multispectral TIFF image analysis.
- """
- from pathlib import Path
- from typing import Dict, Any, Optional, Tuple
- import logging
- import torch
- import torch.nn as nn
- import torchvision.models as M
- import numpy as np
- import cv2
- import tifffile
- from PyQt5.QtGui import QImage, QPixmap
- from models.base_model import BaseModel
- from utils.config import get_device
- logger = logging.getLogger(__name__)
- # ==================== MODEL ARCHITECTURE ====================
def make_resnet18_8ch(num_classes: int):
    """Build a ResNet18 whose input stem accepts 8-channel images.

    Args:
        num_classes: Number of output classes for the final FC layer.

    Returns:
        A torchvision ResNet18 with conv1 widened to 8 input channels and
        the classification head resized to ``num_classes``.
    """
    # Newer torchvision uses `weights=`; older releases only accept `pretrained=`.
    try:
        net = M.resnet18(weights=None)
    except TypeError:
        net = M.resnet18(pretrained=False)
    # Swap the stock 3-channel stem for an 8-channel one (same geometry: 7x7/2, pad 3).
    net.conv1 = nn.Conv2d(8, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Replace the head to match the requested class count.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
- # ==================== TIFF LOADING HELPERS ====================
def _read_first_page(path):
    """Return the first page of a TIFF as an ndarray (whole file if single-page)."""
    with tifffile.TiffFile(path) as tif:
        multi_page = len(tif.pages) > 1
        return tif.pages[0].asarray() if multi_page else tif.asarray()
- def _split_2x4_mosaic_to_cube(img2d):
- """Splits a 2D mosaic (H, W) into an 8-channel cube (h, w, 8)."""
- if img2d.ndim != 2:
- raise ValueError(f"Expected 2D mosaic, got {img2d.shape}.")
- H, W = img2d.shape
- if (H % 2 != 0) or (W % 4 != 0):
- raise ValueError(f"Image shape {img2d.shape} not divisible by (2,4).")
- h2, w4 = H // 2, W // 4
- # Tiles are ordered row-by-row
- tiles = [img2d[r * h2:(r + 1) * h2, c * w4:(c + 1) * w4] for r in range(2) for c in range(4)]
- cube = np.stack(tiles, axis=-1)
- return cube
def load_and_split_mosaic(path, *, as_float=True, normalize="uint16", eps=1e-6):
    """Load an 8-band TIFF and return it as an (h, w, 8) float32 cube.

    A 2D page is treated as a 2x4 mosaic and split into its 8 bands; a
    ready-made (H, W, 8) cube is used as-is. With normalize="uint16" the
    data is scaled by 1/65535 (matching training); non-uint16 data falls
    back to min/max scaling guarded by ``eps``.
    """
    arr = _read_first_page(path)
    if arr.ndim == 3 and arr.shape[-1] == 8:
        cube = arr  # already band-interleaved
    elif arr.ndim == 2:
        cube = _split_2x4_mosaic_to_cube(arr)  # mosaic layout
    else:
        raise ValueError(f"Unsupported TIFF shape {arr.shape}. Expect 2D mosaic or (H,W,8).")
    cube = np.ascontiguousarray(cube)
    if normalize == "uint16":
        if cube.dtype == np.uint16:
            cube = cube.astype(np.float32) / 65535.0
        else:
            # Non-uint16 input: stretch by the actual data range instead.
            lo = float(cube.min())
            hi = float(cube.max())
            span = hi - lo
            cube = (cube.astype(np.float32) - lo) / (span if span > eps else 1.0)
    if as_float and cube.dtype != np.float32:
        cube = cube.astype(np.float32)
    return cube
- # ==================== MASKING PIPELINE ====================
- def _odd(k):
- k = int(k)
- return k if k % 2 == 1 else k + 1
- def _robust_u8(x, p_lo=1, p_hi=99, eps=1e-6):
- lo, hi = np.percentile(x, [p_lo, p_hi])
- if hi - lo < eps:
- lo, hi = float(x.min()), float(x.max()) if float(x.max()) > float(x.min()) else (0.0, 1.0)
- y = (x - lo) / (hi - lo + eps)
- return (np.clip(y, 0, 1) * 255).astype(np.uint8)
def _clear_border(mask255):
    """Zero out every connected component that touches the image border."""
    lab_n, lab, _, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4)
    if lab_n <= 1:
        return mask255
    H, W = mask255.shape
    # Labels present anywhere on the four border rows/columns (0 = background).
    border_vals = np.concatenate([lab[0, :], lab[H - 1, :], lab[:, 0], lab[:, W - 1]])
    touching = set(border_vals.tolist()) - {0}
    out = mask255.copy()
    if touching:
        out[np.isin(lab, list(touching))] = 0
    return out
def _largest_cc(mask255):
    """Keep only the largest foreground connected component (as a 0/255 mask)."""
    num, lab, stats, _ = cv2.connectedComponentsWithStats((mask255 > 0).astype(np.uint8), connectivity=4)
    if num <= 1:
        return mask255
    # Row 0 of `stats` is the background; pick the biggest foreground label.
    areas = stats[1:, cv2.CC_STAT_AREA]
    keep = 1 + int(np.argmax(areas))
    return np.where(lab == keep, 255, 0).astype(np.uint8)
def _fill_holes(mask255):
    """Fill internal holes: zero-regions not connected to the border become 255."""
    inv = (mask255 == 0).astype(np.uint8)
    num, lab, _, _ = cv2.connectedComponentsWithStats(inv, connectivity=4)
    if num <= 1:
        return mask255
    H, W = mask255.shape
    # Background components touching the border are true background, not holes.
    border = set(np.concatenate([lab[0, :], lab[H - 1, :], lab[:, 0], lab[:, W - 1]]).tolist())
    holes = [k for k in range(1, num) if k not in border]
    out = mask255.copy()
    if holes:
        out[np.isin(lab, holes)] = 255
    return out
def otsu_filled_robust_mask(
    cube, *, band_index=4, blur_ksize=81, p_lo=1, p_hi=99,
    close_ksize=5, close_iters=1, clear_border=False, auto_invert=True,
):
    """The full masking function from Cell 4.

    Pipeline: background-subtract one band, robust-stretch to uint8,
    Otsu-threshold, optionally auto-invert, morphologically close, keep
    the largest component, optionally clear border blobs, fill holes.

    Args:
        cube: (H, W, 8) band cube.
        band_index: Band used for segmentation (default 4).
        blur_ksize: Gaussian kernel size for background estimation; values
            <= 1 fall back to subtracting the global median instead.
        p_lo, p_hi: Percentiles for the robust uint8 stretch.
        close_ksize, close_iters: Morphological-closing parameters.
        clear_border: If True, drop components touching the image border.
        auto_invert: If True, flip polarity when the border looks foreground.

    Returns:
        uint8 mask of 0/255 values with shape (H, W).
    """
    band = cube[..., band_index].astype(np.float32)
    # Estimate a smooth background; subtracting it flattens uneven illumination.
    k = _odd(blur_ksize) if blur_ksize and blur_ksize > 1 else 0
    bg = cv2.GaussianBlur(band, (k, k), 0) if k else np.median(band)
    diff = band - bg
    diff8 = _robust_u8(diff, p_lo=p_lo, p_hi=p_hi)
    # Otsu picks T; re-threshold with >= so pixels exactly at T count as foreground.
    T, _ = cv2.threshold(diff8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    mask = (diff8 >= T).astype(np.uint8) * 255
    if auto_invert:
        # Heuristic: if a border frame (1/10 of the short side wide) is mostly
        # foreground, polarity is flipped — invert so the object is white.
        H, W = mask.shape
        b = max(1, min(H, W) // 10)
        border_pixels = np.concatenate([
            mask[:b, :].ravel(), mask[-b:, :].ravel(),
            mask[:, :b].ravel(), mask[:, -b:].ravel()
        ])
        if border_pixels.size > 0 and float(border_pixels.mean()) > 127:
            mask = 255 - mask
    if close_ksize and close_ksize > 1:
        # Closing bridges small gaps before the connected-component filtering.
        se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (_odd(close_ksize), _odd(close_ksize)))
        for _ in range(max(1, int(close_iters))):
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, se)
    mask = _largest_cc(mask)
    if clear_border:
        mask = _clear_border(mask)
    mask = _fill_holes(mask)
    return mask.astype(np.uint8)
def make_mask_filled(cube, *, band_index=4, blur_ksize=81, close_ksize=4, **kwargs):
    """The 'SOT mask maker' from Cell 4b, used in your dataset.

    Thin wrapper around otsu_filled_robust_mask that pins the dataset's
    default closing kernel (close_ksize=4) and forwards everything else.
    """
    return otsu_filled_robust_mask(
        cube,
        band_index=band_index,
        blur_ksize=blur_ksize,
        close_ksize=close_ksize,
        **kwargs,
    )
- # ==================== CROPPING HELPERS ====================
def crop_square_from_mask(cube, mask, pad=8):
    """Crop a square window around the mask's bounding box.

    The bounding box is padded by *pad* per side, expanded to a square, and
    clamped to the image. An empty mask falls back to a centered window
    covering 80% of the short side.

    Returns:
        (crop, (y0, y1, x0, x1)) where crop = cube[y0:y1, x0:x1, :].
    """
    H, W = mask.shape
    ys, xs = np.where(mask > 0)
    if ys.size == 0:
        # No foreground at all: use a centered fallback square.
        side = int(0.8 * min(H, W))
        y0, x0 = (H - side) // 2, (W - side) // 2
        y1, x1 = y0 + side, x0 + side
    else:
        y0, y1 = ys.min(), ys.max()
        x0, x1 = xs.min(), xs.max()
        side = max(y1 - y0 + 1, x1 - x0 + 1) + 2 * pad
        cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
        # Center the square on the box, then shift it back inside the image.
        y0 = max(0, cy - side // 2)
        x0 = max(0, cx - side // 2)
        y1 = min(H, y0 + side)
        x1 = min(W, x0 + side)
        y0 = max(0, y1 - side)
        x0 = max(0, x1 - side)
    return cube[y0:y1, x0:x1, :], (y0, y1, x0, x1)
def crop_and_resize_square(cube, mask, size=256, pad=8):
    """Square-crop around the mask, then resize the crop to (size, size)."""
    crop, box = crop_square_from_mask(cube, mask, pad=pad)
    resized = cv2.resize(crop, (size, size), interpolation=cv2.INTER_AREA)
    return resized, box
- # ==================== GRAD-CAM ====================
def gradcam(model, img_tensor, target_class=None, layer_name='layer4'):
    """Compute a Grad-CAM heatmap for the model.

    Args:
        model: Module whose named submodule `layer_name` provides the CAM features.
        img_tensor: Input batch of shape (1, C, H, W).
        target_class: Class index to explain; defaults to the argmax prediction.
        layer_name: Dotted name of the target layer (default 'layer4').

    Returns:
        (cam, target_class): cam is a non-negative array at the target
        layer's spatial resolution, scaled to at most ~1.

    Raises:
        ValueError: If `layer_name` is not a named module of `model`.
    """
    model.eval()
    activations, gradients = {}, {}

    def f_hook(_, __, out):
        activations['v'] = out

    def b_hook(_, grad_in, grad_out):
        gradients['v'] = grad_out[0]

    try:
        layer = dict(model.named_modules())[layer_name]
    except KeyError:
        raise ValueError(f"Layer '{layer_name}' not found in model.")
    fh = layer.register_forward_hook(f_hook)
    # Prefer the non-deprecated full backward hook (torch >= 1.8); fall back
    # for older torch. grad_out semantics are identical for this usage.
    try:
        bh = layer.register_full_backward_hook(b_hook)
    except AttributeError:
        bh = layer.register_backward_hook(b_hook)
    try:
        # Forward pass
        out = model(img_tensor)
        if target_class is None:
            target_class = int(out.argmax(1).item())
        # Backward pass w.r.t. the chosen class score
        model.zero_grad()
        out[0, target_class].backward()
        acts = activations['v'][0]   # (C, H', W')
        grads = gradients['v'][0]    # (C, H', W')
        # Channel weights = global-average-pooled gradients
        weights = grads.mean(dim=(1, 2))  # (C,)
        cam = (weights[:, None, None] * acts).sum(0).detach().cpu().numpy()
    finally:
        # BUG FIX: hooks were previously leaked if forward/backward raised.
        fh.remove()
        bh.remove()
    cam = np.maximum(cam, 0)
    cam /= (cam.max() + 1e-6)
    return cam, target_class
- # ==================== MATURITY MODEL CLASS ====================
class MaturityModel(BaseModel):
    """
    Multispectral maturity classification model.

    Uses ResNet18 with 8-channel input for multispectral TIFF image analysis.
    Processes 8-band multispectral images and predicts maturity/ripeness classes.

    Attributes:
        model: ResNet18 model instance
        class_names: List of class names
        mean: Mean values for normalization (1, 1, 8)
        std: Std values for normalization (1, 1, 8)
        mask_band_index: Band index used for masking (default: 4, 860nm)
        img_size: Target image size after preprocessing (default: 256)
        img_pad: Padding for cropping (default: 8)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_band_index: int = 4,
        img_size: int = 256,
        img_pad: int = 8
    ):
        """
        Initialize the maturity model.

        Args:
            model_path: Path to the .pt model file (optional)
            device: Device to use ('cuda' or 'cpu', optional)
            mask_band_index: Band index for masking (default: 4, 860nm)
            img_size: Target image size (default: 256)
            img_pad: Padding for cropping (default: 8)
        """
        # NOTE(review): imported inside __init__ rather than at module top —
        # presumably to avoid a circular import with utils.config; confirm.
        from utils.config import MATURITY_MODEL_PATH

        if model_path is None:
            model_path = str(MATURITY_MODEL_PATH)

        if device is None:
            device = get_device()

        super().__init__(model_path, device)
        self.mask_band_index = mask_band_index
        self.img_size = img_size
        self.img_pad = img_pad
        # Populated by load() from the checkpoint; empty/None until then.
        self.class_names = []
        self.mean = None
        self.std = None

    def load(self) -> bool:
        """
        Load the maturity model from checkpoint.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            model_path = Path(self.model_path)

            if not model_path.exists():
                logger.error(f"Model file does not exist: {model_path}")
                return False

            logger.info(f"Loading maturity model from {model_path}")

            # Load checkpoint
            # NOTE(review): torch.load without weights_only unpickles arbitrary
            # objects — only load trusted checkpoint files.
            ckpt = torch.load(model_path, map_location=self.device)

            # Extract class names, mean, and std
            self.class_names = ckpt['class_names']
            # Reshape mean/std for broadcasting (1, 1, 8)
            # Assumes ckpt['mean']/'std'] are array-likes exposing .reshape
            # (e.g. np.ndarray) with 8 elements — TODO confirm checkpoint schema.
            self.mean = ckpt['mean'].reshape(1, 1, 8)
            self.std = ckpt['std'].reshape(1, 1, 8)

            # Re-create model architecture
            self.model = make_resnet18_8ch(num_classes=len(self.class_names))
            self.model.load_state_dict(ckpt['state_dict'])
            self.model.to(self.device)
            self.model.eval()

            self._is_loaded = True
            logger.info(f"Maturity model loaded on {self.device}. Classes: {self.class_names}")
            return True

        except Exception as e:
            # Best-effort loader: log the full traceback and report failure
            # via the return value rather than raising.
            logger.error(f"Failed to load maturity model: {e}")
            import traceback
            logger.error(traceback.format_exc())
            self._is_loaded = False
            return False

    def _preprocess(self, tif_path) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Runs the full preprocessing pipeline.

        1. Load 8-band cube
        2. Get mask
        3. Crop and resize
        4. Apply pixel mask
        5. Normalize
        6. Convert to Tensor

        Returns:
            Tuple[torch.Tensor, np.ndarray]: (tensor for model, visual crop for display)
        """
        # 1. Load 8-band cube
        cube = load_and_split_mosaic(str(tif_path), as_float=True, normalize="uint16")

        # 2. Get mask (using the exact params from training)
        mask_full = make_mask_filled(
            cube,
            band_index=self.mask_band_index,
            blur_ksize=81,
            close_ksize=4
        )

        # 3. Crop and resize
        crop, (y0, y1, x0, x1) = crop_and_resize_square(
            cube,
            mask_full,
            size=self.img_size,
            pad=self.img_pad
        )

        # 4. Apply pixel mask (critical step from training)
        # Crop the full-resolution mask with the same box, then resize with
        # nearest-neighbor so it stays binary.
        mask_crop = mask_full[y0:y1, x0:x1]
        mask_resz = cv2.resize(mask_crop, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
        # Create a {0, 1} mask and broadcast it
        m = (mask_resz > 0).astype(crop.dtype)
        # Zero out the background pixels in all 8 channels
        crop_masked = crop * m[..., None]

        # 5. Normalize
        # NOTE(review): assumes self.mean/self.std broadcast against an
        # (H, W, 8) ndarray, i.e. they are numpy arrays — see load().
        norm_crop = (crop_masked - self.mean) / self.std

        # 6. To Tensor
        # (H, W, C) -> (C, H, W) -> (B, C, H, W)
        x = torch.from_numpy(norm_crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)

        # Return tensor for model and the visible (un-normalized, masked) crop for visualization
        return x, crop_masked

    def predict(self, tif_path: str) -> Dict[str, Any]:
        """
        Run prediction on a multispectral TIFF file.

        Args:
            tif_path: Path to .tif file (8-band multispectral)

        Returns:
            Dict containing:
                - 'success': Whether prediction succeeded
                - 'prediction': Predicted class name
                - 'confidence': Confidence score (0-1)
                - 'probabilities': Dictionary of class probabilities
                - 'error': Error message if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        try:
            # Preprocess
            img_tensor, _ = self._preprocess(tif_path)

            # Run inference
            with torch.no_grad():
                logits = self.model(img_tensor)
                probs = torch.softmax(logits, dim=1)[0]
                pred_idx = logits.argmax(1).item()

            probs_cpu = probs.cpu().numpy()

            return {
                'success': True,
                'prediction': self.class_names[pred_idx],
                'confidence': float(probs_cpu[pred_idx]),
                'probabilities': {name: float(prob) for name, prob in zip(self.class_names, probs_cpu)},
                'error': None
            }

        except Exception as e:
            # Errors are reported in the result dict so batch callers can
            # continue past a bad file.
            logger.error(f"Prediction failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return {
                'success': False,
                'prediction': None,
                'confidence': 0.0,
                'probabilities': {},
                'error': str(e)
            }

    def run_gradcam(self, tif_path: str, band_to_show: Optional[int] = None) -> Tuple[Optional[QImage], Optional[str]]:
        """
        Run Grad-CAM visualization on a .tif file.

        Args:
            tif_path: Path to input .tif file
            band_to_show: Which of the 8 bands to use as background (default: mask_band_index)

        Returns:
            Tuple[QImage, str]: (overlay_image, predicted_class_name) or (None, None) if failed

        Raises:
            RuntimeError: If model is not loaded
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        if band_to_show is None:
            band_to_show = self.mask_band_index

        try:
            img_tensor, crop_img = self._preprocess(tif_path)

            # Run Grad-CAM
            heatmap, pred_idx = gradcam(
                self.model,
                img_tensor,
                target_class=None,
                layer_name='layer4'
            )

            pred_name = self.class_names[pred_idx]

            # Create visualization
            # Get the specific band to show (it's already float [0,1] and masked)
            band_img = crop_img[..., band_to_show]

            # Normalize band to [0,255] uint8 for display
            band_u8 = (np.clip(band_img, 0, 1) * 255).astype(np.uint8)
            vis_img = cv2.cvtColor(band_u8, cv2.COLOR_GRAY2BGR)

            # Resize heatmap and apply colormap
            hm_resized = (cv2.resize(heatmap, (self.img_size, self.img_size)) * 255).astype(np.uint8)
            hm_color = cv2.applyColorMap(hm_resized, cv2.COLORMAP_JET)

            # Create overlay
            overlay = cv2.addWeighted(vis_img, 0.6, hm_color, 0.4, 0)

            # Apply overlay only to fruit pixels
            # NOTE(review): uses band intensity > 0 as a proxy for the fruit
            # mask; near-black fruit pixels in this band would be dropped too.
            mask = (band_u8 > 0)
            final_vis = np.zeros_like(vis_img)
            final_vis[mask] = overlay[mask]

            # Convert to QImage
            rgb_image = cv2.cvtColor(final_vis, cv2.COLOR_BGR2RGB)
            rgb_image = np.ascontiguousarray(rgb_image)  # Ensure contiguous memory
            h, w, ch = rgb_image.shape
            bytes_per_line = ch * w

            # Create QImage with copied data to prevent memory issues
            q_image = QImage(rgb_image.tobytes(), w, h, bytes_per_line, QImage.Format_RGB888)
            q_image = q_image.copy()  # Make a deep copy to own the memory

            return q_image, pred_name

        except Exception as e:
            # Visualization failures are non-fatal: log and return (None, None).
            logger.error(f"Grad-CAM failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None, None

    def predict_batch(self, tif_paths: list) -> list:
        """
        Predict maturity for multiple TIFF files.

        Args:
            tif_paths: List of paths to .tif files

        Returns:
            List[Dict]: List of prediction results
        """
        # predict() catches its own exceptions, so one bad file does not
        # abort the batch; each entry carries its own success/error fields.
        results = []
        for tif_path in tif_paths:
            result = self.predict(tif_path)
            results.append(result)
        return results
|