#4609: Deit gs demo benchmarking
kkeerthana0573 committed Jan 22, 2024
1 parent 7b9aae9 commit 8958ba3
Showing 2 changed files with 253 additions and 0 deletions.
46 changes: 46 additions & 0 deletions models/experimental/deit/tests/demo_utils.py
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import os
import glob

from PIL import Image

from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES


class InputExample:
    def __init__(self, image, label=None):
        self.image = image
        self.label = label


def get_input(image_path):
    return Image.open(image_path)


def get_label(image_path):
    # Filenames are expected to end in an ImageNet synset id, e.g.
    # "<name>_n01440764.JPEG"; the label is that id's position in
    # IMAGENET2012_CLASSES.
    _, image_name = image_path.rsplit("/", 1)
    image_name_exact, _ = image_name.rsplit(".", 1)
    _, label_id = image_name_exact.rsplit("_", 1)
    return list(IMAGENET2012_CLASSES).index(label_id)


def get_data(input_loc):
    # "*G" matches image files whose extension ends in G (e.g. .JPEG, .PNG).
    data_path = os.path.join(input_loc, "*G")
    files = sorted(glob.glob(data_path))
    return [InputExample(image=get_input(f), label=get_label(f)) for f in files]
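
For context, a minimal sketch of how these helpers fit together; the sample directory path below is illustrative only, not one this commit guarantees:

    # Hypothetical usage of get_data; the path is an example only.
    from models.experimental.deit.tests.demo_utils import get_data

    examples = get_data("/path/to/imagenet_samples")
    for example in examples[:3]:
        print(example.image.size, example.label)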
207 changes: 207 additions & 0 deletions models/experimental/deit/tests/test_perf_accuracy_deit.py
@@ -0,0 +1,207 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
import numpy as np

from loguru import logger
from datasets import load_dataset
from transformers import AutoImageProcessor, DeiTForImageClassificationWithTeacher

from models.experimental.deit.tt.deit_for_image_classification_with_teacher import (
    deit_for_image_classification_with_teacher,
)
from models.utility_functions import (
    disable_persistent_kernel_cache,
    enable_persistent_kernel_cache,
    torch_to_tt_tensor_rm,
    tt_to_torch_tensor,
    profiler,
)
from models.perf.perf_utils import prep_perf_report
from models.experimental.deit.tests.demo_utils import get_data


BATCH_SIZE = 1


def run_perf_deit(
    expected_inference_time,
    expected_compile_time,
    hf_cat_image_sample_input,
    device,
    iterations,
    model_location_generator,
):
    disable_persistent_kernel_cache()
    first_key = "first_iter"
    second_key = "second_iter"
    third_key = "third_iter"
    cpu_key = "ref_key"
    comments = "distilled-patch16-wteacher"

    sample_image = hf_cat_image_sample_input

    image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
    HF_model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
    hf_inputs = image_processor(sample_image, return_tensors="pt")

    tt_input = torch_to_tt_tensor_rm(hf_inputs["pixel_values"], device, put_on_device=False)
    tt_model_with_teacher = deit_for_image_classification_with_teacher(device)

    with torch.no_grad():
        # CPU reference run, reported as inference_time_cpu.
        profiler.start(cpu_key)
        logits = HF_model(**hf_inputs).logits
        profiler.end(cpu_key)

        # First device run includes kernel compilation.
        profiler.start(first_key)
        tt_output = tt_model_with_teacher(tt_input)[0]
        tt_lib.device.Synchronize(device)
        profiler.end(first_key)
        del tt_output

        enable_persistent_kernel_cache()

        # Second run hits the persistent kernel cache, so it measures inference alone.
        profiler.start(second_key)
        tt_output = tt_model_with_teacher(tt_input)[0]
        tt_lib.device.Synchronize(device)
        profiler.end(second_key)
        del tt_output

        iteration = 0
        logger.info("ImageNet-1k validation Dataset")
        # Small runs use the bundled sample images; larger runs expect the full
        # ImageNet data to be available through model_location_generator.
        if iterations <= 50:
            input_loc = str(model_location_generator("sample_data"))
        else:
            input_loc = str(model_location_generator("ImageNet_data"))
        image_examples = get_data(input_loc)
        reference_labels = []
        predicted_labels = []

        # If no local images were found (e.g. the weka mount is unavailable),
        # fall back to streaming imagenet-1k from the HuggingFace Hub.
        weka_is_on = True
        if len(image_examples) == 0:
            weka_is_on = False
            files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True))
            image_examples = []
            sample_count = BATCH_SIZE * iterations
            for _ in range(sample_count):
                image_examples.append(next(files_raw))

        profiler.start(third_key)
        while iteration < iterations:
            if weka_is_on:
                input_image = image_examples[iteration].image
            else:
                input_image = image_examples[iteration]["image"]
            # Grayscale samples must be converted to RGB before preprocessing.
            if input_image.mode == "L":
                input_image = input_image.convert(mode="RGB")

            inputs = image_processor(images=input_image, return_tensors="pt")
            tt_inputs = torch_to_tt_tensor_rm(inputs["pixel_values"], device, put_on_device=False)

            tt_output_with_teacher = tt_model_with_teacher(tt_inputs)[0]
            # Convert back to a torch tensor and keep the classification logits.
            tt_output_with_teacher = tt_to_torch_tensor(tt_output_with_teacher).squeeze(0)[:, 0, :]

            prediction = tt_output_with_teacher.argmax(-1).item()
            predicted_labels.append(prediction)

            if weka_is_on:
                reference_labels.append(image_examples[iteration].label)
            else:
                reference_labels.append(image_examples[iteration]["label"])

            iteration += 1

        # Top-1 accuracy over the sampled validation images.
        predicted_labels = np.array(predicted_labels)
        reference_labels = np.array(reference_labels)
        accuracy = np.mean(predicted_labels == reference_labels)

        logger.info(f"Accuracy: {accuracy}")
        tt_lib.device.Synchronize(device)
        profiler.end(third_key)

    first_iter_time = profiler.get(first_key)
    second_iter_time = profiler.get(second_key)
    third_iter_time = profiler.get(third_key)

    cpu_time = profiler.get(cpu_key)
    # Compile time is the compile-and-run iteration minus a cached run.
    compile_time = first_iter_time - second_iter_time

    prep_perf_report(
        model_name="deit",
        batch_size=BATCH_SIZE,
        inference_and_compile_time=first_iter_time,
        inference_time=second_iter_time,
        expected_compile_time=expected_compile_time,
        expected_inference_time=expected_inference_time,
        comments=comments,
        inference_time_cpu=cpu_time,
    )

    logger.info(f"deit {comments} inference time: {second_iter_time}")
    logger.info(f"deit {comments} compile time: {compile_time}")
    logger.info(f"deit inference for {iterations} samples: {third_iter_time}")


@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
    "expected_inference_time, expected_compile_time, iterations",
    ((1.8, 18, 50),),
)
def test_perf_bare_metal(
    use_program_cache,
    expected_inference_time,
    expected_compile_time,
    hf_cat_image_sample_input,
    device,
    iterations,
    model_location_generator,
):
    run_perf_deit(
        expected_inference_time,
        expected_compile_time,
        hf_cat_image_sample_input,
        device,
        iterations,
        model_location_generator,
    )


@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize(
    "expected_inference_time, expected_compile_time, iterations",
    ((2.0, 19.5, 50),),
)
def test_perf_virtual_machine(
    use_program_cache,
    expected_inference_time,
    expected_compile_time,
    hf_cat_image_sample_input,
    device,
    iterations,
    model_location_generator,
):
    run_perf_deit(
        expected_inference_time,
        expected_compile_time,
        hf_cat_image_sample_input,
        device,
        iterations,
        model_location_generator,
    )
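
Both tests depend on the hf_cat_image_sample_input and device fixtures defined in the repository's shared conftest, which this commit does not touch. As a rough, hypothetical sketch of what the image fixture might provide (the URL and implementation here are assumptions, not taken from this commit):

    # Hypothetical stand-in for the hf_cat_image_sample_input fixture.
    import pytest
    import requests
    from PIL import Image


    @pytest.fixture
    def hf_cat_image_sample_input():
        # The COCO cats photo widely used in HuggingFace examples; the real
        # fixture may load a bundled local image instead.
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        return Image.open(requests.get(url, stream=True).raw)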
