
[Feature Request] Add predicate load optimization #1479

@LJC00118

Description


Required prerequisites

  • I have searched the Issue Tracker to confirm that this hasn't already been reported (comment there if it has).

Motivation

Currently in Tilelang, conditional loads require explicit branching:

int4 data;
if (condition) {
    data = *addr;                  // guarded global load
} else {
    data = make_int4(0, 0, 0, 0);  // zero fill when the guard fails
}

Instead, add predicated-load support so that Tilelang emits PTX predicated instructions directly:

setp.ge.s32 %p, %cond, 0;
@p ld.global.nc.v4.s32 {%r0, %r1, %r2, %r3}, [%ptr];

Benefit: this eliminates branch-divergence penalties around guarded loads and produces register-friendly SASS.
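To pin down the intended semantics, the sketch below models both forms on the host in plain C++ (it is not GPU code). It assumes std::array<int32_t, 4> as a stand-in for CUDA's int4 and uses value >= 0 as the condition, matching the setp.ge.s32 line above. load_branching and load_predicated are hypothetical names. The model shows that the predicated form only matches the if/else form when the destination is zeroed before the guarded load: a false predicate simply skips the load and leaves the registers untouched.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Host-side stand-in for CUDA's int4 (an assumption for illustration only).
using Vec4 = std::array<int32_t, 4>;

// Branching form: what an explicit `if (cond) load else zero` computes.
Vec4 load_branching(const Vec4* ptr, int32_t value) {
    Vec4 data;
    if (value >= 0) {
        data = *ptr;              // condition holds: real load
    } else {
        data = Vec4{0, 0, 0, 0};  // condition fails: zero fill
    }
    return data;
}

// Predicated form, modelled on the PTX above: the destination is zeroed
// first, the predicate is computed as value >= 0 (setp.ge.s32), and only
// the load itself is guarded (@p ld). No else-branch is needed.
Vec4 load_predicated(const Vec4* ptr, int32_t value) {
    Vec4 ret{0, 0, 0, 0};
    const bool p = value >= 0;
    if (p) {
        ret = *ptr;
    }
    return ret;
}
```

With a false predicate the pointer is never dereferenced, so in this model even a null pointer is safe to pass. The same property is what would let a predicated load guard an out-of-bounds address on the GPU.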

Solution

Add standard predicated-load functions to Tilelang, and have the compiler transform if (cond) load() into a predicated load. For example, a 128-bit load that is performed only when value >= 0:

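// Loads 16 bytes from global memory only if value >= 0 ("gez" = greater
// than or equal to zero); otherwise returns zeros. The v4 load requires ptr
// to be 16-byte aligned, and the .nc (non-coherent, read-only) path is only
// valid if the data is not written while the kernel runs.
__forceinline__ __device__ int4 ldg_with_gez_pred(const int4* ptr, const int& value) {
    // Pre-zero the result: if the predicate is false the load is skipped,
    // and the "+r" (read-write) constraints keep these zeros in the outputs.
    int4 ret = make_int4(0, 0, 0, 0);
    asm volatile(
        "{\n\t"
        "  .reg .pred p;\n\t"                                   // predicate scoped to this block
        "  setp.ge.s32 p, %5, 0;\n\t"                           // p = (value >= 0)
        "  @p ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];\n\t"  // load only when p is true
        "}"
        : "+r"(ret.x), "+r"(ret.y), "+r"(ret.z), "+r"(ret.w)
        : "l"(ptr), "r"(value)
        : "memory"
    );
    return ret;
}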
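The transformation itself could be a simple pattern match on the compiler IR. The following is a toy sketch of that idea in host C++. The Expr node kinds and the functions mk, rewrite_pred_load and show are all hypothetical; they are not Tilelang's or TVM's actual IR or pass API. The pass rewrites select(cond, A[i], 0) into pred_load(cond, A[i]) and leaves every other expression unchanged. Only a literal-zero else-branch qualifies, since that is exactly what the pre-zeroed destination registers reproduce.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR for illustration; not Tilelang's or TVM's node types.
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
    enum class Kind { Var, Const, Load, IfThenElse, PredLoad };
    Kind kind = Kind::Var;
    std::string name;  // Var: variable name; Load/PredLoad: buffer name
    int value = 0;     // Const: literal value
    ExprPtr a, b, c;   // Load: a=index | IfThenElse: a=cond, b=then, c=else
                       // PredLoad: a=cond, b=index
};

ExprPtr mk(Expr::Kind k, std::string name = "", int value = 0,
           ExprPtr a = nullptr, ExprPtr b = nullptr, ExprPtr c = nullptr) {
    auto e = std::make_shared<Expr>();
    e->kind = k;
    e->name = std::move(name);
    e->value = value;
    e->a = std::move(a);
    e->b = std::move(b);
    e->c = std::move(c);
    return e;
}

// Bottom-up rewrite: IfThenElse(cond, Load(buf, idx), Const(0)) becomes
// PredLoad(cond, buf, idx). All other nodes are rebuilt unchanged.
ExprPtr rewrite_pred_load(const ExprPtr& e) {
    if (!e) return e;
    ExprPtr a = rewrite_pred_load(e->a);
    ExprPtr b = rewrite_pred_load(e->b);
    ExprPtr c = rewrite_pred_load(e->c);
    const bool zero_else = c && c->kind == Expr::Kind::Const && c->value == 0;
    if (e->kind == Expr::Kind::IfThenElse && b && b->kind == Expr::Kind::Load && zero_else) {
        // The zero else-branch is covered by pre-zeroed destination registers.
        return mk(Expr::Kind::PredLoad, b->name, 0, a, b->a);
    }
    return mk(e->kind, e->name, e->value, a, b, c);
}

// Pretty-printer used to inspect the rewritten expression.
std::string show(const ExprPtr& e) {
    switch (e->kind) {
    case Expr::Kind::Var: return e->name;
    case Expr::Kind::Const: return std::to_string(e->value);
    case Expr::Kind::Load: return e->name + "[" + show(e->a) + "]";
    case Expr::Kind::IfThenElse:
        return "select(" + show(e->a) + ", " + show(e->b) + ", " + show(e->c) + ")";
    case Expr::Kind::PredLoad:
        return "pred_load(" + show(e->a) + ", " + e->name + "[" + show(e->b) + "])";
    }
    return "";
}
```

For example, select(p, A[i], 0) becomes pred_load(p, A[i]), while select(p, A[i], 1) is left alone because a non-zero fallback would need an extra select after the load. A real pass would presumably also have to check that the element type and vector width map onto an available ld variant, and that the buffer is read-only before choosing the .nc path.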

Alternatives

No response

Additional context

No response

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
