Commit

Modify offline_inference_tt.py to include max_tokens arg
milank94 committed Oct 21, 2024
1 parent 088f40a commit f8e8324
Showing 1 changed file with 18 additions and 14 deletions.
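At a high level, the commit renames the default_max_tokens parameter to max_tokens, exposes it as a --max_tokens CLI argument, and uses it in the performance-measurement path instead of the previously hard-coded 33. A minimal sketch of the idea, using a generic vLLM model rather than the TT-specific model registration done in this script (the model name and prompt below are placeholders, not taken from this repo):

from vllm import LLM, SamplingParams

max_tokens = 64  # in the updated script this comes from the new --max_tokens CLI argument

# Greedy-sampling path from the diff; the top-k/top-p path only swaps the sampling arguments.
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=False, temperature=0.0)

llm = LLM(model="facebook/opt-125m")  # placeholder model; the script registers TTLlamaForCausalLM instead
outputs = llm.generate(["What is the capital of France?"], sampling_params)
print(outputs[0].outputs[0].text)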
32 changes: 18 additions & 14 deletions examples/offline_inference_tt.py
@@ -14,7 +14,7 @@

def run_inference(
prompts_json,
default_max_tokens=128,
max_tokens=128,
max_seqs_in_batch=32,
num_repeat_prompts=2,
measure_perf=False,
@@ -23,11 +23,11 @@ def run_inference(
):
# Generation args
ignore_eos = True if measure_perf else False

if greedy_sampling:
sampling_params = SamplingParams(max_tokens=default_max_tokens, ignore_eos=ignore_eos, temperature=0.0)
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=ignore_eos, temperature=0.0)
else:
sampling_params = SamplingParams(max_tokens=default_max_tokens, ignore_eos=ignore_eos, top_k=10, top_p=0.9, temperature=1.0)
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=ignore_eos, top_k=10, top_p=0.9, temperature=1.0)

# Create an LLM.
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
@@ -41,19 +41,20 @@
if num_repeat_prompts is not None:
prompts = prompts * num_repeat_prompts
print("Number of prompts:", len(prompts))

generate_tokens(llm, prompts, sampling_params, print_output=True)
else:
print("Note: Ignoring prompts for performance measurement")
run_inference_perf(llm, sampling_params, max_seqs_in_batch, input_prompt_len=perf_prompt_len)
run_inference_perf(llm, sampling_params, max_seqs_in_batch, max_tokens, input_prompt_len=perf_prompt_len)


def run_inference_perf(
llm : LLM,
sampling_params,
max_seqs_in_batch,
max_tokens,
prompts=None,
input_prompt_len=None # Used to generate dummy prompts if prompts is None
input_prompt_len=None, # Used to generate dummy prompts if prompts is None
):
assert llm.llm_engine.log_stats, "disable_log_stats=False is required for llm to use stat loggers"
if prompts is not None:
@@ -64,10 +64,11 @@ def run_inference_perf(
print("Measuring performance with dummy prompts of length", input_prompt_len)
prompt_token_ids = [[0]*input_prompt_len]*max_seqs_in_batch # dummy prompts
sampling_params = sampling_params[:max_seqs_in_batch] if isinstance(sampling_params, list) else sampling_params

# Set an arbitrary max_tokens to simulate generating multiple tokens consecutively
sampling_params.max_tokens = 33 # 1 prefill output token + 32 decode output tokens

print("Generating prompts with output length", max_tokens)
sampling_params.max_tokens = max_tokens

assert_str = f"prompt length ({input_prompt_len}) + num generated tokens ({sampling_params.max_tokens}) will exceed max_model_len ({llm.llm_engine.model_config.max_model_len})"
assert input_prompt_len + sampling_params.max_tokens <= llm.llm_engine.model_config.max_model_len, assert_str

@@ -88,11 +90,11 @@ def run_inference_perf(
print("Finished inference runs")

# Collect stats
ttft = llm.llm_engine.stat_loggers['global'].time_to_first_token.avg
ttft = llm.llm_engine.stat_loggers['global'].time_to_first_token.avg / max_seqs_in_batch
tpot = llm.llm_engine.stat_loggers['global'].time_per_output_token.avg
print(f"Average time to first token (batch): {ttft} s")
print(f"Average time to first token per user: {ttft} s")
print(f"Average decode throughput: {1/tpot} t/s/u")


def generate_tokens(llm : LLM, prompts, sampling_params, prompt_token_ids=None, print_output=True):
# Generate texts from the prompts. The output is a list of RequestOutput objects
@@ -104,13 +106,14 @@ def generate_tokens(llm : LLM, prompts, sampling_params, prompt_token_ids=None,
generated_text = output.outputs[0].text
if print_output:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompts_json", type=str, default="tt_metal/prompts.json", help="Path to JSON file containing prompts")
parser.add_argument("--measure_perf", action="store_true", help="Measure performance")
parser.add_argument("--perf_prompt_len", type=int, default=128, help="Length of dummy prompts for performance measurement")
parser.add_argument("--max_tokens", type=int, default=128, help="Length of outputs")
parser.add_argument("--greedy_sampling", action="store_true", help="Use greedy decoding instead of top-k/p")
parser.add_argument("--max_seqs_in_batch", type=int, default=32, help="Maximum batch size for inference")
args = parser.parse_args()
@@ -119,6 +122,7 @@ def generate_tokens(llm : LLM, prompts, sampling_params, prompt_token_ids=None,
args.prompts_json,
measure_perf=args.measure_perf,
perf_prompt_len=args.perf_prompt_len,
max_tokens=args.max_tokens,
greedy_sampling=args.greedy_sampling,
max_seqs_in_batch=args.max_seqs_in_batch
)
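Assuming the module layout in this repository, the updated entry point can also be driven directly from Python rather than through argparse; the import path is an assumption, the prompts path is the script's default, and the values mirror the new CLI defaults:

from examples.offline_inference_tt import run_inference

# Roughly equivalent to:
#   python examples/offline_inference_tt.py --measure_perf --perf_prompt_len 128 --max_tokens 128 --greedy_sampling
run_inference(
    "tt_metal/prompts.json",   # default prompts file from the argparse setup above
    measure_perf=True,
    perf_prompt_len=128,
    max_tokens=128,            # new argument; the perf path previously fixed this at 33
    greedy_sampling=True,
    max_seqs_in_batch=32,
)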
