neuron cc error with a simple model #1011

Open
aavbsouza opened this issue Oct 11, 2024 · 0 comments


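I am getting started with the Neuron SDK, beginning with small toy examples. However, even with a very simple model I run into compiler errors. For instance, the script below defines a small MLP and a small CNN (`Net`) for MNIST and trains `Net` on an XLA device: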

```python
import os
import time

import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# XLA imports
import torch_xla.core.xla_model as xm

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, output_size=10, layers=(120, 84)):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32

# Load MNIST train dataset
train_dataset = mnist.MNIST(root='./MNIST_DATA_train',
                            train=True, download=True, transform=ToTensor())

def main():
    # os.environ["NEURON_USE_EAGER_DEBUG_MODE"] = "1"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    device = "xla"
    model = Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.NLLLoss()

    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            optimizer.step()
            xm.mark_step()
if __name__ == "__main__":
    main()
```
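As a sanity check on the model definition, the 9216 input features of `fc1` in `Net` match the flattened output of the convolution stack for a 28×28 MNIST image (28 → 26 → 24 after the two 3×3 convolutions, 12 after the 2×2 max pool, times 64 channels), so the layer shapes themselves look consistent. A minimal sketch of that arithmetic, using the output-size formula from the PyTorch `Conv2d`/`MaxPool2d` documentation; the helper names are hypothetical and nothing here calls Neuron or XLA:

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of one spatial dim for Conv2d / max_pool2d (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def net_flatten_features(height=28, width=28):
    """Number of features torch.flatten(x, 1) produces in Net.forward."""
    # conv1: 3x3 kernel, stride 1, 1 -> 32 channels
    h, w = conv2d_out(height, 3), conv2d_out(width, 3)
    # conv2: 3x3 kernel, stride 1, 32 -> 64 channels
    h, w = conv2d_out(h, 3), conv2d_out(w, 3)
    # F.max_pool2d(x, 2): kernel 2, stride defaults to the kernel size
    h, w = conv2d_out(h, 2, stride=2), conv2d_out(w, 2, stride=2)
    return 64 * h * w


print(net_flatten_features())  # 9216, the in_features of Net.fc1
```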

Running `mnist_simple.py` fails with this compiler error:

```
root@0be38a63d4f3:/workspaces/neuron-container# python3 mnist_simple.py 
----------Training ---------------
2024-10-11 18:59:06.000915:  170396  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-10-11 18:59:06.000917:  170396  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.hlo_module.pb --output /tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.neff --verbose=35
.root = neuronxcc/starfish/penguin/targets/sunda/Tiling.py
root = neuronxcc/starfish/penguin/targets/sunda
root = neuronxcc/starfish/penguin/targets
root = neuronxcc/starfish/penguin
root = neuronxcc/starfish

2024-10-11 18:59:09.000869:  170396  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.hlo_module.pb', '--output', '/tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.neff', '--verbose=35']: 2024-10-11T18:59:09Z [TEN404] Internal tensorizer error: SundaSizeTiling:Exceed tonga size without loop? - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.

2024-10-11 18:59:09.000870:  170396  ERROR ||NEURON_CC_WRAPPER||: Compilation failed for /tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.hlo_module.pb after 0 retries.
Traceback (most recent call last):
  File "/workspaces/neuron-container/mnist_simple.py", line 85, in <module>
    main()
  File "/workspaces/neuron-container/mnist_simple.py", line 83, in main
    xm.mark_step()
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 969, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: INTERNAL: RunNeuronCCImpl: error condition error != 0: <class 'subprocess.CalledProcessError'>: Command '['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.hlo_module.pb', '--output', '/tmp/no-user/neuroncc_compile_workdir/2fc2cd8e-a860-46c4-bcd5-98e3520cb0e7/model.MODULE_15325567888101316750+d7517139.neff', '--verbose=35']' returned non-zero exit status 70.
```
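The error message points to the `XLA_IR_DEBUG` and `XLA_HLO_DEBUG` environment variables. A minimal sketch for re-running the repro with both set, assuming `"1"` is the value that enables them (as described in the PyTorch/XLA troubleshooting guide) and that the script is saved as `mnist_simple.py`; `debug_env` and `rerun_with_debug` are hypothetical helper names:

```python
import os
import subprocess
import sys

# Variables named in the TEN404 message; "1" is assumed to enable them.
DEBUG_VARS = {"XLA_IR_DEBUG": "1", "XLA_HLO_DEBUG": "1"}


def debug_env(base=None):
    """Return a copy of `base` (default: os.environ) with the debug vars set."""
    env = dict(os.environ if base is None else base)
    env.update(DEBUG_VARS)
    return env


def rerun_with_debug(script="mnist_simple.py"):
    """Run the repro in a child process that has the vars from interpreter start."""
    return subprocess.run([sys.executable, script], env=debug_env()).returncode
```

Calling `rerun_with_debug()` from the script's directory starts a fresh interpreter, so the variables are already present when `torch_xla` is imported and the caller's own environment is left untouched. How much extra detail the Neuron compiler reports with these settings is not verified here.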

I am using the latest Deep Learning Container (DLC) with Neuron 2.20, running under Podman. In the same setup I am able to run some of the AWS Neuron samples from their Git repository.

Thanks.
