Skip to content

Conversation

@kurisu6912
Copy link
Collaborator

@kurisu6912 kurisu6912 commented Dec 19, 2025

In the previous version of LazyJIT, some constant/symblic are written outside of the function. This makes the global namespace crowded and name conflict.

In this revision, we move the type annotation inside function, which is simpler and clear:

import tilelang
import tilelang.language as T
import torch

@tilelang.lazy_jit
def gemm(
    A: T.ptr, B: T.ptr, C: T.ptr,
    out_dtype: T.dtype = T.float32,
    block_M: int = 128,
    block_N: int = 128,
    block_K: int = 32,
):
    M, N, K = T.const('M, N, K')

    A: T.Tensor[[M, K], T.float16]
    B: T.Tensor[[K, N], T.float16]
    C: T.Tensor[[M, N], out_dtype]

    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N)) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), A.dtype)
        B_shared = T.alloc_shared((block_K, block_N), B.dtype)
        C_local = T.alloc_fragment((block_M, block_N), out_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[bx * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, by * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[bx * block_M, by * block_N])

A = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
C = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')

gemm(A, B, C)

TODO

  • Two-phase elaboration
  • Add support for T.empty and return values
  • Syntax sugar: remove T.ptr annotation completely
  • Add error report for Python positional-only arguments
  • Write comprehensive tests
  • Add examples

Summary by CodeRabbit

  • New Features

    • Dynamic-shape support and in-body shape/annotation patterns.
    • Lazy JIT workflow for on-demand, template-driven compilation.
    • Public side-effect helper exposed for scheduling.
  • Refactor

    • Simplified public signatures by moving type/shape declarations into function bodies.
    • Streamlined JIT/caching/logging and removed legacy annotation framework.
  • Documentation

    • Notebooks updated to demonstrate direct-call JIT and dynamic-shape usage.
  • Tests

    • Tests updated to validate lazy JIT flows and consolidated signature patterns.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions

This comment was marked as resolved.

@LeiWang1999 LeiWang1999 self-requested a review December 19, 2025 10:07
@kurisu6912 kurisu6912 marked this pull request as ready for review December 23, 2025 05:03
coderabbitai[bot]

This comment was marked as resolved.

coderabbitai[bot]

This comment was marked as resolved.

@kurisu6912
Copy link
Collaborator Author

kurisu6912 commented Dec 23, 2025

This PR:

  1. Add SideEffect function to the Python side, which allows the user to write some complex shape expressoin in the function body:

    def foo(A, B):
        N = T.dynamic('N')
        # In the previous version, this generates a LetStmt
        # But in this version, we check its side effects to disable LetStmt generation
        M = N * 2 + 1
        A: T.Tensor[[N], T.float32]
        B: T.Tensor[[M], T.float32]
  2. Two-phase elaboration:

    1. Phase 1: Generating a function with a dynamic/const shape placeholder
    2. Phase 2: Replace the const placeholder with the matched tensor shape/stride
      https://github.com/kurisu6912/tilelang/blob/71ed76bbbee5beb96a12f151be83a9b20bfd67f2/tilelang/language/v2/builder.py#L919-L933
  3. Function annotation heuristic: In both JIT and LazyJIT, if the function arguments are annotated inside the function body, we apply transforms to change it to match_buffer and add it to arguments

    def foo(A, B):
        A: T.Tensor[[128], T.float32]
        B: T.float32

    Is transformed to

    def foo(A: T.ptr, B: T.float32):
        A = T.match_buffer(A, [128], T.float32)
        B = B

@tile-ai tile-ai deleted a comment from coderabbitai bot Dec 23, 2025
@kurisu6912 kurisu6912 changed the title [LazyJIT] Move type annotation inside function [LazyJIT] Move Type Annotations to Function Body Dec 23, 2025
@coderabbitai

This comment was marked as resolved.

coderabbitai[bot]

This comment was marked as resolved.

coderabbitai[bot]

This comment was marked as resolved.

@kurisu6912

This comment was marked as resolved.

@kurisu6912

This comment was marked as resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant