
Commit

More fixes
Alejandro Gaston Alvarez Franceschi committed Jan 3, 2024
1 parent 1de46b2 commit 44a98ab
Showing 1 changed file with 14 additions and 4 deletions:
coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -9559,13 +9559,13 @@ def forward(self, x):
[None, 1, 3], # channels
[16, 32], # n_fft
[5, 9], # num_frames
[None, 4, 5], # hop_length
[None, 5], # hop_length
[None, 10, 8], # win_length
[None, torch.hann_window], # window
[False, True], # center
[False, True], # normalized
[None, False, True], # onesided
[None, 30, 40], # length
[None, "shorter", "larger"], # length
[False, True], # return_complex
)
)
@@ -9576,9 +9576,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
if hop_length is None and win_length is not None:
pytest.skip("If win_length is set then hop_length must also be set, with 0 < hop_length <= win_length")

# Compute input_shape to generate test case
freq = n_fft//2+1 if onesided else n_fft
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
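# (A one-sided spectrum keeps only n_fft // 2 + 1 frequency bins, since real
# input yields a conjugate-symmetric FFT.)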

# If not set, compute the default hop_length so the error cases below can be predicted
if hop_length is None:
hop_length = n_fft // 4
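
# For reference: torch.istft's natural output length is hop_length * (num_frames - 1)
# with center=True, and n_fft + hop_length * (num_frames - 1) with center=False,
# so the "shorter" and "larger" cases below pick lengths n_fft // 2 under and
# over the uncentered value to exercise trimming and zero-padding.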

if length == "shorter":
length = n_fft//2 + hop_length * (num_frames - 1)
elif length == "larger":
length = n_fft*3//2 + hop_length * (num_frames - 1)

class ISTFTModel(torch.nn.Module):
def forward(self, x):
applied_window = window(win_length) if window and win_length else None
@@ -9598,7 +9608,7 @@ def forward(self, x):
else:
return torch.real(x)

if win_length and center is False:
if (center is False and win_length) or (center and win_length and length):
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
TorchBaseTest.run_compare_torch(
Expand All @@ -9607,7 +9617,7 @@ def forward(self, x):
backend=backend,
compute_unit=compute_unit
)
elif length is not None and return_complex is True:
elif length and return_complex:
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
TorchBaseTest.run_compare_torch(
input_shape,
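
For context, a minimal sketch (not part of the commit; values are illustrative) of how torch.istft's length argument trims or zero-pads the reconstruction:

import torch

n_fft, hop_length = 16, 4
window = torch.hann_window(n_fft)
x = torch.randn(64)
spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length, window=window,
                  center=True, return_complex=True)
# Passing length trims or zero-pads the output; using the original sample
# count round-trips the shape exactly.
y = torch.istft(spec, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, length=x.numel())
assert y.shape == x.shape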
