"""
Audio Model Module

Audio classification model for durian ripeness detection using knock detection
and mel-spectrogram features. Detects knocks in audio using librosa, extracts
mel-spectrograms, and averages predictions across knocks.
"""

import os
import tempfile
import pickle
import json
import subprocess
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, List
import logging

import numpy as np
import librosa
import librosa.display
import matplotlib
matplotlib.use('agg')  # Use non-interactive backend (must precede pyplot import)
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from PyQt5.QtGui import QImage, QPixmap

from models.base_model import BaseModel
from utils.config import (
    AUDIO_MODEL_PATH,
    SPECTROGRAM_FIG_SIZE,
    RIPENESS_CLASSES,
)

# Import TensorFlow (optional at import time; load() reports if it is missing)
try:
    import tensorflow as tf
except ImportError:
    tf = None

logger = logging.getLogger(__name__)


class AudioModel(BaseModel):
    """
    Audio-based ripeness classification model.

    Detects knocks in durian audio using onset detection, extracts
    mel-spectrogram features, and averages predictions across all detected
    knocks for robust ripeness classification.

    Attributes:
        model: Keras model for mel-spectrogram classification
        label_encoder: Unpickled label encoder; its ``classes_`` supply the
            class names used in prediction results
        preprocessing_stats: Dict loaded from preprocessing_stats.json; its
            ``max_length`` is the number of time frames each knock is padded to
        class_names: RIPENESS_CLASSES from utils.config
    """

    # Mel-spectrogram parameters (must match training)
    MEL_PARAMS = {
        'n_mels': 64,
        'hop_length': 512,
        'n_fft': 2048,
        'sr': 22050
    }

    # Knock detection parameters
    KNOCK_DETECTION = {
        'delta': 0.3,           # Onset detection delta
        'wait': 10,             # Onset detection wait frames
        'onset_shift': 0.05,    # Shift onsets back by 50ms
        'knock_duration': 0.2   # Extract 200ms per knock
    }

    # Artifact names expected in the model directory
    _MODEL_FILENAME = "best_model_mel_spec_grouped.keras"
    _ENCODER_FILENAME = "label_encoder.pkl"
    _STATS_FILENAME = "preprocessing_stats.json"
    _MODEL_SUBDIR = "voice_memos_ripeness"

    # Seconds cut from both ends of every recording before knock detection
    _EDGE_TRIM_SECONDS = 0.5

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        """
        Initialize the audio model.

        Args:
            model_path: Path to model directory (optional, defaults to
                AUDIO_MODEL_PATH)
            device: Device to use (cpu/gpu - not used for TensorFlow)
        """
        if model_path is None:
            # AUDIO_MODEL_PATH points to models/audio/ which contains our files
            model_path = str(AUDIO_MODEL_PATH)

        super().__init__(model_path, device)
        self.class_names = RIPENESS_CLASSES
        self.model = None
        self.label_encoder = None
        self.preprocessing_stats = None

        logger.info(f"AudioModel initialized with model_path: {model_path}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _find_model_dir(self, base_dir: Path) -> Optional[Path]:
        """Return the first of ``base_dir`` / its model subdir that holds the Keras model file."""
        for candidate in (base_dir, base_dir / self._MODEL_SUBDIR):
            if (candidate / self._MODEL_FILENAME).exists():
                return candidate
        return None

    def load(self) -> bool:
        """
        Load the model, label encoder, and preprocessing statistics.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        # Stay "not loaded" until every artifact is in place, so a failed
        # reload never leaves a half-updated model flagged as ready.
        self._is_loaded = False

        if tf is None:
            logger.error("TensorFlow is not installed - cannot load the audio model")
            return False

        try:
            base_dir = Path(self.model_path)
            model_dir = self._find_model_dir(base_dir)
            if model_dir is None:
                logger.error(
                    f"Could not find model files in: {base_dir} or {base_dir / self._MODEL_SUBDIR}"
                )
                return False

            logger.info(f"Loading audio model from {model_dir}")

            # Load Keras model (existence guaranteed by _find_model_dir)
            model_path = model_dir / self._MODEL_FILENAME
            logger.info(f"Loading TensorFlow model from {model_path}...")
            self.model = tf.keras.models.load_model(str(model_path))
            logger.info("✓ TensorFlow model loaded successfully")

            # Load label encoder
            encoder_path = model_dir / self._ENCODER_FILENAME
            if not encoder_path.exists():
                logger.error(f"Label encoder not found: {encoder_path}")
                return False
            logger.info(f"Loading label encoder from {encoder_path}...")
            # SECURITY: pickle.load can execute arbitrary code; only load
            # encoders shipped in a trusted model directory.
            with open(encoder_path, 'rb') as f:
                self.label_encoder = pickle.load(f)
            logger.info(f"✓ Label encoder loaded with classes: {list(self.label_encoder.classes_)}")

            # Load preprocessing stats
            stats_path = model_dir / self._STATS_FILENAME
            if not stats_path.exists():
                logger.error(f"Preprocessing stats not found: {stats_path}")
                return False
            logger.info(f"Loading preprocessing stats from {stats_path}...")
            with open(stats_path, 'r', encoding='utf-8') as f:
                self.preprocessing_stats = json.load(f)
            if self.preprocessing_stats.get('max_length') is None:
                # predict() pads every knock to this length; fail early if absent.
                logger.error(f"Preprocessing stats missing 'max_length': {stats_path}")
                return False
            logger.info(f"✓ Preprocessing stats loaded, max_length: {self.preprocessing_stats.get('max_length')}")

            self._is_loaded = True
            logger.info("✓ Audio model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load audio model: {e}", exc_info=True)
            self._is_loaded = False
            return False

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    @staticmethod
    def _failure_result(error: str) -> Dict[str, Any]:
        """Build the result dict returned when prediction cannot produce a class."""
        return {
            'success': False,
            'class_name': None,
            'class_index': None,
            'probabilities': {},
            'confidence': 0.0,
            'spectrogram_image': None,
            'waveform_image': None,
            'knock_count': 0,
            'knock_times': [],
            'per_knock_predictions': [],
            'error': error
        }

    @staticmethod
    def _trim_edges(y: np.ndarray, sr: int, seconds: float = 0.5) -> np.ndarray:
        """
        Cut a fixed duration from the start and end of ``y``.

        This is a fixed-length cut, not silence detection. When the signal is
        too short to cut both ends, only the start is cut; when it is shorter
        than one cut, it is returned unchanged.
        """
        cut = int(seconds * sr)
        if cut <= 0:
            return y  # y[0:-0] would be empty
        if len(y) > 2 * cut:
            return y[cut:-cut]
        if len(y) > cut:
            return y[cut:]
        return y

    def _prepare_model_input(self, features: List[np.ndarray]) -> np.ndarray:
        """
        Pad (or truncate) each (time, n_mels) feature to ``max_length`` frames and stack them.

        Returns:
            Array of shape (n_knocks, max_length, n_mels, 1) for the CNN input.
        """
        max_length = int(self.preprocessing_stats['max_length'])
        padded = []
        for feat in features:
            # Truncate first: np.pad raises on a negative pad width.
            feat = feat[:max_length]
            padded.append(
                np.pad(feat, ((0, max_length - feat.shape[0]), (0, 0)), mode='constant')
            )
        X = np.array(padded)
        if X.ndim == 3:
            X = np.expand_dims(X, -1)
        return X

    @staticmethod
    def _knock_prediction(knock_probs: np.ndarray, classes) -> Dict[str, Any]:
        """Summarize one knock's probability vector as a per-knock prediction dict."""
        idx = int(np.argmax(knock_probs))
        return {
            'class': classes[idx],
            'class_idx': idx,
            'confidence': float(knock_probs[idx]),
            'probabilities': {name: float(p) for name, p in zip(classes, knock_probs)}
        }

    def predict(self, audio_path: str) -> Dict[str, Any]:
        """
        Predict ripeness from an audio file using knock detection and mel-spectrogram analysis.

        Args:
            audio_path: Path to audio file (supports WAV and other formats via librosa)

        Returns:
            Dict containing:
                - 'class_name': Predicted class name (Ripe/Unripe/Overripe)
                - 'class_index': Predicted class index (int)
                - 'probabilities': Dictionary of class probabilities (0-1 range)
                - 'confidence': Confidence score (0-1 range, averaged across knocks)
                - 'spectrogram_image': QPixmap of mel-spectrogram with knocks marked
                - 'waveform_image': QPixmap of waveform with knocks marked
                - 'knock_count': Number of knocks detected
                - 'knock_times': List of knock onset times in seconds
                - 'per_knock_predictions': List of per-knock prediction dicts
                - 'success': Whether prediction succeeded
                - 'error': Error message if failed

        Raises:
            RuntimeError: If load() has not completed successfully.
        """
        if not self._is_loaded or self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        wav_path = None
        try:
            # Ensure audio is in WAV format (may create a temporary file)
            wav_path = self._ensure_wav_format(audio_path)

            logger.info(f"Loading audio from {audio_path}")
            y, sr = librosa.load(wav_path, sr=self.MEL_PARAMS['sr'], mono=True)

            # Cut a fixed amount from both ends of the recording
            y = self._trim_edges(y, sr, self._EDGE_TRIM_SECONDS)

            logger.info("Detecting knocks in audio...")
            features, knock_times = self._extract_knock_features(y, sr)
            logger.debug(f"Detected {len(features)} knocks")

            if not features:
                logger.error("❌ No knocks detected in audio file - returning error")
                return self._failure_result('No knocks detected in audio file')

            logger.info(f"Detected {len(features)} knocks at times: {knock_times}")

            X = self._prepare_model_input(features)

            logger.info(f"Running model inference on {len(features)} knocks...")
            probs = self.model.predict(X, verbose=0)

            classes = self.label_encoder.classes_
            per_knock_preds = [self._knock_prediction(knock_probs, classes) for knock_probs in probs]
            logger.info(f"Per-knock predictions: {[p['class'] for p in per_knock_preds]}")

            # Average predictions across all knocks (CONFIDENCE LOGIC)
            avg_probs = np.mean(probs, axis=0)
            predicted_idx = int(np.argmax(avg_probs))
            predicted_class = classes[predicted_idx]
            confidence = float(avg_probs[predicted_idx])

            logger.info(f"Average probabilities: {dict(zip(classes, avg_probs))}")
            logger.info(f"Final prediction: {predicted_class} ({confidence:.2%})")

            prob_dict = {name: float(p) for name, p in zip(classes, avg_probs)}

            # Capitalize class name for display
            predicted_class_display = (
                predicted_class.capitalize() if isinstance(predicted_class, str) else predicted_class
            )

            # Generate visualizations with knock annotations
            spectrogram_image = self._generate_mel_spectrogram_with_knocks(y, sr, knock_times)
            waveform_image = self._generate_waveform_with_knocks(y, sr, knock_times)

            logger.info(f"Prediction: {predicted_class_display} ({confidence:.2%}) from {len(features)} knocks")

            return {
                'success': True,
                'class_name': predicted_class_display,
                'class_index': predicted_idx,
                'probabilities': prob_dict,
                'confidence': confidence,
                'spectrogram_image': spectrogram_image,
                'waveform_image': waveform_image,
                'knock_count': len(features),
                'knock_times': knock_times,
                'per_knock_predictions': per_knock_preds,
                'error': None
            }

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Prediction failed: {error_msg}", exc_info=True)

            # Provide helpful error message for audio format issues
            if 'convert' in error_msg.lower() or 'format' in error_msg.lower():
                error_msg += " - Please ensure ffmpeg is installed: conda install -c conda-forge ffmpeg"

            return self._failure_result(error_msg)

        finally:
            # Delete any temporary WAV created by _ensure_wav_format
            if wav_path is not None and wav_path != audio_path:
                self._remove_file_quietly(wav_path)

    def predict_batch(self, audio_paths: list) -> list:
        """
        Predict ripeness for multiple audio files.

        Args:
            audio_paths: List of paths to audio files

        Returns:
            List[Dict]: One predict() result per path, in input order
        """
        return [self.predict(audio_path) for audio_path in audio_paths]

    # ------------------------------------------------------------------
    # Audio conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _new_temp_wav_path() -> str:
        """Create an empty, closed temporary ``.wav`` file and return its path (caller deletes it)."""
        fd, path = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        return path

    @staticmethod
    def _remove_file_quietly(path: Optional[str]) -> None:
        """Delete ``path`` if given; log (never raise) on failure."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove temporary file {path}: {e}")

    def _ensure_wav_format(self, audio_path: str) -> str:
        """
        Ensure audio file is in WAV format, converting if necessary.

        Supports M4A, MP3, OGG, FLAC, WMA and other formats. Tries ffmpeg, then
        pydub, then librosa. A converted file is a temporary file that the
        caller is responsible for deleting.

        Args:
            audio_path: Path to audio file

        Returns:
            Path to WAV file (original path if already .wav, otherwise a temp file)

        Raises:
            RuntimeError: If every conversion method fails.
        """
        ext = os.path.splitext(audio_path)[1].lower()
        if ext == '.wav':
            return audio_path

        logger.info(f"Converting {ext} to WAV format...")

        # Try ffmpeg first (most reliable for various formats)
        tmp_path = None
        try:
            tmp_path = self._new_temp_wav_path()
            cmd = [
                'ffmpeg',
                '-i', audio_path,
                '-acodec', 'pcm_s16le',
                '-ar', str(self.MEL_PARAMS['sr']),
                '-ac', '1',  # mono
                '-y',        # overwrite the (empty) temp file
                tmp_path
            ]
            logger.info(f"Using ffmpeg to convert {ext}")
            subprocess.run(cmd, capture_output=True, check=True, timeout=30)
            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._remove_file_quietly(tmp_path)
            logger.warning(f"ffmpeg conversion failed: {e}, trying pydub...")

        # Try pydub second (handles most formats if installed)
        tmp_path = None
        try:
            from pydub import AudioSegment
            logger.info(f"Using pydub to convert {ext}")
            audio = AudioSegment.from_file(audio_path)
            tmp_path = self._new_temp_wav_path()
            exported = audio.export(tmp_path, format='wav')
            # pydub's export() returns the output file object; close it so the
            # handle is not leaked and the file can be deleted later.
            if hasattr(exported, 'close'):
                exported.close()
            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._remove_file_quietly(tmp_path)
            logger.warning(f"pydub conversion failed: {e}, trying librosa...")

        # Try librosa as final fallback
        try:
            import soundfile as sf
        except ImportError:
            logger.warning("soundfile not available, using scipy for conversion")
            sf = None

        tmp_path = None
        try:
            logger.info("Using librosa to load and convert audio")
            # Load with librosa (requires ffmpeg backend for non-WAV)
            y, sr = librosa.load(audio_path, sr=self.MEL_PARAMS['sr'], mono=True)
            tmp_path = self._new_temp_wav_path()
            if sf is not None:
                sf.write(tmp_path, y, sr)
            else:
                # Fallback: use scipy, scaling float samples to the 16-bit range
                import scipy.io.wavfile as wavfile
                y_int16 = np.clip(y * 32767, -32768, 32767).astype(np.int16)
                wavfile.write(tmp_path, sr, y_int16)
            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._remove_file_quietly(tmp_path)
            logger.error(f"Audio conversion failed with all methods: {e}")
            logger.error("Please ensure ffmpeg is installed or install pydub: pip install pydub")
            # Note: the ffmpeg *binary* is required; 'pip install ffmpeg' does not provide it.
            raise RuntimeError(
                f"Failed to convert {ext} audio file. "
                "Install the ffmpeg binary and/or pydub ('pip install pydub')"
            ) from e

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------

    def _compute_mel_db(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """Return the dB-scaled (n_mels, frames) mel-spectrogram of ``audio`` using MEL_PARAMS."""
        S = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_mels=self.MEL_PARAMS['n_mels'],
            hop_length=self.MEL_PARAMS['hop_length'],
            n_fft=self.MEL_PARAMS['n_fft']
        )
        return librosa.power_to_db(S, ref=np.max)

    def _extract_knock_features(self, audio: np.ndarray, sr: int) -> Tuple[List[np.ndarray], List[float]]:
        """
        Detect knocks in audio and extract mel-spectrogram features.

        Args:
            audio: Audio time series
            sr: Sample rate

        Returns:
            Tuple of:
                - List of mel-spectrogram arrays (one per knock)
                - List of knock onset times in seconds (after the onset shift)
        """
        logger.info("Detecting onset times...")
        onset_frames = librosa.onset.onset_detect(
            y=audio,
            sr=sr,
            delta=self.KNOCK_DETECTION['delta'],
            wait=self.KNOCK_DETECTION['wait'],
            units='frames'
        )
        onset_times = librosa.frames_to_time(onset_frames, sr=sr)

        if len(onset_times) == 0:
            logger.warning("No onsets detected")
            return [], []

        # Shift onsets slightly back so each window includes the knock attack
        shift = self.KNOCK_DETECTION['onset_shift']
        shifted_times = [float(max(0.0, t - shift)) for t in onset_times]
        logger.info(f"Detected {len(shifted_times)} onset times")

        knock_samples = int(round(self.KNOCK_DETECTION['knock_duration'] * sr))
        knocks = []
        valid_times = []
        for onset in shifted_times:
            start = int(round(onset * sr))
            if start >= len(audio):
                continue
            knock = audio[start:start + knock_samples]
            if len(knock) < knock_samples:
                # Zero-pad knocks that run past the end of the recording
                knock = np.pad(knock, (0, knock_samples - len(knock)))
            knocks.append(knock)
            valid_times.append(onset)

        logger.info(f"Extracting mel-spectrograms from {len(knocks)} knocks...")
        features = [self._extract_mel_spectrogram(knock, sr) for knock in knocks]
        return features, valid_times

    def _extract_mel_spectrogram(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        Extract normalized mel-spectrogram from audio.

        Args:
            audio: Audio segment
            sr: Sample rate

        Returns:
            Mel-spectrogram standardized to zero mean (and unit std when the
            std is non-zero), transposed to (time, n_mels)
        """
        S_db = self._compute_mel_db(audio, sr)

        std = np.std(S_db)
        if std != 0:
            S_db = (S_db - np.mean(S_db)) / std
        else:
            S_db = S_db - np.mean(S_db)

        return S_db.T  # (time, n_mels)

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------

    @staticmethod
    def _figure_to_pixmap(fig) -> QPixmap:
        """Render a matplotlib figure and return it as a QPixmap (does not close the figure)."""
        canvas = FigureCanvas(fig)
        canvas.draw()
        buffer = canvas.buffer_rgba()
        # Take the size from the rendered buffer rather than figsize*dpi, which
        # can disagree after rounding or on high-DPI canvases.
        height, width = np.asarray(buffer).shape[:2]
        image = QImage(buffer, width, height, 4 * width, QImage.Format_RGBA8888)
        # QImage wraps the buffer without copying; copy() detaches it so the
        # pixmap does not depend on the canvas memory.
        return QPixmap.fromImage(image.copy())

    def _mark_knocks(self, ax, knock_times: List[float], line_color: str,
                     span_color: str, span_alpha: float) -> None:
        """Draw a dashed line at each knock onset and shade the extracted knock window."""
        knock_duration = self.KNOCK_DETECTION['knock_duration']
        for knock_time in knock_times:
            ax.axvline(knock_time, color=line_color, linestyle='--', alpha=0.8, linewidth=1.5)
            ax.axvspan(knock_time, knock_time + knock_duration, color=span_color, alpha=span_alpha)

    def _generate_waveform_with_knocks(self, audio: np.ndarray, sr: int,
                                       knock_times: List[float]) -> Optional[QPixmap]:
        """
        Generate waveform visualization with knock locations marked.

        Args:
            audio: Audio time series
            sr: Sample rate
            knock_times: List of knock onset times in seconds

        Returns:
            QPixmap: Waveform plot with knock markers, or None if rendering failed
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=SPECTROGRAM_FIG_SIZE)
            librosa.display.waveshow(audio, sr=sr, alpha=0.6, ax=ax)
            self._mark_knocks(ax, knock_times, line_color='red', span_color='orange', span_alpha=0.2)

            ax.set_title(f'Waveform with {len(knock_times)} Detected Knocks')
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Amplitude')
            ax.grid(True, alpha=0.3)

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate waveform: {e}")
            return None
        finally:
            # Always release the figure, even on failure, so pyplot does not leak it
            if fig is not None:
                plt.close(fig)

    def _generate_mel_spectrogram_with_knocks(self, audio: np.ndarray, sr: int,
                                              knock_times: List[float]) -> Optional[QPixmap]:
        """
        Generate mel-spectrogram visualization with knock locations marked.

        Args:
            audio: Audio time series
            sr: Sample rate
            knock_times: List of knock onset times in seconds

        Returns:
            QPixmap: Mel-spectrogram plot with knock markers, or None if rendering failed
        """
        fig = None
        try:
            S_db = self._compute_mel_db(audio, sr)

            fig = plt.figure(figsize=SPECTROGRAM_FIG_SIZE)
            ax = fig.add_subplot(111)

            mesh = librosa.display.specshow(
                S_db,
                x_axis='time',
                y_axis='mel',
                sr=sr,
                hop_length=self.MEL_PARAMS['hop_length'],
                cmap='magma',
                ax=ax
            )
            self._mark_knocks(ax, knock_times, line_color='cyan', span_color='cyan', span_alpha=0.15)

            ax.set_title(
                f"Mel Spectrogram with {len(knock_times)} Detected Knocks "
                f"({self.MEL_PARAMS['n_mels']} Coefficients)"
            )
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Mel Frequency')

            fig.colorbar(mesh, ax=ax, format='%+2.0f dB', label='Power (dB)')
            fig.tight_layout()

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate mel-spectrogram: {e}", exc_info=True)
            return None
        finally:
            if fig is not None:
                plt.close(fig)

    def _generate_spectrogram_image(self, audio: np.ndarray, sr: int) -> Optional[QPixmap]:
        """
        Generate a mel-spectrogram visualization from audio.

        Args:
            audio: Audio time series
            sr: Sample rate

        Returns:
            QPixmap: Rendered mel-spectrogram image or None if failed
        """
        fig = None
        try:
            S_db = self._compute_mel_db(audio, sr)

            fig, ax = plt.subplots(figsize=SPECTROGRAM_FIG_SIZE)
            librosa.display.specshow(
                S_db,
                x_axis='time',
                y_axis='mel',
                sr=sr,
                hop_length=self.MEL_PARAMS['hop_length'],
                cmap='magma',
                ax=ax
            )
            ax.set_title(f"Mel Spectrogram ({self.MEL_PARAMS['n_mels']} coefficients)")
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Mel Frequency')

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate spectrogram image: {e}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)