""" Shape Worker Module Worker thread for async shape classification. """ from typing import Optional import logging from PyQt5.QtCore import pyqtSignal from PyQt5.QtGui import QImage from workers.base_worker import BaseWorker, WorkerSignals from models.shape_model import ShapeModel from utils.config import get_device logger = logging.getLogger(__name__) class ShapeWorkerSignals(WorkerSignals): """ Signals specific to shape classification. Signals: result_ready: Emitted when classification is complete (annotated_image: QImage, shape_class: str, class_id: int, confidence: float) """ result_ready = pyqtSignal(QImage, str, int, float) class ShapeWorker(BaseWorker): """ Worker for processing images and classifying durian shape. Runs ShapeModel inference in a background thread without blocking UI. Attributes: image_path: Path to image file model: ShapeModel instance signals: ShapeWorkerSignals for emitting results """ def __init__(self, image_path: str, model: Optional[ShapeModel] = None): """ Initialize the shape worker. Args: image_path: Path to image file to process model: ShapeModel instance (if None, creates new one) """ super().__init__() self.image_path = image_path self.model = model # Replace base signals with shape-specific signals self.signals = ShapeWorkerSignals() # If no model provided, create and load one if self.model is None: device = get_device() # Shape model must be provided with explicit path # This will fail if model doesn't exist - that's expected from utils.config import PROJECT_ROOT shape_model_path = PROJECT_ROOT / "model_files" / "shape.pt" self.model = ShapeModel(str(shape_model_path), device=device) self.model.load() logger.info(f"ShapeWorker created for: {image_path}") def process(self): """ Process the image and classify shape. Emits result_ready signal with classification results. """ if self.is_cancelled(): logger.info("ShapeWorker cancelled before processing") return # Update progress self.emit_progress(10, "Loading image...") if not self.model.is_loaded: logger.warning("Model not loaded, loading now...") self.emit_progress(30, "Loading model...") if not self.model.load(): raise RuntimeError("Failed to load shape model") # Process image self.emit_progress(50, "Classifying shape...") result = self.model.predict(self.image_path) if self.is_cancelled(): logger.info("ShapeWorker cancelled during processing") return if not result['success']: raise RuntimeError(result['error']) self.emit_progress(90, "Finalizing results...") # Make a copy of the QImage to ensure thread safety annotated_image_copy = result['annotated_image'].copy() if result['annotated_image'] else None # Emit results with annotated image if available self.signals.result_ready.emit( annotated_image_copy, result['shape_class'], result['class_id'], result['confidence'] ) self.emit_progress(100, "Complete!") logger.info( f"ShapeWorker completed: {result['shape_class']} " f"(confidence: {result['confidence']:.3f})" )