Update for near perfect correlation with pystoi (#9)
* Update for near perfect correlation with pystoi

* Revert resampling_method update

* Fix comments

* Fix CUDA processing

* Update plots

* Update to reflect mpariente/pystoi#33

* Update plots

* Update docstring and README.md
philgzl authored May 10, 2023
1 parent 0242f25 commit eb58aae
Showing 10 changed files with 150 additions and 74 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -43,7 +43,7 @@ loss_batch.mean().backward()
```

### Comparing NumPy and PyTorch versions: the static test
- Values obtained with the NumPy version are compared to
+ Values obtained with the NumPy version (commit [84b1bd8](https://github.com/mpariente/pystoi/commit/84b1bd8f894c76eb5ddc3425946a4e2052e825fe)) are compared to
the PyTorch version in the following graphs.
##### 8kHz
Classic STOI measure
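The "near perfect correlation" the commit title refers to can be quantified with a Pearson coefficient between the two implementations' outputs. A minimal sketch with made-up values (not the actual benchmark data behind the plots):

```python
import numpy as np

# Hypothetical STOI values from the two implementations (illustrative only)
numpy_stoi = np.array([0.61, 0.72, 0.85, 0.91, 0.97])
torch_stoi_vals = np.array([0.60, 0.72, 0.86, 0.91, 0.96])

# Pearson correlation coefficient between the two sets of measurements
corr = np.corrcoef(numpy_stoi, torch_stoi_vals)[0, 1]
print(corr)
```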
Binary file modified plots/16kHzExtendedwithVAD.png
Binary file modified plots/16kHzExtendedwoVAD.png
Binary file modified plots/16kHzwithVAD.png
Binary file modified plots/16kHzwoVAD.png
Binary file modified plots/8kHzExtendedwithVAD.png
Binary file modified plots/8kHzExtendedwoVAD.png
Binary file modified plots/8kHzwithVAD.png
Binary file modified plots/8kHzwoVAD.png
222 changes: 149 additions & 73 deletions torch_stoi/stoi.py
@@ -12,8 +12,8 @@
class NegSTOILoss(nn.Module):
""" Negated Short Term Objective Intelligibility (STOI) metric, to be used
as a loss function.
- Inspired from [1, 2, 3] but not exactly the same : cannot be used as
- the STOI metric directly (use pystoi instead). See Notes.
+ Inspired from [1, 2, 3] but not exactly the same due to a different
+ resampling technique. Use pystoi when evaluating your system.
Args:
sample_rate (int): sample rate of audio input
@@ -31,18 +31,15 @@ class NegSTOILoss(nn.Module):
been reduced.
Warnings:
- This function cannot be used to compute the "real" STOI metric as
- we applied some changes to speed-up loss computation. See Notes section.
+ This function does not exactly match the "real" STOI metric due to a
+ different resampling technique. Use pystoi when evaluating your system.
Notes:
In the NumPy version, some kind of simple VAD was used to remove the
silent frames before chunking the signal into short-term envelope
vectors. We don't do the same here because removing frames in a
batch is cumbersome and inefficient.
If `use_vad` is set to True, instead we detect the silent frames and
keep a mask tensor. At the end, the normalized correlation of
short-term envelope vectors is masked using this mask (unfolded) and
the mean is computed taking the mask values into account.
`use_vad` can be set to `False` to skip the VAD for efficiency. However,
results can become substantially different compared to the "real" STOI.
When `True` (default), results are very close but still slightly
different due to a different resampling technique.
Compared against mpariente/pystoi@84b1bd8.
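The masked averaging described in the Notes can be sketched in isolation: silent frames are zeroed out and the mean is taken over speech frames only. Shapes and values below are illustrative, not the actual ones used in the loss:

```python
import torch

# Per-frame correlation values for a batch of 2 signals, 6 frames each
corr = torch.tensor([[0.9, 0.8, 0.0, 0.7, 0.0, 0.6],
                     [0.5, 0.0, 0.0, 0.9, 0.8, 0.7]])
# Boolean VAD mask: True where the frame contains speech
mask = torch.tensor([[True, True, False, True, False, True],
                     [True, False, False, True, True, True]])

# Mean over speech frames only: zero out silent frames, then divide
# by the number of speech frames instead of the total frame count
masked_mean = (corr * mask).sum(-1) / mask.sum(-1)
print(masked_mean)
```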
References
[1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
@@ -116,70 +113,111 @@ def forward(self, est_targets: torch.Tensor,
est_targets.view(-1, wav_len),
targets.view(-1, wav_len),
).view(inner)
batch_size = targets.shape[0]

if self.do_resample and self.sample_rate != FS:
targets = self.resample(targets)
est_targets = self.resample(est_targets)

if self.use_vad:
targets, est_targets, mask = self.remove_silent_frames(
targets, est_targets, self.dyn_range, self.win, self.win_len,
self.win_len//2
)

# Here comes the real computation, take STFT
x_spec = self.stft(targets, self.win, self.nfft, overlap=2)
y_spec = self.stft(est_targets, self.win, self.nfft, overlap=2)
# Reapply the mask because of overlap in STFT
if self.use_vad:
x_spec *= mask.unsqueeze(1)
y_spec *= mask.unsqueeze(1)

"""Uncommenting the following lines and the last block at the end
allows to replicate the pystoi behavior when less than N speech frames
are detected
"""
# # Ensure at least 30 frames for intermediate intelligibility
# if self.use_vad:
# not_enough = mask.sum(-1) < self.intel_frames
# if not_enough.any():
# import warnings
# warnings.warn('Not enough STFT frames to compute intermediate '
# 'intelligibility measure after removing silent '
# 'frames. Returning 1e-5. Please check you wav '
# 'files', RuntimeWarning)
# if not_enough.all():
# return torch.full((batch_size,), 1e-5)
# x_spec = x_spec[~not_enough]
# y_spec = y_spec[~not_enough]
# mask = mask[~not_enough]

# Apply OB matrix to the spectrograms as in Eq. (1)
- x_tob = torch.matmul(self.OBM, torch.norm(x_spec, 2, -1) ** 2 + EPS).pow(0.5)
- y_tob = torch.matmul(self.OBM, torch.norm(y_spec, 2, -1) ** 2 + EPS).pow(0.5)
+ x_tob = torch.matmul(self.OBM, x_spec.abs().pow(2) + EPS).sqrt()
+ y_tob = torch.matmul(self.OBM, y_spec.abs().pow(2) + EPS).sqrt()

# Perform N-frame segmentation --> (batch, 15, N, n_chunks)
- batch = targets.shape[0]
x_seg = unfold(x_tob.unsqueeze(2),
kernel_size=(1, self.intel_frames),
- stride=(1, 1)).view(batch, x_tob.shape[1], N, -1)
+ stride=(1, 1)).view(batch_size, x_tob.shape[1], self.intel_frames, -1)
y_seg = unfold(y_tob.unsqueeze(2),
kernel_size=(1, self.intel_frames),
- stride=(1, 1)).view(batch, y_tob.shape[1], N, -1)
+ stride=(1, 1)).view(batch_size, y_tob.shape[1], self.intel_frames, -1)
- # Compute mask if use_vad
+ # Reapply the mask because of overlap in N-frame segmentation
if self.use_vad:
- # Detect silent frames (boolean mask of shape (batch, 1, frame_idx))
- mask = self.detect_silent_frames(targets, self.dyn_range,
- self.win_len, self.win_len // 2)
- mask = pad(mask, [0, x_tob.shape[-1] - mask.shape[-1]])
- # Unfold on the mask, to float and mean per frame.
- mask_f = unfold(mask.unsqueeze(2).float(),
- kernel_size=(1, self.intel_frames),
- stride=(1, 1)).view(batch, 1, N, -1)
- else:
- mask_f = None
+ mask = mask[..., self.intel_frames-1:]
+ x_seg *= mask.unsqueeze(1).unsqueeze(2)
+ y_seg *= mask.unsqueeze(1).unsqueeze(2)

if self.extended:
- # Normalize rows and columns of intermediate intelligibility frames
- x_n = self.rowcol_norm(x_seg, mask=mask_f)
- y_n = self.rowcol_norm(y_seg, mask=mask_f)
+ # No need to pass the mask because zeros do not affect statistics
+ x_n = self.rowcol_norm(x_seg)
+ y_n = self.rowcol_norm(y_seg)
corr_comp = x_n * y_n
correction = self.intel_frames * x_n.shape[-1]
corr_comp = corr_comp.sum(1)

else:
# Find normalization constants and normalize
- norm_const = (masked_norm(x_seg, dim=2, keepdim=True, mask=mask_f) /
- (masked_norm(y_seg, dim=2, keepdim=True, mask=mask_f)
- + EPS))
+ # No need to pass the mask because zeros do not affect statistics
+ norm_const = (
+ x_seg.norm(p=2, dim=2, keepdim=True) /
+ (y_seg.norm(p=2, dim=2, keepdim=True) + EPS)
+ )
y_seg_normed = y_seg * norm_const
# Clip as described in [1]
clip_val = 10 ** (-self.beta / 20)
y_prim = torch.min(y_seg_normed, x_seg * (1 + clip_val))
- # Mean/var normalize vectors
- y_prim = meanvar_norm(y_prim, dim=2, mask=mask_f)
- x_seg = meanvar_norm(x_seg, dim=2, mask=mask_f)
+ # No need to pass the mask because zeros do not affect statistics
+ y_prim = y_prim - y_prim.mean(2, keepdim=True)
+ x_seg = x_seg - x_seg.mean(2, keepdim=True)
+ y_prim = y_prim / (y_prim.norm(p=2, dim=2, keepdim=True) + EPS)
+ x_seg = x_seg / (x_seg.norm(p=2, dim=2, keepdim=True) + EPS)
# Matrix with entries summing to sum of correlations of vectors
corr_comp = y_prim * x_seg
# J, M as in [1], eq.6
correction = x_seg.shape[1] * x_seg.shape[-1]
corr_comp = corr_comp.sum(2)
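The clipping and normalization steps above can be sketched on a single short-term vector pair. The values and shapes here are illustrative; `beta = -15` is the lower SDR bound used by STOI:

```python
import torch

EPS = 1e-8
beta = -15  # lower SDR bound, in dB

x_seg = torch.tensor([[1.0, 2.0, 3.0]])   # clean short-term vector
y_seg = torch.tensor([[0.5, 1.0, 10.0]])  # degraded short-term vector

# Scale y to the energy of x
norm_const = (x_seg.norm(dim=-1, keepdim=True) /
              (y_seg.norm(dim=-1, keepdim=True) + EPS))
y_normed = y_seg * norm_const

# Clip so y cannot exceed x by more than beta dB, element-wise
clip_val = 10 ** (-beta / 20)
y_prim = torch.min(y_normed, x_seg * (1 + clip_val))
print(y_prim)
```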

# Compute average (E)STOI w. or w/o VAD.
- sum_over = list(range(1, x_seg.ndim))  # Keep batch dim
+ output = corr_comp.mean(1)
if self.use_vad:
- corr_comp = corr_comp * mask_f
- correction = correction * mask_f.mean() + EPS
- # Return -(E)STOI to optimize for
- return - torch.sum(corr_comp, dim=sum_over) / correction
+ output = output.sum(-1)/mask.sum(-1)
else:
output = output.mean(-1)

"""Uncomment this to replicate the pystoi behavior when less than N
speech frames are detected
"""
# if np.any(not_enough):
# output_ = torch.empty(batch_size)
# output_[not_enough] = 1e-5
# output_[~not_enough] = output
# output = output_

return - output

@staticmethod
- def detect_silent_frames(x, dyn_range, framelen, hop):
+ def remove_silent_frames(x, y, dyn_range, window, framelen, hop):
""" Detects silent frames on input tensor.
A frame is excluded if its energy is lower than max(energy) - dyn_range
@@ -193,54 +231,92 @@ def detect_silent_frames(x, dyn_range, framelen, hop):
torch.BoolTensor, framewise mask.
"""
x_frames = unfold(x[:, None, None, :], kernel_size=(1, framelen),
- stride=(1, hop))[..., :-1]
+ stride=(1, hop))
+ y_frames = unfold(y[:, None, None, :], kernel_size=(1, framelen),
+ stride=(1, hop))
+ x_frames *= window[:, None]
+ y_frames *= window[:, None]

# Compute energies in dB
x_energies = 20 * torch.log10(torch.norm(x_frames, dim=1,
keepdim=True) + EPS)
# Find boolean mask of energies lower than dynamic_range dB
# with respect to maximum clean speech energy frame
- mask = (torch.max(x_energies, dim=2, keepdim=True)[0] - dyn_range -
- x_energies) < 0
- return mask
+ mask = (x_energies.amax(2, keepdim=True) - dyn_range - x_energies) < 0
+ mask = mask.squeeze(1)

# Remove silent frames and pad with zeroes
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)
x_frames = _mask_audio(x_frames, mask)
y_frames = _mask_audio(y_frames, mask)

x_sil = _overlap_and_add(x_frames, hop)
y_sil = _overlap_and_add(y_frames, hop)
x_frames = x_frames.permute(0, 2, 1)
y_frames = y_frames.permute(0, 2, 1)

mask, _ = mask.long().sort(-1, descending=True)

return x_sil, y_sil, mask
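The detect-then-remove logic above can be sketched on a toy framed signal. The shapes, values, and the 40 dB threshold below are illustrative only; frames are kept if they are within `dyn_range` dB of the loudest clean frame, and removed frames are replaced by zero padding at the end:

```python
import torch
from torch.nn.functional import pad

EPS = 1e-8
dyn_range = 40  # keep frames within 40 dB of the loudest frame

# Toy framed signal: (batch=1, framelen=4, n_frames=3); frame 1 is near-silent
x_frames = torch.tensor([[[0.5, 1e-4, 0.4],
                          [-0.4, -1e-4, -0.5],
                          [0.3, 1e-4, 0.2],
                          [-0.2, -1e-4, -0.3]]])

# Frame energies in dB, computed over the framelen axis (dim=1)
energies = 20 * torch.log10(x_frames.norm(dim=1, keepdim=True) + EPS)
# Speech mask: True where the frame is loud enough
mask = (energies.amax(2, keepdim=True) - dyn_range - energies) < 0
mask = mask.squeeze(1)
print(mask.tolist())  # [[True, False, True]]

# Drop the masked-out frames and zero-pad back to the original frame count
frames = x_frames.permute(0, 2, 1)  # (batch, n_frames, framelen)
kept = torch.stack([
    pad(xi[mi], (0, 0, 0, len(xi) - int(mi.sum())))
    for xi, mi in zip(frames, mask)
])
print(tuple(kept.shape))  # (1, 3, 4)
```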

@staticmethod
def stft(x, win, fft_size, overlap=4):
"""We can't use torch.stft:
- It's buggy with center=False as it discards the last frame
- It pads the frame left and right before taking the fft instead
of padding right
Instead we unfold and take rfft. This gives the same result as
pystoi.utils.stft.
"""
win_len = win.shape[0]
hop = int(win_len / overlap)
- # Last frame not taken because NFFT size is larger, torch bug IMO.
- x_padded = torch.nn.functional.pad(x, pad=[0, hop])
- # From torch 1.8.0
- try:
- return torch.stft(x_padded, fft_size, hop_length=hop, window=win,
- center=False, win_length=win_len, return_complex=False)
- # Under 1.8.0
- except TypeError:
- return torch.stft(x_padded, fft_size, hop_length=hop, window=win,
- center=False, win_length=win_len)
+ frames = unfold(x[:, None, None, :], kernel_size=(1, win_len),
+ stride=(1, hop))
+ return torch.fft.rfft(frames*win[:, None], n=fft_size, dim=1)
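The unfold-plus-rfft approach can be checked against an explicit per-frame FFT loop. A self-contained sketch; the Hann window and sizes here are arbitrary choices for the demonstration, not the ones used by the loss:

```python
import torch

torch.manual_seed(0)
win_len, hop, nfft = 256, 128, 512
x = torch.randn(1, 2048)
win = torch.hann_window(win_len, periodic=False)

# Frame with unfold, window, then take the real FFT along the frame axis
frames = torch.nn.functional.unfold(
    x[:, None, None, :], kernel_size=(1, win_len), stride=(1, hop))
spec = torch.fft.rfft(frames * win[:, None], n=nfft, dim=1)

# Same thing with an explicit Python loop over frames, for comparison
n_frames = (x.shape[-1] - win_len) // hop + 1
ref = torch.stack([
    torch.fft.rfft(x[0, i * hop:i * hop + win_len] * win, n=nfft)
    for i in range(n_frames)
], dim=-1)
print(torch.allclose(spec[0], ref, atol=1e-4))  # True
```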

@staticmethod
- def rowcol_norm(x, mask=None):
+ def rowcol_norm(x):
""" Mean/variance normalize axis 2 and 1 of input vector"""
for dim in [2, 1]:
- x = meanvar_norm(x, mask=mask, dim=dim)
+ x = x - x.mean(dim, keepdim=True)
+ x = x / (x.norm(p=2, dim=dim, keepdim=True) + EPS)
return x


- def meanvar_norm(x, mask=None, dim=-1):
-     x = x - masked_mean(x, dim=dim, mask=mask, keepdim=True)
-     x = x / (masked_norm(x, p=2, dim=dim, keepdim=True, mask=mask) + EPS)
-     return x
-
-
- def masked_mean(x, dim=-1, mask=None, keepdim=False):
-     if mask is None:
-         return x.mean(dim=dim, keepdim=keepdim)
-     return (x * mask).sum(dim=dim, keepdim=keepdim) / (
-         mask.sum(dim=dim, keepdim=keepdim) + EPS
-     )
-
-
- def masked_norm(x, p=2, dim=-1, mask=None, keepdim=False):
-     if mask is None:
-         return torch.norm(x, p=p, dim=dim, keepdim=keepdim)
-     return torch.norm(x * mask, p=p, dim=dim, keepdim=keepdim)
+ def _overlap_and_add(x_frames, hop):
+     batch_size, num_frames, framelen = x_frames.shape
+     # Compute the number of segments, per frame.
+     segments = -(-framelen // hop)  # Divide and round up.
+
+     # Pad the framelen dimension to segments * hop and add n=segments frames
+     signal = pad(
+         x_frames, (0, segments * hop - framelen, 0, segments)
+     )
+
+     # Reshape to a 4D tensor, splitting the framelen dimension in two
+     signal = signal.reshape((batch_size, num_frames + segments, segments, hop))
+     # Transpose dimensions so shape = (batch, segments, frame+segments, hop)
+     signal = signal.permute(0, 2, 1, 3)
+     # Reshape so that signal.shape = (batch, segments * (frame+segments), hop)
+     signal = signal.reshape((batch_size, -1, hop))
+
+     # Now behold the magic!! Remove last n=segments elements from second axis
+     signal = signal[:, :-segments]
+     # Reshape to (batch, segments, frame+segments-1, hop)
+     signal = signal.reshape(
+         (batch_size, segments, num_frames + segments - 1, hop)
+     )
+     # This has introduced a shift by one in all rows
+
+     # Now, reduce over the columns and flatten the array to achieve the result
+     signal = signal.sum(axis=1)
+     end = (num_frames - 1) * hop + framelen
+     signal = signal.reshape((batch_size, -1))[:, :end]
+     return signal
+
+
+ def _mask_audio(x, mask):
+     return torch.stack([
+         pad(xi[mi], (0, 0, 0, len(xi) - mi.sum())) for xi, mi in zip(x, mask)
+     ])
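The pad/reshape/permute trick in `_overlap_and_add` can be validated against a straightforward loop. A self-contained sketch that duplicates the vectorized steps (sizes chosen arbitrarily):

```python
import torch
from torch.nn.functional import pad

def overlap_add_naive(x_frames, hop):
    # Reference implementation: explicit loop over frames
    batch, n_frames, framelen = x_frames.shape
    out = torch.zeros(batch, (n_frames - 1) * hop + framelen)
    for i in range(n_frames):
        out[:, i * hop:i * hop + framelen] += x_frames[:, i]
    return out

torch.manual_seed(0)
frames = torch.randn(2, 5, 8)  # (batch, num_frames, framelen)
hop = 4

# Vectorized overlap-add via the same pad/reshape/permute trick
batch_size, num_frames, framelen = frames.shape
segments = -(-framelen // hop)  # ceil division
sig = pad(frames, (0, segments * hop - framelen, 0, segments))
sig = sig.reshape(batch_size, num_frames + segments, segments, hop)
sig = sig.permute(0, 2, 1, 3).reshape(batch_size, -1, hop)
sig = sig[:, :-segments]
sig = sig.reshape(batch_size, segments, num_frames + segments - 1, hop)
sig = sig.sum(1).reshape(batch_size, -1)[:, :(num_frames - 1) * hop + framelen]

print(torch.allclose(sig, overlap_add_naive(frames, hop)))  # True
```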
