shape_worker.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """
  2. Shape Worker Module
  3. Worker thread for async shape classification.
  4. """
  5. from typing import Optional
  6. import logging
  7. from PyQt5.QtCore import pyqtSignal
  8. from PyQt5.QtGui import QImage
  9. from workers.base_worker import BaseWorker, WorkerSignals
  10. from models.shape_model import ShapeModel
  11. from utils.config import get_device
  12. logger = logging.getLogger(__name__)
  13. class ShapeWorkerSignals(WorkerSignals):
  14. """
  15. Signals specific to shape classification.
  16. Signals:
  17. result_ready: Emitted when classification is complete
  18. (annotated_image: QImage, shape_class: str, class_id: int, confidence: float)
  19. """
  20. result_ready = pyqtSignal(QImage, str, int, float)
  21. class ShapeWorker(BaseWorker):
  22. """
  23. Worker for processing images and classifying durian shape.
  24. Runs ShapeModel inference in a background thread without blocking UI.
  25. Attributes:
  26. image_path: Path to image file
  27. model: ShapeModel instance
  28. signals: ShapeWorkerSignals for emitting results
  29. """
  30. def __init__(self, image_path: str, model: Optional[ShapeModel] = None):
  31. """
  32. Initialize the shape worker.
  33. Args:
  34. image_path: Path to image file to process
  35. model: ShapeModel instance (if None, creates new one)
  36. """
  37. super().__init__()
  38. self.image_path = image_path
  39. self.model = model
  40. # Replace base signals with shape-specific signals
  41. self.signals = ShapeWorkerSignals()
  42. # If no model provided, create and load one
  43. if self.model is None:
  44. device = get_device()
  45. # Shape model must be provided with explicit path
  46. # This will fail if model doesn't exist - that's expected
  47. from utils.config import PROJECT_ROOT
  48. shape_model_path = PROJECT_ROOT / "model_files" / "shape.pt"
  49. self.model = ShapeModel(str(shape_model_path), device=device)
  50. self.model.load()
  51. logger.info(f"ShapeWorker created for: {image_path}")
  52. def process(self):
  53. """
  54. Process the image and classify shape.
  55. Emits result_ready signal with classification results.
  56. """
  57. if self.is_cancelled():
  58. logger.info("ShapeWorker cancelled before processing")
  59. return
  60. # Update progress
  61. self.emit_progress(10, "Loading image...")
  62. if not self.model.is_loaded:
  63. logger.warning("Model not loaded, loading now...")
  64. self.emit_progress(30, "Loading model...")
  65. if not self.model.load():
  66. raise RuntimeError("Failed to load shape model")
  67. # Process image
  68. self.emit_progress(50, "Classifying shape...")
  69. result = self.model.predict(self.image_path)
  70. if self.is_cancelled():
  71. logger.info("ShapeWorker cancelled during processing")
  72. return
  73. if not result['success']:
  74. raise RuntimeError(result['error'])
  75. self.emit_progress(90, "Finalizing results...")
  76. # Make a copy of the QImage to ensure thread safety
  77. annotated_image_copy = result['annotated_image'].copy() if result['annotated_image'] else None
  78. # Emit results with annotated image if available
  79. self.signals.result_ready.emit(
  80. annotated_image_copy,
  81. result['shape_class'],
  82. result['class_id'],
  83. result['confidence']
  84. )
  85. self.emit_progress(100, "Complete!")
  86. logger.info(
  87. f"ShapeWorker completed: {result['shape_class']} "
  88. f"(confidence: {result['confidence']:.3f})"
  89. )