Skip to content

Commit

Permalink
#13329: TTNN implementation of MNIST model
Browse files Browse the repository at this point in the history
  • Loading branch information
sabira-mcw committed Oct 28, 2024
1 parent 4a11a11 commit a3e1a7b
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 0 deletions.
32 changes: 32 additions & 0 deletions models/demos/mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# MNIST

## Platforms

GS E150, WH N150, WH N300

## Introduction

The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks.

### Batch size: 4

Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 4

## How to Run

To run the demo for digit classification using the MNIST model, follow these instructions:

- Use the following command to run the MNIST model.
```
pytest models/demos/mnist/demo/demo.py::test_demo_dataset
```

## Inputs

The demo receives inputs from respective dataset MNIST.

## Additional Information

If you encounter issues when running the model, ensure that device has support for all required operations.

### Owner: [sabira-mcw](https://github.com/sabira-mcw)
78 changes: 78 additions & 0 deletions models/demos/mnist/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn

from torchvision import transforms, datasets
from loguru import logger

from torch.utils.data import DataLoader
from models.demos.mnist.reference.mnist import MnistModel
from models.demos.mnist.tt import tt_mnist

from ttnn.model_preprocessing import preprocess_model_parameters


def run_demo_dataset(device, batch_size, iterations, model_location_generator):
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
model = MnistModel(state_dict)
model = model.eval()

parameters = preprocess_model_parameters(
initialize_model=lambda: model,
convert_to_ttnn=lambda *_: True,
device=device,
)
correct = 0
for iters in range(iterations):
dataloader = DataLoader(test_dataset, batch_size=batch_size)
x, labels = next(iter(dataloader))
dataset_predictions = []
ttnn_predictions = []
dataset_ttnn_correct = 0
x = ttnn.from_torch(x, dtype=ttnn.bfloat16, device=device)
tt_output = tt_mnist.mnist(device, batch_size, x, parameters)
tt_output = ttnn.to_torch(tt_output)
predicted_probabilities = torch.nn.functional.softmax(tt_output, dim=1)
_, predicted_label = torch.max(predicted_probabilities, 1)
tt_output = tt_output
for i in range(batch_size):
dataset_predictions.append(labels[i])
ttnn_predictions.append(predicted_label[i])
logger.info(f"Iter: {iters} Sample {i}:")
logger.info(f"Expected Label: {dataset_predictions[i]}")
logger.info(f"Predicted Label: {ttnn_predictions[i]}")

if dataset_predictions[i] == ttnn_predictions[i]:
dataset_ttnn_correct += 1
correct += 1
dataset_ttnn_accuracy = dataset_ttnn_correct / (batch_size)
logger.info(
f"ImageNet Inference Accuracy for iter {iters} of {batch_size} input samples : {dataset_ttnn_accuracy}"
)

accuracy = correct / (batch_size * iterations)
logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("iterations", [1])
def test_demo_dataset(
device,
batch_size,
iterations,
model_location_generator,
):
return run_demo_dataset(
device=device,
batch_size=batch_size,
iterations=iterations,
model_location_generator=model_location_generator,
)
30 changes: 30 additions & 0 deletions models/demos/mnist/reference/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch


class MnistModel(torch.nn.Module):
def __init__(self, state_dict):
super().__init__()

self.fc1 = torch.nn.Linear(784, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)

self.load_state_dict(state_dict)

def forward(self, x):
x = x.view(x.shape[0], -1)

x = self.fc1(x)
x = torch.nn.functional.relu(x)

x = self.fc2(x)
x = torch.nn.functional.relu(x)

x = self.fc3(x)
x = torch.nn.functional.relu(x)

return torch.nn.functional.softmax(x)
127 changes: 127 additions & 0 deletions models/demos/mnist/tests/test_perf_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import time
import pytest
import torch
from loguru import logger
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from models.perf.perf_utils import prep_perf_report
from models.demos.mnist.tt import tt_mnist
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.mnist.reference.mnist import MnistModel
from models.utility_functions import is_grayskull, is_wormhole_b0
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report

transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=None, download=True)


def get_expected_times(tt_mnist):
if is_grayskull():
return {
tt_mnist: (2.3, 0.0041),
}[tt_mnist]
elif is_wormhole_b0():
return {
tt_mnist: (3.3, 0.0045),
}[tt_mnist]


@pytest.mark.models_performance_bare_metal
@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize(
"batch_size",
[4],
)
@pytest.mark.parametrize(
"tt_mnist",
[tt_mnist],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_performance_mnist(device, batch_size, tt_mnist, model_location_generator, reset_seeds):
state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
model = MnistModel(state_dict)
model = model.eval()
disable_persistent_kernel_cache()
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
convert_to_ttnn=lambda *_: True,
device=device,
)
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
dataloader = DataLoader(test_dataset, batch_size=batch_size)
x, labels = next(iter(dataloader))

test_input = ttnn.from_torch(x, dtype=ttnn.bfloat16, device=device)
durations = []
for _ in range(2):
start = time.time()

ttnn_output = tt_mnist.mnist(
device=device,
x=test_input,
batch_size=batch_size,
parameters=parameters,
)
end = time.time()
durations.append(end - start)

inference_and_compile_time, *inference_times = durations
average_inference_time = sum(inference_times) / len(inference_times)
expected_compile_time, expected_inference_time = get_expected_times(tt_mnist)

prep_perf_report(
model_name="MNIST",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=average_inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
logger.info(f"Inference time: {average_inference_time}")
logger.info(f"Inference times: {inference_times}")
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")


@pytest.mark.parametrize(
"batch_size",
[4],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal(batch_size, reset_seeds):
subdir = "ttnn_mnist"
num_iterations = 1
margin = 0.03
if is_grayskull():
expected_perf = 21202.73
elif is_wormhole_b0():
expected_perf = 27223.11

command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py::test_mnist"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"tt_mnist{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
)
36 changes: 36 additions & 0 deletions models/demos/mnist/tt/tt_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch


def mnist(device, batch_size, x, parameters):
x = ttnn.reshape(x, (x.shape[0], -1))
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
x = ttnn.linear(
x,
parameters.fc1.weight,
bias=parameters.fc1.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
activation="relu",
)
x = ttnn.linear(
x,
parameters.fc2.weight,
bias=parameters.fc2.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
activation="relu",
)
x = ttnn.linear(
x,
parameters.fc3.weight,
bias=parameters.fc3.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
activation="relu",
)

x = ttnn.softmax(x)

return x
4 changes: 4 additions & 0 deletions tests/scripts/run_performance.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ run_perf_models_other() {

env pytest -n auto models/demos/convnet_mnist/tests -m $test_marker

env pytest -n auto models/demos/mnist/tests/test_perf_mnist.py -m $test_marker

## Merge all the generated reports
env python models/perf/merge_perf_results.py
}
Expand Down Expand Up @@ -95,6 +97,8 @@ run_device_perf_models() {

env pytest models/demos/convnet_mnist/tests/ -m $test_marker

env pytest models/demos/mnist/tests -m $test_marker

if [ "$tt_arch" == "grayskull" ]; then
#TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with
#Model Device perf regression tests to make sure thy run on no-soft-reset BMs
Expand Down
3 changes: 3 additions & 0 deletions tests/scripts/single_card/run_single_card_demo_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ run_common_func_tests() {
# ConvNet Mnist
pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$?

#mnist
pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$?

return $fail
}

Expand Down
43 changes: 43 additions & 0 deletions tests/ttnn/integration_tests/mnist/test_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
import pytest
from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.mnist.reference.mnist import MnistModel
from models.demos.mnist.tt import tt_mnist
from torch.utils.data import DataLoader
from torchvision import transforms, datasets


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize(
"batch_size",
[4],
)
def test_mnist(reset_seeds, device, batch_size, model_location_generator):
state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
model = MnistModel(state_dict)
model = model.eval()
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
dataloader = DataLoader(test_dataset, batch_size=batch_size)

x, labels = next(iter(dataloader))

torch_output = model(x)

parameters = preprocess_model_parameters(
initialize_model=lambda: model,
convert_to_ttnn=lambda *_: True,
device=device,
)
x = ttnn.from_torch(x, dtype=ttnn.bfloat16, device=device)

tt_output = tt_mnist.mnist(device, batch_size, x, parameters)

tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.99)

0 comments on commit a3e1a7b

Please sign in to comment.