
[Tracking Issue] Abstraction for sub-warp reduction in the IR. #55

Open
yzh119 opened this issue Oct 14, 2022 · 0 comments

yzh119 commented Oct 14, 2022

In apache/tvm#10207 we introduced sub-warp reduction.

Users can use one warp (32 threads) to perform eight 4-element reductions in parallel:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def subwarpreduce(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [8, 4], dtype="float32")
    B = T.match_buffer(b, [8], dtype="float32")
    for o, i, j in T.grid(1, 8, 4):
        with T.block("red"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vj]


def test_subwarp_reduce_flag():
    mod = tvm.IRModule.from_expr(subwarpreduce)
    sch = tvm.tir.Schedule(mod["main"])
    blk = sch.get_block("red")
    o, i, j = sch.get_loops(blk)
    sch.bind(i, "threadIdx.y")  # 8 output rows
    sch.bind(j, "threadIdx.x")  # 4 reduction lanes per row
    sch.bind(o, "blockIdx.x")
```
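For intuition, the sketch below simulates in plain Python (it is not TVM code) what a sub-warp reduction computes under the bindings above. It assumes the lowered kernel uses CUDA `__shfl_down_sync`-style shuffles with a `width` of 4 and that lane ids are `threadIdx.y * 4 + threadIdx.x`; `shfl_down` and `subwarp_reduce_sum` are hypothetical helpers.

```python
# Pure-Python simulation (not TVM code) of a warp-shuffle sum reduction in
# which every group of `width` consecutive lanes is reduced independently.
WARP_SIZE = 32


def shfl_down(values, delta, width):
    """Mimic CUDA __shfl_down_sync with a `width` argument: lane l reads lane
    l + delta of its own width-sized segment; lanes whose source would fall
    past the segment end keep their own value."""
    out = []
    for lane, v in enumerate(values):
        seg_end = (lane // width + 1) * width
        src = lane + delta
        out.append(values[src] if src < seg_end else v)
    return out


def subwarp_reduce_sum(values, width):
    """Tree reduction inside each width-lane group; afterwards the first lane
    of every group holds that group's sum."""
    assert len(values) == WARP_SIZE
    assert 0 < width <= WARP_SIZE and (width & (width - 1)) == 0
    offset = width // 2
    while offset > 0:
        shifted = shfl_down(values, offset, width)
        values = [a + b for a, b in zip(values, shifted)]
        offset //= 2
    return values


# A[i][j] lives on lane threadIdx.y * 4 + threadIdx.x = i * 4 + j.
A = [[float(4 * i + j) for j in range(4)] for i in range(8)]
lanes = [A[lane // 4][lane % 4] for lane in range(WARP_SIZE)]
reduced = subwarp_reduce_sum(lanes, width=4)
B = [reduced[4 * i] for i in range(8)]  # lane 4*i holds the sum of row i
print(B)  # [6.0, 22.0, 38.0, 54.0, 70.0, 86.0, 102.0, 118.0]
```

Calling `subwarp_reduce_sum(lanes, width=32)` instead folds all 32 lanes into lane 0, giving one sum over the whole warp rather than eight row sums.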

However, in some cases threadIdx.x is already bound with an extent of 32 in other blocks of the same kernel, which contradicts the extent of 4 required here.
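A toy model of that conflict, assuming (as in CUDA) that all blocks of one kernel share a single launch configuration so each thread axis can carry only one extent. `LaunchConfig` is a hypothetical illustration, not TVM's actual check.

```python
# Toy model (not TVM's implementation): all blocks of one kernel share a single
# launch configuration, so each thread axis may be given only one extent.
class LaunchConfig:
    def __init__(self):
        self.extents = {}

    def bind(self, thread_axis, extent):
        prev = self.extents.setdefault(thread_axis, extent)
        if prev != extent:
            raise ValueError(
                f"{thread_axis} already has extent {prev}, cannot rebind to {extent}"
            )


cfg = LaunchConfig()
cfg.bind("threadIdx.x", 32)  # another block in the same kernel
conflict = None
try:
    cfg.bind("threadIdx.x", 4)  # the sub-warp reduction block
except ValueError as err:
    conflict = str(err)
print(conflict)  # threadIdx.x already has extent 32, cannot rebind to 4
```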
Another way is to fuse i and j, then bind the fused loop to the 32 threads of a warp:

```python
from tvm.script import tir as T


@T.prim_func
def func(A: T.Buffer[(8, 4), "float32"], B: T.Buffer[(8,), "float32"]) -> None:
    # body
    # with T.block("root")
    for o in T.thread_binding(1, thread="blockIdx.x"):
        # i and j fused into one loop that covers all 32 lanes of the warp
        for i_j_fused in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("red"):
                vi = T.axis.spatial(8, i_j_fused // 4)
                vj = T.axis.reduce(4, i_j_fused % 4)
                T.reads(A[vi, vj])
                T.writes(B[vi])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vj]
```
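The fused binding still keeps each reduction group on consecutive lanes. Enumerating the index mapping in plain Python (not TVM code) shows that lanes `4*vi` through `4*vi + 3` all contribute to `B[vi]`, which is the layout a width-4 shuffle reduction would need.

```python
# Reproduce the index mapping of the fused loop:
# vi = i_j_fused // 4 (spatial), vj = i_j_fused % 4 (reduction).
groups = {}
for lane in range(32):  # i_j_fused is bound to threadIdx.x, i.e. the lane id
    vi, vj = lane // 4, lane % 4
    groups.setdefault(vi, []).append(lane)

# Each output row B[vi] is reduced by 4 consecutive, 4-aligned lanes.
for vi, lanes in groups.items():
    assert lanes == list(range(4 * vi, 4 * vi + 4))
print(groups[0], groups[7])  # [0, 1, 2, 3] [28, 29, 30, 31]
```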

In this case, however, the LowerThreadAllReduce pass cannot recognize the sub-warp reduction structure and emits code that reduces all 32 threads of the warp together instead of eight independent groups of 4.
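Roughly, what is missing is a way for lowering to see that the 32 lanes split into independent, aligned groups. As a hypothetical illustration (not taken from LowerThreadAllReduce and not a proposed API), `infer_subwarp_width` takes a lane-to-(spatial, reduce) index mapping and returns the shuffle width it would allow, or `None`.

```python
# Hypothetical analysis: decide whether a warp's lanes split into independent,
# aligned groups that a width-limited shuffle reduction could handle.
WARP_SIZE = 32


def infer_subwarp_width(lane_to_indices):
    """lane_to_indices(lane) -> (spatial_index, reduce_index)."""
    groups = {}
    for lane in range(WARP_SIZE):
        spatial, _ = lane_to_indices(lane)
        groups.setdefault(spatial, []).append(lane)
    sizes = {len(lanes) for lanes in groups.values()}
    if len(sizes) != 1:
        return None  # groups of different sizes
    width = sizes.pop()
    if width & (width - 1):
        return None  # shuffle widths must be powers of two
    for lanes in groups.values():
        start = lanes[0]
        if start % width or lanes != list(range(start, start + width)):
            return None  # group is not one contiguous, width-aligned segment
    return width


def fused(lane):
    # vi = i_j_fused // 4, vj = i_j_fused % 4, with i_j_fused bound to threadIdx.x
    return lane // 4, lane % 4


def interleaved(lane):
    # rows spread across the warp with a stride of 8 lanes
    return lane % 8, lane // 8


print(infer_subwarp_width(fused))        # 4
print(infer_subwarp_width(interleaved))  # None
```

The check is deliberately conservative: an interleaved layout could still be reduced with strided shuffles, but it is rejected here.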

There should be an abstraction to denote sub-warp reduction in TensorIR.
