diff --git a/README.md b/README.md index 0c4c757..85f37bd 100755 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ loss_batch.mean().backward() ``` ### Comparing NumPy and PyTorch versions : the static test -Values obtained with the NumPy version are compared to +Values obtained with the NumPy version (commit [84b1bd8](https://github.com/mpariente/pystoi/commit/84b1bd8f894c76eb5ddc3425946a4e2052e825fe)) are compared to the PyTorch version in the following graphs. ##### 8kHz Classic STOI measure diff --git a/plots/16kHzExtendedwithVAD.png b/plots/16kHzExtendedwithVAD.png index 83ff77a..f3d4218 100755 Binary files a/plots/16kHzExtendedwithVAD.png and b/plots/16kHzExtendedwithVAD.png differ diff --git a/plots/16kHzExtendedwoVAD.png b/plots/16kHzExtendedwoVAD.png index 57c137c..e554283 100755 Binary files a/plots/16kHzExtendedwoVAD.png and b/plots/16kHzExtendedwoVAD.png differ diff --git a/plots/16kHzwithVAD.png b/plots/16kHzwithVAD.png index 5a037c5..91da9d8 100755 Binary files a/plots/16kHzwithVAD.png and b/plots/16kHzwithVAD.png differ diff --git a/plots/16kHzwoVAD.png b/plots/16kHzwoVAD.png index 0624f6e..1e3d96c 100755 Binary files a/plots/16kHzwoVAD.png and b/plots/16kHzwoVAD.png differ diff --git a/plots/8kHzExtendedwithVAD.png b/plots/8kHzExtendedwithVAD.png index f1f9877..2420d80 100755 Binary files a/plots/8kHzExtendedwithVAD.png and b/plots/8kHzExtendedwithVAD.png differ diff --git a/plots/8kHzExtendedwoVAD.png b/plots/8kHzExtendedwoVAD.png index 4dedcd0..1be46c7 100755 Binary files a/plots/8kHzExtendedwoVAD.png and b/plots/8kHzExtendedwoVAD.png differ diff --git a/plots/8kHzwithVAD.png b/plots/8kHzwithVAD.png index 834dd3d..7fc2d65 100755 Binary files a/plots/8kHzwithVAD.png and b/plots/8kHzwithVAD.png differ diff --git a/plots/8kHzwoVAD.png b/plots/8kHzwoVAD.png index 94cb792..c54218e 100755 Binary files a/plots/8kHzwoVAD.png and b/plots/8kHzwoVAD.png differ diff --git a/torch_stoi/stoi.py b/torch_stoi/stoi.py index 7181c00..25a7156 100755 --- a/torch_stoi/stoi.py +++ b/torch_stoi/stoi.py @@ -12,8 +12,8 @@ class NegSTOILoss(nn.Module): """ Negated Short Term Objective Intelligibility (STOI) metric, to be used as a loss function. - Inspired from [1, 2, 3] but not exactly the same : cannot be used as - the STOI metric directly (use pystoi instead). See Notes. + Inspired from [1, 2, 3] but not exactly the same due to a different + resampling technique. Use pystoi when evaluating your system. Args: sample_rate (int): sample rate of audio input @@ -31,18 +31,15 @@ class NegSTOILoss(nn.Module): been reduced. Warnings: - This function cannot be used to compute the "real" STOI metric as - we applied some changes to speed-up loss computation. See Notes section. + This function does not exactly match the "real" STOI metric due to a + different resampling technique. Use pystoi when evaluating your system. Notes: - In the NumPy version, some kind of simple VAD was used to remove the - silent frames before chunking the signal into short-term envelope - vectors. We don't do the same here because removing frames in a - batch is cumbersome and inefficient. - If `use_vad` is set to True, instead we detect the silent frames and - keep a mask tensor. At the end, the normalized correlation of - short-term envelope vectors is masked using this mask (unfolded) and - the mean is computed taking the mask values into account. + `use_vad` can be set to `False` to skip the VAD for efficiency. However + results can become substantially different compared to the "real" STOI. + When `True` (default), results are very close but still slightly + different due to a different resampling technique. + Compared against mpariente/pystoi@84b1bd8. References [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time @@ -116,70 +113,111 @@ def forward(self, est_targets: torch.Tensor, est_targets.view(-1, wav_len), targets.view(-1, wav_len), ).view(inner) + batch_size = targets.shape[0] + if self.do_resample and self.sample_rate != FS: targets = self.resample(targets) est_targets = self.resample(est_targets) + if self.use_vad: + targets, est_targets, mask = self.remove_silent_frames( + targets, est_targets, self.dyn_range, self.win, self.win_len, + self.win_len//2 + ) + # Here comes the real computation, take STFT x_spec = self.stft(targets, self.win, self.nfft, overlap=2) y_spec = self.stft(est_targets, self.win, self.nfft, overlap=2) + # Reapply the mask because of overlap in STFT + if self.use_vad: + x_spec *= mask.unsqueeze(1) + y_spec *= mask.unsqueeze(1) + + """Uncommenting the following lines and the last block at the end + allows to replicate the pystoi behavior when less than N speech frames + are detected + """ + # # Ensure at least 30 frames for intermediate intelligibility + # if self.use_vad: + # not_enough = mask.sum(-1) < self.intel_frames + # if not_enough.any(): + # import warnings + # warnings.warn('Not enough STFT frames to compute intermediate ' + # 'intelligibility measure after removing silent ' + # 'frames. Returning 1e-5. Please check you wav ' + # 'files', RuntimeWarning) + # if not_enough.all(): + # return torch.full(batch_size, 1e-5) + # x_spec = x_spec[~not_enough] + # y_spec = y_spec[~not_enough] + # mask = mask[~not_enough] + # Apply OB matrix to the spectrograms as in Eq. (1) - x_tob = torch.matmul(self.OBM, torch.norm(x_spec, 2, -1) ** 2 + EPS).pow(0.5) - y_tob = torch.matmul(self.OBM, torch.norm(y_spec, 2, -1) ** 2 + EPS).pow(0.5) + x_tob = torch.matmul(self.OBM, x_spec.abs().pow(2) + EPS).sqrt() + y_tob = torch.matmul(self.OBM, y_spec.abs().pow(2) + EPS).sqrt() + # Perform N-frame segmentation --> (batch, 15, N, n_chunks) - batch = targets.shape[0] x_seg = unfold(x_tob.unsqueeze(2), kernel_size=(1, self.intel_frames), - stride=(1, 1)).view(batch, x_tob.shape[1], N, -1) + stride=(1, 1)).view(batch_size, x_tob.shape[1], self.intel_frames, -1) y_seg = unfold(y_tob.unsqueeze(2), kernel_size=(1, self.intel_frames), - stride=(1, 1)).view(batch, y_tob.shape[1], N, -1) - # Compute mask if use_vad + stride=(1, 1)).view(batch_size, y_tob.shape[1], self.intel_frames, -1) + # Reapply the mask because of overlap in N-frame segmentation if self.use_vad: - # Detech silent frames (boolean mask of shape (batch, 1, frame_idx) - mask = self.detect_silent_frames(targets, self.dyn_range, - self.win_len, self.win_len // 2) - mask = pad(mask, [0, x_tob.shape[-1] - mask.shape[-1]]) - # Unfold on the mask, to float and mean per frame. - mask_f = unfold(mask.unsqueeze(2).float(), - kernel_size=(1, self.intel_frames), - stride=(1, 1)).view(batch, 1, N, -1) - else: - mask_f = None + mask = mask[..., self.intel_frames-1:] + x_seg *= mask.unsqueeze(1).unsqueeze(2) + y_seg *= mask.unsqueeze(1).unsqueeze(2) if self.extended: # Normalize rows and columns of intermediate intelligibility frames - x_n = self.rowcol_norm(x_seg, mask=mask_f) - y_n = self.rowcol_norm(y_seg, mask=mask_f) + # No need to pass the mask because zeros do not affect statistics + x_n = self.rowcol_norm(x_seg) + y_n = self.rowcol_norm(y_seg) corr_comp = x_n * y_n - correction = self.intel_frames * x_n.shape[-1] + corr_comp = corr_comp.sum(1) + else: # Find normalization constants and normalize - norm_const = (masked_norm(x_seg, dim=2, keepdim=True, mask=mask_f) / - (masked_norm(y_seg, dim=2, keepdim=True, mask=mask_f) - + EPS)) + # No need to pass the mask because zeros do not affect statistics + norm_const = ( + x_seg.norm(p=2, dim=2, keepdim=True) / + (y_seg.norm(p=2, dim=2, keepdim=True) + EPS) + ) y_seg_normed = y_seg * norm_const # Clip as described in [1] clip_val = 10 ** (-self.beta / 20) y_prim = torch.min(y_seg_normed, x_seg * (1 + clip_val)) # Mean/var normalize vectors - y_prim = meanvar_norm(y_prim, dim=2, mask=mask_f) - x_seg = meanvar_norm(x_seg, dim=2, mask=mask_f) + # No need to pass the mask because zeros do not affect statistics + y_prim = y_prim - y_prim.mean(2, keepdim=True) + x_seg = x_seg - x_seg.mean(2, keepdim=True) + y_prim = y_prim / (y_prim.norm(p=2, dim=2, keepdim=True) + EPS) + x_seg = x_seg / (x_seg.norm(p=2, dim=2, keepdim=True) + EPS) # Matrix with entries summing to sum of correlations of vectors corr_comp = y_prim * x_seg - # J, M as in [1], eq.6 - correction = x_seg.shape[1] * x_seg.shape[-1] + corr_comp = corr_comp.sum(2) # Compute average (E)STOI w. or w/o VAD. - sum_over = list(range(1, x_seg.ndim)) # Keep batch dim + output = corr_comp.mean(1) if self.use_vad: - corr_comp = corr_comp * mask_f - correction = correction * mask_f.mean() + EPS - # Return -(E)STOI to optimize for - return - torch.sum(corr_comp, dim=sum_over) / correction + output = output.sum(-1)/mask.sum(-1) + else: + output = output.mean(-1) + + """Uncomment this to replicate the pystoi behavior when less than N + speech frames are detected + """ + # if np.any(not_enough): + # output_ = torch.empty(batch_size) + # output_[not_enough] = 1e-5 + # output_[~not_enough] = output + # output = output_ + + return - output @staticmethod - def detect_silent_frames(x, dyn_range, framelen, hop): + def remove_silent_frames(x, y, dyn_range, window, framelen, hop): """ Detects silent frames on input tensor. A frame is excluded if its energy is lower than max(energy) - dyn_range @@ -193,54 +231,92 @@ def detect_silent_frames(x, dyn_range, framelen, hop): torch.BoolTensor, framewise mask. """ x_frames = unfold(x[:, None, None, :], kernel_size=(1, framelen), - stride=(1, hop))[..., :-1] + stride=(1, hop)) + y_frames = unfold(y[:, None, None, :], kernel_size=(1, framelen), + stride=(1, hop)) + x_frames *= window[:, None] + y_frames *= window[:, None] + # Compute energies in dB x_energies = 20 * torch.log10(torch.norm(x_frames, dim=1, keepdim=True) + EPS) # Find boolean mask of energies lower than dynamic_range dB # with respect to maximum clean speech energy frame - mask = (torch.max(x_energies, dim=2, keepdim=True)[0] - dyn_range - - x_energies) < 0 - return mask + mask = (x_energies.amax(2, keepdim=True) - dyn_range - x_energies) < 0 + mask = mask.squeeze(1) + + # Remove silent frames and pad with zeroes + x_frames = x_frames.permute(0, 2, 1) + y_frames = y_frames.permute(0, 2, 1) + x_frames = _mask_audio(x_frames, mask) + y_frames = _mask_audio(y_frames, mask) + + x_sil = _overlap_and_add(x_frames, hop) + y_sil = _overlap_and_add(y_frames, hop) + x_frames = x_frames.permute(0, 2, 1) + y_frames = y_frames.permute(0, 2, 1) + + mask, _ = mask.long().sort(-1, descending=True) + + return x_sil, y_sil, mask @staticmethod def stft(x, win, fft_size, overlap=4): + """We can't use torch.stft: + - It's buggy with center=False as it discards the last frame + - It pads the frame left and right before taking the fft instead + of padding right + Instead we unfold and take rfft. This gives the same result as + pystoi.utils.stft. + """ win_len = win.shape[0] hop = int(win_len / overlap) - # Last frame not taken because NFFT size is larger, torch bug IMO. - x_padded = torch.nn.functional.pad(x, pad=[0, hop]) - # From torch 1.8.0 - try: - return torch.stft(x_padded, fft_size, hop_length=hop, window=win, - center=False, win_length=win_len, return_complex=False) - # Under 1.8.0 - except TypeError: - return torch.stft(x_padded, fft_size, hop_length=hop, window=win, - center=False, win_length=win_len) + frames = unfold(x[:, None, None, :], kernel_size=(1, win_len), + stride=(1, hop)) + return torch.fft.rfft(frames*win[:, None], n=fft_size, dim=1) @staticmethod - def rowcol_norm(x, mask=None): + def rowcol_norm(x): """ Mean/variance normalize axis 2 and 1 of input vector""" for dim in [2, 1]: - x = meanvar_norm(x, mask=mask, dim=dim) + x = x - x.mean(dim, keepdim=True) + x = x / (x.norm(p=2, dim=dim, keepdim=True) + EPS) return x -def meanvar_norm(x, mask=None, dim=-1): - x = x - masked_mean(x, dim=dim, mask=mask, keepdim=True) - x = x / (masked_norm(x, p=2, dim=dim, keepdim=True, mask=mask) + EPS) - return x +def _overlap_and_add(x_frames, hop): + batch_size, num_frames, framelen = x_frames.shape + # Compute the number of segments, per frame. + segments = -(-framelen // hop) # Divide and round up. + # Pad the framelen dimension to segments * hop and add n=segments frames + signal = pad( + x_frames, (0, segments * hop - framelen, 0, segments) + ) + + # Reshape to a 4D tensor, splitting the framelen dimension in two + signal = signal.reshape((batch_size, num_frames + segments, segments, hop)) + # Transpose dimensions so shape = (batch, segments, frame+segments, hop) + signal = signal.permute(0, 2, 1, 3) + # Reshape so that signal.shape = (batch, segments * (frame+segments), hop) + signal = signal.reshape((batch_size, -1, hop)) -def masked_mean(x, dim=-1, mask=None, keepdim=False): - if mask is None: - return x.mean(dim=dim, keepdim=keepdim) - return (x * mask).sum(dim=dim, keepdim=keepdim) / ( - mask.sum(dim=dim, keepdim=keepdim) + EPS + # Now behold the magic!! Remove last n=segments elements from second axis + signal = signal[:, :-segments] + # Reshape to (batch, segments, frame+segments-1, hop) + signal = signal.reshape( + (batch_size, segments, num_frames + segments - 1, hop) ) + # This has introduced a shift by one in all rows + + # Now, reduce over the columns and flatten the array to achieve the result + signal = signal.sum(axis=1) + end = (num_frames - 1) * hop + framelen + signal = signal.reshape((batch_size, -1))[:end] + return signal -def masked_norm(x, p=2, dim=-1, mask=None, keepdim=False): - if mask is None: - return torch.norm(x, p=p, dim=dim, keepdim=keepdim) - return torch.norm(x * mask, p=p, dim=dim, keepdim=keepdim) +def _mask_audio(x, mask): + return torch.stack([ + pad(xi[mi], (0, 0, 0, len(xi) - mi.sum())) for xi, mi in zip(x, mask) + ])