| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
"""
Base Model Module

Abstract base class for all AI models in the DuDONG system.
Provides a common interface and shared functionality for model loading and prediction.
"""
- from abc import ABC, abstractmethod
- from typing import Any, Dict, Optional
- import logging
# Module-level logger; records propagate up to the root logger.
logger = logging.getLogger(name=__name__)

# Root logging configuration applied when this module is imported.
# NOTE(review): basicConfig() in a library module configures the root logger
# for the whole process as an import side effect (and is a no-op if the root
# logger already has handlers). Consider moving this to the application entry
# point; kept here to preserve current behavior.
logging.basicConfig(level=logging.INFO)
class BaseModel(ABC):
    """
    Abstract base class for AI models.

    All model wrappers should inherit from this class and implement the
    abstract methods ``load`` and ``predict``. Because those methods are
    abstract, instantiating ``BaseModel`` directly raises ``TypeError``.

    Attributes:
        model_path (str): Path to the model file/directory
        device (str): Device to run the model on ('cuda' or 'cpu')
        model (Any): The loaded model object, or None while nothing is loaded
        _is_loaded (bool): Flag indicating if the model is loaded
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """
        Initialize the shared model state; no model is loaded here.

        Args:
            model_path: Path to the model file or directory
            device: Device to use for inference ('cuda' or 'cpu'). The value
                is stored as given and is not validated by this class.
        """
        self.model_path = model_path
        self.device = device
        self.model: Optional[Any] = None
        self._is_loaded = False

        # Pass arguments to the logger instead of pre-formatting an f-string,
        # so the message is only built if INFO records are actually emitted.
        logger.info(
            "Initializing %s with device: %s", self.__class__.__name__, device
        )

    @abstractmethod
    def load(self) -> bool:
        """
        Load the model from disk.

        This method must be implemented by all subclasses. Nothing in this
        base class sets ``_is_loaded`` to True, so implementations are
        expected to assign ``self.model`` and ``self._is_loaded`` themselves.

        Returns:
            bool: True if model loaded successfully, False otherwise
        """

    @abstractmethod
    def predict(self, input_data: Any) -> Dict[str, Any]:
        """
        Run inference on the input data.

        This method must be implemented by all subclasses.

        Args:
            input_data: Input data for prediction (format varies by model)

        Returns:
            Dict[str, Any]: Dictionary containing prediction results
        """

    @property
    def is_loaded(self) -> bool:
        """
        Whether the model is currently loaded (read-only view of ``_is_loaded``).

        Returns:
            bool: True if model is loaded, False otherwise
        """
        return self._is_loaded

    def unload(self) -> None:
        """
        Drop the reference to the loaded model and mark it as not loaded.

        Safe to call when nothing is loaded. Subclasses can override this if
        their framework needs additional cleanup.
        """
        # Rebinding to None releases this object's reference to the model;
        # the memory is only reclaimed once no other references remain.
        self.model = None
        self._is_loaded = False
        logger.info("%s unloaded", self.__class__.__name__)

    def get_device(self) -> str:
        """
        Get the device the model is configured to run on.

        Returns:
            str: Device name ('cuda' or 'cpu')
        """
        return self.device

    def __repr__(self) -> str:
        """
        String representation of the model.

        Returns:
            str: Class name, device and load status
        """
        status = "loaded" if self._is_loaded else "not loaded"
        return f"{self.__class__.__name__}(device={self.device}, status={status})"
|