Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alejandro Gaston Alvarez Franceschi committed Jan 2, 2024
1 parent 48c34e3 commit c66d1aa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9540,7 +9540,7 @@ def forward(self, x):
[False, True], # center
[False, True], # normalized
[None, False, True], # onesided
[None, 60], # length
[None, 30, 40], # length
[False, True], # return_complex
)
)
Expand Down Expand Up @@ -9582,6 +9582,14 @@ def forward(self, x):
backend=backend,
compute_unit=compute_unit
)
elif length is not None and return_complex is True:
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
TorchBaseTest.run_compare_torch(
input_shape,
ISTFTModel(),
backend=backend,
compute_unit=compute_unit
)
else:
TorchBaseTest.run_compare_torch(
input_shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def _istft(
window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op)
window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op)
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)

# After this operation if it didn't have any channels dimention it adds one
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
# We need to adapt last dimension
Expand All @@ -487,12 +489,8 @@ def _istft(
real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
elif length.val < expected_output_signal_len:
if channels:
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
else:
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)

return real_result, imag_result

Expand Down

0 comments on commit c66d1aa

Please sign in to comment.