Update for near perfect correlation with pystoi #9

Merged 8 commits on May 10, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -43,7 +43,7 @@ loss_batch.mean().backward()
```

### Comparing NumPy and PyTorch versions : the static test
Values obtained with the NumPy version are compared to
Values obtained with the NumPy version (commit [84b1bd8](https://github.com/mpariente/pystoi/commit/84b1bd8f894c76eb5ddc3425946a4e2052e825fe)) are compared to
the PyTorch version in the following graphs.
##### 8kHz
Classic STOI measure
Binary file modified plots/16kHzExtendedwithVAD.png
Binary file modified plots/16kHzExtendedwoVAD.png
Binary file modified plots/16kHzwithVAD.png
Binary file modified plots/16kHzwoVAD.png
Binary file modified plots/8kHzExtendedwithVAD.png
Binary file modified plots/8kHzExtendedwoVAD.png
Binary file modified plots/8kHzwithVAD.png
Binary file modified plots/8kHzwoVAD.png
222 changes: 149 additions & 73 deletions torch_stoi/stoi.py
@@ -12,8 +12,8 @@
class NegSTOILoss(nn.Module):
""" Negated Short Term Objective Intelligibility (STOI) metric, to be used
as a loss function.
Inspired from [1, 2, 3] but not exactly the same : cannot be used as
the STOI metric directly (use pystoi instead). See Notes.
Inspired from [1, 2, 3] but not exactly the same due to a different
resampling technique. Use pystoi when evaluating your system.
Args:
sample_rate (int): sample rate of audio input
@@ -31,18 +31,15 @@ class NegSTOILoss(nn.Module):
been reduced.
Warnings:
This function cannot be used to compute the "real" STOI metric as
we applied some changes to speed-up loss computation. See Notes section.
This function does not exactly match the "real" STOI metric due to a
different resampling technique. Use pystoi when evaluating your system.
Notes:
In the NumPy version, some kind of simple VAD was used to remove the
silent frames before chunking the signal into short-term envelope
vectors. We don't do the same here because removing frames in a
batch is cumbersome and inefficient.
If `use_vad` is set to True, instead we detect the silent frames and
keep a mask tensor. At the end, the normalized correlation of
short-term envelope vectors is masked using this mask (unfolded) and
the mean is computed taking the mask values into account.
`use_vad` can be set to `False` to skip the VAD for efficiency. However
results can become substantially different compared to the "real" STOI.
When `True` (default), results are very close but still slightly
different due to a different resampling technique.
Compared against mpariente/pystoi@84b1bd8.
References
[1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
@@ -116,70 +113,111 @@ def forward(self, est_targets: torch.Tensor,
est_targets.view(-1, wav_len),
targets.view(-1, wav_len),
).view(inner)
batch_size = targets.shape[0]

if self.do_resample and self.sample_rate != FS:
targets = self.resample(targets)
est_targets = self.resample(est_targets)

if self.use_vad:
targets, est_targets, mask = self.remove_silent_frames(
targets, est_targets, self.dyn_range, self.win, self.win_len,
self.win_len//2
)

# Here comes the real computation, take STFT
x_spec = self.stft(targets, self.win, self.nfft, overlap=2)
y_spec = self.stft(est_targets, self.win, self.nfft, overlap=2)
# Reapply the mask because of overlap in STFT
if self.use_vad:
x_spec *= mask.unsqueeze(1)
y_spec *= mask.unsqueeze(1)

"""Uncommenting the following lines and the last block at the end
allows to replicate the pystoi behavior when less than N speech frames
are detected
"""
# # Ensure at least 30 frames for intermediate intelligibility
# if self.use_vad:
# not_enough = mask.sum(-1) < self.intel_frames
# if not_enough.any():
# import warnings
# warnings.warn('Not enough STFT frames to compute intermediate '
# 'intelligibility measure after removing silent '
# 'frames. Returning 1e-5. Please check you wav '
# 'files', RuntimeWarning)
# if not_enough.all():
# return torch.full(batch_size, 1e-5)
# x_spec = x_spec[~not_enough]
# y_spec = y_spec[~not_enough]
# mask = mask[~not_enough]

# Apply OB matrix to the spectrograms as in Eq. (1)
x_tob = torch.matmul(self.OBM, torch.norm(x_spec, 2, -1) ** 2 + EPS).pow(0.5)
y_tob = torch.matmul(self.OBM, torch.norm(y_spec, 2, -1) ** 2 + EPS).pow(0.5)
x_tob = torch.matmul(self.OBM, x_spec.abs().pow(2) + EPS).sqrt()
y_tob = torch.matmul(self.OBM, y_spec.abs().pow(2) + EPS).sqrt()

# Perform N-frame segmentation --> (batch, 15, N, n_chunks)
batch = targets.shape[0]
x_seg = unfold(x_tob.unsqueeze(2),
kernel_size=(1, self.intel_frames),
stride=(1, 1)).view(batch, x_tob.shape[1], N, -1)
stride=(1, 1)).view(batch_size, x_tob.shape[1], self.intel_frames, -1)
y_seg = unfold(y_tob.unsqueeze(2),
kernel_size=(1, self.intel_frames),
stride=(1, 1)).view(batch, y_tob.shape[1], N, -1)
# Compute mask if use_vad
stride=(1, 1)).view(batch_size, y_tob.shape[1], self.intel_frames, -1)
# Reapply the mask because of overlap in N-frame segmentation
if self.use_vad:
# Detech silent frames (boolean mask of shape (batch, 1, frame_idx)
mask = self.detect_silent_frames(targets, self.dyn_range,
self.win_len, self.win_len // 2)
mask = pad(mask, [0, x_tob.shape[-1] - mask.shape[-1]])
# Unfold on the mask, to float and mean per frame.
mask_f = unfold(mask.unsqueeze(2).float(),
kernel_size=(1, self.intel_frames),
stride=(1, 1)).view(batch, 1, N, -1)
else:
mask_f = None
mask = mask[..., self.intel_frames-1:]
x_seg *= mask.unsqueeze(1).unsqueeze(2)
y_seg *= mask.unsqueeze(1).unsqueeze(2)

if self.extended:
# Normalize rows and columns of intermediate intelligibility frames
x_n = self.rowcol_norm(x_seg, mask=mask_f)
y_n = self.rowcol_norm(y_seg, mask=mask_f)
# No need to pass the mask because zeros do not affect statistics
x_n = self.rowcol_norm(x_seg)
y_n = self.rowcol_norm(y_seg)
corr_comp = x_n * y_n
correction = self.intel_frames * x_n.shape[-1]
corr_comp = corr_comp.sum(1)

else:
# Find normalization constants and normalize
norm_const = (masked_norm(x_seg, dim=2, keepdim=True, mask=mask_f) /
(masked_norm(y_seg, dim=2, keepdim=True, mask=mask_f)
+ EPS))
# No need to pass the mask because zeros do not affect statistics
norm_const = (
x_seg.norm(p=2, dim=2, keepdim=True) /
(y_seg.norm(p=2, dim=2, keepdim=True) + EPS)
)
y_seg_normed = y_seg * norm_const
# Clip as described in [1]
clip_val = 10 ** (-self.beta / 20)
y_prim = torch.min(y_seg_normed, x_seg * (1 + clip_val))
# Mean/var normalize vectors
y_prim = meanvar_norm(y_prim, dim=2, mask=mask_f)
x_seg = meanvar_norm(x_seg, dim=2, mask=mask_f)
# No need to pass the mask because zeros do not affect statistics
y_prim = y_prim - y_prim.mean(2, keepdim=True)
x_seg = x_seg - x_seg.mean(2, keepdim=True)
y_prim = y_prim / (y_prim.norm(p=2, dim=2, keepdim=True) + EPS)
x_seg = x_seg / (x_seg.norm(p=2, dim=2, keepdim=True) + EPS)
# Matrix with entries summing to sum of correlations of vectors
corr_comp = y_prim * x_seg
# J, M as in [1], eq.6
correction = x_seg.shape[1] * x_seg.shape[-1]
corr_comp = corr_comp.sum(2)

# Compute average (E)STOI w. or w/o VAD.
sum_over = list(range(1, x_seg.ndim)) # Keep batch dim
output = corr_comp.mean(1)
if self.use_vad:
corr_comp = corr_comp * mask_f
correction = correction * mask_f.mean() + EPS
# Return -(E)STOI to optimize for
return - torch.sum(corr_comp, dim=sum_over) / correction
output = output.sum(-1)/mask.sum(-1)
else:
output = output.mean(-1)

"""Uncomment this to replicate the pystoi behavior when less than N
speech frames are detected
"""
# if np.any(not_enough):
# output_ = torch.empty(batch_size)
# output_[not_enough] = 1e-5
# output_[~not_enough] = output
# output = output_

return - output
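
The non-extended branch above (energy normalization, clipping, mean/variance normalization, correlation) can be sketched for a single segment as follows. This is a simplified illustration, not the merged code: it works on one `(bands, N)` segment instead of the batched `(batch, 15, N, n_chunks)` tensor, and the `EPS` value and the `beta=-15.0` default are assumptions made here for the example.

```python
import torch

EPS = 1e-8  # assumed value, standing in for the module-level EPS


def segment_correlation(x_seg, y_seg, beta=-15.0):
    """Average band correlation for one segment (hypothetical helper).

    x_seg, y_seg: (bands, N) clean and processed envelopes.
    beta: lower SDR bound in dB (assumed default for this sketch).
    """
    # Scale each band of y to the energy of the matching band of x
    norm_const = x_seg.norm(p=2, dim=1, keepdim=True) / (
        y_seg.norm(p=2, dim=1, keepdim=True) + EPS
    )
    y_normed = y_seg * norm_const
    # Clip y so that it cannot exceed x by more than the SDR bound
    clip_val = 10 ** (-beta / 20)
    y_prim = torch.min(y_normed, x_seg * (1 + clip_val))
    # Remove the mean and scale to unit norm along the frame axis
    y_prim = y_prim - y_prim.mean(1, keepdim=True)
    x_cent = x_seg - x_seg.mean(1, keepdim=True)
    y_prim = y_prim / (y_prim.norm(p=2, dim=1, keepdim=True) + EPS)
    x_cent = x_cent / (x_cent.norm(p=2, dim=1, keepdim=True) + EPS)
    # Per-band correlation, averaged over bands
    return (x_cent * y_prim).sum(1).mean()
```

With identical (or merely rescaled) inputs the value should be close to 1, since the energy normalization removes any global gain before the correlation is taken.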

@staticmethod
def detect_silent_frames(x, dyn_range, framelen, hop):
def remove_silent_frames(x, y, dyn_range, window, framelen, hop):
""" Detects silent frames on input tensor.
A frame is excluded if its energy is lower than max(energy) - dyn_range
@@ -193,54 +231,92 @@ def detect_silent_frames(x, dyn_range, framelen, hop):
torch.BoolTensor, framewise mask.
"""
x_frames = unfold(x[:, None, None, :], kernel_size=(1, framelen),
stride=(1, hop))[..., :-1]
stride=(1, hop))
y_frames = unfold(y[:, None, None, :], kernel_size=(1, framelen),
stride=(1, hop))
x_frames *= window[:, None]
y_frames *= window[:, None]

# Compute energies in dB
x_energies = 20 * torch.log10(torch.norm(x_frames, dim=1,
keepdim=True) + EPS)
# Find boolean mask of energies lower than dynamic_range dB
# with respect to maximum clean speech energy frame
mask = (torch.max(x_energies, dim=2, keepdim=True)[0] - dyn_range -
x_energies) < 0
return mask
mask = (x_energies.amax(2, keepdim=True) - dyn_range - x_energies) < 0
mask = mask.squeeze(1)

# Remove silent frames and pad with zeroes
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)
x_frames = _mask_audio(x_frames, mask)
y_frames = _mask_audio(y_frames, mask)

x_sil = _overlap_and_add(x_frames, hop)
y_sil = _overlap_and_add(y_frames, hop)
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)

mask, _ = mask.long().sort(-1, descending=True)

return x_sil, y_sil, mask
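
The frame selection rule used above (keep a frame when its energy is within `dyn_range` dB of the loudest frame) can be illustrated in isolation. The sketch below is a simplification under stated assumptions: it skips the analysis window, uses `Tensor.unfold` rather than `torch.nn.functional.unfold`, and the `EPS` value and the `dyn_range=40.0` default are chosen here for the example.

```python
import torch

EPS = 1e-8  # assumed value, standing in for the module-level EPS


def speech_frame_mask(x, framelen, hop, dyn_range=40.0):
    """Boolean mask of frames to keep (hypothetical helper, no windowing).

    x: (batch, time) clean signal. Returns (batch, n_frames), True = keep.
    """
    # Slice the signal into overlapping frames: (batch, n_frames, framelen)
    frames = x.unfold(-1, framelen, hop)
    # Frame energies in dB
    energies = 20 * torch.log10(frames.norm(dim=-1) + EPS)
    # Keep frames less than dyn_range dB below the loudest frame
    return energies > energies.amax(-1, keepdim=True) - dyn_range
```

For a signal made of a loud half followed by digital silence, only the frames of the loud half are expected to be kept.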

@staticmethod
def stft(x, win, fft_size, overlap=4):
"""We can't use torch.stft:
- It's buggy with center=False as it discards the last frame
- It pads the frame left and right before taking the fft instead
of padding right
Instead we unfold and take rfft. This gives the same result as
pystoi.utils.stft.
"""
win_len = win.shape[0]
hop = int(win_len / overlap)
# Last frame not taken because NFFT size is larger, torch bug IMO.
x_padded = torch.nn.functional.pad(x, pad=[0, hop])
# From torch 1.8.0
try:
return torch.stft(x_padded, fft_size, hop_length=hop, window=win,
center=False, win_length=win_len, return_complex=False)
# Under 1.8.0
except TypeError:
return torch.stft(x_padded, fft_size, hop_length=hop, window=win,
center=False, win_length=win_len)
frames = unfold(x[:, None, None, :], kernel_size=(1, win_len),
stride=(1, hop))
return torch.fft.rfft(frames*win[:, None], n=fft_size, dim=1)
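
To make the unfold-then-rfft approach above easier to check, here is a small standalone sketch that mirrors it and can be compared against an explicit loop over frames. The function name `framed_rfft` and the explicit `hop` argument are choices made for this example only.

```python
import torch
from torch.nn.functional import unfold


def framed_rfft(x, win, fft_size, hop):
    """Frame, window and zero-pad-right before the FFT (illustrative sketch).

    x: (batch, time), win: (win_len,).
    Returns a complex tensor of shape (batch, fft_size // 2 + 1, n_frames).
    """
    win_len = win.shape[0]
    # unfold expects a 4D input; the result is (batch, win_len, n_frames)
    frames = unfold(x[:, None, None, :], kernel_size=(1, win_len),
                    stride=(1, hop))
    # rfft with n > win_len pads each frame on the right only
    return torch.fft.rfft(frames * win[:, None], n=fft_size, dim=1)
```

With 64 samples, a 16-sample window and a hop of 8, this yields `(64 - 16) // 8 + 1 = 7` frames, so the last full frame is included.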

@staticmethod
def rowcol_norm(x, mask=None):
def rowcol_norm(x):
""" Mean/variance normalize axis 2 and 1 of input vector"""
for dim in [2, 1]:
x = meanvar_norm(x, mask=mask, dim=dim)
x = x - x.mean(dim, keepdim=True)
x = x / (x.norm(p=2, dim=dim, keepdim=True) + EPS)
return x


def meanvar_norm(x, mask=None, dim=-1):
x = x - masked_mean(x, dim=dim, mask=mask, keepdim=True)
x = x / (masked_norm(x, p=2, dim=dim, keepdim=True, mask=mask) + EPS)
return x
def _overlap_and_add(x_frames, hop):
batch_size, num_frames, framelen = x_frames.shape
# Compute the number of segments, per frame.
segments = -(-framelen // hop) # Divide and round up.

# Pad the framelen dimension to segments * hop and add n=segments frames
signal = pad(
x_frames, (0, segments * hop - framelen, 0, segments)
)

# Reshape to a 4D tensor, splitting the framelen dimension in two
signal = signal.reshape((batch_size, num_frames + segments, segments, hop))
# Transpose dimensions so shape = (batch, segments, frame+segments, hop)
signal = signal.permute(0, 2, 1, 3)
# Reshape so that signal.shape = (batch, segments * (frame+segments), hop)
signal = signal.reshape((batch_size, -1, hop))

def masked_mean(x, dim=-1, mask=None, keepdim=False):
if mask is None:
return x.mean(dim=dim, keepdim=keepdim)
return (x * mask).sum(dim=dim, keepdim=keepdim) / (
mask.sum(dim=dim, keepdim=keepdim) + EPS
# Now behold the magic!! Remove last n=segments elements from second axis
signal = signal[:, :-segments]
# Reshape to (batch, segments, frame+segments-1, hop)
signal = signal.reshape(
(batch_size, segments, num_frames + segments - 1, hop)
)
# This has introduced a shift by one in all rows

# Now, reduce over the columns and flatten the array to achieve the result
signal = signal.sum(axis=1)
end = (num_frames - 1) * hop + framelen
signal = signal.reshape((batch_size, -1))[:end]
return signal
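
For readers who find the reshape-based overlap-add above hard to follow, a plain loop version of the same operation may help as a reference. This is a sketch written for this explanation, not part of the PR; it trades speed for readability.

```python
import torch


def overlap_add_reference(frames, hop):
    """Loop-based overlap-add (hypothetical reference helper).

    frames: (batch, num_frames, framelen).
    Returns (batch, (num_frames - 1) * hop + framelen).
    """
    batch_size, num_frames, framelen = frames.shape
    out = frames.new_zeros(batch_size, (num_frames - 1) * hop + framelen)
    for i in range(num_frames):
        # Frame i starts at sample i * hop and is summed into the output
        out[:, i * hop:i * hop + framelen] += frames[:, i]
    return out
```

Three all-ones frames of length 4 with a hop of 2 give `[1, 1, 2, 2, 2, 2, 1, 1]`: the first and last `hop` samples are covered by one frame, everything in between by two.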


def masked_norm(x, p=2, dim=-1, mask=None, keepdim=False):
if mask is None:
return torch.norm(x, p=p, dim=dim, keepdim=keepdim)
return torch.norm(x * mask, p=p, dim=dim, keepdim=keepdim)
def _mask_audio(x, mask):
return torch.stack([
pad(xi[mi], (0, 0, 0, len(xi) - mi.sum())) for xi, mi in zip(x, mask)
])
Comment on lines +320 to +322 (Owner):

This is the list comprehension which takes the most time, and it makes sense that it does.
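
One possible way to avoid the per-example loop is a stable sort followed by a gather. The sketch below is an untimed alternative offered for discussion, not what the PR merged; `mask_audio_sorted` is a name invented here, and `torch.sort(..., stable=True)` is assumed to be available in the installed PyTorch version. `mask_audio_loop` mirrors the list comprehension from the diff so the two can be compared.

```python
import torch
from torch.nn.functional import pad


def mask_audio_loop(x, mask):
    """Reference: mirrors the list comprehension from the diff."""
    return torch.stack([
        pad(xi[mi], (0, 0, 0, len(xi) - int(mi.sum())))
        for xi, mi in zip(x, mask)
    ])


def mask_audio_sorted(x, mask):
    """Loop-free variant (hypothetical alternative).

    x: (batch, frames, framelen), mask: (batch, frames) bool, True = keep.
    """
    # Key is 0 for kept frames and 1 for dropped ones; a stable ascending
    # sort moves kept frames to the front without reordering them
    key = (~mask).long()
    _, order = torch.sort(key, dim=-1, stable=True)
    gathered = torch.gather(x, 1, order.unsqueeze(-1).expand_as(x))
    kept = torch.gather(mask, 1, order)
    # Zero out the trailing dropped frames, matching the zero padding
    return gathered.masked_fill(~kept.unsqueeze(-1), 0)
```

Whether this is actually faster than the loop would need to be measured on realistic batch sizes and devices.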