forked from ml-explore/mlx-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
benchmark.py
93 lines (75 loc) · 1.86 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright © 2023 Apple Inc.
import time
import mlx.core as mx
from whisper import load_models
from whisper import audio
from whisper import decoding
from whisper import transcribe
# Audio clip used by every benchmark stage below. This is a relative path --
# presumably resolved against the current working directory, so the script is
# expected to be run from the directory that contains ``whisper/`` (confirm).
audio_file = "whisper/assets/ls_test.flac"
def timer(fn, *args, warmup=5, num_its=10):
    """Return the mean wall-clock time, in seconds, of one call to ``fn(*args)``.

    ``fn`` is first called ``warmup`` times without timing, so one-off
    start-up costs are kept out of the measurement. It is then called
    ``num_its`` times between two ``time.perf_counter`` readings, and the
    elapsed time is averaged.

    The measurement only covers work that ``fn`` performs before returning.
    With lazily evaluated arrays, ``fn`` must force evaluation itself.

    Args:
        fn: Callable to benchmark. Its return value is discarded.
        *args: Positional arguments forwarded unchanged to every call of ``fn``.
        warmup: Number of untimed calls made before measuring (default 5).
        num_its: Number of timed calls to average over (default 10).

    Returns:
        Average elapsed seconds per timed call, as a float.

    Raises:
        ValueError: If ``warmup`` is negative or ``num_its`` is less than 1
            (the latter would otherwise divide by zero).
    """
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if num_its < 1:
        raise ValueError(f"num_its must be >= 1, got {num_its}")

    # Untimed warm-up calls.
    for _ in range(warmup):
        fn(*args)

    tic = time.perf_counter()
    for _ in range(num_its):
        fn(*args)
    toc = time.perf_counter()
    return (toc - tic) / num_its
def feats():
    """Compute the log-mel spectrogram of the benchmark audio clip.

    Loads ``audio_file`` with ``audio.load_audio`` and passes it through
    ``audio.pad_or_trim``, presumably to reach the model's fixed input
    length. The result is then transformed with ``audio.log_mel_spectrogram``.

    The spectrogram is forced with ``mx.eval``, so the computation happens
    inside this call rather than being deferred.
    """
    samples = audio.pad_or_trim(audio.load_audio(audio_file))
    spectrogram = audio.log_mel_spectrogram(samples)
    mx.eval(spectrogram)
    return spectrogram
def model_forward(model, mels, tokens):
    """Run one forward pass of ``model`` on ``mels`` and ``tokens``.

    The output is materialized with ``mx.eval`` before it is returned, so
    the cost of the pass is paid inside this call.
    """
    out = model(mels, tokens)
    mx.eval(out)
    return out
def decode(model, mels):
    """Decode ``mels`` with ``model`` through ``decoding.decode``.

    No arguments beyond the model and features are passed.
    """
    result = decoding.decode(model, mels)
    return result
def everything():
    """Run the end-to-end ``transcribe`` call on ``audio_file`` and return its result."""
    result = transcribe(audio_file)
    return result
if __name__ == "__main__":
    # Stage 1: feature extraction on its own.
    feat_seconds = timer(feats)
    print(f"Feature time {feat_seconds:.3f}")

    # Features reused by the later stages. ``[None]`` prepends a size-1 axis
    # (presumably the batch axis the model expects -- confirm).
    mels = feats()[None]

    # Hard-coded token ids used as the fixed decoder input for the
    # forward-pass benchmark. They also get a leading size-1 axis.
    token_ids = [
        50364, 1396, 264, 665, 5133, 23109, 25462,
        264, 6582, 293, 750, 632, 42841, 292,
        370, 938, 294, 4054, 293, 12653, 356,
        50620, 50620, 23563, 322, 3312, 13, 50680,
    ]
    tokens = mx.array(token_ids, mx.int32)[None]

    model = load_models.load_model("tiny")

    # Stage 2: a single forward pass with fixed inputs.
    forward_seconds = timer(model_forward, model, mels, tokens)
    print(f"Model forward time {forward_seconds:.3f}")

    # Stage 3: decoding from the precomputed features.
    decode_seconds = timer(decode, model, mels)
    print(f"Decode time {decode_seconds:.3f}")

    # Stage 4: the end-to-end ``transcribe`` call on the same file.
    everything_seconds = timer(everything)
    print(f"Everything time {everything_seconds:.3f}")