Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alejandro Gaston Alvarez Franceschi committed Nov 15, 2023
1 parent 4966daf commit 1aaed1c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 44 deletions.
20 changes: 9 additions & 11 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9504,9 +9504,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng
class STFTModel(torch.nn.Module):
def forward(self, x):
applied_window = window(win_length) if window and win_length else None
x = torch.complex(x, x) if complex else x
x = torch.stft(
x,
torch.complex(x, x) if complex else x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
Expand Down Expand Up @@ -9534,28 +9533,26 @@ class TestISTFT(TorchBaseTest):
compute_units,
backends,
[(1, 32, 9), (32, 9), (3, 32, 9)], # input shape
[False, True], # complex
[16], # n_fft
[None, 4, 5], # hop_length
[None, 16, 9], # win_length
[None, torch.hann_window], # window
[None, False, True], # center
["constant", "reflect", "replicate"], # pad mode
[False, True], # normalized
[None, False, True], # onesided
[None, 60], # length
[False, True], # return_complex
)
)
def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
if complex and onesided:
pytest.skip("Onesided stft not possible for complex inputs")
def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
if return_complex and onesided:
pytest.skip("Complex output is incompatible with onesided")

class ISTFTModel(torch.nn.Module):
def forward(self, x):
applied_window = window(win_length) if window and win_length else None
x = torch.complex(x, x)
x = torch.istft(
x,
torch.complex(x, x),
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
Expand All @@ -9564,8 +9561,9 @@ def forward(self, x):
normalized=normalized,
onesided=onesided,
length=length,
return_complex=True)
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
return_complex=return_complex)
if return_complex:
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
return x

TorchBaseTest.run_compare_torch(
Expand Down
2 changes: 1 addition & 1 deletion coremltools/converters/mil/mil/input_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class TensorInputType(_InputType):
class conv(Operation):
input_spec = InputSpec(
x=TensorInputType(type_domain="T"),
weight=TensorInputType(type_domain="U"),
weight=TensorInputType(type_domain="T"),
)
type_domains = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ class complex_istft(Operation):
Attributes
----------
V: complex64
T: fp32, complex64
References
Expand All @@ -901,7 +902,7 @@ class complex_istft(Operation):
"""

input_spec = InputSpec(
input=TensorInputType(type_domain="T"),
input=TensorInputType(type_domain="V"),
n_fft=TensorInputType(const=True, type_domain=types.int32),
hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
Expand All @@ -912,7 +913,7 @@ class complex_istft(Operation):
)

type_domains = {
"T": (types.fp32, types.complex64),
"V": types.complex64,
}

def default_inputs(self):
Expand All @@ -937,7 +938,6 @@ def type_inference(self):
output_shape += [self.length]
return types.tensor(output_type, tuple(output_shape))


n_frames = self.input.shape[-1]
output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _stft(
We can write STFT in terms of convolutions with a DFT kernel.
At the end:
* The real part output is: cos_base * input_real + sin_base * input_imag
* The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
* The imaginary part output is: cos_base * input_imag - sin_base * input_real
Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
"""
hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op)
Expand All @@ -342,7 +342,7 @@ def _stft(

# create a window of centered 1s of the requested size
if win_length:
window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op)
window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op)

# apply time window
if window:
Expand All @@ -358,12 +358,13 @@ def _stft(
if input_imaginary:
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)

# conv with DFT kernel across the input signal
# The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
# DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
# If x is complex then x[n]=(a+i*b)
# So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
# So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
# Convolve the DFT kernel with the input signal
# DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
# real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
# imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
# But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k):
# real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
# imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
if input_imaginary:
Expand All @@ -372,11 +373,11 @@ def _stft(

# add everything together
if input_imaginary:
real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
else:
real_result = cos_windows_real
imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op)
imag_result = sin_windows_real

# reduce the rank of the output
if should_increase_rank:
Expand Down Expand Up @@ -417,17 +418,18 @@ def _istft(
# By default, use the entire frame
win_length = win_length or n_fft

input_shape = mb.shape(x=x, before_op=before_op)
n_frames = input_shape.val[-1]
fft_size = input_shape.val[-2]
# expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
input_shape = mb.shape(x=input_real, before_op=before_op)
channels = input_shape.val[0]
fft_size = input_shape.val[1]
n_frames = input_shape.val[2]
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)

is_onesided = onesided.val if onesided else fft_size != n_fft
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)

# create a window of centered 1s of the requested size
if win_length:
window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op)
window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op)

# apply time window
if window:
Expand All @@ -447,14 +449,13 @@ def _istft(
signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op)
signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op)

# Conv with DFT kernel across the input signal
# We can describe the IDFT in terms of DFT just by swapping the input and output
# Convolve the DFT kernel with the input signal
# We can describe the IDFT in terms of DFT just by swapping the input and output.
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
# So IDFT(x) = (1/N) * swap(DFT(swap(x)))
# and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
# If x is complex then x[n]=(a+i*b)
# then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
# then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
# IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N
# So using the definition in stft function, we get:
# real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
# imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
Expand Down Expand Up @@ -519,6 +520,7 @@ def _overlap_add(
def _get_window(
win_length: Var,
n_fft: Var,
window: Optional[Var],
before_op: Operation,
) -> Var:
n_left = (n_fft.val - win_length.val) // 2
Expand Down Expand Up @@ -750,17 +752,21 @@ def _lower_complex_istft(op: Operation):
is_complex = types.is_complex(op.input.dtype)

# check parameters for validity
if is_complex:
raise ValueError("Only complex inputs are allowed")
if op.win_length and op.win_length.val > op.n_fft.val:
raise ValueError("Window length must be less than or equal to n_fft")
if is_complex and op.onesided and op.onesided.val:
raise ValueError("Onesided is only valid for real inputs")
if op.return_complex and op.onesided and op.onesided.val:
raise ValueError("Complex output is not compatible with onesided")

real, imag = _istft(
op.input.real if is_complex else op.input,
op.input.imag if is_complex else None,
op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op)
op.input.real, op.input.imag,
op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op)

return _wrap_complex_output(op.outputs[0], real, imag)
if op.return_complex:
return _wrap_complex_output(op.outputs[0], real, imag)
else
return real


@LowerComplex.register_lower_func(op_type="complex_shape")
Expand Down

0 comments on commit 1aaed1c

Please sign in to comment.