[Question] how to use ptx_cp_async directly in tilelang #1493

@Dream-math

Questions

Hello everyone, I have a for loop that copies data from global memory to shared memory, like this:

            tn = T.get_thread_binding(0)
            warp_id = tn // 32
            lane_id = tn % 32
            warp_num = threads // 32
            start = bx * block_m
            # Main GEMM loop
            for k in T.Pipelined(T.ceildiv(d_model, block_k), num_stages=num_stages):
                # Load A tile
                for i in T.serial(T.ceildiv(block_m, warp_num)):
                    index = i * warp_num
                    tok = sorted_token_ids[start+index+warp_id]
                    if tok < num_valid_tokens_mul_top_k and tok >= 0:
                        for tt in T.vectorized(block_k // 32):
                            val = A[tok // top_k, k * block_k + lane_id * block_k // 32 + tt]
                            A_shared[index+warp_id, lane_id * block_k // 32 + tt] = val
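
To make the access pattern explicit, the loop can be emulated on the host in plain Python. This is only a sketch for checking indices, not TileLang code: the function name `gather_a_tile` and the use of nested lists are made up for illustration, it assumes `block_k` is a multiple of 32 (so `lane_id * block_k // 32` equals `lane_id * (block_k // 32)`), and it adds a row bound check and zero-fills skipped rows, neither of which the snippet above does.

```python
def gather_a_tile(A, sorted_token_ids, start, k, block_m, block_k, threads,
                  top_k, num_valid_tokens_mul_top_k):
    """Emulate what all threads of one block write into A_shared for step k."""
    warp_num = threads // 32
    elems_per_lane = block_k // 32  # assumes block_k % 32 == 0
    # Assumption: skipped rows read as zero. The original leaves them untouched.
    A_shared = [[0] * block_k for _ in range(block_m)]
    for tn in range(threads):
        warp_id, lane_id = tn // 32, tn % 32
        for i in range(-(-block_m // warp_num)):  # ceildiv(block_m, warp_num)
            row = i * warp_num + warp_id
            if row >= block_m:
                # Guard added for the sketch; needed when block_m % warp_num != 0.
                continue
            tok = sorted_token_ids[start + row]
            if 0 <= tok < num_valid_tokens_mul_top_k:
                # Each lane copies one contiguous chunk of elems_per_lane elements.
                for tt in range(elems_per_lane):
                    col = lane_id * elems_per_lane + tt
                    A_shared[row][col] = A[tok // top_k][k * block_k + col]
    return A_shared
```

Each warp handles one row per iteration, and each lane moves a contiguous run of `block_k // 32` elements, so the per-lane chunk is the natural unit to hand to an asynchronous copy.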

I want to replace this code with an asynchronous copy. I tried using T.copy directly, but it didn't work. My platform is H20. I see that there are ptx_cp_async and ptx_cp_async_bulk, so how should I use them here? I can't find any examples. Thanks!
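
One constraint worth checking before trying `ptx_cp_async`: the underlying PTX `cp.async` instruction (the non-bulk form) only copies 4, 8, or 16 bytes per call, so the per-lane chunk of `block_k // 32` elements has to be expressible in those sizes. The helper below is a small host-side sketch of that check. The name `cp_async_plan` and the example tile sizes are hypothetical, since the issue does not state `block_k` or the dtype, and the sketch does not check the alignment of the source and destination addresses, which must also match the copy size.

```python
CP_ASYNC_SIZES = (16, 8, 4)  # legal byte counts for one non-bulk cp.async


def cp_async_plan(block_k, dtype_bytes, lanes=32):
    """Return (bytes_per_lane, copies_per_lane, bytes_per_copy).

    Raises ValueError when a lane's chunk cannot be split into legal copies.
    """
    bytes_per_lane = (block_k // lanes) * dtype_bytes
    for size in CP_ASYNC_SIZES:  # prefer the widest copy that divides the chunk
        if bytes_per_lane >= size and bytes_per_lane % size == 0:
            return bytes_per_lane, bytes_per_lane // size, size
    raise ValueError(
        f"{bytes_per_lane} bytes per lane cannot be split into 4/8/16-byte copies"
    )
```

For example, with a hypothetical `block_k = 128` and a 2-byte dtype, each lane owns 8 bytes and needs a single 8-byte copy, while `block_k = 32` with the same dtype leaves only 2 bytes per lane, which is too small and would require a different thread-to-element mapping.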
