
[Bug]: block_softmax accuracy issue in flat_pa kernel, qwen2-7B model #275

Open
jikunshang opened this issue Sep 12, 2024 · 44 comments
Labels: bug (Something isn't working)

@jikunshang

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

This issue is introduced by the block_softmax kernel (part of flat_pa, see #169).
For some models (like Qwen2-7B), the values of Q * K.t() * scale can be much larger (2000+), but block_softmax only subtracts a dummy max_value (hardcoded to 10), so the exp kernel can overflow and the subsequent calculations become incorrect.
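A minimal numerical sketch of the failure mode (illustration only, not the actual kernel code):

```python
import torch

# Attention scores of the magnitude reported above (Q * K.t() * scale around 2000),
# in bf16 as used by the HPU kernel.
scores = torch.tensor([2000.0, 1992.0, 1200.0], dtype=torch.bfloat16)

# Subtracting the hardcoded 10 still leaves exponents around 1990, so exp() overflows
# to inf and the downstream sum/normalization turns into inf/nan garbage.
print(torch.exp(scores - 10.0))

# Subtracting the true maximum keeps every exponent <= 0 and exp() finite in [0, 1].
print(torch.exp(scores - scores.max()))
```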
You can reproduce this with examples/offline_inference.py and the Qwen/Qwen2-7B model; the result looks like the output below.

python3 examples/offline_inference.py

... vllm log ...

Processed prompts: 100%|████████████████████████████████| 5/5 [00:10<00:00,  2.05s/it, est. speed input: 2.44 toks/s, output: 62.53 toks/s]
Prompt: 'Hello, my name is', Generated text: ' Dr!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
Prompt: 'The president of the United States is', Generated text: ' the!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
Prompt: 'The capital of France is', Generated text: ' Paris!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
Prompt: 'The future of AI is', Generated text: ' here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
Prompt: 'who are you', Generated text: '?!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

If you print attn in block_softmax, you will see the values are abnormal compared to the Llama model.

@jikunshang jikunshang added the bug Something isn't working label Sep 12, 2024
@szutenberg

Hi @jikunshang,

Thanks for the analysis.

We had a look at the issue, and a clean fix is not trivial because block2batch also performs a summation, which breaks the max operation.

Could you check if our new WA ( #280 ) fixes the issue?

Thanks.

PS: I'll close #268 since it's marked as "DNM" anyway. Let's move the discussion to this issue.

@xuechendi

xuechendi commented Sep 12, 2024

Hi, @szutenberg ,

Thanks for the quick fix in #280.

I tested with #280; however, the generated output is not very good, and some queries even produce incorrect characters.

Query:
"给我介绍一下中国" => "Introduce me to China"

The output does not finish and does not make sense.

@jikunshang
Author

Thanks @xuechendi for verifying this.
@szutenberg I also tried the fix from #280; the problem is the same. For the attention values I showed above, it gets max_val = 2768. After sub_(max_val), the other values become extremely small (very negative), so exp() returns zero. The correct max value should be computed over dim=-1 and partially over dim=1 (blocks of the same sequence).
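Something along the lines of the following sketch is what is being asked for (hypothetical names and shapes for illustration; block_groups is assumed to map each block to the sequence it belongs to, which is not necessarily how the real kernel is structured):

```python
import torch

def per_sequence_max(attn: torch.Tensor, block_groups: torch.Tensor, num_seqs: int) -> torch.Tensor:
    """attn: [num_blocks, num_heads, 1, block_size] flat paged-attention scores."""
    # Max inside each block over the key dimension (dim=-1) ...
    block_max = attn.amax(dim=-1)                                   # [num_blocks, num_heads, 1]
    # ... then reduced across all blocks that belong to the same sequence.
    seq_max = torch.full((num_seqs, *block_max.shape[1:]), float("-inf"),
                         dtype=block_max.dtype, device=block_max.device)
    seq_max.index_reduce_(0, block_groups, block_max, reduce="amax")
    # Broadcast the per-sequence max back to each block so it can be subtracted before exp().
    return seq_max.index_select(0, block_groups).unsqueeze(-1)      # [num_blocks, num_heads, 1, 1]
```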

@madamczykhabana

Hi @jikunshang and @xuechendi. Thanks for the detailed analysis!
max_val = 2768 is definitely going to be problematic.

Just to let you know I've started working on a proper fix. The tricky part is to implement it in a way that won't dramatically hurt performance in other models. Once I have a working prototype I'll send it to you so that you can verify on your end.

@madamczykhabana

I've looked at the attention values between heads, and there's no option here other than calculating a local maximum per bs/qheads/kvheads. Even for a single attention head the values can span from -2k to 2k, which is in line with your observations.

The good news is that once I changed the code to properly calculate the local maximum, the outputs started to look OK:

Prompt: 'The capital of France is', Generated text: ' Paris. It is the largest city in France and the second largest city in the European Union. Paris is a major global city and is considered one of the most important cultural, artistic, and scientific centers in the world. It is home to many famous landmarks such as the Eiffel Tower, the Louvre Museum,'


Prompt: 'The president of the United States is', Generated text: " the head of state and head of government of the United States. The president directs the executive branch of the federal government and is the commander-in-chief of the United States Armed Forces. The president is also the head of the nation's foreign policy and is responsible for appointing most of the officials who run the federal government."

Prompt: 'Hello, my name is', Generated text: ' Dr. David M. Berman. I am a board-certified plastic surgeon in the Washington, DC area. I am a member of the American Society of Plastic Surgeons, the American Society for Aesthetic Plastic Surgery, and the American Society of Maxillofacial Surgeons. I am also a member of the'

The next step is to remove the hardcoded values and figure out how to calculate the local maximums efficiently on Gaudi.

@jikunshang
Author

Thanks for your investigation!

@Yanli2190

Yanli2190 commented Sep 19, 2024

@jikunshang @madamczykhabana, does this issue only impact Qwen? We tried meta-llama/Llama-2-7b-chat-hf, meta-llama/Llama-2-70b-chat-hf (tp=4), and sbintuitions/sarashina2-70b (tp=4) based on the latest habana_main commit (b62fba8), and the model output quality isn't good.
Is it the same root cause?
meta-llama/Llama-2-7b-chat-hf

Prompt: 'Hello, my name is', Generated text: ' Ethan,'
Prompt: 'The president of the United States is', Generated text: ' a great and'
Prompt: 'The capital of France is', Generated text: ' Paris, which'
Prompt: 'The future of AI is', Generated text: ' in the hands'

meta-llama/Llama-2-70b-chat-hf

Prompt: 'Hello, my name is', Generated text: ' Margo'
Prompt: 'The president of the United States is', Generated text: ' a powerful'
Prompt: 'The capital of France is', Generated text: ' Paris,'
Prompt: 'The future of AI is', Generated text: ' in the'

sbintuitions/sarashina2-70b

Prompt: 'Hello, my name is', Generated text: ' Noemi. I have been a Spanish teacher for 14 years. I'
Prompt: 'The president of the United States is', Generated text: ' mandated to hold the title of Commander in Chief of the U.S'
Prompt: 'The capital of France is', Generated text: ' Paris. The capital of Great Britain is London.'
Prompt: 'The future of AI is', Generated text: ' ultimately unclear, and its development is a complex process that is often'

@jikunshang
Author


@Yanli2190 You are running offline_inference.py, right? Did you change max_tokens (the default value is 20) in SamplingParams? If not, it seems the previous fix influences the calculation result and EOS gets generated early instead of the correct tokens.

@Yanli2190


Yes, I used offline_inference.py. I didn't modify any sampling parameters; I only changed the LLM model and set the TP parameters. Not sure whether it's the same issue or whether we need another issue to track it?

@jikunshang
Author


I feel it's the same issue.

@Yanli2190


Thanks a lot. Then we'll check again once there's a fix for this issue.

@madamczykhabana

I've pushed a fix candidate to the dev/madamczyk/flat_pa_acc branch, both in the vllm-fork and vllm-hpu-extension repositories.

Unfortunately, using scatter_reduce introduces a big performance hit, so it's disabled by default. You can try it out by running with the VLLM_PA_SOFTMAX_IMPL=scatter_reduce env flag.

Please check if it solves the accuracy issue you're seeing; in the meantime we'll continue exploring internally how to speed it up.
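For anyone else trying this: the flag is a plain environment variable, so it can be exported in the shell or set in Python before vLLM is initialized. A minimal sketch (assuming the extension reads the variable when the attention backend is set up):

```python
import os

# Must be set before the HPU attention backend is initialized.
os.environ["VLLM_PA_SOFTMAX_IMPL"] = "scatter_reduce"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-7B")
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)
```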

@jikunshang
Author


Thanks for your fix! I did a quick test and it works on the Qwen2-7B model. How much performance regression did you observe? I will also run benchmarks later.

@madamczykhabana

On Llama-3.1-8B it's around a 30% perf drop.

@jikunshang
Author

Thanks, I also observed about a 30% perf drop.

@madamczykhabana

Hi @jikunshang. During brainstorming in my team we found another possible solution. Please take another look at the dev/madamczyk/flat_pa_acc branch. I've added another option to normalize by the average of maximums (it's enabled by default, so no need to set anything; the current behavior can be re-enabled by running with VLLM_PA_SOFTMAX_IMPL=amax).
I don't have any performance data yet (waiting for the results).
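Conceptually, the trade-off looks something like the sketch below (illustrative only, not the vllm-hpu-extension code; the real implementation presumably averages over blocks of the same sequence via a block mapping not shown here):

```python
import torch

def softmax_with_estimated_max(attn: torch.Tensor) -> torch.Tensor:
    """attn: [num_blocks, num_heads, 1, block_size] paged-attention scores."""
    # Per-block maxima are cheap to compute locally ...
    block_max = attn.amax(dim=-1, keepdim=True)
    # ... and their average serves as an estimate of the true maximum.
    est_max = block_max.mean(dim=0, keepdim=True)
    # Works when blocks have similar magnitudes; if one block's max is far above the
    # average, exp() can still overflow - matching the long-prompt failures reported below.
    return torch.exp(attn - est_max)
```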

@jikunshang
Author


I took a quick try with this fix; the results look quite promising. I will also test performance. Thanks a lot!

@jikunshang
Author

jikunshang commented Oct 10, 2024

block_sum_attn.mul_(block_scales.reshape(-1, *[1 for _ in range(missing_dims)]))
For graph mode, does this conflict with the for loop here? E.g., in the warm-up stage we have some dummy inputs and it should generate a fixed graph, while in the benchmark stage missing_dims may be dynamic.

@madamczykhabana

As far as I understand, missing_dims shouldn't be dynamic.
missing_dims = block_sum_attn.dim() - block_scales.dim()
(https://github.com/HabanaAI/vllm-hpu-extension/blob/7b7c867a25390af34a47ab2ced6a96337364d7ef/vllm_hpu_extension/ops.py#L56C5-L56C61)
block_sum_attn.dim() can have different values depending on whether the model uses GQA or MHA, but for a given model it should remain constant. block_scales is initialized as a vector, so its dim should always be 1.
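A concrete example with made-up shapes (just to illustrate why missing_dims stays static across graphs):

```python
import torch

# Hypothetical GQA shapes: [num_blocks, kv_heads, q_heads_per_kv, 1] vs. a flat scales vector.
block_sum_attn = torch.zeros(64, 4, 7, 1)   # dim() == 4 for this model, always
block_scales = torch.ones(64)               # dim() == 1, always

missing_dims = block_sum_attn.dim() - block_scales.dim()              # 4 - 1 = 3, a Python constant
scales = block_scales.reshape(-1, *[1 for _ in range(missing_dims)])  # [64, 1, 1, 1]
block_sum_attn.mul_(scales)  # broadcasts; only num_blocks changes between graphs, not missing_dims
```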

@czhu15

czhu15 commented Oct 10, 2024

@madamczykhabana, we found an accuracy issue with your latest wsum solution. It happens when the input length is large. E.g., the outputs of the 4 prompts below were not correct (the first one has a long input).
Prompt: 'The voltage gain of a transistor amplifier is a measure of how much the amplifier amplifies the input voltage to produce the output voltage. In the context of the SPICE netlist I provided, the voltage gain can be calculated by taking the ratio of the output voltage to the input voltage. However, voltage gain is not a meaningful concept when using the SPICE netlist alone, because the voltage gain depends on the specific values of the resistors and capacitors in the circuit, which I provided just as examples. To determine the voltage gain of the amplifier, you would need to run a simulation using the specific values of the resistors and capacitors that you are using in your circuit.\n\nTo find the voltage gain of the amplifier you need to run a DC or AC analysis, in the DC analysis you can check the output voltage over input voltage ratio and the voltage gain of the circuit will be approximately the same as the ratio of the collector resistance (Rc1 and Rc2) over the input resistance (Rin1 and Rin2), and in the AC analysis you can check the voltage gain by looking at the magnitude of the output voltage over the input voltage in the frequency domain.\n\nKeep in mind that for a differential amplifier like the one you described, the voltage gain is not a single value but rather a differential voltage gain. The voltage gain of a differential amplifier is the ratio of the differential output voltage to the differential input voltage.', Generated text: ' The!!!!!!!!!!!!!!!' Prompt: 'Create a NextJS page to upload a CSV, XLS, XLSX file; then process that file in a NextJS API route saving the contents to Supabase', Generated text: '.\n\n!!!!!!!!!!!!!!!' Prompt: 'What ways can I improve CRO', Generated text: ' for!!!!!!!!!!!!!!!' Prompt: 'Also ciite data for Tinyowl and urban ladder', Generated text: '\n\n!!!!!!!!!!!!!!!'
After I replaced the long input prompt with a short prompt, the output looks good.
Prompt: 'Create a NextJS page to upload a CSV, XLS, XLSX file; then process that file in a NextJS API route saving the contents to Supabase', Generated text: '.\n\nTo create a NextJS page for uploading CSV, XLS, and X' Prompt: 'What ways can I improve CRO', Generated text: ' for my website?\nAs an AI language model, I can suggest the following ways' Prompt: 'Also ciite data for Tinyowl and urban ladder', Generated text: '\n\nSure, here are some data for TinyOwl and Urban Ladder:\n\n' Prompt: '司机说外卖慢了,我怎么说没关系,英语', Generated text: "怎么说?\n在英语中,你可以这样回答司机:“It's okay, the"
Could you kindly check whether the solution misses the long-input-prompt case?
The model I verified is Qwen2-7B-Instruct.
Thanks,
Bob

@madamczykhabana

madamczykhabana commented Oct 10, 2024

@czhu15 thanks for the failing test cases! I'll take a closer look at what's going on there.

@Yanli2190

@madamczykhabana will this fix be merged to both habana_main and the 1.18 release?

@madamczykhabana

@czhu15 there was in fact a bug in the handling of padding blocks. I've pushed the fix to dev/madamczyk/flat_pa_acc. Here are my results with the latest code:

Prompt: 'The voltage gain of a transistor amplifier is a measure of how much the amplifier amplifies the input voltage to produce the output voltage. In the context of the SPICE netlist I provided, the voltage gain can be calculated by taking the ratio of the output voltage to the input voltage. However, voltage gain is not a meaningful concept when using the SPICE netlist alone, because the voltage gain depends on the specific values of the resistors and capacitors in the circuit, which I provided just as examples. To determine the voltage gain of the amplifier, you would need to run a simulation using the specific values of the resistors and capacitors that you are using in your circuit.\n\nTo find the voltage gain of the amplifier you need to run a DC or AC analysis, in the DC analysis you can check the output voltage over input voltage ratio and the voltage gain of the circuit will be approximately the same as the ratio of the collector resistance (Rc1 and Rc2) over the input resistance (Rin1 and Rin2), and in the AC analysis you can check the voltage gain by looking at the magnitude of the output voltage over the input voltage in the frequency domain.\n\nKeep in mind that for a differential amplifier like the one you described, the voltage gain is not a single value but rather a differential voltage gain. The voltage gain of a differential amplifier is the ratio of the differential output voltage to the differential input voltage.',
Generated text: ' The differential output voltage is the difference between the output voltage of the two transistors, and the differential input voltage is the difference between the input voltage of the two transistors.\n\nIn the SPICE netlist I provided, the voltage gain can be!'

Prompt: 'Create a NextJS page to upload a CSV, XLS, XLSX file; then process that file in a NextJS API route saving the contents to Supabase',
Generated text: '.\n\nTo create a NextJS page to upload a CSV, XLS, XLSX file, you can follow these steps:\n\n1. Install the necessary dependencies:\n\nbash\nnpm install next react react-dom supabase\n\n\n2. Create!'

Prompt: 'What ways can I improve CRO',
Generated text: ' for my website? I have a website that sells products. I have a lot of traffic but my conversion rate is low. I want to improve my conversion rate. What are the best ways to do this? There are several ways to improve conversion rate optimization! Here are some tips:\n\n1. Optimize your website for mobile devices: Make sure your website is mobile-friendly and easy to navigate on smaller screens.\n2. Use clear and concise headlines: Your headlines should be easy to read and understand, and should clearly communicate the benefits of your product.\n3. Use high-quality images: Use high-quality images that showcase your product in the best possible light.\n4. Use social proof: Use customer reviews, testimonials, and social media shares to build trust and credibility with your audience.\n5. Use a clear call-to-action: Make sure your call-to-action is clear and easy to find, and use action-oriented language to encourage visitors to take action.\n6. Use A/B testing: Test different versions of your website to see which one performs better in terms of conversion rate.\n7. Use retargeting: Use retargeting ads to show your product to visitors who have already shown interest in it.\n8. Use a landing page: Use a dedicated landing page for your product, with a clear call-to-action and a form for visitors to fill out.\n9. Use a chatbot: Use a chatbot to provide visitors with instant support and answer their questions.\n10. Use a pop-up: Use a pop-up to show visitors a special offer or discount code.\n\nThese are just a few ways to improve conversion rate optimization. There are many other techniques you can use, depending on your specific goals and audience.'

Prompt: 'Also ciite data for Tinyowl and urban ladder'
Generated text: '\nSure, here are the data for Tinyowl and Urban Ladder:\n\nTinyowl:\n\n* Founded in 2012\n* Raised $100 million in funding\n* Offered on-demand delivery of groceries, food, and other products!'

Please double check on your end.

@madamczykhabana


@Yanli2190 It's hard to tell. I'll discuss it with the rest of the team.

@madamczykhabana

@Yanli2190 we're not planning to merge this fix to 1.18.1 (1.18.0 has already been released). It will be merged to habana_main (and thus automatically included in 1.19.0)

@czhu15

czhu15 commented Oct 12, 2024

@madamczykhabana, verified that the latest change has fixed the long-input-sequence issue :)

@czhu15

czhu15 commented Oct 14, 2024

@madamczykhabana, I back-ported your changes onto the v1.18.0 branch. It works correctly under the Synapse 1.17.0 docker, but not under the Synapse 1.18.0 docker.
Below is the output from the 1.17.0 docker:
Prompt: 'The voltage gain of a transistor amplifier is a measure of how much the amplifier amplifies the input voltage to produce the output voltage. In the context of the SPICE netlist I provided, the voltage gain can be calculated by taking the ratio of the output voltage to the input voltage. However, voltage gain is not a meaningful concept when using the SPICE netlist alone, because the voltage gain depends on the specific values of the resistors and capacitors in the circuit, which I provided just as examples. To determine the voltage gain of the amplifier, you would need to run a simulation using the specific values of the resistors and capacitors that you are using in your circuit.\n\nTo find the voltage gain of the amplifier you need to run a DC or AC analysis, in the DC analysis you can check the output voltage over input voltage ratio and the voltage gain of the circuit will be approximately the same as the ratio of the collector resistance (Rc1 and Rc2) over the input resistance (Rin1 and Rin2), and in the AC analysis you can check the voltage gain by looking at the magnitude of the output voltage over the input voltage in the frequency domain.\n\nKeep in mind that for a differential amplifier like the one you described, the voltage gain is not a single value but rather a differential voltage gain. The voltage gain of a differential amplifier is the ratio of the differential output voltage to the differential input voltage', Generated text: '. The differential voltage gain is usually denoted as Avd and is given by' Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States. The president is' Prompt: 'The capital of France is', Generated text: ' Paris. It is the most populous city in the European Union and the second most' Prompt: 'The future of AI is', Generated text: " here, and it's already changing the way we live and work. From self"

And below is the output from the 1.18.0 docker:
Prompt: 'The voltage gain of a transistor amplifier is a measure of how much the amplifier amplifies the input voltage to produce the output voltage. In the context of the SPICE netlist I provided, the voltage gain can be calculated by taking the ratio of the output voltage to the input voltage. However, voltage gain is not a meaningful concept when using the SPICE netlist alone, because the voltage gain depends on the specific values of the resistors and capacitors in the circuit, which I provided just as examples. To determine the voltage gain of the amplifier, you would need to run a simulation using the specific values of the resistors and capacitors that you are using in your circuit.\n\nTo find the voltage gain of the amplifier you need to run a DC or AC analysis, in the DC analysis you can check the output voltage over input voltage ratio and the voltage gain of the circuit will be approximately the same as the ratio of the collector resistance (Rc1 and Rc2) over the input resistance (Rin1 and Rin2), and in the AC analysis you can check the voltage gain by looking at the magnitude of the output voltage over the input voltage in the frequency domain.\n\nKeep in mind that for a differential amplifier like the one you described, the voltage gain is not a single value but rather a differential voltage gain. The voltage gain of a differential amplifier is the ratio of the differential output voltage to the differential input voltage', Generated text: '. In the case of the SPICE! netlist you! provided, the' Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government! The president is! The president is' Prompt: 'The capital of France is', Generated text: ' Paris. It is the most populous city! It is also! It is the' Prompt: 'The future of AI is', Generated text: " here, and it's already changing the! way we live! and work!"
You can see that the generated text under 1.18.0 is not as good as that under 1.17.0: more exclamation marks and more repetition.
Do you have any idea what the reason could be?

@madamczykhabana

Unfortunately, it turns out that there are still some issues with accuracy. I was able to reproduce it on 1.19, 1.18 and 1.17. In my case the issue starts reproducing when the difference between the average and the maximum values becomes large enough. I might have an idea for how to work around it...

@czhu15

czhu15 commented Oct 17, 2024

@madamczykhabana Did your latest commit (HabanaAI/vllm-hpu-extension@fd7f2e6#diff-c45331496d9b806755711b375270bd629261a029e414691d69401e5c297d0fd7R53) eventually fix this accuracy issue, or are you still working on it?

@madamczykhabana

Unfortunately, I'm still working on it. I've decided to push it and use it as the default since it helps in the Llama case. There are still numerical issues in Qwen when we're crossing the block boundary.

@czhu15

czhu15 commented Oct 17, 2024

Got it. Please kindly let me know when you make progress. Good luck!

@madamczykhabana

madamczykhabana commented Oct 17, 2024

I have a new idea that I'm testing. @czhu15 please check the dev/madamczyk/softmax_options branch. By default it normalizes by a weighted sum, followed by amax. It's not perfect, but it should be better. There might be a slight perf drop (around 5%) that can still be improved.

@yangulei

yangulei commented Oct 18, 2024

@madamczykhabana
Hi Michal,
Thanks for your effort.
I tested your branch dev/madamczyk/softmax_options with Qwen2-7B-Instruct using offline_inference.py and benchmark_throughput.py with the ShareGPT dataset. The results show that the accuracy issue remains for this model. Below is an example output for ShareGPT:

{
  "Prompt": "Do you know the book Traction by Gino Wickman", 
  "Generated text": "?! I!LO!!VE!! IT!! I! Absolutely!!! T! r!! a! c!! o!! n!! ! b! y!! G!!! i!! n!!! o!! W!!!!!!!! Wick!!!man!!!!!!!!! is!!!!!!! one!!!!! !!!of!!!!!!!the!!!!!!!!!!!!!best!"
}

BTW: the scatter_reduce solution did solve the accuracy issue for both offline_inference.py and benchmark_throughput.py with the ShareGPT dataset.

@madamczykhabana

@yangulei thanks for checking!
Please share the full command line + options you're using. I'll double-check on my end. I'm working on re-adding the scatter_reduce option (alongside index_reduce).
I'll keep you posted.

@yangulei

yangulei commented Oct 21, 2024

1. Apply the following patch to enable saving the results:
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index b7bc2a64..8d65af40 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -135,8 +135,20 @@ def run_vllm(
 
     if not use_beam_search:
         start = time.perf_counter()
-        llm.generate(prompts, sampling_params, use_tqdm=True)
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
+        print("saving results ...")
+        records=[]
+        for output in outputs:
+            record={
+                "Prompt": output.prompt,
+                "Generated text": output.outputs[0].text
+            }
+            records.append(record)
+        with open('benchmark_throughput_results.json', 'w', encoding='utf-8') as file:
+            for record in records:
+                json_record = json.dumps(record)
+                file.write(json_record + '\n')
     else:
         prompts = [prompt for prompt, _, _ in requests]
         # output_len should be the same for all requests.
2. Run the following bash command:
#! /bin/bash

model_path=/models/Qwen2-7B-Instruct
dataset=/models/ShareGPT_V3_unfiltered_cleaned_split.json

model_name=$( echo $model_path | awk -F/ '{print $NF}' )
export VLLM_PA_SOFTMAX_IMPL=wsum,amax
export VLLM_SKIP_WARMUP=True
export TOKENIZERS_PARALLELISM=true
export VLLM_GRAPH_RESERVED_MEM=0.2
export VLLM_GRAPH_PROMPT_RATIO=0.8
python benchmark_throughput.py \
    --backend vllm \
    --model $model_path \
    --trust-remote-code \
    --tensor-parallel-size 1 \
    --dataset ${dataset} \
    --device hpu \
    --dtype bfloat16 \
    --seed 2024 \
    --max-num-batched-tokens 8192 \
    --max-model-len 4096 \
    --num-prompts 1000 \
    > benchmark_throughput_${model_name}_sharegpt.log 2>&1

@madamczykhabana

@yangulei please check if the problem persists on current habana_main.
A new algorithm, wsum_head_amax, was turned on by default. Additionally, you can check with VLLM_PA_SOFTMAX_IMPL=scatter_reduce, as it was reintroduced.

I'm still not 100% convinced that accuracy is as good as it should be. I'm currently running multiple benchmarks from lm_eval to compare accuracy results.

@czhu15

czhu15 commented Oct 25, 2024

Hi @madamczykhabana, I verified the wsum_head_amax solution on ShareGPT and compared it to scatter_reduce. Basically the results are very close (but diverge later in the sentence in most cases).
And if we take a close look at the outputs of some prompts, I would say that in some cases scatter_reduce produced slightly better output.
Below is the output for the prompt "What is Hadoop".
wsum_head_amax:
"Generated text": " As the Hadoop digital revolution speeded through the industry beginning with MapReduce\u2019s data ingestion reliability, expanding through VMs, cloud offerings to storage technologies, storage and server costs plummeted exponentially, enabling analysis in real-time giving an analytical agility akin to the rocket research analysis advantages of digital computers replacing punch cards and analytical profanity instead of analogy interjection ad hoc assumptions\u2026then in fair hearing we politely dismiss poetic biases, open the intellectual drawbridge leaning on raw refreshment housed api calls AND test aggressively using peer-reviewed benchmarks. Reality stores relations using science not arbitrary privilege access controlled social customs computing resources stacked as nuts inside 3DES3AAD hashed inheritance locked in ice clause warbirg knot machines sweating short stack otherwise enlightened theologians behind flapping navels convinced things were good thoroughly.\n\nHomologating understanding governed dynamics, reflecting through inquiry intercourse across platforms cultivating relationships honest encryption enabled experimentation freed from currencies collapsing into fiat slaves, decentralized superposition awareness rewritten to interconnect paths quantum states transcend physical observation collapse, opening up metaphysical explorations into conscious computation merging talent quantum entanglement through c = hv modulation, consciousness fractal emergent dimensions thought experiments conducted in trions between photons accelerated epiphanies, questioning nature of information intrinsic to existence"
And the output with the scatter_reduce solution:
"Generated text": " As the name suggests, Hadoop (is an open source platform for processing & handling Big Data with ease. It supports complex distributed computing to extract insights, which helps organizations in its decision making process. Some of the major features of Hadoop:\n\n- Designed to process data in a resilient way to tolerate failures\n- Designed for an inexpensive infrastructure by distributing the processing on local commodity devices to achieve scalability\n- Process data in parallel as they can handle millions of inputs, independently process them & collect their output in a structured way\n\nHadoop is an open source cloud computing framework, which uses a master-slave computing architecture. It can process data in a distributed & parallel computing way, through a cluster of computers by breaking down the units of data into blocks and distribute them to processing nodes to carry out the operation in a parallel mode. This helps to optimize the operation time by loading the data in parallel into several slaves/worker node in the cluster to carry out the computation.\n\nHadoop always maintains the two copies of the data block in the cluster, which helps to avoid loss of data in the event of failure of one of the nodes ( which is known as replication. This factor significantly increases the fault tolerance of the platform. \n\nApache H"

@czhu15

czhu15 commented Oct 25, 2024

Below are two files that store the ShareGPT outputs with the different algorithms:
benchmark_throughput_results-wsum_head_amax.json
benchmark_throughput_results-scatter_reduce.json

@czhu15

czhu15 commented Oct 25, 2024

@madamczykhabana, could you kindly share the lm_eval scores once you have them?

@ccrhx4

ccrhx4 commented Oct 29, 2024

Hi, on Llama-3-8B I found that the quality of the generated text with float32 is much better than with bfloat16. I have attached my code at the end.

CUDA BF16: Prompt: 'The capital of France is', Generated text: ' Paris. It is located in the north of the country. Paris is the largest city in France and the second largest city in the European Union. It is also the most visited city in the world. Paris is a city of art, culture, and history. It is home to some of the most famous landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre Dame Cathedral. Paris is also a city of fashion, with many famous designers and brands headquartered there. The city is also home to some of the best restaurants in the world. Paris is a city that is loved by people all over'

HPU BF16: Prompt: 'The capital of France is', Generated text: ' Paris, which is located in the north of the country. The city is located on the Seine River and is the largest city in France. Paris is a major tourist destination and is home to many famous landmarks, including the Eiffel Tower, the Louvre Museum, and the Notre Dame Cathedral. The city is also known for its fashion, food, and art.\nThe capital of France is Paris, which is located in the north of the country. The city is located on the Seine River and is the largest city in France. Paris is a major tourist destination and is home to many famous landmarks, including the Eiffel'

HPU FP32: Prompt: 'The capital of France is', Generated text: ' Paris. It is located in the north of the country. Paris is the largest city in France and the second largest city in the European Union. It is also the most visited city in the world. Paris is a city of art, culture, and history. It is home to some of the most famous landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre Dame Cathedral. Paris is also a city of fashion, with many famous designers and brands headquartered there. The city is also known for its food, with many Michelin-starred restaurants and cafes. Paris is a city that is loved'

And with PT_HPU_MAX_COMPOUND_OP_SIZE=1 the generated result is garbage:

Prompt: 'The future of AI is', Generated text: ' in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe'

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "can you create python programs?",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=128, temperature=0.0)

# Create an LLM.
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

@madamczykhabana

Hi @ccrhx4. Thanks for the info! Could you please retry your experiment with VLLM_PA_SOFTMAX_IMPL=scatter_reduce? I'd like to confirm whether it's a different issue or it's related to the softmax normalization.

@ccrhx4

ccrhx4 commented Oct 29, 2024

@madamczykhabana Hi Michal, the result is the same.

REDUCE_SCATTER: Prompt: 'The capital of France is', Generated text: ' Paris, which is located in the north of the country. The city is located on the Seine River and is the largest city in France. Paris is a major tourist destination and is home to many famous landmarks, including the Eiffel Tower, the Louvre Museum, and the Notre Dame Cathedral. The city is also known for its fashion, food, and art.\nThe capital of France is Paris, which is located in the north of the country. The city is located on the Seine River and is the largest city in France. Paris is a major tourist destination and is home to many famous landmarks, including the Eiffel'

@madamczykhabana

Then most likely it's a different root cause. Please create a new issue for it. We'll take a look at it.

@ccrhx4

ccrhx4 commented Oct 29, 2024

Hi Michal, I have created the new issue. #443
