diff --git a/models/demos/whisper/README.md b/models/demos/whisper/README.md new file mode 100644 index 00000000000..a1fe94ee1f8 --- /dev/null +++ b/models/demos/whisper/README.md @@ -0,0 +1,45 @@ +# Whisper Demo + +Demo showcasing Whisper running on Grayskull - e150 and Wormhole - n150, n300 using ttnn. + +## Introduction + +Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets. + +## Details + +The entry point to whisper model is `whisper` in `models/demos/whisper/tt/ttnn_optimized_functional_whisper.py` for optimized version.. The model picks up certain configs and weights from huggingface pretrained model. We have used openai/whisper-base version from huggingface as our reference. + +### Max Tokens: 32 + +Max Tokens determines the maximum number of input tokens processed by the model in a single pass durig transcription, optimizing performance and compatibility. It's recommended to set the max_tokens to 32 + +### Batch size: 8 + +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 8 + +## How to Run + +### Whisper For Audio Classification +Use `pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_audio_classification[models.demos.whisper.tt.ttnn_optimized_functional_whisper-1-8-WHISPER_MEMORY_CONFIG0-sanchit-gandhi/whisper-medium-fleurs-lang-id-models/demos/whisper/demo/dataset/audio_classification]` to run the ttnn optimized functional whisper demo for audio classification. + +#### Our another demo is designed to run with `google/fleurs` for Audio classification + +Use `pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset inputs. + +### Whisper For Conditional Generation + +Use `pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation[models.demos.whisper.tt.ttnn_optimized_functional_whisper-8-32-WHISPER_MEMORY_CONFIG0-openai/whisper-tiny.en-models/demos/whisper/demo/dataset/conditional_generation-device_params0]` to run the ttnn optimized functional whisper demo for conditional generation. + +#### Our another demo is designed to run with `hf-internal-testing/librispeech_asr_dummy` for Conditional generation + +Use `pytest --disable-warnings models/demos/whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset inputs. + + +## Inputs + +Inputs by default are provided from `dataset/audio_classification` and `dataset/conditional_generation` folder. If you wish to change the inputs, provide a different path to demo. + +For demo with dataset, Inputs for Audio classification is taken from `google/fleurs` dataset and Inputs for Conditional generation is taken from `hf-internal-testing/librispeech_asr_dummy` dataset. + +### Owner: [kkeerthana0573](https://github.com/kkeerthana0573) diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/10116516891483200485.wav b/models/demos/whisper/demo/dataset/audio_classification/10116516891483200485.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/10116516891483200485.wav rename to models/demos/whisper/demo/dataset/audio_classification/10116516891483200485.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/140291826269534354.wav b/models/demos/whisper/demo/dataset/audio_classification/140291826269534354.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/140291826269534354.wav rename to models/demos/whisper/demo/dataset/audio_classification/140291826269534354.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/1689242038473278354.wav b/models/demos/whisper/demo/dataset/audio_classification/1689242038473278354.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/1689242038473278354.wav rename to models/demos/whisper/demo/dataset/audio_classification/1689242038473278354.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17340315164505628698.wav b/models/demos/whisper/demo/dataset/audio_classification/17340315164505628698.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17340315164505628698.wav rename to models/demos/whisper/demo/dataset/audio_classification/17340315164505628698.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17659141715436566244.wav b/models/demos/whisper/demo/dataset/audio_classification/17659141715436566244.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17659141715436566244.wav rename to models/demos/whisper/demo/dataset/audio_classification/17659141715436566244.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/17928171511082320095.wav b/models/demos/whisper/demo/dataset/audio_classification/17928171511082320095.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/17928171511082320095.wav rename to models/demos/whisper/demo/dataset/audio_classification/17928171511082320095.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/2086639904747050008.wav b/models/demos/whisper/demo/dataset/audio_classification/2086639904747050008.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/2086639904747050008.wav rename to models/demos/whisper/demo/dataset/audio_classification/2086639904747050008.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/622196158886216764.wav b/models/demos/whisper/demo/dataset/audio_classification/622196158886216764.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/622196158886216764.wav rename to models/demos/whisper/demo/dataset/audio_classification/622196158886216764.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/7043619860143829064.wav b/models/demos/whisper/demo/dataset/audio_classification/7043619860143829064.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/7043619860143829064.wav rename to models/demos/whisper/demo/dataset/audio_classification/7043619860143829064.wav diff --git a/models/experimental/functional_whisper/demo/dataset/audio_classification/9522084197299278725.wav b/models/demos/whisper/demo/dataset/audio_classification/9522084197299278725.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/audio_classification/9522084197299278725.wav rename to models/demos/whisper/demo/dataset/audio_classification/9522084197299278725.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/11150113890463037787.wav b/models/demos/whisper/demo/dataset/conditional_generation/11150113890463037787.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/11150113890463037787.wav rename to models/demos/whisper/demo/dataset/conditional_generation/11150113890463037787.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/1298409023920250606.wav b/models/demos/whisper/demo/dataset/conditional_generation/1298409023920250606.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/1298409023920250606.wav rename to models/demos/whisper/demo/dataset/conditional_generation/1298409023920250606.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17566024285835266239.wav b/models/demos/whisper/demo/dataset/conditional_generation/17566024285835266239.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17566024285835266239.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17566024285835266239.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17646385371758249908.wav b/models/demos/whisper/demo/dataset/conditional_generation/17646385371758249908.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17646385371758249908.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17646385371758249908.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17659141715436566244.wav b/models/demos/whisper/demo/dataset/conditional_generation/17659141715436566244.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17659141715436566244.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17659141715436566244.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17928171511082320095.wav b/models/demos/whisper/demo/dataset/conditional_generation/17928171511082320095.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17928171511082320095.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17928171511082320095.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/17938133003986293739.wav b/models/demos/whisper/demo/dataset/conditional_generation/17938133003986293739.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/17938133003986293739.wav rename to models/demos/whisper/demo/dataset/conditional_generation/17938133003986293739.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/2842775607363710885.wav b/models/demos/whisper/demo/dataset/conditional_generation/2842775607363710885.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/2842775607363710885.wav rename to models/demos/whisper/demo/dataset/conditional_generation/2842775607363710885.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/6757317816154782558.wav b/models/demos/whisper/demo/dataset/conditional_generation/6757317816154782558.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/6757317816154782558.wav rename to models/demos/whisper/demo/dataset/conditional_generation/6757317816154782558.wav diff --git a/models/experimental/functional_whisper/demo/dataset/conditional_generation/6969469525741631060.wav b/models/demos/whisper/demo/dataset/conditional_generation/6969469525741631060.wav similarity index 100% rename from models/experimental/functional_whisper/demo/dataset/conditional_generation/6969469525741631060.wav rename to models/demos/whisper/demo/dataset/conditional_generation/6969469525741631060.wav diff --git a/models/demos/whisper/demo/demo.py b/models/demos/whisper/demo/demo.py new file mode 100644 index 00000000000..37bd6182357 --- /dev/null +++ b/models/demos/whisper/demo/demo.py @@ -0,0 +1,543 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from datasets import load_dataset +from loguru import logger +from scipy.io import wavfile +import ttnn +from transformers import ( + AutoFeatureExtractor, + WhisperModel, + WhisperConfig, + AutoProcessor, + WhisperForConditionalGeneration, +) +from models.utility_functions import ( + disable_compilation_reports, + disable_persistent_kernel_cache, + profiler, +) +from models.demos.whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper +from models.generation_utils import get_logits_processor, pad_input_32 +from ttnn.model_preprocessing import preprocess_model_parameters + +import torch +import os +from os import listdir +from os.path import isfile, join + +from transformers import AutoFeatureExtractor, WhisperForAudioClassification +from datasets import load_dataset +from torchmetrics.text import WordErrorRate +from sklearn.metrics import accuracy_score + + +def load_input_paths(folder_path): + files = [os.path.join(folder_path, f) for f in listdir(folder_path) if isfile(join(folder_path, f))] + return files + + +def run_generate( + config, + input_embeds, + input_features, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask, + parameters, + ttnn_linear_weight, + device, + decoder_input_ids, + generation_config, + batch_size, + max_tokens, + whisper_memory_config, +): + logits_processor = get_logits_processor(decoder_input_ids, config) + decoder_start_values = generation_config.pad_token_id * torch.ones(batch_size, input_features.shape[1]).to( + torch.long + ) + eos_reached = torch.zeros(batch_size, dtype=torch.bool) + + profiler.start(f"inference_time") + for i in range(max_tokens): + ttnn_output = ttnn_model.whisper_for_conditional_generation( + config=config, + input_embeds=input_embeds, + decoder_hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ttnn_linear_weight=ttnn_linear_weight, + whisper_memory_config=whisper_memory_config, + ) + ttnn_output = ttnn.from_device(ttnn_output) + logits_to_torch = ttnn.to_torch(ttnn_output) + next_token_logits = logits_to_torch[:, i, :] + next_tokens_scores = logits_processor(input_features, next_token_logits) + next_tokens = torch.argmax(next_tokens_scores, dim=-1).unsqueeze(0) + + # Check if EOS token is generated for any sample in the batch and + # Setting subsequent next_tokens to config.pad_token_id if EOS token is reached. + eos_generated_flags = next_tokens == config.eos_token_id + eos_reached = eos_reached | eos_generated_flags.squeeze(0) + next_tokens[:, eos_reached] = config.pad_token_id + + if (i + 1) % 32 == 0: + decoder_input_ids = torch.cat([decoder_input_ids, decoder_start_values], dim=1) + + decoder_input_ids[:, i + 1] = next_tokens[:, None] + decoder_hidden_states, decoder_attention_mask = ttnn_model.preprocess_decoder_inputs( + config=config, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=parameters.decoder, + device=device, + ) + + if torch.all(next_tokens == config.eos_token_id): + break + + profiler.end(f"inference_time") + return decoder_input_ids + + +def run_demo_functional_whisper_for_audio_classification_inference( + device, model_name, input_path, ttnn_model, num_inputs, batch_size, whisper_memory_config +): + feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") + + model.eval() + input_data = load_input_paths(input_path) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + batched_inputs = [] + for i in range(batch_size): + input_file_path = input_data[i] + samplerate, data = wavfile.read(input_file_path) + + inputs = feature_extractor( + data, + sampling_rate=samplerate, + return_tensors="pt", + ) + + input_features = inputs.input_features + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + config = model.config + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batched_inputs, + parameters=parameters.encoder, + device=device, + whisper_memory_config=whisper_memory_config, + ) + + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=parameters, + device=device, + batch_size=batch_size, + whisper_memory_config=whisper_memory_config, + ) + + logits_torch = ttnn.to_torch(out_logits) + predicted_list = [] + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() + predicted_label = model.config.id2label[predicted_class_ids] + logger.info(f"predicted_label: {predicted_label}") + predicted_list.append(predicted_label) + + return predicted_list + + +def run_demo_functional_whisper_for_conditional_generation_inference( + device, + reset_seeds, + batch_size, + model_name, + input_path, + ttnn_model, + max_tokens=32, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, +): + model = WhisperModel.from_pretrained(model_name).eval() + config = WhisperConfig.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained(model_name, language="English", task="transcribe") + hf_reference_model = WhisperForConditionalGeneration.from_pretrained(model_name) + + linear_weight = hf_reference_model.proj_out.weight + ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) + ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) + ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) + + feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) + input_data = load_input_paths(input_path) + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) + + if len(input_data) < batch_size: + assert False, "batch_size exceeds number of audio files available in folder" + + for i in range(batch_size): + input_file_path = input_data[i] + samplerate, data = wavfile.read(input_file_path) + inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt") + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + profiler.start(f"preprocessing_inputs") + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + whisper_memory_config=whisper_memory_config, + ) + profiler.end(f"preprocessing_inputs") + + generation_config = hf_reference_model.generation_config + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + whisper_memory_config=whisper_memory_config, + ) + + profiler.start(f"post_processing_output_to_string") + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + profiler.end(f"post_processing_output_to_string") + + logger.info("Model Output") + logger.info(ttnn_transcription) + + measurements = { + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements, ttnn_transcription + + +def run_demo_functional_whisper_for_audio_classification_dataset( + device, reset_seeds, model_name, ttnn_model, batch_size, n_iterations, whisper_memory_config +): + feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) + model = WhisperForAudioClassification.from_pretrained(model_name) + + model.eval() + ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) + sample = iter(ds) + + reference_labels = [] + predicted_labels = [] + config = model.config + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) + + for _ in range(n_iterations): + batch_input = [] + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features.type(torch.bfloat16) + batch_input = input_features if i == 0 else torch.cat((batch_input, input_features), dim=0) + reference_labels.append(s["language"]) + + input_embedding = ttnn_model.preprocess_encoder_inputs( + input_features=batch_input, + parameters=parameters.encoder, + device=device, + whisper_memory_config=whisper_memory_config, + ) + + out_logits = ttnn_model.whisper_for_audio_classification( + config=config, + inputs_embeds=input_embedding, + parameters=parameters, + device=device, + batch_size=batch_size, + whisper_memory_config=whisper_memory_config, + ) + logits_torch = ttnn.to_torch(out_logits) + + for i in range(batch_size): + single_logits_torch = logits_torch[i].squeeze(0) + predicted_class_ids = torch.argmax(single_logits_torch).item() + predicted_label = model.config.id2label[predicted_class_ids] + predicted_labels.append(predicted_label) + + accuracy = accuracy_score(reference_labels, predicted_labels) + logger.info(f"Reference labels: {reference_labels}") + logger.info(f"Predicted labels: {predicted_labels}") + logger.info(f"Accuracy: {accuracy}") + return accuracy + + +def run_demo_functional_whisper_for_conditional_generation_dataset( + device, + reset_seeds, + model_name, + ttnn_model, + batch_size=1, + n_iterations=1, + max_tokens=32, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, +): + model = WhisperModel.from_pretrained(model_name).eval() + config = WhisperConfig.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained(model_name, language="English", task="transcribe") + hf_reference_model = WhisperForConditionalGeneration.from_pretrained(model_name) + + linear_weight = hf_reference_model.proj_out.weight + ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) + ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) + ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) + + feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = iter(ds) + batched_ground_truth_transcriptions = [] + + for _ in range(n_iterations): + for i in range(batch_size): + s = next(sample) + inputs = feature_extractor(s["audio"]["array"], sampling_rate=16000, return_tensors="pt") + ground_truth_transcriptions = s["text"] + dtype_to_use = torch.bfloat16 + input_features = inputs.input_features.type(dtype_to_use) + + batched_inputs = input_features if i == 0 else torch.cat((batched_inputs, input_features), dim=0) + + decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id + decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) + batched_decoder_input_ids = ( + decoder_input_ids if i == 0 else torch.cat((batched_decoder_input_ids, decoder_input_ids), dim=0) + ) + + batched_ground_truth_transcriptions.append(ground_truth_transcriptions) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.eval(), + convert_to_ttnn=ttnn_model.convert_to_ttnn, + custom_preprocessor=ttnn_model.custom_preprocessor, + device=device, + ) + + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=batched_inputs, + input_ids=batched_decoder_input_ids, + attention_mask=None, + parameters=parameters, + device=device, + whisper_memory_config=whisper_memory_config, + ) + + ttnn_output = run_generate( + config, + input_embeds, + batched_inputs, + ttnn_model, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + ttnn_linear_weight=ttnn_linear_weight, + device=device, + decoder_input_ids=batched_decoder_input_ids, + generation_config=hf_reference_model.generation_config, + batch_size=batch_size, + max_tokens=max_tokens, + whisper_memory_config=whisper_memory_config, + ) + ttnn_transcription = processor.batch_decode(ttnn_output, skip_special_tokens=True) + + logger.info("Model Output") + logger.info(ttnn_transcription) + + wer = WordErrorRate() + wer_scores = [] + for transcription, ground_truth in zip(ttnn_transcription, batched_ground_truth_transcriptions): + transcription = transcription.upper() + individual_wer_score = wer([transcription], [ground_truth]) + wer_scores.append(individual_wer_score) + logger.info(f"Individual Sample WER score: {individual_wer_score}") + + average_wer_score = sum(wer_scores) / len(wer_scores) + logger.info(f"Average WER score: {average_wer_score}") + accuracy = 1 - average_wer_score + logger.info(f"Accuracy: {accuracy}") + + return average_wer_score + + +@pytest.mark.parametrize( + "model_name, input_loc", + ((["sanchit-gandhi/whisper-medium-fleurs-lang-id", "models/demos/whisper/demo/dataset/audio_classification"]),), +) +@pytest.mark.parametrize( + ("ttnn_model", "num_inputs", "batch_size", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 1, 8, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_audio_classification( + device, + reset_seeds, + use_program_cache, + model_name, + input_loc, + ttnn_model, + num_inputs, + batch_size, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + return run_demo_functional_whisper_for_audio_classification_inference( + device, + model_name=model_name, + input_path=input_loc, + ttnn_model=ttnn_model, + num_inputs=num_inputs, + batch_size=batch_size, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name, input_loc", + ((["openai/whisper-tiny.en", "models/demos/whisper/demo/dataset/conditional_generation"]),), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "max_tokens", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 8, 32, ttnn.L1_MEMORY_CONFIG),), +) +def test_demo_for_conditional_generation( + device, + reset_seeds, + use_program_cache, + model_name, + input_loc, + ttnn_model, + batch_size, + max_tokens, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + return run_demo_functional_whisper_for_conditional_generation_inference( + device, + reset_seeds, + batch_size=batch_size, + model_name=model_name, + input_path=input_loc, + ttnn_model=ttnn_model, + max_tokens=max_tokens, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name", + (["sanchit-gandhi/whisper-medium-fleurs-lang-id"]), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "n_iterations", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 8, 1, ttnn.DRAM_MEMORY_CONFIG),), +) +def test_demo_for_audio_classification_dataset( + device, reset_seeds, use_program_cache, model_name, ttnn_model, batch_size, n_iterations, WHISPER_MEMORY_CONFIG +): + disable_persistent_kernel_cache() + disable_compilation_reports() + return run_demo_functional_whisper_for_audio_classification_dataset( + device, + reset_seeds, + model_name=model_name, + ttnn_model=ttnn_model, + batch_size=batch_size, + n_iterations=n_iterations, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "model_name", + (["openai/whisper-tiny.en"]), +) +@pytest.mark.parametrize( + ("ttnn_model", "batch_size", "n_iterations", "max_tokens", "WHISPER_MEMORY_CONFIG"), + ((ttnn_optimized_functional_whisper, 8, 1, 32, ttnn.L1_MEMORY_CONFIG),), +) +def test_demo_for_conditional_generation_dataset( + device, + reset_seeds, + use_program_cache, + model_name, + ttnn_model, + batch_size, + n_iterations, + max_tokens, + WHISPER_MEMORY_CONFIG, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + return run_demo_functional_whisper_for_conditional_generation_dataset( + device, + reset_seeds, + model_name=model_name, + ttnn_model=ttnn_model, + batch_size=batch_size, + n_iterations=n_iterations, + max_tokens=max_tokens, + whisper_memory_config=WHISPER_MEMORY_CONFIG, + ) diff --git a/models/experimental/functional_whisper/reference/torch_baseline_whisper.py b/models/demos/whisper/reference/torch_baseline_whisper.py similarity index 100% rename from models/experimental/functional_whisper/reference/torch_baseline_whisper.py rename to models/demos/whisper/reference/torch_baseline_whisper.py diff --git a/models/experimental/functional_whisper/reference/torch_functional_whisper.py b/models/demos/whisper/reference/torch_functional_whisper.py similarity index 100% rename from models/experimental/functional_whisper/reference/torch_functional_whisper.py rename to models/demos/whisper/reference/torch_functional_whisper.py diff --git a/models/demos/whisper/tests/test_perf_device_whisper.py b/models/demos/whisper/tests/test_perf_device_whisper.py new file mode 100644 index 00000000000..f8b812fdc32 --- /dev/null +++ b/models/demos/whisper/tests/test_perf_device_whisper.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from models.utility_functions import is_grayskull +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + + +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [8]) +def test_perf_device_bare_metal(device, batch_size, reset_seeds): + subdir = "ttnn_whisper_optimized_" + margin = 0.03 + num_iterations = 1 + + expected_perf = 13.38 if is_grayskull else 3.06 + command = ( + f"pytest tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py::test_ttnn_whisper" + ) + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_optimized_whisper_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/models/demos/whisper/tests/test_performance.py similarity index 53% rename from tests/ttnn/integration_tests/whisper/test_performance.py rename to models/demos/whisper/tests/test_performance.py index 288afa78719..c05c8e6194d 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/models/demos/whisper/tests/test_performance.py @@ -2,56 +2,47 @@ # SPDX-License-Identifier: Apache-2.0 -import pytest -from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper -from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig -from datasets import load_dataset +import time +import ttnn import torch -from ttnn.model_preprocessing import preprocess_model_parameters +import pytest +import transformers from loguru import logger -from models.utility_functions import is_wormhole_b0, is_blackhole +from datasets import load_dataset + +from models.utility_functions import is_grayskull from models.perf.perf_utils import prep_perf_report -import time -import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.whisper.tt import ttnn_optimized_functional_whisper +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report def get_expected_times(functional_whisper): return { - ttnn_functional_whisper: (11.7, 4.16), - ttnn_optimized_functional_whisper: (1.5, 1.35), + ttnn_optimized_functional_whisper: (10.47, 5.7) if is_grayskull else (15.84, 9.5), }[functional_whisper] -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Not tested on single WH") @pytest.mark.models_performance_bare_metal @pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("model_name", ["openai/whisper-base"]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("sequence_size", [500]) -@pytest.mark.parametrize("functional_whisper", [ttnn_functional_whisper, ttnn_optimized_functional_whisper]) -def test_performance(device, use_program_cache, model_name, batch_size, sequence_size, functional_whisper): - config = WhisperConfig.from_pretrained(model_name) - +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("functional_whisper", [ttnn_optimized_functional_whisper]) +def test_performance(reset_seeds, use_program_cache, device, model_name, batch_size, functional_whisper): # Run TT Model - if functional_whisper == ttnn_functional_whisper: - tt_model_name = f"ttnn_{model_name}" - elif functional_whisper == ttnn_optimized_functional_whisper: - tt_model_name = f"ttnn_{model_name}_optimized" - else: - raise ValueError(f"Unknown functional_t5: {functional_whisper}") - - config = WhisperConfig.from_pretrained(model_name) - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) + tt_model_name = f"ttnn_{model_name}_optimized" + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) input_features = inputs.input_features - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - attention_mask = None - - parameters = preprocess_model_parameters( - model_name=tt_model_name, - initialize_model=lambda: WhisperModel.from_pretrained(model_name).eval(), + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id + model = transformers.WhisperModel.from_pretrained(model_name).eval() + ttnn_parameters = preprocess_model_parameters( + initialize_model=lambda: model, convert_to_ttnn=functional_whisper.convert_to_ttnn, custom_preprocessor=functional_whisper.custom_preprocessor, device=device, @@ -63,27 +54,27 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence config=config, input_features=input_features, input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, + attention_mask=None, + parameters=ttnn_parameters, device=device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) - start = time.time() - tt_output = functional_whisper.whisper( + last_hidden_state = functional_whisper.whisper( config, + device, input_embeds, decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, - parameters=parameters, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) - tt_output = ttnn.to_torch(tt_output) + tt_output = ttnn.to_torch(last_hidden_state) end = time.time() - duration = end - start durations.append(duration) inference_and_compile_time, inference_time, *_ = durations - expected_compile_time, expected_inference_time = get_expected_times(functional_whisper) prep_perf_report( model_name=tt_model_name, @@ -95,7 +86,10 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence comments="", inference_time_cpu=0.0, ) - logger.info(f"Compile time: {inference_and_compile_time - inference_time}") logger.info(f"Inference time: {inference_time}") logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/demos/whisper/tt/ttnn_functional_whisper.py similarity index 95% rename from models/experimental/functional_whisper/tt/ttnn_functional_whisper.py rename to models/demos/whisper/tt/ttnn_functional_whisper.py index 30b5fa712cf..8f1f2cb837d 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/demos/whisper/tt/ttnn_functional_whisper.py @@ -198,8 +198,11 @@ def encoder_layer(config, hidden_states, *, parameters): return hidden_states -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight +def encoder(config, inputs_embeds, *, parameters, device): + weights = ttnn.to_torch(parameters.embed_positions.weight) + inputs_embeds = ttnn.to_torch(inputs_embeds) + hidden_states = torch.add(inputs_embeds, weights) + hidden_states = ttnn.from_torch(hidden_states, device=device, layout=ttnn.TILE_LAYOUT) hidden_states = dropout(hidden_states, p=0, training=False) for encoder_layer_parameter in parameters.layers: @@ -399,8 +402,8 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) +def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters, device): + encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder, device=device) last_hidden_state = decoder( config, decoder_hidden_states, @@ -411,6 +414,25 @@ def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attent return last_hidden_state +def whisper_for_conditional_generation( + config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters, device, ttnn_linear_weight +): + output = whisper( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + device=device, + ) + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): diff --git a/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py new file mode 100644 index 00000000000..172668664b9 --- /dev/null +++ b/models/demos/whisper/tt/ttnn_optimized_functional_whisper.py @@ -0,0 +1,672 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import transformers +from typing import Optional +from torch.nn import functional as F +from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias + + +WHISPER_DTYPE = ttnn.bfloat8_b + + +def dropout(hidden_states, p, training): + # ignored for inference + return hidden_states + + +# The split_query_key_value_and_split_heads requires the query to have the same volume as the key and values +# This is not the case however for whisper so we currently cannot swap out calculate_key_values below +# def calculate_key_values(config, query_states, key_value_states, *, parameters): +# fused_kv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias +# head_size = config.d_model // config.encoder_attention_heads +# batch_size, *_, _, two_times_hidden_size = fused_kv.shape.with_tile_padding() +# hidden_size = two_times_hidden_size // 2 +# encoder_attention_heads = hidden_size // head_size +# query_states, key_states, value_states = ttnn.transformer.split_query_key_value_and_split_heads( +# query_states, +# kv_input_tensor=fused_kv, +# num_heads=encoder_attention_heads, +# memory_config=WHISPER_MEMORY_CONFIG, +# ) +# key_states = ttnn.permute(key_states, (0, 1, 3, 2)) +# return query_states, key_states, value_states + + +def calculate_key_values(config, key_value_states, *, parameters, whisper_memory_config): + bsz, tgt_len, hidden_size = key_value_states.shape + bsz, tgt_len_padded, _ = key_value_states.shape.with_tile_padding() + head_size = hidden_size // config.encoder_attention_heads + + fused_qkv = ttnn.linear( + key_value_states, + parameters.weight, + bias=parameters.bias, + memory_config=whisper_memory_config, + ) + dtype = fused_qkv.dtype + device = fused_qkv.device() + + # fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + # fused_qkv = ttnn.from_device(fused_qkv) + # fused_qkv = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, 2, head_size)) + # # Without Split: 0.84 pcc + # key_states = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, head_size * 2))[..., :head_size] + # value_states = ttnn.reshape(fused_qkv, (bsz, tgt_len, config.encoder_attention_heads, head_size * 2))[..., head_size:] + + # key_states = ttnn.to_device(key_states, device) + # key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + # key_states = ttnn.permute(key_states, (0, 2, 3, 1)) + + # value_states = ttnn.to_device(value_states, device) + # value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + # value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + fused_qkv = ttnn.from_device(fused_qkv) + fused_qkv = ttnn.reshape(fused_qkv, (bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) + fused_qkv = ttnn.to_layout(fused_qkv, layout=ttnn.TILE_LAYOUT) + fused_qkv = ttnn.to_device(fused_qkv, device=device) + + # #13672: Slice op Not supported for 5d tensors. + fused_qkv = ttnn.to_torch(fused_qkv) + key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] # + key_states = ttnn.from_torch(key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + value_states = ttnn.from_torch(value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + key_states = ttnn.permute(key_states, (0, 2, 3, 1)) + value_states = ttnn.permute(value_states, (0, 2, 1, 3)) + key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) + value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, head_size, tgt_len], + [bsz, config.encoder_attention_heads, head_size, tgt_len_padded], + ) + key_states = ttnn.reshape(key_states, shape=desired_shape) + + desired_shape = ttnn.Shape( + [bsz, config.encoder_attention_heads, tgt_len, head_size], + [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], + ) + value_states = ttnn.reshape(value_states, shape=desired_shape) + + return key_states, value_states + + +def calculate_query_key_values(config, hidden_states, *, parameters, whisper_memory_config): + fused_qkv = ttnn.linear( + hidden_states, + parameters.weight, + bias=parameters.bias, + ) + + return ttnn.transformer.split_query_key_value_and_split_heads( + fused_qkv, memory_config=whisper_memory_config, num_heads=config.num_attention_heads + ) + + +def whisper_attention( + config, device, hidden_states, attention_mask, key_value_states=None, *, parameters, whisper_memory_config +): + head_size = config.d_model // config.encoder_attention_heads + scaling = head_size**-0.5 + bsz, *_, tgt_len, _ = hidden_states.shape + + is_cross_attention = key_value_states is not None + if is_cross_attention: + query_states = ttnn.linear( + hidden_states, + parameters.q_proj.weight, + bias=parameters.q_proj.bias, + memory_config=whisper_memory_config, + ) + query_states = ttnn.to_layout(query_states, layout=ttnn.ROW_MAJOR_LAYOUT) + query_states = ttnn.from_device(query_states) + query_states = ttnn.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) + query_states = ttnn.to_layout(query_states, layout=ttnn.TILE_LAYOUT) + query_states = ttnn.to_device(query_states, device=device) + query_states = ttnn.permute(query_states, (0, 2, 1, 3)) + key_states, value_states = calculate_key_values( + config, key_value_states, parameters=parameters.key_value, whisper_memory_config=whisper_memory_config + ) + else: + query_states, key_states, value_states = calculate_query_key_values( + config, hidden_states, parameters=parameters.query_key_value, whisper_memory_config=whisper_memory_config + ) + + query_states *= scaling + attn_weights = ttnn.matmul(query_states, key_states) + + if attention_mask is not None: + attn_weights = ttnn.add(attn_weights, attention_mask) + + # differences in ttnn.softmax vs torch.softmax cause the attn_weights to be slightly different + attn_weights = ttnn.softmax(attn_weights, dim=-1) + + attn_probs = dropout(attn_weights, p=0, training=False) + attn_output = ttnn.matmul(attn_probs, value_states, memory_config=whisper_memory_config) + + ttnn.deallocate(attn_probs) + ttnn.deallocate(attn_weights) + ttnn.deallocate(query_states) + + attn_output = ttnn.transformer.concatenate_heads(attn_output) + + attn_output = ttnn.linear( + attn_output, + parameters.out_proj.weight, + bias=parameters.out_proj.bias, + memory_config=whisper_memory_config, + ) + + return attn_output + + +def encoder_layer(config, device, hidden_states, *, parameters, whisper_memory_config): + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + residual = hidden_states + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc1.weight, + bias=parameters.fc1.bias, + ) + + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, + parameters.fc2.weight, + bias=parameters.fc2.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +def encoder(config, device, inputs_embeds, *, parameters, whisper_memory_config): + hidden_states = ttnn.add(inputs_embeds, parameters.embed_positions.weight) + hidden_states = dropout(hidden_states, p=0, training=False) + + for encoder_layer_parameter in parameters.layers: + hidden_states = encoder_layer( + config, + device, + hidden_states, + parameters=encoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def make_causal_mask(input_ids_shape, dtype): + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def decoder_layer( + config, device, hidden_states, attention_mask, encoder_hidden_states, *, parameters, whisper_memory_config +): + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.self_attn_layer_norm.weight, + bias=parameters.self_attn_layer_norm.bias, + memory_config=whisper_memory_config, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states=hidden_states, + attention_mask=attention_mask, + parameters=parameters.self_attn, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + # Cross-Attention Block + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.encoder_attn_layer_norm.weight, + bias=parameters.encoder_attn_layer_norm.bias, + ) + + hidden_states = whisper_attention( + config, + device, + hidden_states, + attention_mask=None, + key_value_states=encoder_hidden_states, + parameters=parameters.encoder_attn, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + + residual = hidden_states + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.final_layer_norm.weight, + bias=parameters.final_layer_norm.bias, + ) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=whisper_memory_config + ) + hidden_states = ttnn.gelu(hidden_states, memory_config=whisper_memory_config) + hidden_states = dropout(hidden_states, p=0, training=False) + + hidden_states = ttnn.linear( + hidden_states, parameters.fc2.weight, bias=parameters.fc2.bias, memory_config=whisper_memory_config + ) + hidden_states = dropout(hidden_states, p=0, training=False) + hidden_states = ttnn.add(residual, hidden_states) + return hidden_states + + +def prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + if input_shape[-1] > 1: + combined_attention_mask = make_causal_mask(input_shape, input_embeds.dtype) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def decoder( + config, device, hidden_states, decoder_attention_mask, encoder_hidden_states, *, parameters, whisper_memory_config +): + hidden_states = dropout(hidden_states, p=0, training=False) + + for decoder_layer_parameter in parameters.layers: + hidden_states = decoder_layer( + config, + device, + hidden_states, + decoder_attention_mask, + encoder_hidden_states, + parameters=decoder_layer_parameter, + whisper_memory_config=whisper_memory_config, + ) + + hidden_states = ttnn.layer_norm( + hidden_states, + weight=parameters.layer_norm.weight, + bias=parameters.layer_norm.bias, + ) + + return hidden_states + + +def convert_to_ttnn(model, name): + return name not in [ + "encoder.conv1", + "encoder.conv2", + "decoder.embed_tokens", + "decoder.embed_positions", + ] + + +def preprocess_encoder_inputs(input_features, *, parameters, device, whisper_memory_config): + def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): + return F.conv1d(input, weight, bias, stride, padding, dilation, groups) + + def ttnn_conv1d( + device, + tt_input_tensor, + weights, + conv_params, + bias, + *, + output_dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_fidelity=ttnn.MathFidelity.LoFi, + deallocate_activation=True, + act_block_h=32, + height_sharding=True, + use_shallow_conv_variant=False, + fp32_accum=False, + packer_l1_acc=False, + debug=False, + groups=1, + math_approx=False, + activation="", + reallocate_halo=False, + reshard=False, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, + ): + weights = ttnn.from_torch(weights, dtype=ttnn.float32) + bias = ttnn.from_torch(bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32) + + conv_config = ttnn.Conv1dConfig( + dtype=output_dtype, + weights_dtype=weights_dtype, + math_approx_mode_enabled=math_approx, + fp32_dest_acc_enabled=fp32_accum, + packer_l1_accum_enabled=packer_l1_acc, + activation=activation, + input_channels_alignment=(16 if use_shallow_conv_variant else 32), + deallocate_activation=deallocate_activation, + reallocate_halo_output=reallocate_halo, + act_block_h_override=act_block_h, + reshard_if_not_optimal=reshard, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + math_fidelity=math_fidelity, + ) + + [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + input_tensor=tt_input_tensor, + weight_tensor=weights, + in_channels=tt_input_tensor.shape[-1], + out_channels=weights.shape[0], + device=device, + bias_tensor=bias, + kernel_size=3, + stride=conv_params[0], + padding=conv_params[1], + batch_size=tt_input_tensor.shape[0], + input_length=tt_input_tensor.shape[1], + conv_config=conv_config, + conv_op_cache={}, + debug=debug, + groups=groups, + ) + tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0) + tt_output_tensor_on_device = ttnn.to_layout(tt_output_tensor_on_device, layout=ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor_on_device = ttnn.reshape( + tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1]) + ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + + return tt_output_tensor + + if parameters.conv1.weight.shape[0] == 512: + input_features = ttnn.from_torch(input_features, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + input_features = ttnn.permute(input_features, (0, 2, 1)) + conv1 = ttnn_conv1d( + device, + input_features, + parameters.conv1.weight, + [1, 1], + parameters.conv1.bias, + ) + conv1 = ttnn.to_layout(conv1, ttnn.TILE_LAYOUT) + conv1 = ttnn.to_device(conv1, device) + conv1 = ttnn.permute(conv1, (0, 2, 1)) + + else: + conv1 = conv( + input_features.float(), + weight=parameters.conv1.weight, + bias=parameters.conv1.bias, + stride=1, + padding=1, + ) + conv1 = ttnn.from_torch(conv1, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + input_embeds = ttnn.gelu(conv1, memory_config=whisper_memory_config) + input_embeds = ttnn.to_layout(input_embeds, layout=ttnn.ROW_MAJOR_LAYOUT) + + # input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + input_embeds = ttnn.to_torch(input_embeds) + + # #13529 ttnn.conv1d throws OOM here. + # conv2 = ttnn_conv1d( + # device, + # input_embeds, + # parameters.conv2.weight, + # [2, 1], + # parameters.conv2.bias, + # ) + # conv2 = ttnn.to_layout(conv2, ttnn.TILE_LAYOUT) + # conv2 = ttnn.to_device(conv2, device) + # conv2 = ttnn.permute(conv2, (0, 2, 1)) + # input_embeds = ttnn.gelu(conv2, memory_config=whisper_memory_config) + + conv = conv( + input_embeds.float(), + weight=parameters.conv2.weight, + bias=parameters.conv2.bias, + stride=2, + padding=1, + ) + conv = ttnn.from_torch(conv, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + input_embeds = ttnn.gelu(conv, memory_config=whisper_memory_config) + input_embeds = ttnn.permute(input_embeds, (0, 2, 1)) + + return input_embeds + + +def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, device): + input_shape = input_ids.size() + input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) + inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) + attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) + # ttnn cannot broadcast when adding on the batch or channel dimensions so this is a workaround + attention_mask = attention_mask.expand(-1, config.encoder_attention_heads, -1, -1) + + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + decoder_hidden_states = inputs_embeds + positions + + decoder_hidden_states = ttnn.from_torch( + decoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + return decoder_hidden_states, attention_mask + + +def preprocess_inputs(*, config, input_features, input_ids, attention_mask, parameters, device, whisper_memory_config): + input_embeds = preprocess_encoder_inputs( + input_features, parameters=parameters.encoder, device=device, whisper_memory_config=whisper_memory_config + ) + (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( + config, input_ids, attention_mask, parameters=parameters.decoder, device=device + ) + return input_embeds, decoder_hidden_states, attention_mask + + +def whisper( + config, + device, + encoder_hidden_states, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + whisper_memory_config, +): + encoder_hidden_states = encoder( + config, + device, + encoder_hidden_states, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + ) + + last_hidden_state = decoder( + config, + device, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + parameters=parameters.decoder, + whisper_memory_config=whisper_memory_config, + ) + + return last_hidden_state + + +def whisper_for_audio_classification(config, inputs_embeds, *, parameters, device, batch_size, whisper_memory_config): + encoder_outputs = encoder( + config=config, + device=device, + inputs_embeds=inputs_embeds, + parameters=parameters.encoder, + whisper_memory_config=whisper_memory_config, + ) + hidden_states = ttnn.linear( + encoder_outputs, + parameters.projector.weight, + bias=parameters.projector.bias, + memory_config=whisper_memory_config, + ) + pooled_output = ttnn.mean(hidden_states, dim=-2, keepdim=True) + + logits = ttnn.linear( + pooled_output, + parameters.classifier.weight, + bias=parameters.classifier.bias, + memory_config=whisper_memory_config, + ) + return logits + + +def whisper_for_conditional_generation( + config, + input_embeds, + decoder_hidden_states, + decoder_attention_mask, + *, + parameters, + device, + ttnn_linear_weight, + whisper_memory_config, +): + output = whisper( + config=config, + device=device, + encoder_hidden_states=input_embeds, + decoder_hidden_states=decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=parameters, + whisper_memory_config=whisper_memory_config, + ) + + ttnn_output = ttnn.matmul( + output, + ttnn_linear_weight, + dtype=ttnn.bfloat16, + ) + return ttnn_output + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): + height, width = torch_model.k_proj.weight.shape + + if "encoder_attn" in name: + parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} + preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) + preprocessed_bias = torch.cat([torch.zeros(height), torch_model.v_proj.bias], dim=0) + parameters["key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + parameters["q_proj"]["weight"] = preprocess_linear_weight(torch_model.q_proj.weight, dtype=ttnn.bfloat16) + parameters["q_proj"]["bias"] = preprocess_linear_bias(torch_model.q_proj.bias, dtype=ttnn.bfloat16) + else: + parameters = {"query_key_value": {}, "out_proj": {}} + preprocessed_weight = torch.cat( + [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 + ) + preprocessed_bias = torch.cat( + [torch_model.q_proj.bias, torch.zeros(height), torch_model.v_proj.bias], dim=0 + ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) + parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) + + elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): + embeddings = torch_model.weight.unsqueeze(0).expand(8, -1, -1) + embeddings = ttnn.from_torch(embeddings, dtype=ttnn.bfloat16) + embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT) + parameters["weight"] = embeddings + + return parameters diff --git a/models/experimental/functional_whisper/README.md b/models/experimental/functional_whisper/README.md deleted file mode 100644 index 8a228b35e7e..00000000000 --- a/models/experimental/functional_whisper/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# ttnn_functional_whisper Demo - -## How to Run - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for audio classification. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/audio_classification" models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for audio classification. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_optimized_functional_whisper]` to run the ttnn optimized functional whisper demo for conditional generation. - -Use `pytest --disable-warnings --input-path="models/experimental/functional_whisper/demo/dataset/conditional_generation" models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation[1-models.experimental.functional_whisper.tt.ttnn_functional_whisper]` to run the ttnn functional whisper demo for conditional generation. - -Our another demo is designed to run with `google/fleurs` for Audio classification and `hf-internal-testing/librispeech_asr_dummy` for Conditional generation - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_audio_classification_dataset` to run audio classification demo with dataset input. - -Use `pytest --disable-warnings models/experimental/functional_whisper/demo/demo.py::test_demo_for_conditional_generation_dataset` to run conditional generation demo with dataset input. - -## Inputs - -Inputs by default are provided from `dataset/audio_classification` and `dataset/conditional_generation` folder. If you wish to change the inputs, provide a different path to demo. - -For demo with dataset,Inputs for Audio classification is taken from `google/fleurs` dataset and Inputs for Conditional generation is taken from `hf-internal-testing/librispeech_asr_dummy` dataset. - -## Details - -The entry point to whisper model is whisper in `models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py` for optimized version.(`models/experimental/functional_whisper/tt/ttnn_functional_whisper.py` for normal version). diff --git a/models/experimental/functional_whisper/demo/demo.py b/models/experimental/functional_whisper/demo/demo.py deleted file mode 100644 index ad0910f35d8..00000000000 --- a/models/experimental/functional_whisper/demo/demo.py +++ /dev/null @@ -1,396 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -from datasets import load_dataset -from loguru import logger -from scipy.io import wavfile -import ttnn -from transformers import ( - AutoFeatureExtractor, - WhisperModel, - WhisperConfig, - AutoProcessor, - WhisperForConditionalGeneration, -) -from models.utility_functions import ( - disable_compilation_reports, - disable_persistent_kernel_cache, -) -from models.experimental.functional_whisper.tt import ttnn_functional_whisper, ttnn_optimized_functional_whisper -from models.generation_utils import get_logits_processor -from ttnn.model_preprocessing import preprocess_model_parameters - -import torch -import os -from os import listdir -from os.path import isfile, join - -from transformers import AutoFeatureExtractor, WhisperForAudioClassification -from datasets import load_dataset - - -def load_input_paths(folder_path): - files = [os.path.join(folder_path, f) for f in listdir(folder_path) if isfile(join(folder_path, f))] - return files - - -def pad_input_32(tensor, value): - len = tensor.shape[1] - - if len % 32 == 0: - return tensor - - padded_len = ((len // 32) + 1) * 32 - - pad_tensor = (value * torch.ones(tensor.shape[0], padded_len - len)).to(torch.long) - tensor = torch.cat([tensor, pad_tensor], dim=1) - - return tensor - - -def run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask, - parameters, - processor, - ttnn_linear_weight, - device, - generation_config, -): - input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - logits_processor = get_logits_processor(input_ids, config) - - input_ids = pad_input_32(input_ids, config.pad_token_id).to(torch.long) - - decoder_start_values = generation_config.pad_token_id * torch.ones(1, 32).to(torch.long) - - for i in range(32): - output = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - ) - output = output @ ttnn_linear_weight - - output = ttnn.from_device(output) - - logits_to_torch = ttnn.to_torch(output) - - next_token_logits = logits_to_torch[:, i, :] - - next_tokens_scores = logits_processor(input_features, next_token_logits) - next_tokens = torch.argmax(next_tokens_scores, dim=-1) - - if (i + 1) % 32 == 0: - input_ids = torch.cat([input_ids, decoder_start_values], dim=1) - - input_ids[:, i + 1] = next_tokens[:, None] - - decoder_hidden_states, decoder_attention_mask = ttnn_model.preprocess_decoder_inputs( - config=config, input_ids=input_ids, attention_mask=None, parameters=parameters.decoder, device=device - ) - - if next_tokens == config.eos_token_id: - break - logger.info(processor.batch_decode(input_ids, skip_special_tokens=True)[0]) - - ttnn_transcription = processor.batch_decode(input_ids, skip_special_tokens=True)[0] - - return ttnn_transcription - - -def run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs): - torch.manual_seed(1234) - - feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") - model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") - - model.eval() - input_data = load_input_paths(input_path) - - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) - if len(input_data) < num_inputs: - assert False, "num_inputs exceeds number of audio files available in folder" - - for i in range(num_inputs): - input_file_path = input_data[i] - samplerate, data = wavfile.read(input_file_path) - - inputs = feature_extractor( - data, - sampling_rate=samplerate, - return_tensors="pt", - ) - - input_features = inputs.input_features - - config = model.config - input_embedding = ttnn_model.preprocess_encoder_inputs( - input_features=input_features, parameters=parameters.encoder, device=device - ) - - encoder_outputs = ttnn_model.encoder( - config=config, inputs_embeds=input_embedding, parameters=parameters.encoder - ) - - hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) - hidden_states = ttnn.add(hidden_states, parameters.projector.bias) - - pooled_output = ttnn.mean(hidden_states, dim=-2) - - logits = ttnn.matmul(pooled_output, parameters.classifier.weight) - logits = ttnn.add(logits, parameters.classifier.bias) - - logits_torch = ttnn.to_torch(logits) - predicted_class_ids = torch.argmax(logits_torch).item() - predicted_label = model.config.id2label[predicted_class_ids] - - logger.info("predicted_label") - logger.info(predicted_label) - - -def run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs): - torch.manual_seed(0) - - model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") - hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight - - linear_weight = hf_reference_model.proj_out.weight - ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) - ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) - ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) - - feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") - input_data = load_input_paths(input_path) - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) - - if len(input_data) < num_inputs: - assert False, "num_inputs exceeds number of audio files available in folder" - output_list = {} - for i in range(num_inputs): - input_file_path = input_data[i] - samplerate, data = wavfile.read(input_file_path) - inputs = feature_extractor(data, sampling_rate=samplerate, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) - - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) - - attention_mask = None - - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) - - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) - output_list[i] = ttnn_output - for i in range(len(output_list)): - logger.info(f"output for input {i+1}") - logger.info(output_list[i]) - - -def run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device): - torch.manual_seed(1234) - - feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") - model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") - - model.eval() - - ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) - sample = next(iter(ds)) - - inputs = feature_extractor( - sample["audio"]["array"], - sampling_rate=sample["audio"]["sampling_rate"], - return_tensors="pt", - ) - - input_features = inputs.input_features - - logger.debug("Input audio language:") - logger.debug(sample["language"]) - - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) - - config = model.config - input_embedding = ttnn_model.preprocess_encoder_inputs( - input_features=input_features, parameters=parameters.encoder, device=device - ) - - encoder_outputs = ttnn_model.encoder(config=config, inputs_embeds=input_embedding, parameters=parameters.encoder) - - hidden_states = ttnn.matmul(encoder_outputs, parameters.projector.weight) - hidden_states = ttnn.add(hidden_states, parameters.projector.bias) - - pooled_output = ttnn.mean(hidden_states, dim=-2) - - logits = ttnn.matmul(pooled_output, parameters.classifier.weight) - logits = ttnn.add(logits, parameters.classifier.bias) - - logits_torch = ttnn.to_torch(logits) - predicted_class_ids = torch.argmax(logits_torch).item() - predicted_label = model.config.id2label[predicted_class_ids] - - logger.info("predicted_label") - logger.info(predicted_label) - - -def run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device): - torch.manual_seed(0) - - model = WhisperModel.from_pretrained("openai/whisper-tiny.en").to(torch.bfloat16).eval() - - config = WhisperConfig.from_pretrained("openai/whisper-tiny.en") - - processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en", language="English", task="transcribe") - hf_reference_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - linear_weight = hf_reference_model.proj_out.weight - - linear_weight = hf_reference_model.proj_out.weight - ttnn_linear_weight = ttnn.from_torch(linear_weight, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) - ttnn_linear_weight = ttnn.permute(ttnn_linear_weight, (1, 0)) - ttnn_linear_weight = ttnn.to_layout(ttnn_linear_weight, layout=ttnn.TILE_LAYOUT) - - feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) - - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - decoder_input_ids = pad_input_32(decoder_input_ids, config.pad_token_id).to(torch.long) - - attention_mask = None - - parameters = preprocess_model_parameters( - initialize_model=lambda: model, - convert_to_ttnn=ttnn_model.convert_to_ttnn, - custom_preprocessor=ttnn_model.custom_preprocessor, - device=device, - ) - - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=parameters, - device=device, - ) - - generation_config = hf_reference_model.generation_config - ttnn_output = run_generate( - config, - input_embeds, - input_features, - ttnn_model, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=parameters, - processor=processor, - ttnn_linear_weight=ttnn_linear_weight, - device=device, - generation_config=generation_config, - ) - logger.info("Model Output") - logger.info(ttnn_output) - - -@pytest.mark.parametrize( - "ttnn_model", - (ttnn_optimized_functional_whisper, ttnn_functional_whisper), -) -@pytest.mark.parametrize( - "num_inputs", - ((1),), -) -def test_demo_for_audio_classification(input_path, ttnn_model, device, num_inputs): - disable_persistent_kernel_cache() - disable_compilation_reports() - return run_demo_functional_whisper_for_audio_classification_inference(input_path, ttnn_model, device, num_inputs) - - -@pytest.mark.parametrize( - "ttnn_model", - (ttnn_optimized_functional_whisper, ttnn_functional_whisper), -) -@pytest.mark.parametrize( - "num_inputs", - ((1),), -) -def test_demo_for_conditional_generation(input_path, ttnn_model, device, num_inputs): - disable_persistent_kernel_cache() - disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_inference(input_path, ttnn_model, device, num_inputs) - - -@pytest.mark.parametrize( - "ttnn_model", - (ttnn_optimized_functional_whisper, ttnn_functional_whisper), -) -def test_demo_for_audio_classification_dataset(ttnn_model, device): - disable_persistent_kernel_cache() - disable_compilation_reports() - return run_demo_functional_whisper_for_audio_classification_dataset(ttnn_model, device) - - -@pytest.mark.parametrize( - "ttnn_model", - (ttnn_functional_whisper, ttnn_optimized_functional_whisper), -) -def test_demo_for_conditional_generation_dataset(ttnn_model, device): - disable_persistent_kernel_cache() - disable_compilation_reports() - return run_demo_functional_whisper_for_conditional_generation_dataset(ttnn_model, device) diff --git a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py deleted file mode 100644 index 56b8c0054f6..00000000000 --- a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py +++ /dev/null @@ -1,441 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import transformers -import torch -from typing import Optional, Tuple - -from torch.nn import functional as F -from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias -import ttnn -from loguru import logger - -WHISPER_MEMORY_CONFIG = ttnn.DRAM_MEMORY_CONFIG -WHISPER_DTYPE = ttnn.bfloat8_b - - -def gelu(tensor): - return ttnn.gelu(tensor, memory_config=WHISPER_MEMORY_CONFIG) - - -def dropout(hidden_states, p, training): - # ignored for inference - return hidden_states - - -# The split_query_key_value_and_split_heads requires the query to have the same volume as the key and values -# This is not the case however for whisper so we currently cannot swap out calculate_key_values below -# def calculate_key_values(config, query_states, key_value_states, *, parameters): -# fused_kv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias -# head_size = config.d_model // config.encoder_attention_heads -# batch_size, *_, _, two_times_hidden_size = fused_kv.shape.with_tile_padding() -# hidden_size = two_times_hidden_size // 2 -# encoder_attention_heads = hidden_size // head_size -# query_states, key_states, value_states = ttnn.transformer.split_query_key_value_and_split_heads( -# query_states, -# kv_input_tensor=fused_kv, -# num_heads=encoder_attention_heads, -# memory_config=WHISPER_MEMORY_CONFIG, -# ) -# key_states = ttnn.permute(key_states, (0, 1, 3, 2)) -# return query_states, key_states, value_states - - -def calculate_key_values(config, key_value_states, *, parameters): - bsz, tgt_len, hidden_size = key_value_states.shape - bsz, tgt_len_padded, _ = key_value_states.shape.with_tile_padding() - head_size = hidden_size // config.encoder_attention_heads - - fused_qkv = key_value_states @ parameters.key_value.weight + parameters.key_value.bias - - dtype = fused_qkv.dtype - device = fused_qkv.device() - fused_qkv = ttnn.to_torch(fused_qkv) - fused_qkv = torch.reshape(fused_qkv, (bsz, tgt_len, 2, config.encoder_attention_heads, head_size)) - key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :] - - key_states = ttnn.from_torch(key_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) - value_states = ttnn.from_torch(value_states, dtype=dtype, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) - - key_states = ttnn.permute(key_states, (0, 2, 3, 1)) - value_states = ttnn.permute(value_states, (0, 2, 1, 3)) - key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT) - value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT) - - desired_shape = ttnn.Shape( - [bsz, config.encoder_attention_heads, head_size, tgt_len], - [bsz, config.encoder_attention_heads, head_size, tgt_len_padded], - ) - key_states = ttnn.reshape(key_states, shape=desired_shape) - - desired_shape = ttnn.Shape( - [bsz, config.encoder_attention_heads, tgt_len, head_size], - [bsz, config.encoder_attention_heads, tgt_len_padded, head_size], - ) - value_states = ttnn.reshape(value_states, shape=desired_shape) - - return key_states, value_states - - -# The following functionis expected to replace calculate_query_key_values and split_query_key_value_and_split_heads below -# however the pcc is incorrect on the final layer unless we keep the original split_query_key_value_and_split_heads below -# def calculate_query_key_values(config, hidden_states, *, parameters): -# fused_qkv = hidden_states @ parameters.query_key_value.weight + parameters.query_key_value.bias -# head_size = config.d_model // config.encoder_attention_heads -# batch_size, *_, _, three_times_hidden_size = fused_qkv.shape.with_tile_padding() -# hidden_size = three_times_hidden_size // 3 -# encoder_attention_heads = hidden_size // head_size -# return ttnn.transformer.split_query_key_value_and_split_heads( -# fused_qkv, -# num_heads=encoder_attention_heads, -# memory_config=WHISPER_MEMORY_CONFIG, -# ) - - -def split_query_key_value_and_split_heads( - config, fused_qkv: ttnn.Tensor -) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]: - head_size = config.d_model // config.encoder_attention_heads - batch_size, *_, seq_length, _ = fused_qkv.shape - batch_size, *_, padded_seq_length, _ = fused_qkv.shape.with_tile_padding() - - query_states, key_states, value_states = ttnn.transformer.split_query_key_value_and_split_heads( - fused_qkv, num_heads=config.encoder_attention_heads - ) - - desired_shape = ttnn.Shape( - [batch_size, config.encoder_attention_heads, seq_length, head_size], - [batch_size, config.encoder_attention_heads, padded_seq_length, head_size], - ) - desired_key_shape = ttnn.Shape( - [batch_size, config.encoder_attention_heads, head_size, seq_length], - [batch_size, config.encoder_attention_heads, head_size, padded_seq_length], - ) - query_states = ttnn.reshape(query_states, shape=desired_shape) - key_states = ttnn.reshape(key_states, shape=desired_key_shape) - value_states = ttnn.reshape(value_states, shape=desired_shape) - return query_states, key_states, value_states - - -def calculate_query_key_values(config, hidden_states, *, parameters): - fused_qkv = hidden_states @ parameters.query_key_value.weight + parameters.query_key_value.bias - return split_query_key_value_and_split_heads(config, fused_qkv) - - -def whisper_attention(config, hidden_states, attention_mask, key_value_states=None, *, parameters): - head_size = config.d_model // config.encoder_attention_heads - scaling = head_size**-0.5 - bsz, *_, tgt_len, _ = hidden_states.shape - - is_cross_attention = key_value_states is not None - if is_cross_attention: - query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias - dtype = query_states.dtype - device = query_states.device() - query_states = ttnn.to_torch(query_states) - query_states = torch.reshape(query_states, (bsz, tgt_len, config.encoder_attention_heads, head_size)) - query_states = ttnn.from_torch(query_states, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) - query_states = ttnn.permute(query_states, (0, 2, 1, 3)) - key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) - else: - query_states, key_states, value_states = calculate_query_key_values( - config, hidden_states, parameters=parameters - ) - - query_states *= scaling - - attn_weights = query_states @ key_states - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # differences in ttnn.softmax vs torch.softmax cause the attn_weights to be slightly different - attn_weights = ttnn.softmax(attn_weights, dim=-1, memory_config=WHISPER_MEMORY_CONFIG) - - attn_probs = dropout(attn_weights, p=0, training=False) - attn_output = attn_probs @ value_states - - attn_output = ttnn.transformer.concatenate_heads(attn_output) - attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias - return attn_output - - -def encoder_layer(config, hidden_states, *, parameters): - residual = hidden_states - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.self_attn_layer_norm.weight, - bias=parameters.self_attn_layer_norm.bias, - memory_config=WHISPER_MEMORY_CONFIG, - ) - - hidden_states = whisper_attention(config, hidden_states, attention_mask=None, parameters=parameters.self_attn) - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.final_layer_norm.weight, - bias=parameters.final_layer_norm.bias, - memory_config=WHISPER_MEMORY_CONFIG, - ) - hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias - hidden_states = gelu(hidden_states) - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = residual + hidden_states - - # if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): - # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - return hidden_states - - -def encoder(config, inputs_embeds, *, parameters): - hidden_states = inputs_embeds + parameters.embed_positions.weight - hidden_states = dropout(hidden_states, p=0, training=False) - - for encoder_layer_parameter in parameters.layers: - hidden_states = encoder_layer(config, hidden_states, parameters=encoder_layer_parameter) - - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.layer_norm.weight, - bias=parameters.layer_norm.bias, - ) - return hidden_states - - -def make_causal_mask(input_ids_shape, dtype): - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) - mask_cond = torch.arange(mask.size(-1)) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.shape - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -def decoder_layer(config, hidden_states, attention_mask, encoder_hidden_states, *, parameters): - residual = hidden_states - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.self_attn_layer_norm.weight, - bias=parameters.self_attn_layer_norm.bias, - ) - - hidden_states = whisper_attention( - config, - hidden_states=hidden_states, - attention_mask=attention_mask, - parameters=parameters.self_attn, - ) - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = residual + hidden_states - - # Cross-Attention Block - residual = hidden_states - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.encoder_attn_layer_norm.weight, - bias=parameters.encoder_attn_layer_norm.bias, - ) - - hidden_states = whisper_attention( - config, - hidden_states, - attention_mask=None, - key_value_states=encoder_hidden_states, - parameters=parameters.encoder_attn, - ) - - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.final_layer_norm.weight, - bias=parameters.final_layer_norm.bias, - ) - hidden_states = hidden_states @ parameters.fc1.weight + parameters.fc1.bias - hidden_states = gelu(hidden_states) - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = hidden_states @ parameters.fc2.weight + parameters.fc2.bias - hidden_states = dropout(hidden_states, p=0, training=False) - hidden_states = residual + hidden_states - - return hidden_states - - -def prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - if input_shape[-1] > 1: - combined_attention_mask = make_causal_mask(input_shape, input_embeds.dtype) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -def decoder(config, hidden_states, decoder_attention_mask, encoder_hidden_states, *, parameters): - hidden_states = dropout(hidden_states, p=0, training=False) - - for decoder_layer_parameter in parameters.layers: - hidden_states = decoder_layer( - config, - hidden_states, - decoder_attention_mask, - encoder_hidden_states, - parameters=decoder_layer_parameter, - ) - - hidden_states = ttnn.layer_norm( - hidden_states, - weight=parameters.layer_norm.weight, - bias=parameters.layer_norm.bias, - ) - - return hidden_states - - -def convert_to_ttnn(model, name): - return name not in [ - "encoder.conv1", - "encoder.conv2", - "decoder.embed_tokens", - "decoder.embed_positions", - ] - - -def preprocess_encoder_inputs(input_features, *, parameters, device): - def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): - return F.conv1d(input, weight, bias, stride, padding, dilation, groups) - - input_embeds = torch.nn.functional.gelu( - conv( - input_features, - weight=parameters.conv1.weight, - bias=parameters.conv1.bias, - padding=1, - ) - ) - input_embeds = torch.nn.functional.gelu( - conv( - input_embeds, - weight=parameters.conv2.weight, - bias=parameters.conv2.bias, - stride=2, - padding=1, - ) - ) - input_embeds = input_embeds.permute(0, 2, 1) - input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) - - return input_embeds - - -def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, device): - input_shape = input_ids.size() - input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) - inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) - attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) - # ttnn cannot broadcast when adding on the batch or channel dimensions so this is a workaround - attention_mask = attention_mask.expand(-1, config.encoder_attention_heads, -1, -1) - - positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] - decoder_hidden_states = inputs_embeds + positions - - decoder_hidden_states = ttnn.from_torch( - decoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device - ) - attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) - - return decoder_hidden_states, attention_mask - - -def preprocess_inputs( - *, - config, - input_features, - input_ids, - attention_mask, - parameters, - device, -): - input_embeds = preprocess_encoder_inputs(input_features, parameters=parameters.encoder, device=device) - (decoder_hidden_states, attention_mask) = preprocess_decoder_inputs( - config, input_ids, attention_mask, parameters=parameters.decoder, device=device - ) - return input_embeds, decoder_hidden_states, attention_mask - - -def whisper(config, encoder_hidden_states, decoder_hidden_states, decoder_attention_mask, *, parameters): - encoder_hidden_states = encoder(config, encoder_hidden_states, parameters=parameters.encoder) - last_hidden_state = decoder( - config, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - parameters=parameters.decoder, - ) - return last_hidden_state - - -def custom_preprocessor(torch_model, name): - parameters = {} - if isinstance(torch_model, transformers.models.whisper.modeling_whisper.WhisperAttention): - height, width = torch_model.k_proj.weight.shape - - if "encoder_attn" in name: - parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} - preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) - preprocessed_bias = torch.cat([torch.zeros(height), torch_model.v_proj.bias], dim=0) - parameters["key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) - parameters["key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) - parameters["q_proj"]["weight"] = preprocess_linear_weight(torch_model.q_proj.weight, dtype=ttnn.bfloat16) - parameters["q_proj"]["bias"] = preprocess_linear_bias(torch_model.q_proj.bias, dtype=ttnn.bfloat16) - else: - parameters = {"query_key_value": {}, "out_proj": {}} - preprocessed_weight = torch.cat( - [torch_model.q_proj.weight, torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0 - ) - preprocessed_bias = torch.cat( - [torch_model.q_proj.bias, torch.zeros(height), torch_model.v_proj.bias], dim=0 - ) - parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) - parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) - - parameters["out_proj"]["weight"] = preprocess_linear_weight(torch_model.out_proj.weight, dtype=ttnn.bfloat16) - parameters["out_proj"]["bias"] = preprocess_linear_bias(torch_model.out_proj.bias, dtype=ttnn.bfloat16) - elif name == "encoder.embed_positions" and isinstance(torch_model, torch.nn.Embedding): - embeddings = ttnn.from_torch(torch_model.weight, dtype=ttnn.bfloat16) - embeddings = ttnn.to_layout(embeddings, ttnn.TILE_LAYOUT) - parameters["weight"] = embeddings - return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index c251fa4ccb3..7535f5eece3 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -25,7 +25,7 @@ run_perf_models_other() { env pytest models/demos/distilbert/tests/test_perf_distilbert.py -m $test_marker - env pytest -n auto tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker + env pytest -n auto models/demos/whisper/tests/test_performance.py -m $test_marker env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker @@ -95,6 +95,8 @@ run_device_perf_models() { env pytest models/demos/convnet_mnist/tests/ -m $test_marker + env pytest models/demos/whisper/tests/ -m $test_marker + if [ "$tt_arch" == "grayskull" ]; then #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with #Model Device perf regression tests to make sure thy run on no-soft-reset BMs diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 0994d8fe24b..bd45506555e 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -43,6 +43,9 @@ run_common_func_tests() { # ConvNet Mnist pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$? + # Whisper + pytest --disable-warnings models/demos/whisper/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py index 30c91e7f8fe..656e150483d 100644 --- a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper +from models.demos.whisper.reference import torch_functional_whisper import transformers from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset @@ -17,7 +17,7 @@ @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) @pytest.mark.parametrize("use_key_value_states", [False, True]) def test_whisper_attention(model_name, batch_size, sequence_size, use_key_value_states): @@ -52,7 +52,7 @@ def test_whisper_attention(model_name, batch_size, sequence_size, use_key_value_ @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_encoder_layer(model_name, batch_size, sequence_size): torch.manual_seed(0) @@ -77,7 +77,7 @@ def test_encoder_layer(model_name, batch_size, sequence_size): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("feature_size", [80]) @pytest.mark.parametrize("sequence_length", [3000]) def test_encoder(model_name, batch_size, feature_size, sequence_length): @@ -106,7 +106,7 @@ def test_encoder(model_name, batch_size, feature_size, sequence_length): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_decoder_layer(model_name, batch_size, sequence_size): torch.manual_seed(0) @@ -132,11 +132,11 @@ def test_decoder_layer(model_name, batch_size, sequence_size): config, torch_hidden_states, attention_mask, torch_encoder_hidden_states, parameters=parameters ) - assert_with_pcc(torch_output[0], output, 0.94) + assert_with_pcc(torch_output[0], output, 0.99) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_decoder(model_name, batch_size, sequence_size): torch.manual_seed(0) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index e6f02bf3203..c645e03984f 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -3,8 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_functional_whisper +from models.demos.whisper.reference import torch_functional_whisper +from models.demos.whisper.tt import ttnn_functional_whisper import transformers from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset @@ -19,10 +19,9 @@ MODEL_NAME = "openai/whisper-base" -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) @pytest.mark.parametrize("use_key_value_states", [False, True]) def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): @@ -81,13 +80,12 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ output = ttnn.from_device(output) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.98) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): torch.manual_seed(0) @@ -120,10 +118,9 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("feature_size", [80]) @pytest.mark.parametrize("sequence_length", [3000]) def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length): @@ -167,17 +164,16 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque device=device, ) - output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder(config, ttnn_inputs_embeds, parameters=ttnn_parameters, device=device) output = ttnn.from_device(output) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.97) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): torch.manual_seed(0) @@ -230,13 +226,12 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.97) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): torch.manual_seed(0) @@ -305,7 +300,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("ttnn_model", [ttnn_functional_whisper]) def test_ttnn_whisper(device, ttnn_model): torch.manual_seed(0) @@ -370,8 +364,9 @@ def test_ttnn_whisper(device, ttnn_model): decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, parameters=ttnn_parameters, + device=device, ) last_hidden_state = ttnn.from_device(last_hidden_state) last_hidden_state = ttnn.to_torch(last_hidden_state) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.9895) + assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.989) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index e6cea2f8870..61f04e5b7f8 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -2,30 +2,29 @@ # SPDX-License-Identifier: Apache-2.0 +import ttnn +import torch import pytest -from models.experimental.functional_whisper.reference import torch_functional_whisper -from models.experimental.functional_whisper.tt import ttnn_optimized_functional_whisper import transformers -from transformers import AutoFeatureExtractor, WhisperModel, WhisperConfig from datasets import load_dataset -import torch -import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import torch_random +from models.utility_functions import torch_random, is_grayskull from ttnn.model_preprocessing import preprocess_model_parameters -from models.utility_functions import is_wormhole_b0, is_blackhole +from models.demos.whisper.reference import torch_functional_whisper +from models.demos.whisper.tt import ttnn_optimized_functional_whisper MODEL_NAME = "openai/whisper-base" -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) @pytest.mark.parametrize("use_key_value_states", [False, True]) -def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): - torch.manual_seed(0) +def test_whisper_attention( + device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states, reset_seeds +): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperAttention( embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout @@ -72,23 +71,24 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_ attention_mask = None output = ttnn_model.whisper_attention( config, + device, ttnn_hidden_states, attention_mask, key_value_states=ttnn_key_value_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.98) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() model = model @@ -113,20 +113,25 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device ) - output = ttnn_model.encoder_layer(config, ttnn_hidden_states, parameters=ttnn_parameters) + output = ttnn_model.encoder_layer( + config, + device, + ttnn_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("feature_size", [80]) @pytest.mark.parametrize("sequence_length", [3000]) -def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length): - torch.manual_seed(0) +def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, sequence_length, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() model = model @@ -139,10 +144,6 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.encoder_original( - # torch_input_features, parameters, embed_dim, num_heads - # ) - inputs_embeds = torch_functional_whisper.preprocess_encoder_inputs( input_features=torch_input_features, parameters=parameters, @@ -162,17 +163,24 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque input_features=torch_input_features, parameters=ttnn_parameters, device=device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) input_embeds = ttnn.to_layout(input_embeds, ttnn.TILE_LAYOUT) input_embeds = ttnn.to_device(input_embeds, device) - output = ttnn_model.encoder(config, input_embeds, parameters=ttnn_parameters) + output = ttnn_model.encoder( + config, + device, + input_embeds, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.968) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @@ -222,20 +230,25 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device) output = ttnn_model.decoder_layer( - config, ttnn_hidden_states, ttnn_attention_mask, ttnn_encoder_hidden_states, parameters=ttnn_parameters + config, + device, + ttnn_hidden_states, + ttnn_attention_mask, + ttnn_encoder_hidden_states, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) - assert_with_pcc(torch_output, output, 0.97) + assert_with_pcc(torch_output, output, 0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("sequence_size", [1500]) -def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): - torch.manual_seed(0) +def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size, reset_seeds): config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() model = model @@ -244,7 +257,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - # decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id attention_mask = None @@ -255,10 +267,6 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) - # torch_original_output = torch_functional_whisper.decoder_original( - # decoder_input_ids, attention_mask, torch_encoder_hidden_states, parameters, embed_dim, num_heads - # ) - (decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_decoder_inputs( decoder_input_ids, attention_mask, parameters=parameters ) @@ -291,32 +299,33 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): output = ttnn_model.decoder( config, + device=device, hidden_states=decoder_hidden_states, decoder_attention_mask=decoder_attention_mask, encoder_hidden_states=ttnn_encoder_hidden_states, parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, ) output = ttnn.to_torch(output) assert_with_pcc(torch_output, output, pcc=0.99) -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") -@pytest.mark.requires_fast_runtime_mode_off +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("ttnn_model", [ttnn_optimized_functional_whisper]) -def test_ttnn_whisper(tmp_path, device, ttnn_model): - torch.manual_seed(0) - model_name = "openai/whisper-base" - config = WhisperConfig.from_pretrained(model_name) - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) +def test_ttnn_whisper(reset_seeds, device, batch_size, model_name, ttnn_model): + config = transformers.WhisperConfig.from_pretrained(model_name) + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") + inputs = feature_extractor( + [ds[i]["audio"]["array"] for i in range(batch_size)], sampling_rate=16000, return_tensors="pt" + ) input_features = inputs.input_features - decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id - - attention_mask = None + decoder_input_ids = torch.tensor([[1, 1]] * batch_size) * config.decoder_start_token_id - model = WhisperModel.from_pretrained(model_name).eval() + model = transformers.WhisperModel.from_pretrained(model_name).eval() parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -327,11 +336,11 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): (encoder_hidden_states, decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_inputs( input_features=input_features, input_ids=decoder_input_ids, - attention_mask=attention_mask, + attention_mask=None, parameters=parameters, ) - expected_last_hidden_state = torch_functional_whisper.whisper( + torch_last_hidden_state = torch_functional_whisper.whisper( config, encoder_hidden_states, decoder_hidden_states, @@ -346,24 +355,26 @@ def test_ttnn_whisper(tmp_path, device, ttnn_model): device=device, ) - with ttnn.tracer.trace(): - (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( - config=config, - input_features=input_features, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - parameters=ttnn_parameters, - device=device, - ) + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( + config=config, + input_features=input_features, + input_ids=decoder_input_ids, + attention_mask=None, + parameters=ttnn_parameters, + device=device, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) - last_hidden_state = ttnn_model.whisper( - config, - input_embeds, - decoder_hidden_states, - decoder_attention_mask=decoder_attention_mask, - parameters=ttnn_parameters, - ) - last_hidden_state = ttnn.to_torch(last_hidden_state) - ttnn.tracer.visualize(last_hidden_state, file_name=tmp_path / "whisper.svg") + last_hidden_state = ttnn_model.whisper( + config, + device, + input_embeds, + decoder_hidden_states, + decoder_attention_mask=decoder_attention_mask, + parameters=ttnn_parameters, + whisper_memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull else ttnn.L1_MEMORY_CONFIG, + ) + + last_hidden_state = ttnn.to_torch(last_hidden_state) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.964) + assert_with_pcc(torch_last_hidden_state, last_hidden_state, 0.97 if is_grayskull else 0.989)