add better validation
kzawora-intel committed Aug 13, 2024
1 parent ebcb4ab commit 506e026
Showing 2 changed files with 19 additions and 7 deletions.
23 changes: 17 additions & 6 deletions examples/offline_inference_fakehpu.py
@@ -2,21 +2,32 @@
 
 # Sample prompts.
 prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+    "Berlin is the capital city of ",
+    "Louvre is located in the city called ",
+    "Barack Obama was the 44th president of ",
+    "Warsaw is the capital city of ",
+    "Gniezno is a city in ",
+    "Hebrew is an official state language of ",
+    "San Francisco is located in the state of ",
+    "Llanfairpwllgwyngyll is located in country of ",
 ]
+ref_answers = [
+    "Germany", "Paris", "United States", "Poland", "Poland", "Israel",
+    "California", "Wales"
+]
 # Create a sampling params object.
-sampling_params = SamplingParams()
+sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)
 
 # Create an LLM.
 llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
-for output in outputs:
+for output, answer in zip(outputs, ref_answers):
     prompt = output.prompt
     generated_text = output.outputs[0].text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert answer in generated_text, (
+        f"The generated text does not contain the correct answer: {answer}")
+print('PASSED')
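
The validation added here pairs each prompt with a reference answer and asserts that the reference string appears in the greedy (temperature=0) completion. A minimal sketch of the same substring check, run against a hypothetical generation rather than a real model, for illustration only:

# Illustration only, not part of the commit: the substring-based check used
# above, applied to a made-up generation. Prompt and answer come from the
# diff; the generated text is a stand-in for llm.generate() output.
prompts = ["Berlin is the capital city of "]
ref_answers = ["Germany"]
hypothetical_outputs = ["Germany, and its largest city."]

for prompt, generated_text, answer in zip(prompts, hypothetical_outputs, ref_answers):
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    assert answer in generated_text, (
        f"The generated text does not contain the correct answer: {answer}")
print('PASSED')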
3 changes: 2 additions & 1 deletion vllm/utils.py
@@ -212,7 +212,8 @@ def is_hpu() -> bool:
 
 @lru_cache(maxsize=None)
 def is_fake_hpu() -> bool:
-    return not _is_habana_frameworks_installed() and _is_built_for_hpu()
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' or (
+        not _is_habana_frameworks_installed() and _is_built_for_hpu())
 
 
 @lru_cache(maxsize=None)
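With this change, the fake-HPU path can be forced through the VLLM_USE_FAKE_HPU environment variable (any value other than '0' enables it), in addition to being inferred from the build. A minimal sketch of driving it from Python; setting the variable before importing vllm is an assumption made here, since is_fake_hpu() is cached by lru_cache after its first call:

# Minimal sketch, not part of the commit. VLLM_USE_FAKE_HPU comes from this
# diff; setting it before importing vllm is an assumption made to avoid the
# lru_cache on is_fake_hpu() capturing a value computed without the variable.
import os

os.environ["VLLM_USE_FAKE_HPU"] = "1"  # any value other than '0' enables the fake-HPU path

from vllm.utils import is_fake_hpu

print(is_fake_hpu())  # expected to report True with the variable set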
