Add b2ai speaker verification functions #87

Merged
merged 18 commits into from
Jul 11, 2024
1 change: 1 addition & 0 deletions audio_48khz_mono_16bits.wav
43 changes: 29 additions & 14 deletions src/senselab/audio/tasks/preprocessing/preprocessing.py
@@ -1,34 +1,49 @@
"""This module implements some utilities for the preprocessing task."""

from typing import List, Tuple
from typing import List, Optional, Tuple

import pydra
import torchaudio.functional as F
import torch
from scipy import signal
from speechbrain.augment.time_domain import Resample

from senselab.audio.data_structures.audio import Audio


def resample_audios(audios: List[Audio], resample_rate: int, rolloff: float = 0.99) -> List[Audio]:
    """Resamples all Audios to a given sampling rate.

    Takes a list of audios and resamples each into the new sampling rate. Notably does not assume any
    specific structure of the audios (can vary in stereo vs. mono as well as their original sampling rate)
def resample_audios(
Collaborator

thank you @ibevers. this is helpful. i did some further research and i think we should go with an alternative implementation that is neither yours nor mine: the one from transforms.Resample. transforms.Resample is not that different from functional.resample, but it precomputes and reuses the resampling kernel, so it is more efficient when resampling multiple waveforms with the same resampling parameters. they both internally do the butterworth filtering for anti-aliasing (which is why your method and mine are redundant in the same wrapping function) and then resample the signal. we can pass order and lowcut as params and compute the roll-off ourselves. i would appreciate it if you could refactor the code.
reference: https://pytorch.org/audio/main/generated/torchaudio.transforms.Resample.html
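
A minimal sketch of the "compute the roll-off ourselves" step mentioned above. The helper name rolloff_from_lowcut is hypothetical, and the sketch assumes that torchaudio's rolloff argument is a fraction of the Nyquist frequency of the lower of the two sampling rates; check that against the torchaudio documentation before relying on it.

```python
def rolloff_from_lowcut(lowcut_hz: float, orig_freq: int, new_freq: int) -> float:
    """Express a low-pass cutoff in Hz as a fraction of the limiting Nyquist frequency.

    Hypothetical helper, not part of senselab or torchaudio.
    """
    # Assumption: the limiting Nyquist is half of the lower sampling rate.
    nyquist_hz = min(orig_freq, new_freq) / 2
    if not 0 < lowcut_hz <= nyquist_hz:
        raise ValueError(f"lowcut must be in (0, {nyquist_hz}] Hz, got {lowcut_hz}")
    return lowcut_hz / nyquist_hz


# Same default cutoff as the IIR path in this PR: 100 Hz below the target Nyquist.
target_rate = 16000
default_lowcut = target_rate / 2 - 100  # 7900.0 Hz
rolloff = rolloff_from_lowcut(default_lowcut, orig_freq=48000, new_freq=target_rate)
print(rolloff)  # 0.9875
```

The resulting value would then be passed as rolloff= when constructing the resampler, so that lowcut stays the user-facing parameter.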

Collaborator

just an fyi that transforms.Resample does not use butterworth filtering. it uses sinc interpolation with a hann or kaiser window. in at least my initial anti-aliasing tests with fixed sinusoids it was not great at creating a good filter. it still passed some amount of the signal through. i can't find the notebook right this minute, but the general idea of the test is:

create a sinusoid at 14 kHz sampled at 48 kHz, then resample it down to a sampling rate of 16 kHz. you should not see any signal peak at 2 kHz (the aliased signal) on an FFT. if you do, that means the anti-aliasing filter is not doing a good job. anything that far above nyquist (8 kHz) should be completely filtered out.

hence it may make sense to have multiple resamplers still.
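
The aliasing check described above can be sketched without any audio libraries. This is not the original notebook, only a dependency-free stand-in: the Hamming-windowed sinc FIR filter and all function names here are assumptions chosen for illustration, and in practice the resampler under test (torchaudio, scipy, speechbrain) would take the place of the filter-and-decimate step.

```python
import cmath
import math
from typing import List


def sine_wave(freq_hz: float, sample_rate: int, num_samples: int) -> List[float]:
    """Unit-amplitude sinusoid at freq_hz."""
    return [math.sin(2 * math.pi * freq_hz * n / sample_rate) for n in range(num_samples)]


def lowpass_taps(cutoff_hz: float, sample_rate: int, num_taps: int = 101) -> List[float]:
    """Hamming-windowed sinc low-pass FIR filter, normalized to unity gain at DC."""
    fc = cutoff_hz / sample_rate  # cutoff in cycles per sample
    mid = (num_taps - 1) / 2
    taps = []
    for k in range(num_taps):
        x = k - mid
        ideal = 2 * fc if x == 0 else math.sin(2 * math.pi * fc * x) / (math.pi * x)
        window = 0.54 - 0.46 * math.cos(2 * math.pi * k / (num_taps - 1))
        taps.append(ideal * window)
    gain = sum(taps)
    return [t / gain for t in taps]


def fir_filter(samples: List[float], taps: List[float]) -> List[float]:
    """Direct-form FIR convolution with zero initial state."""
    return [
        sum(t * samples[n - k] for k, t in enumerate(taps) if n - k >= 0)
        for n in range(len(samples))
    ]


def amplitude_at(samples: List[float], freq_hz: float, sample_rate: int) -> float:
    """Amplitude of the freq_hz component, estimated from a single-bin DFT."""
    acc = sum(s * cmath.exp(-2j * math.pi * freq_hz * n / sample_rate) for n, s in enumerate(samples))
    return 2 * abs(acc) / len(samples)


ORIG_RATE, NEW_RATE, TONE_HZ = 48000, 16000, 14000
ALIAS_HZ = NEW_RATE - TONE_HZ  # 14 kHz folds down to 2 kHz at a 16 kHz rate
factor = ORIG_RATE // NEW_RATE

tone = sine_wave(TONE_HZ, ORIG_RATE, num_samples=4800)  # 0.1 s of signal

# Decimation with no anti-aliasing filter: the tone reappears at 2 kHz.
naive = tone[::factor]
# Low-pass below the new Nyquist (8 kHz) first, then decimate.
filtered = fir_filter(tone, lowpass_taps(cutoff_hz=7000, sample_rate=ORIG_RATE))[::factor]

alias_naive = amplitude_at(naive, ALIAS_HZ, NEW_RATE)
alias_filtered = amplitude_at(filtered, ALIAS_HZ, NEW_RATE)
print(f"alias amplitude without filter: {alias_naive:.3f}")
print(f"alias amplitude with filter:    {alias_filtered:.3f}")
```

A resampler with a good anti-aliasing filter should behave like the second case, with the 2 kHz component close to zero. A clearly visible 2 kHz component is the failure mode described in the comment.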

Collaborator Author

@ibevers ibevers Jul 11, 2024

Can we hash this out in #90 and leave the resampling the way it is for this PR? @fabiocat93 @satra

Collaborator

i see. i think that at this point we could just remove the torchaudio implementation and use yours from b2aiprep, @ibevers. this will solve 2 issues at once.

Collaborator Author

Done.

    audios: List[Audio],
    resample_rate: int,
    lowcut: Optional[float] = None,
    order: int = 4,
) -> List[Audio]:
    """Resamples a list of audio signals to a given sampling rate.

    Args:
        audios: List of Audios to resample
        resample_rate: Rate at which to resample the Audio
        rolloff: The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies
        audios (List[Audio]): List of audio objects to resample.
        resample_rate (int): Target sampling rate.
        lowcut (float, optional): Low cut frequency for IIR filter.
        order (int, optional): Order of the IIR filter. Defaults to 4.

    Returns:
        List of Audios that have all been resampled to the given resampling rate
        List[Audio]: Resampled audio objects.
    """
    resampled_audios = []
    for audio in audios:
        resampled = F.resample(audio.waveform, audio.sampling_rate, resample_rate, rolloff=rolloff)
        if lowcut is None:
            lowcut = resample_rate / 2 - 100
        sos = signal.butter(order, lowcut, btype="low", output="sos", fs=resample_rate)

        channels = []
        for channel in audio.waveform:
            filtered_channel = torch.from_numpy(signal.sosfiltfilt(sos, channel.numpy()).copy()).float()
            resampler = Resample(orig_freq=audio.sampling_rate, new_freq=resample_rate)
            resampled_channel = resampler(filtered_channel.unsqueeze(0)).squeeze(0)
            channels.append(resampled_channel)

        resampled_waveform = torch.stack(channels)
        resampled_audios.append(
            Audio(
                waveform=resampled,
                waveform=resampled_waveform,
                sampling_rate=resample_rate,
                metadata=audio.metadata.copy(),
                orig_path_or_id=audio.orig_path_or_id,
1 change: 1 addition & 0 deletions src/senselab/audio/tasks/speaker_verification/__init__.py
@@ -0,0 +1 @@
"""Verifies whether two audio segments belong to the same speaker."""
@@ -0,0 +1,61 @@
"""Audio Processing and Speaker Verification Module.

This module provides functions for resampling audio using an IIR filter and
verifying if two audio samples or files are from the same speaker using a
specified model.
"""

from typing import List, Optional, Tuple

from torch.nn.functional import cosine_similarity

from senselab.audio.data_structures.audio import Audio
from senselab.audio.tasks.speaker_embeddings.speechbrain import SpeechBrainEmbeddings
from senselab.utils.data_structures.device import DeviceType, _select_device_and_dtype
from senselab.utils.data_structures.model import SpeechBrainModel

TRAINING_SAMPLE_RATE = 16000 # spkrec-ecapa-voxceleb trained on 16kHz audio


def verify_speaker(
    audios: List[Tuple[Audio, Audio]],
    model: SpeechBrainModel = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main"),
    device: Optional[DeviceType] = None,
    threshold: float = 0.25,
) -> List[Tuple[float, bool]]:
    """Verifies if two audio samples are from the same speaker.

    Args:
        audios (List[Tuple[Audio, Audio]]): A list of tuples, where each tuple contains
            two audio samples to be compared.
        model (SpeechBrainModel, optional): The model for speaker verification.
        device (DeviceType, optional): The device to run the model on. Defaults to CPU.
        threshold (float, optional): The threshold to determine same speaker.

    Returns:
        List[Tuple[float, bool]]: A list of tuples containing the verification score and
            the prediction for each pair of audio samples. The
            verification score is a float indicating the similarity
            between the two samples, and the prediction is a boolean
            indicating if the two samples are from the same speaker.
    """
    device = _select_device_and_dtype(compatible_devices=[DeviceType.CPU, DeviceType.CUDA])[0]

    scores_and_predictions = []
    for audio1, audio2 in audios:
        if audio1.sampling_rate != TRAINING_SAMPLE_RATE:
            raise ValueError(f"{model.path_or_uri} trained on {TRAINING_SAMPLE_RATE} \
sample audio, but audio1 has sample rate {audio1.sampling_rate}.")
        if audio2.sampling_rate != TRAINING_SAMPLE_RATE:
            raise ValueError(f"{model.path_or_uri} trained on {TRAINING_SAMPLE_RATE} \
sample audio, but audio2 has sample rate {audio2.sampling_rate}.")

        embeddings = SpeechBrainEmbeddings.extract_speechbrain_speaker_embeddings_from_audios(
            audios=[audio1, audio2], model=model, device=device
        )
        embedding1, embedding2 = embeddings
        similarity = cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
        score = similarity.mean().item()
        prediction = score > threshold
        scores_and_predictions.append((score, prediction))
    return scores_and_predictions
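
The decision rule in verify_speaker (cosine similarity of two embeddings, compared against a threshold) can be illustrated on plain Python lists. This is only a sketch: the function names and the toy two-dimensional vectors are made up for the example, real ECAPA embeddings are much longer tensors, and the zero-vector handling here (raising an error) is a choice of this sketch rather than the behaviour of torch's cosine_similarity.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def score_pairs(
    pairs: Sequence[Tuple[Sequence[float], Sequence[float]]],
    threshold: float = 0.25,  # same default threshold as verify_speaker in this PR
) -> List[Tuple[float, bool]]:
    """Return (score, same_speaker) for each pair of embeddings."""
    results = []
    for emb1, emb2 in pairs:
        score = cosine_similarity(emb1, emb2)
        results.append((score, score > threshold))
    return results


toy_pairs = [
    ([1.0, 0.0], [1.0, 0.0]),   # identical direction
    ([1.0, 0.0], [0.0, 1.0]),   # orthogonal
    ([1.0, 0.0], [-1.0, 0.0]),  # opposite direction
]
for score, same_speaker in score_pairs(toy_pairs):
    print(f"score={score:+.2f} same_speaker={same_speaker}")
```

Because the comparison is a strict score > threshold, a pair whose score lands exactly on the threshold is reported as different speakers.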
40 changes: 40 additions & 0 deletions src/tests/audio/tasks/speaker_verification_test.py
@@ -0,0 +1,40 @@
"""Test Module for Audio Processing and Speaker Verification.

This module contains minimal tests to ensure the audio processing and speaker verification functions do not fail.

Tests:
- test_resample_iir: Tests the resample_iir function.
- test_verify_speaker: Tests the verify_speaker function.
- test_verify_speaker_from_files: Tests the verify_speaker_from_files function.
"""

import os

import pytest

from senselab.audio.data_structures.audio import Audio
from senselab.audio.tasks.preprocessing.preprocessing import resample_audios
from senselab.audio.tasks.speaker_verification.speaker_verification import (
    verify_speaker,
)

if os.getenv("GITHUB_ACTIONS") != "true":

    @pytest.mark.large_model
    def test_verify_speaker(mono_audio_sample: Audio) -> None:
        """Tests the verify_speaker function to ensure it does not fail.

        Args:
            mono_audio_sample (Audio): The mono audio sample to use for testing.

        Returns:
            None
        """
        mono_audio_sample = resample_audios([mono_audio_sample], 16000)[0]
        assert mono_audio_sample.sampling_rate == 16000
        mono_audio_samples = [(mono_audio_sample, mono_audio_sample)] * 3
        scores_and_predictions = verify_speaker(mono_audio_samples)
        assert scores_and_predictions
        assert len(scores_and_predictions[0]) == 2
        assert isinstance(scores_and_predictions[0][0], float)
        assert isinstance(scores_and_predictions[0][1], bool)