#!/usr/bin/env python
# -*-coding:utf-8 -*-
"""
@File    :   demo.py
@Time    :   2024/05/13 00:59:00
@Author  :   Oriol Nieto
@Contact :   onieto@adobe.com
@License :   (C)Copyright 2024, Adobe Research
@Desc    :   Demo script comparing torchaudio's spectrogram with DFT version
             of the ConvertibleSpectrogram.
"""

import librosa
import numpy as np
import torch
import torchaudio
from scipy import signal as sig

from convmelspec.stft import ConvertibleSpectrogram as Spectrogram

FFT_SIZE = 1024
SR = 16_000
HOP_SIZE = 512
MEL_BANDS = None


def get_hann_torch(win_size, sym=True):
    """Return a Hann window of ``win_size`` samples as a float32 tensor.

    The signature matches what torchaudio expects from a ``window_fn``
    callable, so it can be passed directly to
    ``torchaudio.transforms.Spectrogram``.

    Args:
        win_size: Number of samples in the window.
        sym: Forwarded unchanged to ``scipy.signal.windows.hann``.

    Returns:
        A 1-D ``torch.Tensor`` of dtype float32 holding the window.
    """
    wn = sig.windows.hann(win_size, sym=sym).astype(np.float32)
    return torch.from_numpy(wn)


def get_audio():
    """Load a one-second excerpt of librosa's "nutcracker" example.

    The recording is loaded at ``SR`` and the samples between second 1 and
    second 2 are kept.

    Returns:
        A float32 ``torch.Tensor`` with a leading dimension of size 1
        prepended to the excerpt.
    """
    example_audio_path = librosa.example("nutcracker")
    y, sr = librosa.load(example_audio_path, sr=SR)
    total_sec = 1
    # Derive both slice bounds from the same integer sample rate so the
    # start and end indices are guaranteed to be ints.
    start = int(sr)
    stop = start + total_sec * start
    y = y[start:stop].astype(np.float32)
    return torch.from_numpy(y).unsqueeze(0)


def main():
    """Compare the DFT-mode ConvertibleSpectrogram against torchaudio.

    Both transforms are configured with the same FFT size, hop size and a
    symmetric Hann window, then applied to the same audio excerpt.

    Raises:
        AssertionError: If the two outputs are not element-wise close
            (``atol=1e-5``).
    """
    x = get_audio()
    wn = sig.windows.hann(FFT_SIZE, sym=True)

    stft = Spectrogram(
        sr=SR,
        n_fft=FFT_SIZE,
        hop_size=HOP_SIZE,
        padding=0,
        window=wn,
        spec_mode="DFT",
        dtype=torch.float32,
    )

    stft_ta = torchaudio.transforms.Spectrogram(
        n_fft=FFT_SIZE,
        hop_length=HOP_SIZE,
        window_fn=get_hann_torch,
        power=2.0,
        center=False,
    )

    dft_spec = stft(x)
    ta_spec = stft_ta(x)

    # NOTE(review): the original author marked this check as raising
    # AssertionError; the cause is not visible from this script alone.
    # An explicit raise is used instead of `assert` so the comparison is
    # not silently skipped when Python runs with -O.
    if not torch.allclose(dft_spec, ta_spec, atol=1e-5):
        max_diff = (dft_spec - ta_spec).abs().max().item()
        raise AssertionError(
            "DFT and Torchaudio's Spectrogram differ "
            f"(max abs diff: {max_diff:.3e}, atol=1e-05)"
        )
    print("DFT and Torchaudio's Spectrogram all close!")


if __name__ == "__main__":
    main()