| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696 |
- """
- Audio Model Module
- Audio classification model for durian ripeness detection using knock detection and mel-spectrogram features.
- Detects knocks in audio using librosa, extracts mel-spectrograms, and averages predictions across knocks.
- """
import json
import logging
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import librosa
import librosa.display
import matplotlib

matplotlib.use('agg')  # Use non-interactive backend (figures rendered off-screen only)

import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from PyQt5.QtGui import QImage, QPixmap

from models.base_model import BaseModel
from utils.config import (
    AUDIO_MODEL_PATH,
    SPECTROGRAM_FIG_SIZE,
    RIPENESS_CLASSES,
)

# Module logger (defined before the TensorFlow import so the fallback can log).
logger = logging.getLogger(__name__)

# TensorFlow is optional at import time: AudioModel.load() checks for it and
# fails gracefully instead of crashing the whole application at import.
try:
    import tensorflow as tf
except ImportError:
    tf = None
    logger.warning("TensorFlow is not installed; AudioModel.load() will fail.")
class AudioModel(BaseModel):
    """
    Audio-based ripeness classification model.

    Detects knocks in durian audio using onset detection, extracts
    mel-spectrogram features, and averages predictions across all detected
    knocks for robust ripeness classification.

    Attributes:
        model: Keras model for mel-spectrogram classification (set by load())
        label_encoder: Scikit-learn label encoder for class names (set by load())
        preprocessing_stats: JSON statistics (max_length, normalization params)
        class_names: List of class names (unripe, ripe, overripe)
    """

    # Mel-spectrogram parameters (must match training)
    MEL_PARAMS = {
        'n_mels': 64,
        'hop_length': 512,
        'n_fft': 2048,
        'sr': 22050
    }

    # Knock detection parameters
    KNOCK_DETECTION = {
        'delta': 0.3,           # Onset detection delta
        'wait': 10,             # Onset detection wait frames
        'onset_shift': 0.05,    # Shift onsets back by 50ms
        'knock_duration': 0.2   # Extract 200ms per knock
    }

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        """
        Initialize the audio model.

        Args:
            model_path: Path to model directory (optional; defaults to
                AUDIO_MODEL_PATH from the application config)
            device: Device to use (cpu/gpu - not used for TensorFlow)
        """
        if model_path is None:
            # AUDIO_MODEL_PATH points to models/audio/ which contains our files
            model_path = str(AUDIO_MODEL_PATH)
        super().__init__(model_path, device)
        self.class_names = RIPENESS_CLASSES
        self.model = None                # Keras model, populated by load()
        self.label_encoder = None        # sklearn LabelEncoder, populated by load()
        self.preprocessing_stats = None  # dict from preprocessing_stats.json

        logger.info(f"AudioModel initialized with model_path: {model_path}")

    def load(self) -> bool:
        """
        Load the model, label encoder, and preprocessing statistics.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        # Fail fast with a clear message when TensorFlow is missing, instead of
        # raising AttributeError on `tf.keras` below.
        if tf is None:
            logger.error("TensorFlow is not available - cannot load audio model")
            self._is_loaded = False
            return False
        try:
            base_dir = Path(self.model_path)

            # Try two possible paths: direct path or voice_memos_ripeness subdirectory
            possible_dirs = [
                base_dir,                           # Files directly in model_path
                base_dir / "voice_memos_ripeness"   # Files in voice_memos_ripeness subdir
            ]

            # (dir / file).exists() is False when the directory itself is
            # missing, so one check covers both conditions.
            model_dir = None
            for possible_dir in possible_dirs:
                if (possible_dir / "best_model_mel_spec_grouped.keras").exists():
                    model_dir = possible_dir
                    break

            if model_dir is None:
                logger.error(f"Could not find model files in: {base_dir} or {base_dir / 'voice_memos_ripeness'}")
                return False

            logger.info(f"Loading audio model from {model_dir}")

            # Load Keras model (existence verified by the search loop above)
            model_path = model_dir / "best_model_mel_spec_grouped.keras"
            logger.info(f"Loading TensorFlow model from {model_path}...")
            self.model = tf.keras.models.load_model(str(model_path))
            logger.info("✓ TensorFlow model loaded successfully")

            # Load label encoder
            encoder_path = model_dir / "label_encoder.pkl"
            if not encoder_path.exists():
                logger.error(f"Label encoder not found: {encoder_path}")
                return False

            logger.info(f"Loading label encoder from {encoder_path}...")
            # NOTE(review): pickle.load is acceptable only because this file
            # ships with the application; never point model_path at untrusted data.
            with open(encoder_path, 'rb') as f:
                self.label_encoder = pickle.load(f)
            logger.info(f"✓ Label encoder loaded with classes: {list(self.label_encoder.classes_)}")

            # Load preprocessing stats
            stats_path = model_dir / "preprocessing_stats.json"
            if not stats_path.exists():
                logger.error(f"Preprocessing stats not found: {stats_path}")
                return False

            logger.info(f"Loading preprocessing stats from {stats_path}...")
            with open(stats_path, 'r') as f:
                self.preprocessing_stats = json.load(f)
            logger.info(f"✓ Preprocessing stats loaded, max_length: {self.preprocessing_stats.get('max_length')}")

            self._is_loaded = True
            logger.info("✓ Audio model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load audio model: {e}", exc_info=True)
            self._is_loaded = False
            return False

    def predict(self, audio_path: str) -> Dict[str, Any]:
        """
        Predict ripeness from an audio file using knock detection and
        mel-spectrogram analysis.

        Args:
            audio_path: Path to audio file (supports WAV and other formats via librosa)

        Returns:
            Dict containing:
                - 'class_name': Predicted class name (Ripe/Unripe/Overripe)
                - 'class_index': Predicted class index
                - 'probabilities': Dictionary of class probabilities (0-1 range)
                - 'confidence': Confidence score (0-1 range, averaged across knocks)
                - 'spectrogram_image': QPixmap of mel-spectrogram with knocks marked
                - 'waveform_image': QPixmap of waveform with knocks marked
                - 'knock_count': Number of knocks detected
                - 'knock_times': List of knock onset times in seconds
                - 'per_knock_predictions': Per-knock class/confidence breakdown
                - 'success': Whether prediction succeeded
                - 'error': Error message if failed

        Raises:
            RuntimeError: If the model has not been loaded via load().
        """
        if not self._is_loaded or self.model is None:
            raise RuntimeError("Model not loaded. Call load() first.")

        wav_path = None
        try:
            # Ensure audio is in WAV format (may create a temporary file,
            # removed in the finally block below)
            wav_path = self._ensure_wav_format(audio_path)

            # Load audio at the training sample rate, downmixed to mono
            logger.info(f"Loading audio from {audio_path}")
            y, sr = librosa.load(wav_path, sr=self.MEL_PARAMS['sr'], mono=True)

            # Trim 0.5 s of silence/handling noise from each end when the
            # recording is long enough
            cut_samples = int(0.5 * sr)
            if len(y) > 2 * cut_samples:
                y = y[cut_samples:-cut_samples]
            elif len(y) > cut_samples:
                y = y[cut_samples:]

            # Detect knocks and extract features
            logger.info("Detecting knocks in audio...")
            features, knock_times = self._extract_knock_features(y, sr)

            if not features:
                logger.error("❌ No knocks detected in audio file - returning error")
                return self._error_result('No knocks detected in audio file')

            logger.info(f"Detected {len(features)} knocks at times: {knock_times}")

            # Pad (or truncate) every knock to the training max_length.
            # Truncating first guards against np.pad raising on a negative pad
            # width when a knock spectrogram is longer than max_length.
            max_length = self.preprocessing_stats['max_length']
            padded = []
            for feat in features:
                feat = feat[:max_length]
                padded.append(
                    np.pad(feat, ((0, max_length - feat.shape[0]), (0, 0)), mode='constant')
                )
            X = np.array(padded)

            # Add the channel dimension the CNN expects
            if len(X.shape) == 3:
                X = np.expand_dims(X, -1)

            # Run inference on all knocks in one batch
            logger.info(f"Running model inference on {len(features)} knocks...")
            probs = self.model.predict(X, verbose=0)

            per_knock_preds = self._summarize_per_knock(probs)
            logger.info(f"Per-knock predictions: {[p['class'] for p in per_knock_preds]}")

            # Average predictions across all knocks (CONFIDENCE LOGIC)
            avg_probs = np.mean(probs, axis=0)
            predicted_idx = int(np.argmax(avg_probs))
            predicted_class = self.label_encoder.classes_[predicted_idx]
            confidence = float(avg_probs[predicted_idx])

            logger.info(f"Average probabilities: {dict(zip(self.label_encoder.classes_, avg_probs))}")
            logger.info(f"Final prediction: {predicted_class} ({confidence:.2%})")

            # Plain-float probability dictionary for serialization
            prob_dict = {
                self.label_encoder.classes_[i]: float(avg_probs[i])
                for i in range(len(self.label_encoder.classes_))
            }

            # Capitalize class name for display
            predicted_class_display = (
                predicted_class.capitalize() if isinstance(predicted_class, str) else predicted_class
            )

            # Generate visualizations with knock annotations
            spectrogram_image = self._generate_mel_spectrogram_with_knocks(y, sr, knock_times)
            waveform_image = self._generate_waveform_with_knocks(y, sr, knock_times)

            logger.info(f"Prediction: {predicted_class_display} ({confidence:.2%}) from {len(features)} knocks")

            return {
                'success': True,
                'class_name': predicted_class_display,
                'class_index': predicted_idx,
                'probabilities': prob_dict,
                'confidence': confidence,
                'spectrogram_image': spectrogram_image,
                'waveform_image': waveform_image,
                'knock_count': len(features),
                'knock_times': knock_times,
                'per_knock_predictions': per_knock_preds,
                'error': None
            }
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Prediction failed: {error_msg}", exc_info=True)

            # Provide helpful error message for audio format issues
            if 'convert' in error_msg.lower() or 'format' in error_msg.lower():
                error_msg += " - Please ensure ffmpeg is installed: conda install -c conda-forge ffmpeg"

            return self._error_result(error_msg)
        finally:
            # Remove the temporary WAV produced by format conversion, if any
            # (previously these files accumulated in the temp directory).
            if wav_path is not None and wav_path != audio_path:
                try:
                    os.remove(wav_path)
                except OSError:
                    pass

    @staticmethod
    def _error_result(error_msg: str) -> Dict[str, Any]:
        """Build the uniform failure payload returned by predict()."""
        return {
            'success': False,
            'class_name': None,
            'class_index': None,
            'probabilities': {},
            'confidence': 0.0,
            'spectrogram_image': None,
            'waveform_image': None,
            'knock_count': 0,
            'knock_times': [],
            'per_knock_predictions': [],
            'error': error_msg
        }

    def _summarize_per_knock(self, probs: np.ndarray) -> List[Dict[str, Any]]:
        """
        Build per-knock prediction summaries from the raw probability matrix.

        Args:
            probs: (n_knocks, n_classes) model output probabilities

        Returns:
            One dict per knock with 'class', 'class_idx', 'confidence' and a
            per-class 'probabilities' mapping (plain Python types).
        """
        classes = self.label_encoder.classes_
        summaries = []
        for knock_probs in probs:
            idx = int(np.argmax(knock_probs))
            summaries.append({
                'class': classes[idx],
                'class_idx': idx,
                'confidence': float(knock_probs[idx]),
                'probabilities': {classes[j]: float(knock_probs[j]) for j in range(len(classes))}
            })
        return summaries

    def _ensure_wav_format(self, audio_path: str) -> str:
        """
        Ensure audio file is in WAV format, converting if necessary.

        Supports M4A, MP3, OGG, FLAC, WMA and other formats. Conversion is
        attempted with ffmpeg, then pydub, then librosa as a final fallback.

        Args:
            audio_path: Path to audio file

        Returns:
            Path to WAV file. The input path for .wav files; otherwise a
            temporary file the caller is responsible for deleting.

        Raises:
            RuntimeError: If no conversion method succeeded.
        """
        ext = os.path.splitext(audio_path)[1].lower()
        if ext == '.wav':
            return audio_path

        logger.info(f"Converting {ext} to WAV format...")

        # Try ffmpeg first (most reliable for various formats)
        tmp_path = None
        try:
            import subprocess
            tmp_path = self._new_temp_wav()

            # Use ffmpeg to convert
            cmd = [
                'ffmpeg',
                '-i', audio_path,
                '-acodec', 'pcm_s16le',
                '-ar', str(self.MEL_PARAMS['sr']),
                '-ac', '1',  # mono
                '-y',        # overwrite
                tmp_path
            ]

            logger.info(f"Using ffmpeg to convert {ext}")
            subprocess.run(cmd, capture_output=True, check=True, timeout=30)
            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._discard_temp(tmp_path)  # don't orphan the temp file on failure
            logger.warning(f"ffmpeg conversion failed: {e}, trying pydub...")

        # Try pydub second (handles most formats if installed)
        tmp_path = None
        try:
            from pydub import AudioSegment
            logger.info(f"Using pydub to convert {ext}")
            audio = AudioSegment.from_file(audio_path)

            # Temp handle is closed before export so pydub can reopen the path
            # (required on Windows).
            tmp_path = self._new_temp_wav()
            audio.export(tmp_path, format='wav')
            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._discard_temp(tmp_path)
            logger.warning(f"pydub conversion failed: {e}, trying librosa...")

        # Try librosa as final fallback
        try:
            import soundfile as sf
        except ImportError:
            logger.warning("soundfile not available, using scipy for conversion")
            sf = None

        tmp_path = None
        try:
            logger.info("Using librosa to load and convert audio")
            # Load with librosa (requires ffmpeg backend for non-WAV)
            y, sr = librosa.load(audio_path, sr=self.MEL_PARAMS['sr'], mono=True)

            tmp_path = self._new_temp_wav()
            if sf is not None:
                sf.write(tmp_path, y, sr)
            else:
                # Fallback: use scipy
                import scipy.io.wavfile as wavfile
                # Normalize to 16-bit range
                y_int16 = np.clip(y * 32767, -32768, 32767).astype(np.int16)
                wavfile.write(tmp_path, sr, y_int16)

            logger.info(f"Converted to temporary WAV: {tmp_path}")
            return tmp_path
        except Exception as e:
            self._discard_temp(tmp_path)
            logger.error(f"Audio conversion failed with all methods: {e}")
            logger.error(f"Please ensure ffmpeg is installed or install pydub: pip install pydub")
            raise RuntimeError(
                f"Failed to convert {ext} audio file. "
                "Install ffmpeg or pydub: 'pip install pydub' and 'pip install ffmpeg'"
            ) from e

    @staticmethod
    def _new_temp_wav() -> str:
        """Create an empty named temporary .wav file and return its (closed) path."""
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        tmp_file.close()  # close so other processes/libraries can write to it
        return tmp_file.name

    @staticmethod
    def _discard_temp(path: Optional[str]) -> None:
        """Best-effort removal of a temporary file; None and missing paths are ignored."""
        if path:
            try:
                os.remove(path)
            except OSError:
                pass

    def _extract_knock_features(self, audio: np.ndarray, sr: int) -> Tuple[List[np.ndarray], List[float]]:
        """
        Detect knocks in audio and extract mel-spectrogram features.

        Args:
            audio: Audio time series
            sr: Sample rate

        Returns:
            Tuple of:
                - List of mel-spectrogram arrays (one per knock), shape (time, n_mels)
                - List of knock onset times in seconds (plain floats)
        """
        # Detect onsets (knock starts)
        logger.info("Detecting onset times...")
        onset_frames = librosa.onset.onset_detect(
            y=audio, sr=sr,
            delta=self.KNOCK_DETECTION['delta'],
            wait=self.KNOCK_DETECTION['wait'],
            units='frames'
        )
        onset_times = librosa.frames_to_time(onset_frames, sr=sr)

        if len(onset_times) == 0:
            logger.warning("No onsets detected")
            return [], []

        # Shift onsets slightly back so the attack transient is fully captured
        shifted_times = [
            max(0.0, float(t) - self.KNOCK_DETECTION['onset_shift']) for t in onset_times
        ]
        logger.info(f"Detected {len(shifted_times)} onset times")

        # Extract fixed-length knock segments
        knock_samples = int(round(self.KNOCK_DETECTION['knock_duration'] * sr))
        knocks = []
        valid_times = []

        for onset in shifted_times:
            start = int(round(onset * sr))
            end = start + knock_samples

            if start >= len(audio):
                continue  # onset lies past the end of the (trimmed) signal

            if end <= len(audio):
                knock = audio[start:end]
            else:
                # Zero-pad a knock that runs past the end of the signal
                knock = np.zeros(knock_samples, dtype=audio.dtype)
                available = len(audio) - start
                if available > 0:
                    knock[:available] = audio[start:]
                else:
                    continue

            knocks.append(knock)
            valid_times.append(onset)

        # Extract mel-spectrograms
        logger.info(f"Extracting mel-spectrograms from {len(knocks)} knocks...")
        features = [self._extract_mel_spectrogram(knock, sr) for knock in knocks]

        return features, valid_times

    def _extract_mel_spectrogram(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        Extract a normalized mel-spectrogram from an audio segment.

        Args:
            audio: Audio segment
            sr: Sample rate

        Returns:
            Z-score-normalized mel-spectrogram, transposed to (time, n_mels)
        """
        # Compute mel-spectrogram with the training-time parameters
        S = librosa.feature.melspectrogram(
            y=audio, sr=sr,
            n_mels=self.MEL_PARAMS['n_mels'],
            hop_length=self.MEL_PARAMS['hop_length'],
            n_fft=self.MEL_PARAMS['n_fft']
        )

        # Convert to dB scale
        S_db = librosa.power_to_db(S, ref=np.max)

        # Normalize; skip the division for a constant (zero-variance) spectrogram
        std = np.std(S_db)
        if std != 0:
            S_db = (S_db - np.mean(S_db)) / std
        else:
            S_db = S_db - np.mean(S_db)

        return S_db.T  # (time, n_mels)

    def predict_batch(self, audio_paths: list) -> list:
        """
        Predict ripeness for multiple audio files.

        Args:
            audio_paths: List of paths to audio files

        Returns:
            List[Dict]: One prediction result per input path (predict() never
            raises for per-file failures; it returns success=False dicts).
        """
        return [self.predict(audio_path) for audio_path in audio_paths]

    def _figure_to_pixmap(self, fig) -> QPixmap:
        """
        Render a matplotlib figure to a QPixmap and close the figure.

        The canvas buffer is RGBA; reading it as ARGB32 and then swapping the
        R/B channels reproduces correct colors on little-endian hosts (same
        path the original per-method code used) — TODO confirm on big-endian.
        """
        canvas = FigureCanvas(fig)
        canvas.draw()

        width_px, height_px = fig.get_size_inches() * fig.get_dpi()
        width_px, height_px = int(width_px), int(height_px)

        img = QImage(canvas.buffer_rgba(), width_px, height_px, QImage.Format_ARGB32)
        pixmap = QPixmap(img.rgbSwapped())

        plt.close(fig)  # release the figure to avoid matplotlib memory growth
        return pixmap

    def _generate_waveform_with_knocks(self, audio: np.ndarray, sr: int, knock_times: List[float]) -> Optional[QPixmap]:
        """
        Generate waveform visualization with knock locations marked.

        Args:
            audio: Audio time series
            sr: Sample rate
            knock_times: List of knock onset times in seconds

        Returns:
            QPixmap: Waveform plot with knock markers, or None on failure
        """
        try:
            fig, ax = plt.subplots(figsize=SPECTROGRAM_FIG_SIZE)

            # Plot waveform
            librosa.display.waveshow(audio, sr=sr, alpha=0.6, ax=ax)

            # Mark knock locations: onset line plus a span for the analyzed window
            knock_duration = self.KNOCK_DETECTION['knock_duration']
            for knock_time in knock_times:
                ax.axvline(knock_time, color='red', linestyle='--', alpha=0.8, linewidth=1.5)
                ax.axvspan(knock_time, knock_time + knock_duration, color='orange', alpha=0.2)

            ax.set_title(f'Waveform with {len(knock_times)} Detected Knocks')
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Amplitude')
            ax.grid(True, alpha=0.3)

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate waveform: {e}")
            return None

    def _generate_mel_spectrogram_with_knocks(self, audio: np.ndarray, sr: int, knock_times: List[float]) -> Optional[QPixmap]:
        """
        Generate mel-spectrogram visualization with knock locations marked.

        Args:
            audio: Audio time series
            sr: Sample rate
            knock_times: List of knock onset times in seconds

        Returns:
            QPixmap: Mel-spectrogram plot with knock markers, or None on failure
        """
        try:
            # Compute mel-spectrogram
            S = librosa.feature.melspectrogram(
                y=audio, sr=sr,
                n_mels=self.MEL_PARAMS['n_mels'],
                hop_length=self.MEL_PARAMS['hop_length'],
                n_fft=self.MEL_PARAMS['n_fft']
            )

            # Convert to dB scale
            S_db = librosa.power_to_db(S, ref=np.max)

            fig = plt.figure(figsize=SPECTROGRAM_FIG_SIZE)
            ax = fig.add_subplot(111)

            # Display mel-spectrogram
            img = librosa.display.specshow(
                S_db,
                x_axis='time',
                y_axis='mel',
                sr=sr,
                hop_length=self.MEL_PARAMS['hop_length'],
                cmap='magma',
                ax=ax
            )

            # Mark knock locations
            knock_duration = self.KNOCK_DETECTION['knock_duration']
            for knock_time in knock_times:
                ax.axvline(knock_time, color='cyan', linestyle='--', alpha=0.8, linewidth=1.5)
                ax.axvspan(knock_time, knock_time + knock_duration, color='cyan', alpha=0.15)

            ax.set_title(f'Mel Spectrogram with {len(knock_times)} Detected Knocks (64 Coefficients)')
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Mel Frequency')

            # Add colorbar properly
            plt.colorbar(img, ax=ax, format='%+2.0f dB', label='Power (dB)')
            plt.tight_layout()

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate mel-spectrogram: {e}", exc_info=True)
            return None

    def _generate_spectrogram_image(self, audio: np.ndarray, sr: int) -> Optional[QPixmap]:
        """
        Generate a plain mel-spectrogram visualization (no knock markers).

        Args:
            audio: Audio time series
            sr: Sample rate

        Returns:
            QPixmap: Rendered mel-spectrogram image or None if failed
        """
        try:
            # Compute mel-spectrogram
            S = librosa.feature.melspectrogram(
                y=audio, sr=sr,
                n_mels=self.MEL_PARAMS['n_mels'],
                hop_length=self.MEL_PARAMS['hop_length'],
                n_fft=self.MEL_PARAMS['n_fft']
            )

            # Convert to dB scale
            S_db = librosa.power_to_db(S, ref=np.max)

            fig, ax = plt.subplots(figsize=SPECTROGRAM_FIG_SIZE)

            # Display mel-spectrogram
            librosa.display.specshow(
                S_db,
                x_axis='time',
                y_axis='mel',
                sr=sr,
                hop_length=self.MEL_PARAMS['hop_length'],
                cmap='magma',
                ax=ax
            )

            ax.set_title('Mel Spectrogram (64 coefficients)')
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Mel Frequency')

            return self._figure_to_pixmap(fig)
        except Exception as e:
            logger.error(f"Failed to generate spectrogram image: {e}")
            return None
|