-
Notifications
You must be signed in to change notification settings - Fork 8
/
whisper_to_cml.py
55 lines (42 loc) · 1.24 KB
/
whisper_to_cml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import whisper
import numpy as np
import torch
import coremltools as ct
def load_models(name="small"):
    """Load a Whisper checkpoint on the CPU and return its two sub-networks.

    Args:
        name: checkpoint name passed to ``whisper.load_model``. Defaults to
            ``"small"``, the checkpoint the script originally hard-coded.

    Returns:
        A ``(encoder, decoder)`` tuple taken from the loaded model.
    """
    # Load directly onto the CPU (the converters trace with CPU tensors)
    # instead of loading on the default device -- possibly a GPU -- and then
    # copying everything back with .cpu().
    model = whisper.load_model(name, device="cpu")
    return model.encoder, model.decoder
def convert_encoder_to_tvm(model):
    """Trace a Whisper audio encoder and convert it to a Core ML ML Program.

    Despite the ``_to_tvm`` suffix, this converts with ``coremltools``
    (``convert_to="mlprogram"``), not Apache TVM. The name is kept so
    existing callers keep working.

    Args:
        model: the encoder module returned by ``load_models()``. It is put
            into eval mode in place.

    Returns:
        The converted ``coremltools`` model (call ``.save(path)`` on it).
    """
    model.eval()
    # Input is (batch, n_mels, n_frames). Read the mel-bin count from the
    # first convolution when the module exposes one, so checkpoints with a
    # different mel count also convert. Fall back to the 80 that was
    # originally hard-coded.
    conv1 = getattr(model, "conv1", None)
    n_mels = conv1.in_channels if conv1 is not None else 80
    # 3000 frames: presumably Whisper's fixed 30 s input window -- confirm
    # before changing, since the traced graph is fixed to this shape.
    input_shape = (1, n_mels, 3000)
    input_data = torch.randn(input_shape)
    # Tracing only needs forward passes; skip autograd bookkeeping.
    with torch.no_grad():
        traced = torch.jit.trace(model, input_data)
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        inputs=[ct.TensorType(shape=input_shape)],
    )
    return mlmodel
def convert_decoder_to_tvm(model, *, n_audio_ctx=1500):
    """Trace a Whisper text decoder and convert it to a Core ML ML Program.

    Despite the ``_to_tvm`` suffix, this converts with ``coremltools``
    (``convert_to="mlprogram"``), not Apache TVM. The name is kept so
    existing callers keep working.

    The traced graph has fixed shapes: one token per call, ``(1, 1)``, and
    an audio-feature input of ``(1, n_audio_ctx, n_state)``.

    Args:
        model: the decoder module returned by ``load_models()``. It is put
            into eval mode in place.
        n_audio_ctx: length of the audio-feature sequence fed to the decoder.
            Defaults to 1500, the value originally hard-coded. This assumes
            it matches the encoder's output length -- confirm for custom
            models.

    Returns:
        The converted ``coremltools`` model (call ``.save(path)`` on it).
    """
    model.eval()
    # Assumes (as in openai-whisper) the audio-feature width equals the
    # decoder's embedding width. Read it from the module when available so
    # other checkpoint sizes work; fall back to the 768 originally
    # hard-coded.
    embedding = getattr(model, "token_embedding", None)
    n_state = embedding.embedding_dim if embedding is not None else 768
    tokens_shape = (1, 1)
    audio_shape = (1, n_audio_ctx, n_state)
    # Example token IDs must be valid, non-negative embedding indices.
    # torch.randn(...).long() truncates toward zero and produces negative
    # IDs (e.g. -1) roughly 16% of the time, which makes tracing fail.
    token_data = torch.zeros(tokens_shape, dtype=torch.long)
    audio_data = torch.randn(audio_shape)
    # Tracing only needs forward passes; skip autograd bookkeeping.
    with torch.no_grad():
        traced = torch.jit.trace(model, (token_data, audio_data))
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        inputs=[
            # Declare token IDs as integers explicitly rather than relying
            # on the converter's default input dtype.
            ct.TensorType(shape=tokens_shape, dtype=np.int32),
            ct.TensorType(shape=audio_shape),
        ],
    )
    return mlmodel
def main():
    """Convert the Whisper decoder, then the encoder, and save each as an .mlpackage."""
    encoder, decoder = load_models()
    # (converter, source module, output path), processed in this order.
    jobs = (
        (convert_decoder_to_tvm, decoder, "decoder.mlpackage"),
        (convert_encoder_to_tvm, encoder, "encoder.mlpackage"),
    )
    for convert, module, out_path in jobs:
        convert(module).save(out_path)


if __name__ == "__main__":
    main()