[CUDA] Do not emit vector load on unaligned base offset (#9731)
* [CUDA] Do not emit vector load on unaligned base offset

* fix alignment check condition

* black

* improve test

* improve the vectorization condition to avoid error in yolo5 (thanks to vinx13)

* replace coeff != 1 check by coeff % lane == 0
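The last bullet captures the key invariant: for a ramp index of the form k * coeff + base, if both coeff and base are multiples of the vector width (lanes), then the index is a multiple of the vector width for every k, so the vectorized access is always aligned. A minimal pure-Python illustration of that arithmetic (not part of the commit; the values are chosen for illustration):

# Illustration only: if coeff and base are both multiples of lanes,
# then k * coeff + base is a multiple of lanes for every integer k.
lanes = 2

coeff, base = 4, 2  # both divisible by lanes -> aligned for every k
assert all((k * coeff + base) % lanes == 0 for k in range(-100, 100))

coeff, base = 2, 1  # base not divisible by lanes -> never aligned
assert all((k * coeff + base) % lanes != 0 for k in range(-100, 100))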
masahi authored Dec 14, 2021
1 parent b4d595c commit 1f5f3c9
Showing 2 changed files with 64 additions and 27 deletions.
13 changes: 13 additions & 0 deletions src/target/source/codegen_c.cc
@@ -22,6 +22,8 @@
*/
#include "codegen_c.h"

#include <tvm/arith/analyzer.h>

#include <cctype>
#include <iomanip>

@@ -710,8 +712,19 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
  } else {
    ICHECK(is_one(op->predicate)) << "predicated load is not supported";

    bool can_vector_load = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
      const RampNode* ramp = op->index.as<RampNode>();
      ICHECK(ramp);
      arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
      // The condition: {k * coeff + base} divisible by the alignment for any k
      if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) {
        can_vector_load = true;
      }
    }

    if (can_vector_load) {
      std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
      HandleVolatileLoads(ref, op, os);
    } else {
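For reference, the divisibility check added above uses the same analysis that is exposed in Python as tvm.arith.Analyzer().modular_set. A short sketch of how it behaves on the aligned and unaligned offsets used by the test below (assumes a standard TVM install; the variable names are illustrative, not from the commit):

import tvm
from tvm import tir

# Sketch: a vector load is only emitted when the ramp base is provably a
# multiple of the lane count for every value of the loop/thread variable.
analyzer = tvm.arith.Analyzer()
i = tir.Var("i", "int32")
lanes = 2

for ramp_base in (i * 2 + 2, i * 2 + 1):  # aligned vs. unaligned offset
    ms = analyzer.modular_set(ramp_base)  # ramp_base == k * coeff + base
    ok = int(ms.coeff) % lanes == 0 and int(ms.base) % lanes == 0
    print(ramp_base, "-> coeff:", ms.coeff, "base:", ms.base, "vector load:", ok)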
78 changes: 51 additions & 27 deletions tests/python/unittest/test_target_codegen_cuda.py
@@ -18,10 +18,9 @@
from tvm import te
import numpy as np
from tvm import topi
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import pytest

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
@@ -995,29 +994,54 @@ def test_unrolled_vectorization():
    tvm.testing.assert_allclose(c_np, N * np.ones((N, N)))


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_try_unaligned_vector_load():
    def get_compute(N, C_N, offset):
        A = te.placeholder((N,), name="A", dtype="float16")
        C = te.compute((C_N,), lambda i: A[i + offset], name="C")
        return N, C_N, A, C

    def get_compute_unaligned():
        return get_compute(3, 2, 1)

    def get_compute_aligned():
        return get_compute(4, 2, 2)

    def build(A, C, N, C_N):
        s = te.create_schedule(C.op)
        oi, ii = s[C].split(C.op.axis[0], factor=2)
        s[C].bind(oi, te.thread_axis("threadIdx.x"))
        s[C].vectorize(ii)  # BUG: misalignment

        tgt = tvm.target.Target(target="cuda", host="llvm")
        dev = tvm.device(tgt.kind.name, 0)
        f = tvm.build(s, [A, C], tgt, name="foo")
        kernel_source = f.imported_modules[0].get_source()

        a_data = np.arange(0, N).astype(A.dtype)
        a = tvm.nd.array(a_data, dev)
        c = tvm.nd.array(np.zeros(C_N, dtype=C.dtype), dev)
        f(a, c)

        return a_data, c.numpy(), kernel_source

    N, C_N, A, C = get_compute_unaligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (1)) is invalid
    assert "A + (1)" not in kernel_source

    expected = a_data[1 : C_N + 1]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"

    N, C_N, A, C = get_compute_aligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (2)) is a valid vector load
    assert "A + (2)" in kernel_source

    expected = a_data[2 : C_N + 2]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"


if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_bf16_vectorize_add()
    test_cuda_multiply_add()
    test_cuda_vectorize_load()
    test_cuda_make_int4()
    test_cuda_make_int8()
    test_cuda_inf_nan()
    test_cuda_shuffle()
    test_vectorized_casts()
    test_cuda_reduction_binding()
    test_crossthread_reduction1()
    test_crossthread_reduction2()
    test_rfactor_predicates()
    test_cuda_const_float_to_half()
    test_cuda_reduction()
    test_cuda_mix_threaded_and_normal_reduction()
    test_cuda_floordiv_with_vectorization()
    test_cuda_floormod_with_vectorization()
    test_vectorized_intrin1()
    test_vectorized_intrin2()
    test_vectorized_popcount()
    test_cuda_vectorize_load_permute_pad()
    test_vectorized_cooperative_fetching_x()
    test_vectorized_cooperative_fetching_xy()
    test_unrolled_vectorization()
    pytest.main([__file__])
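The new test can also be run on its own with pytest, for example (assuming a CUDA-enabled TVM build):

pytest tests/python/unittest/test_target_codegen_cuda.py -k try_unaligned_vector_load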
