| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- """
- Shape Worker Module
- Worker thread for async shape classification.
- """
- from typing import Optional
- import logging
- from PyQt5.QtCore import pyqtSignal
- from PyQt5.QtGui import QImage
- from workers.base_worker import BaseWorker, WorkerSignals
- from models.shape_model import ShapeModel
- from utils.config import get_device
- logger = logging.getLogger(__name__)
class ShapeWorkerSignals(WorkerSignals):
    """
    Signal set used by the shape classification worker.

    Extends the base ``WorkerSignals`` with one additional signal.

    Signals:
        result_ready: Emitted when classification is complete, carrying
            (annotated_image: QImage, shape_class: str, class_id: int,
            confidence: float).
    """

    # Payload order: annotated image, class label, class id, confidence.
    result_ready = pyqtSignal(QImage, str, int, float)
class ShapeWorker(BaseWorker):
    """
    Worker for processing images and classifying durian shape.

    Runs ShapeModel inference on a single image and reports the outcome
    through ``ShapeWorkerSignals.result_ready``. It is meant to be executed
    off the UI thread by the ``BaseWorker`` machinery (not visible in this
    module -- confirm against ``workers.base_worker``).

    Attributes:
        image_path: Path to image file
        model: ShapeModel instance
        signals: ShapeWorkerSignals for emitting results
    """

    def __init__(self, image_path: str, model: Optional[ShapeModel] = None):
        """
        Initialize the shape worker.

        Args:
            image_path: Path to image file to process
            model: ShapeModel instance (if None, a new one is created from
                ``PROJECT_ROOT / "model_files" / "shape.pt"`` and loaded)
        """
        super().__init__()
        self.image_path = image_path
        self.model = model

        # Replace base signals with shape-specific signals
        self.signals = ShapeWorkerSignals()

        # If no model provided, create and load one
        if self.model is None:
            device = get_device()
            # Shape model must be provided with explicit path.
            # Construction/loading may fail if the file doesn't exist -
            # that's expected and is surfaced to the caller.
            from utils.config import PROJECT_ROOT

            shape_model_path = PROJECT_ROOT / "model_files" / "shape.pt"
            self.model = ShapeModel(str(shape_model_path), device=device)
            # load() reports failure through its return value (process()
            # relies on that too); don't drop it silently. process() will
            # retry the load and raise if it still fails.
            if not self.model.load():
                logger.warning(
                    "Shape model failed to load from %s; will retry in process()",
                    shape_model_path,
                )

        logger.info("ShapeWorker created for: %s", image_path)

    def process(self) -> None:
        """
        Process the image and classify shape.

        Emits ``result_ready`` with (annotated_image, shape_class, class_id,
        confidence). When the model result carries no annotated image, a
        null ``QImage`` is emitted in its place; receivers can detect this
        with ``QImage.isNull()``.

        Returns early, without emitting a result, if the worker is
        cancelled before or right after inference.

        Raises:
            RuntimeError: If the model cannot be loaded, or if prediction
                reports ``success`` as false (message taken from the
                result's ``error`` entry).
        """
        if self.is_cancelled():
            logger.info("ShapeWorker cancelled before processing")
            return

        # Update progress
        self.emit_progress(10, "Loading image...")

        if not self.model.is_loaded:
            logger.warning("Model not loaded, loading now...")
            self.emit_progress(30, "Loading model...")
            if not self.model.load():
                raise RuntimeError("Failed to load shape model")

        # Process image
        self.emit_progress(50, "Classifying shape...")

        result = self.model.predict(self.image_path)

        # Inference cannot be interrupted here, so re-check cancellation
        # before publishing anything.
        if self.is_cancelled():
            logger.info("ShapeWorker cancelled during processing")
            return

        if not result['success']:
            raise RuntimeError(result['error'])

        self.emit_progress(90, "Finalizing results...")

        # Copy the image so the emitted object does not share pixel data
        # with the model's result. The signal is declared with a QImage
        # first argument, which PyQt5 will not accept None for, so fall
        # back to a null QImage when no annotated image is available.
        annotated_image = result['annotated_image']
        if annotated_image is not None:
            annotated_image_copy = annotated_image.copy()
        else:
            annotated_image_copy = QImage()

        self.signals.result_ready.emit(
            annotated_image_copy,
            result['shape_class'],
            result['class_id'],
            result['confidence'],
        )

        self.emit_progress(100, "Complete!")

        logger.info(
            "ShapeWorker completed: %s (confidence: %.3f)",
            result['shape_class'],
            result['confidence'],
        )
|