base_model.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """
  2. Base Model Module
  3. Abstract base class for all AI models in the DuDONG system.
  4. Provides common interface and functionality for model loading and prediction.
  5. """
  6. from abc import ABC, abstractmethod
  7. from typing import Any, Dict, Optional
  8. import logging
  9. # Setup logging
  10. logging.basicConfig(level=logging.INFO)
  11. logger = logging.getLogger(__name__)
  12. class BaseModel(ABC):
  13. """
  14. Abstract base class for AI models.
  15. All model wrappers should inherit from this class and implement
  16. the required abstract methods.
  17. Attributes:
  18. model_path (str): Path to the model file/directory
  19. device (str): Device to run model on ('cuda' or 'cpu')
  20. model (Any): The loaded model object
  21. _is_loaded (bool): Flag indicating if model is loaded
  22. """
  23. def __init__(self, model_path: str, device: str = "cpu"):
  24. """
  25. Initialize the base model.
  26. Args:
  27. model_path: Path to the model file or directory
  28. device: Device to use for inference ('cuda' or 'cpu')
  29. """
  30. self.model_path = model_path
  31. self.device = device
  32. self.model: Optional[Any] = None
  33. self._is_loaded = False
  34. logger.info(f"Initializing {self.__class__.__name__} with device: {device}")
  35. @abstractmethod
  36. def load(self) -> bool:
  37. """
  38. Load the model from disk.
  39. This method must be implemented by all subclasses.
  40. Returns:
  41. bool: True if model loaded successfully, False otherwise
  42. """
  43. pass
  44. @abstractmethod
  45. def predict(self, input_data: Any) -> Dict[str, Any]:
  46. """
  47. Run inference on the input data.
  48. This method must be implemented by all subclasses.
  49. Args:
  50. input_data: Input data for prediction (format varies by model)
  51. Returns:
  52. Dict[str, Any]: Dictionary containing prediction results
  53. """
  54. pass
  55. @property
  56. def is_loaded(self) -> bool:
  57. """
  58. Check if the model is loaded (property accessor).
  59. Returns:
  60. bool: True if model is loaded, False otherwise
  61. """
  62. return self._is_loaded
  63. def unload(self) -> None:
  64. """
  65. Unload the model from memory.
  66. This can be overridden by subclasses if special cleanup is needed.
  67. """
  68. if self.model is not None:
  69. del self.model
  70. self.model = None
  71. self._is_loaded = False
  72. logger.info(f"{self.__class__.__name__} unloaded")
  73. def get_device(self) -> str:
  74. """
  75. Get the device the model is running on.
  76. Returns:
  77. str: Device name ('cuda' or 'cpu')
  78. """
  79. return self.device
  80. def __repr__(self) -> str:
  81. """
  82. String representation of the model.
  83. Returns:
  84. str: Model information
  85. """
  86. status = "loaded" if self._is_loaded else "not loaded"
  87. return f"{self.__class__.__name__}(device={self.device}, status={status})"