-
Notifications
You must be signed in to change notification settings - Fork 311
/
quant_infer.py
34 lines (27 loc) · 1.12 KB
/
quant_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import hiq, time
from hiq.memory import total_gpu_memory_mb, get_memory_mb
import platform
def main():
    """Profile llama quantized inference with HiQ and print a latency/memory report.

    Traces tokenizer loading, model loading, encoding, and generation in the
    ``llama`` package, recording wall time, CPU memory, and (when available)
    GPU memory for each traced call, then prints the collected tree.
    """
    # Detect WSL: its kernel release string contains "microsoft".
    # WSL has no nvidia-smi, so GPU memory tracking must be skipped there.
    try:
        wsl = "microsoft" in platform.release().lower()
    except Exception:
        # Platform info unavailable on this system; assume native Linux/other.
        wsl = False

    # Each entry: [module, class ("" for module-level), function, display tag].
    driver = hiq.HiQLatency(
        hiq_table_or_path=[
            ["llama.llama_infer", "", "run", "run_quant"],
            ["llama.llama_infer", "LLaMATokenizer", "from_pretrained", "from_pretrained"],
            ["llama.hf", "LLaMATokenizer", "encode", "encode"],
            ["llama.llama_infer", "", "load_quant", "load_quant"],
            ["llama.hf.modeling_llama", "LLaMAForCausalLM", "generate", "generate"],
        ],
        # WSL does not contain nvidia-smi, so drop the GPU metric there.
        metric_funcs=[time.time, get_memory_mb] + ([total_gpu_memory_mb] if not wsl else []),
        # extra_metrics={hiq.ExtraMetrics.ARGS},
    )

    # Run inference with CLI-parsed args, then print the profiling report.
    args = hiq.mod("llama.llama_infer").get_args()
    hiq.mod("llama.llama_infer").run(args)
    print("*" * 30, ("GPU/" if not wsl else "") + "CPU/Latency Profiling", "*" * 30)
    if wsl:
        print('(WSL does not contain nvidia-smi, GPU profiling is disabled)')
    driver.show()
# Run the profiler only when executed as a script (not on import).
if __name__ == "__main__":
    main()