
Able to access memory locations beyond those of the tensor when using tensor.accessor #1019

Closed
zafstojano opened this issue Oct 22, 2024 · 4 comments

Comments

@zafstojano

Hello!

Is it normal behavior for tensor.accessor to allow reads and writes to memory locations beyond those occupied by the tensor itself?

Below is a simple example of a kernel that instantiates an output tensor of size {5, 5, 5} and then accesses an element outside the tensor's bounds, at index [15][20][30].

#include <stdint.h>
#include <stdio.h>   // for printf
#include <stdlib.h>
#include <torch/torch.h>


torch::Tensor dummy_kernel_compute(const torch::Tensor& t_in) {
   // Init output tensor
   torch::Tensor t_out = torch::zeros({5, 5, 5}, torch::kFloat);
   auto t_out_accessor = t_out.accessor<float, 3>();
   
   // Simulate memory access out of the output tensor bounds (notice the {5, 5, 5} dimensions)
   t_out_accessor[15][20][30] = 1024.0;
   printf("\n===================================\n");
   printf("t_out_accessor[15][20][30] = %f", static_cast<float>(t_out_accessor[15][20][30]));
   printf("\n===================================\n");
   return t_out;
}
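
For context, my understanding (an assumption based on how PyTorch's TensorAccessor computes offsets, not something I verified against the Neuron sources) is that the accessor simply folds the indices into a flat offset via the strides, so [15][20][30] lands well past the 125 floats the tensor actually owns. A minimal standalone sketch of that arithmetic:

#include <stdint.h>
#include <stdio.h>

// Sketch of the stride arithmetic for a contiguous {5, 5, 5} float tensor
// (assumption: row-major strides {25, 5, 1}, as with PyTorch's TensorAccessor).
int main() {
    const int64_t strides[3] = {25, 5, 1};
    const int64_t idx[3]     = {15, 20, 30};

    // Flat offset the accessor would dereference: 15*25 + 20*5 + 30 = 505
    int64_t offset = idx[0] * strides[0] + idx[1] * strides[1] + idx[2] * strides[2];
    // The tensor only holds 5*5*5 = 125 elements, so element 505 is far out of bounds.
    int64_t numel = 5 * 5 * 5;

    printf("offset = %lld, numel = %lld\n", (long long)offset, (long long)numel);
    return 0;
}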

I would have expected this to cause an illegal memory access error, but it doesn't: the code runs without any errors. Below is a test script that invokes the kernel:

import os 

os.environ["NEURON_RT_LOG_LEVEL"] = "INFO"
os.environ["NEURON_RT_GPSIMD_STDOUT_QUEUE_SIZE_BYTES"] = "65536"

import torch
import torch_xla.core.xla_model as xm
from torch_neuronx.xla_impl import custom_op

custom_op.load_library('./lib_custom_dummy_tensor.so')

device = xm.xla_device()
in_tensor = torch.ones((3, 3)).to(device).float()
out_tensor = torch.ops.my_ops.dummy_kernel(in_tensor)

print("Output tensor: ", out_tensor)

Which produces the following logs:

2024-Oct-22 19:35:52.128525 59798:59798  INFO   NRT:nrt_init                                Neuron Runtime 2.22.14.0 built on Sep  4 2024
2024-Oct-22 19:35:52.128554 59798:59798  INFO   NRT:nrt_init                                Found neuron driver: 2.18
2024-Oct-22 19:35:52.128630 59798:59798  INFO   NRT:nrt_allocate_neuron_cores               Instance info:: device count:1 cores per device:2 architecture:TRN
2024-Oct-22 19:35:52.128897 59798:59798  INFO   ENC:check_version                           Using CCOM 2.22.26.0-
2024-Oct-22 19:36:00.405574 59798:59798  INFO  TDRV:tdrv_init_one_mla_phase2                Configuring stochastic rounding seed for device: 0, nc 0. seed: 0
2024-Oct-22 19:36:00.423328 59798:59798  INFO  TDRV:tdrv_init_one_mla_phase2                Configuring stochastic rounding seed for device: 0, nc 1. seed: 0
2024-Oct-22 19:36:00.440958 59798:59798  INFO  TDRV:tdrv_init_one_mla_phase2                Initialized Device: 0 00:1e.0
2024-Oct-22 19:36:00.443606 59798:59798  INFO   NRT:nrt_infodump                            Neuron runtime information - please include in any support request:
2024-Oct-22 19:36:00.443614 59798:59798  INFO   NRT:nrt_infodump                            ------------->8------------[ cut here ]------------>8-------------
2024-Oct-22 19:36:00.443616 59798:59798  INFO   NRT:nrt_infodump                            NRT version: 2.22.14.0 (6e27b8d5b22dea0e0b8375517f4d8a009b6de5a8)
2024-Oct-22 19:36:00.443619 59798:59798  INFO   NRT:nrt_infodump                            Embedded FW version: 1.12.2.0 (f152b70c827a52701d6b9ee74ec7ff7a15971f7d)
2024-Oct-22 19:36:00.443632 59798:59798  INFO   NRT:nrt_infodump                            CCOM version: 2.22.26.0- (compat 48)
2024-Oct-22 19:36:00.443634 59798:59798  INFO   NRT:nrt_infodump                            Instance ID: i-0fec495110ce69aaf
2024-Oct-22 19:36:00.443637 59798:59798  INFO   NRT:nrt_infodump                            Cluster ID: N/A
2024-Oct-22 19:36:00.443640 59798:59798  INFO   NRT:nrt_infodump                            Kernel: Linux 5.15.0-1031-aws #35-Ubuntu SMP Fri Feb 10 02:07:18 UTC 2023
2024-Oct-22 19:36:00.443643 59798:59798  INFO   NRT:nrt_infodump                            Nodename: ip-172-31-66-134
2024-Oct-22 19:36:00.443652 59798:59798  INFO   NRT:nrt_infodump                            Driver version: 2.18.12.0

2024-Oct-22 19:36:00.443659 59798:59798  INFO   NRT:nrt_infodump                            Visible cores: 0, 1
2024-Oct-22 19:36:00.443661 59798:59798  INFO   NRT:nrt_infodump                            Environment:
2024-Oct-22 19:36:00.443669 59798:59798  INFO   NRT:nrt_infodump                                NEURON_RT_LOG_LEVEL=INFO
2024-Oct-22 19:36:00.443674 59798:59798  INFO   NRT:nrt_infodump                                NEURON_RT_GPSIMD_STDOUT_QUEUE_SIZE_BYTES=65536
2024-Oct-22 19:36:00.443681 59798:59798  INFO   NRT:nrt_infodump                                NEURON_LIBRARY_PATH=/opt/aws_neuronx_venv_pytorch/lib/python3.10/site-packages/libneuronxla/libneuronpjrt.so
2024-Oct-22 19:36:00.443686 59798:59798  INFO   NRT:nrt_infodump                                NEURON_RT_ROOT_COMM_ID=localhost:62182
2024-Oct-22 19:36:00.443688 59798:59798  INFO   NRT:nrt_infodump                                NEURON_INTERNAL_PJRT_C_API_VERSION=0.23
2024-Oct-22 19:36:00.443691 59798:59798  INFO   NRT:nrt_infodump                            -------------8<-----------[ cut to here ]-----------8<------------
Output tensor:  2024-10-22 19:36:00.000464:  59798  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-10-22 19:36:00.000464:  59798  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff --verbose=35
.
Compiler status PASS
2024-Oct-22 19:36:01.785813 59798:59798  INFO  NMGR:kmgr_load_nn_internal                   Loading NN: /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff version: 2.0, VNC range: 0/1
2024-Oct-22 19:36:01.785835 59798:59798  INFO  NMGR:enter_cache                             NEFF load cache - first worker, creating cache
2024-Oct-22 19:36:01.788818 59798:59798  INFO  NMGR:kmgr_load_nn_internal                   /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff mac_count: 0
2024-Oct-22 19:36:01.789057 59798:59798  INFO  TDRV:io_create_queues                        Vring: qPoolIO0 has 17/17 desc's, total tx/rx size: 564/564
2024-Oct-22 19:36:01.789347 59798:59798  INFO  TDRV:io_create_queues                        Vring: qSPIO0 has 17/17 desc's, total tx/rx size: 100/100
2024-Oct-22 19:36:01.791152 59798:59798  INFO  TDRV:kbl_model_add                           Added model: 1001 on: 0(nd0:nc0) count:1; mem_usage: before: (33900,33864) after (100850556,33864) 
2024-Oct-22 19:36:01.791174 59798:59798  INFO  TDRV:log_dev_mem                             ND 0:NC 0, current utilization:
        * total: 96.179MB
        * model code: 96.086MB
        * tensors: 44.000B
        * runtime: 1.062KB
        * dma rings: 94.000KB

2024-Oct-22 19:36:01.791301 59798:59798  INFO  NMGR:kmgr_load_nn_internal_v2                Loaded NN "/tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff" ID: 0, max in-fight: 2
2024-Oct-22 19:36:01.791318 59798:59798  INFO  NMGR:release_tmp_neff_cache                  NEFF load cache - cleared by the last worker
2024-Oct-22 19:36:01.791326 59798:59798  INFO  NMGR:kmgr_load_nn_post_metrics               NN: /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff, loaded successfully
2024-Oct-22 19:36:01.791339 59798:59798  INFO  NMGR:kmgr_load_nn_internal                   Loading NN: /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff version: 2.0, VNC range: 1/1
2024-Oct-22 19:36:01.791347 59798:59798  INFO  NMGR:enter_cache                             NEFF load cache - first worker, creating cache
2024-Oct-22 19:36:01.793961 59798:59798  INFO  NMGR:kmgr_load_nn_internal                   /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff mac_count: 0
2024-Oct-22 19:36:01.794119 59798:59798  INFO  TDRV:io_create_queues                        Vring: qPoolIO0 has 17/17 desc's, total tx/rx size: 564/564
2024-Oct-22 19:36:01.794396 59798:59798  INFO  TDRV:io_create_queues                        Vring: qSPIO0 has 17/17 desc's, total tx/rx size: 100/100
2024-Oct-22 19:36:01.796157 59798:59798  INFO  TDRV:kbl_model_add                           Added model: 1002 on: 1(nd0:nc1) count:1; mem_usage: before: (100850556,33864) after (100850556,100850520) 
2024-Oct-22 19:36:01.796176 59798:59798  INFO  TDRV:log_dev_mem                             ND 0:NC 1, current utilization:
        * total: 96.179MB
        * model code: 96.086MB
        * tensors: 8.000B
        * runtime: 1.062KB
        * dma rings: 94.000KB

2024-Oct-22 19:36:01.796284 59798:59798  INFO  NMGR:kmgr_load_nn_internal_v2                Loaded NN "/tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff" ID: 0, max in-fight: 2
2024-Oct-22 19:36:01.796299 59798:59798  INFO  NMGR:release_tmp_neff_cache                  NEFF load cache - cleared by the last worker
2024-Oct-22 19:36:01.796304 59798:59798  INFO  NMGR:kmgr_load_nn_post_metrics               NN: /tmp/ubuntu/neuroncc_compile_workdir/2b6d2da5-0f0a-4952-8064-210d622dab2b/model.MODULE_2051229325733854787+d7517139.neff, loaded successfully
2024-Oct-22 19:36:01.797272 59798:59898  INFO  TDRV:init_dma_queues_and_start_engines_v2    Configuring stochastic rounding: enabled=0
2024-Oct-22 19:36:01.799032 59798:59898  INFO  TDRV:pool_stdio_queue_consume_all_entries    Printing stdout from GPSIMD:

===================================
t_out_accessor[15][20][30] = 1024.000000
===================================

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='xla:0')

As you can see, I can successfully read and write to memory locations outside of the output tensor.

Is this behavior expected?

@Vie-Jay (Contributor) commented Oct 22, 2024

Having the team take a look...

@delongmeng-aws

Took a look. I was able to reproduce this observation.

@delongmeng-aws

We are looking further into the root cause of this issue and will provide an update later.

@delongmeng-aws

Hi @zafstojano,
Thanks for reporting this issue! After our investigation, we confirm that this is expected behavior. Neuron doesn't perform bounds checking on the TensorAccessor, in order to maximize performance; the responsibility for ensuring that indices are within the valid range falls on the programmer. Note that this behavior matches PyTorch's own TensorAccessor. We will update our public documentation to state this clearly. We will close this issue, but feel free to reopen it if you have any further questions.
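
For anyone landing here later: one way to protect a custom kernel against this is to validate indices against the tensor's sizes before going through the accessor. A minimal sketch of that pattern (the in_bounds helper below is hypothetical, not part of the Neuron or PyTorch API):

#include <stdint.h>
#include <stdio.h>
#include <torch/torch.h>

// Hypothetical helper: checks a 3-D index against the tensor's sizes.
// Shown only as a pattern; not part of the Neuron or PyTorch API.
static bool in_bounds(const torch::Tensor& t, int64_t i, int64_t j, int64_t k) {
    return i >= 0 && i < t.size(0) &&
           j >= 0 && j < t.size(1) &&
           k >= 0 && k < t.size(2);
}

torch::Tensor checked_dummy_kernel_compute(const torch::Tensor& t_in) {
    torch::Tensor t_out = torch::zeros({5, 5, 5}, torch::kFloat);
    auto t_out_accessor = t_out.accessor<float, 3>();

    int64_t i = 15, j = 20, k = 30;
    if (in_bounds(t_out, i, j, k)) {
        t_out_accessor[i][j][k] = 1024.0f;
    } else {
        // Out-of-range index: skip the write instead of corrupting memory.
        printf("index [%lld][%lld][%lld] is out of bounds, skipping write\n",
               (long long)i, (long long)j, (long long)k);
    }
    return t_out;
}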
