Skip to content

Commit

Permalink
#13399: Add data parallel support for convnet mnist model
Browse files Browse the repository at this point in the history
  • Loading branch information
vigneshkeerthivasanx committed Nov 7, 2024
1 parent 62255a8 commit ac45b6e
Show file tree
Hide file tree
Showing 10 changed files with 557 additions and 1 deletion.
1 change: 0 additions & 1 deletion models/demos/convnet_mnist/tt/convnet_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def convnet_mnist(
weights_dtype=ttnn.bfloat16,
math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=False,
packer_l1_accum_enabled=False,
Expand Down
24 changes: 24 additions & 0 deletions models/demos/wormhole/convnet_mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Introduction

Convnet Mnist implements a Convolutions to classify handwritten digits from the MNIST dataset. The MNIST dataset contains grayscale images of handwritten digits (0-9), each of size 32x32 pixels.

# Platforms:
WH N300

## How to Run

To run the demo for digit classification using the MNIST model, follow these instructions:

- Use the following command to run the MNIST model.

```
pytest models/demos/wormhole/convnet_mnist/demo/demo.py
```

Maxpool and Softmax are used in torch inside the model.
ISSUES:
#12664 - [softmax](https://github.com/tenstorrent/tt-metal/issues/12664)
#12642 - [maxpool](https://github.com/tenstorrent/tt-metal/issues/12642)


### Owner: [vigneshkumarkeerthivasan](https://github.com/vigneshkeerthivasanx)
17 changes: 17 additions & 0 deletions models/demos/wormhole/convnet_mnist/convnet_mnist_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


def custom_preprocessor(parameters, device):
parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device)
parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device)

parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device)
parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device)
parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device)
parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device)

return parameters
37 changes: 37 additions & 0 deletions models/demos/wormhole/convnet_mnist/convnet_mnist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision
import torchvision.transforms as transforms


def get_test_data(batch_size=64):
transform = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.05,), std=(0.05,)),
]
)

test_dataset = torchvision.datasets.MNIST(
root="./data",
train=False,
download=True,
)

batch = []
images = []
outputs = []

for i in range(batch_size):
img, output = test_dataset[i]
tensor = transform(img).unsqueeze(0)
batch.append(tensor)
images.append(img)
outputs.append(output)

batch = torch.cat(batch)
return batch, images, outputs
87 changes: 87 additions & 0 deletions models/demos/wormhole/convnet_mnist/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
import pytest

from pathlib import Path
from loguru import logger

from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
convnet_mnist,
custom_preprocessor,
)
from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
from models.demos.wormhole.convnet_mnist.convnet_mnist_utils import get_test_data
from models.experimental.convnet_mnist.reference.convnet import ConvNet
from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import is_wormhole_b0, skip_for_grayskull


def model_location_generator(rel_path):
internal_weka_path = Path("/mnt/MLPerf")
has_internal_weka = (internal_weka_path / "bit_error_tests").exists()

if has_internal_weka:
return Path("/mnt/MLPerf") / rel_path
else:
return Path("/opt/tt-metal-models") / rel_path


@skip_for_grayskull()
@pytest.mark.parametrize(
"batch_size",
((16),),
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_convnet_mnist(mesh_device, batch_size, reset_seeds):
model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/")
state_dict = str(model_path / "convnet_mnist.pt")
state_dict = torch.load(state_dict)

test_input, images, output = get_test_data(batch_size)

model = ConvNet()
model.load_state_dict(state_dict)
model.eval()
torch_output = model(test_input)
batch_size = len(test_input)

inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
parameters = preprocess_model_parameters(
initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor
)
parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)

ttnn_input = torch.permute(test_input, (0, 2, 3, 1))
ttnn_input = ttnn.from_torch(
ttnn_input, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=inputs_mesh_mapper
)

ttnn_output = convnet_mnist(
input_tensor=ttnn_input,
device=mesh_device,
parameters=parameters,
mesh_mapper=inputs_mesh_mapper,
mesh_composer=output_mesh_composer,
)

ttnn_output = ttnn.to_torch(ttnn_output, mesh_composer=output_mesh_composer)

_, torch_predicted = torch.max(torch_output.data, -1)
_, ttnn_predicted = torch.max(ttnn_output.data, -1)

correct = 0
for i in range(batch_size):
if output[i] == ttnn_predicted[i]:
correct += 1
accuracy = correct / (batch_size)

logger.info(f" Accuracy for {batch_size} Samples : {accuracy}")
logger.info(f"torch_predicted {torch_predicted.squeeze()}")
logger.info(f"ttnn_predicted {ttnn_predicted.squeeze()}")
146 changes: 146 additions & 0 deletions models/demos/wormhole/convnet_mnist/tests/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
import time
from pathlib import Path

from loguru import logger
import ttnn
from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from models.perf.perf_utils import prep_perf_report
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
convnet_mnist,
custom_preprocessor,
)
from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
from models.experimental.convnet_mnist.reference.convnet import ConvNet
from models.utility_functions import is_wormhole_b0, skip_for_grayskull


def get_expected_times(convnet_mnist):
return (15.0, 9.2)


def model_location_generator(rel_path):
internal_weka_path = Path("/mnt/MLPerf")
has_internal_weka = (internal_weka_path / "bit_error_tests").exists()

if has_internal_weka:
return Path("/mnt/MLPerf") / rel_path
else:
return Path("/opt/tt-metal-models") / rel_path


@skip_for_grayskull()
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"input_shape",
[
(2, 1, 32, 32),
],
)
def test_convnet_mnist(
mesh_device,
input_shape,
reset_seeds,
):
disable_persistent_kernel_cache()

model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/")
state_dict = str(model_path / "convnet_mnist.pt")
state_dict = torch.load(state_dict)

input_tensor = torch.randn(input_shape, dtype=torch.bfloat16)
batch_size = input_tensor.shape[0]
input_tensor = torch.permute(input_tensor, (0, 2, 3, 1))

model = ConvNet()
model.load_state_dict(state_dict)
model.eval()

inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
parameters = preprocess_model_parameters(
initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor
)
parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)

durations = []
for i in range(2):
start = time.time()
ttnn_input = ttnn.from_torch(
input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=inputs_mesh_mapper
)

ttnn_output = convnet_mnist(
input_tensor=ttnn_input,
device=mesh_device,
parameters=parameters,
mesh_mapper=inputs_mesh_mapper,
mesh_composer=output_mesh_composer,
)
output = ttnn.from_device(ttnn_output)
output = ttnn.to_torch(output, mesh_composer=output_mesh_composer)
end = time.time()
durations.append(end - start)
enable_persistent_kernel_cache()

inference_and_compile_time, inference_time, *_ = durations

expected_compile_time, expected_inference_time = get_expected_times("convnet_mnist")
prep_perf_report(
model_name="convnet_mnist",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

logger.info(f"Compile time: {inference_and_compile_time - inference_time}")
logger.info(f"Inference time: {inference_time}")
logger.info(f"Samples per second: {1 / inference_time * batch_size}")


@skip_for_grayskull()
@pytest.mark.parametrize(
"batch_size, expected_perf",
[
[1, 2885],
],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
subdir = "ttnn_convnet_mnist"
num_iterations = 1
margin = 0.03

command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist_wh.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"ttnn_convnet_mnist_wh_{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
)
Loading

0 comments on commit ac45b6e

Please sign in to comment.