Source code for drumscript.audio_processor.feature_extractor
# DrumScript/audio_processor/feature_extractor.py
"""
This module will extract relevant features from audio segments for drum classification.
"""
from typing import Any
import librosa
import numpy as np
from drumscript.notation_generator.constants import HOP_LENGTH, N_FFT, ONSET_SLICE_DURATION_MS, SAMPLE_RATE, SEGMENT_LENGTH_SECONDS
# Calculate the expected number of frames (timesteps) per segment
# This calculation needs to be robust to ensure consistency with librosa's output.
# librosa.stft typically returns 1 + (len(y) - n_fft) // hop_length frames.
# So, we first determine the expected audio length that corresponds to SEGMENT_LENGTH_SECONDS
EXPECTED_AUDIO_LEN_SAMPLES = int(SEGMENT_LENGTH_SECONDS * SAMPLE_RATE)
EXPECTED_N_FRAMES = 1 + (EXPECTED_AUDIO_LEN_SAMPLES - N_FFT) // HOP_LENGTH
if EXPECTED_N_FRAMES < 1:
EXPECTED_N_FRAMES = 1 # Ensure at least one frame, even for very short segments
# previous Define TOTAL_FEATURES_PER_FRAME globally so it's accessible in __main__ block
# previous Number of MFCCs + Spectral Centroid + Spectral Rolloff + ZCR + RMS = 20 + 1 + 1 + 1 + 1 = 24
# previous TOTAL_FEATURES_PER_FRAME = 20 + 1 + 1 + 1 + 1
# current TOTAL_FEATURES_PER_FRAME = 40 (MFCCs) + 1 (Zero-Crossing Rate) = 41
N_MFCC = 41
# UPDATED: We are adding band energy features (Low, Mid, High), so +3 more features
TOTAL_FEATURES_PER_FRAME = N_MFCC + 3 + 3 # TOTAL_FEATURES_PER_FRAME = 47
# THE SCRIP WILL ADD FEATURES [1 (Centroid) + 1 (Rolloff) + 1 (RMS) + 3 (Band Energies)]
# --- Main Feature Extraction Functions ---
[docs]
def extract_features(audio_path: np.ndarray, sr: int) -> dict[str, Any]:
"""
Extracts a dictionary of features from a single audio segment.
Features are returned as mean values over the segment's duration.
:param audio_path: The audio data array.
:type audio_path: np.ndarray
:param sr: The sampling rate.
:type sr: int
:return: Dictionary of features (spectral_centroid, rms, mfccs, etc.).
:rtype: Dict[str, Any]
"""
if audio_path.size == 0:
return None
try:
# --- Standard Features ---
# spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio_path, sr=sr))
spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio_path, sr=SAMPLE_RATE))
spectral_rolloff = np.mean(librosa.feature.spectral_rolloff(y=audio_path, sr=SAMPLE_RATE))
# spectral_rolloff = np.mean(librosa.feature.spectral_rolloff(y=audio_path, sr=sr))
rms = np.mean(librosa.feature.rms(y=audio_path))
zcr = np.mean(librosa.feature.zero_crossing_rate(y=audio_path))
# mfccs = np.mean(librosa.feature.mfcc(y=audio_path, sr=sr, n_mfcc=N_MFCC), axis=1)
mfccs = np.mean(librosa.feature.mfcc(y=audio_path, sr=SAMPLE_RATE, n_mfcc=N_MFCC), axis=1)
# --- Sustain Feature Calculation ---
# Split the segment into two halves to measure energy decay.
half_point = len(audio_path) // 2
first_half_rms = np.mean(librosa.feature.rms(y=audio_path[:half_point]))
second_half_rms = np.mean(librosa.feature.rms(y=audio_path[half_point:]))
# Calculate the ratio. Add a small epsilon to avoid division by zero.
sustain_level = second_half_rms / (first_half_rms + 1e-6)
# --- Band Energy Calculation (Based on Cheatsheet Logic) ---
# Calculate Spectrogram magnitude
S = np.abs(librosa.stft(audio_path, n_fft=N_FFT, hop_length=HOP_LENGTH))
# Get frequency bins
# fft_freqs = librosa.fft_frequencies(sr=sr, n_fft=N_FFT)
fft_freqs = librosa.fft_frequencies(sr=SAMPLE_RATE, n_fft=N_FFT)
# Define masks for bands based on constants.py (Low < 300Hz, Mid 300-5000Hz, High > 5000Hz)
# Note: 300Hz chosen to separate Kick/Floor Tom from Snare/Rack Tom body
low_band_mask = fft_freqs <= 300
mid_band_mask = (fft_freqs > 300) & (fft_freqs <= 5000)
high_band_mask = fft_freqs > 5000
# Sum energy in these bands (averaging over time)
energy_low = np.mean(np.sum(S[low_band_mask, :], axis=0))
energy_mid = np.mean(np.sum(S[mid_band_mask, :], axis=0))
energy_high = np.mean(np.sum(S[high_band_mask, :], axis=0))
return {
"spectral_centroid": spectral_centroid,
"spectral_rolloff": spectral_rolloff,
"rms": rms,
"zero_crossing_rate": zcr,
"mfccs": mfccs.tolist(),
"sustain_level": sustain_level,
"energy_low": energy_low, # Kick/Floor Tom indicator
"energy_mid": energy_mid, # Snare/Rack Tom indicator
"energy_high": energy_high, # Hi-Hat/Cymbal indicator
}
except Exception as e:
print(f"Warning: Error extracting features from a segment: {e}")
return None
[docs]
def extract_features_for_onsets(audio_path: np.ndarray, sr: int, onset_times: list[float]) -> list[dict[str, Any]]:
"""
Slices an audio array around each onset time and extracts features for each slice.
:param audio_path: Full audio array.
:type audio_path: np.ndarray
:param sr: Sampling rate.
:type sr: int
:param onset_times: List of onset timestamps.
:type onset_times: List[float]
:return: List of feature dictionaries.
:rtype: List[Dict[str, Any]]
"""
all_features = []
sr = SAMPLE_RATE
# Calculate *half* the slice duration in samples
half_slice_samples = int((ONSET_SLICE_DURATION_MS / 1000.0) * sr) // 2
for time_sec in onset_times:
# center_sample = librosa.time_to_samples(time_sec, sr=sr)
center_sample = librosa.time_to_samples(time_sec, sr=SAMPLE_RATE)
# Define start and end points, centered around the onset
start_sample = center_sample - half_slice_samples
end_sample = center_sample + half_slice_samples
# Boundary checks
start_sample = max(0, start_sample)
end_sample = min(len(audio_path), end_sample)
audio_slice = audio_path[start_sample:end_sample]
# Extract features for the slice
features = extract_features(audio_slice, sr)
if features:
# Add the onset time to the dictionary of features
features["onset_time"] = time_sec
all_features.append(features)
return all_features
# --------------------------------------------------------------------------uncomment during testing
# from datetime import datetime
# print("\n# ------------------------------------------------------------------------------------")
# datetimestamp = datetime.now()
# print(f'\ndate/time: {datetimestamp}')
# --------------------------------------------------------------------------------------------------