improve amd examples & bugfixs #1420
base: main
Conversation
…nd clarity (tile-ai#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices.
…ed flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters.
… example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance.
- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability.
- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization.
- Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization. - Added new CI workflows for AMD and documentation publishing. - Updated various requirements files to include necessary dependencies. - Introduced new test cases and examples for better coverage and functionality. - Refactored existing code for improved readability and maintainability.
- Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang. - Added `test.sh` script to facilitate running the new example with specified parameters. - Enhanced the overall structure and organization of the example for better clarity and usability.
- Reduced the number of threads and `num_split_q` options for improved performance. - Adjusted `panel_size` options to streamline configuration settings.
- Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang. - Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps. - Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters. - Included reference implementation for validation against PyTorch's attention mechanism. This addition enhances the examples directory by providing a comprehensive guide for users to understand and utilize Flash Attention in their applications.
- Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation. - Improved the main function to allow for better parameter configuration and benchmarking. - Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example. This update aims to provide users with a more robust tool for understanding and utilizing Flash Attention in their applications.
- Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules. - Adjusted include paths for better organization and clarity in the code structure.
…installation - Removed the installation of `flash_attn==2.5.8` to streamline the CI process. - Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.
…rd example - Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance. - This change focuses on optimizing the backward pass implementation by reducing unnecessary memory overhead.
- Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions. - This change simplifies the CI process and reduces potential overhead during package management.
- Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity. - Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.
- Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining. - No functional changes were made; this update focuses on code style improvements to enhance maintainability.
- Renamed the file from `intrin_rule_cuda.cc` to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules. - Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.
…test_tilelang_gemm_mfma_intrinsic.py`
…r fragments for Q and V, optimizing LDS usage as per specifications. Update related GEMM calls to reflect changes in memory management.
…dding support for k_pack configurations. Update thread configurations and memory management to optimize performance.
…umping functionality. Update memory management for shared and register fragments, and improve tensor comparison for backward pass verification with adjusted tolerances.
…tensor. Update atomic addition overload in common.h to accept lvalue references, improving type safety. Remove unused body_sr function from gemm.h to streamline code.
…on and dumping functionality. Introduce new functions for kernel artifact management, including sanitization of filenames and generation of assembly code. Update backward pass logic to utilize partial gradients and improve memory management for tensor operations.
…rations, enhancing memory efficiency. Adjust kernel configurations and ensure type safety with T.cast for gradient calculations. Refactor tensor initialization in main function to align with new data types.
…ry usage Introduce a new static method, body_sr, in gemm.h to enhance the GEMM operation by utilizing shared memory for efficient data handling. This method includes detailed logic for loading matrices A and B, performing computations, and leveraging warp-level operations. Additionally, a new gemm_sr function is added to facilitate the invocation of body_sr, improving the overall structure and performance of the GEMM implementation.
…ions for Q and V tensors. Update tensor copying and GEMM operations to improve memory efficiency and align with new fragment handling logic.
…el source handling - Removed unnecessary fragment allocation in `flashattn_bwd` to streamline memory usage. - Enhanced GEMM order for better performance by computing `dV` before `dP`, reusing fragments effectively. - Introduced separate saving and loading for host and device kernel sources in the autotuner, improving compatibility and clarity. - Added error handling for loading host and device sources, with a fallback mechanism for backward compatibility with wrapped kernels.
…mputation in flash attention backward pass - Added comprehensive validation checks for configurations in `get_bwd_configs` to prevent divide-by-zero errors. - Integrated Delta computation directly into the `flashattn_bwd` kernel, eliminating the need for a separate kernel launch and enhancing data locality. - Updated function calls in the main execution flow to reflect the changes in Delta handling.
…tations - Updated comments for clarity and consistency in `example_amd_flash_attn_bwd.py` and `example_amd_flash_attn_fwd.py`. - Improved variable naming for better readability, changing `bx` to `bx_loop_var`. - Streamlined memory allocation and kernel execution flow by integrating Delta computation directly into the main kernel. - Adjusted batch and head parameters in the forward pass example for better testing flexibility.
- Translated comments to improve clarity for non-English speakers. - Adjusted memory allocation and kernel execution flow, including changes to block sizes and fragment allocations for better performance. - Integrated Delta computation directly into the main kernel, enhancing data locality and reducing overhead. - Updated the main execution flow to utilize a new reduction kernel for accumulating gradients, improving efficiency in the backward pass.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Reworks AMD FlashAttention forward/backward to use register fragments and per-tile dQ partials with a reduction kernel, adds HIP/HSACO/ASM artifact dumping and a CLI dump_dir, changes AtomicAdd to take an lvalue reference, adds an SR GEMM path, and enables -ffast-math via pass_config.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (main)
    participant Fwd as FlashAttn FWD Kernel
    participant Bwd as FlashAttn BWD Kernel
    participant Reduce as dQ Reduce Kernel
    participant Dumper as Artifact Dumper
    Host->>Fwd: launch forward (Q -> Q_fragment regs)
    Fwd->>Dumper: dump forward artifacts (HIP/HSACO/asm)
    Fwd-->>Host: return activations/state
    Host->>Bwd: launch backward (per-tile -> dQ_partial, dK, dV)
    Bwd->>Dumper: dump backward artifacts
    Bwd-->>Host: return per-tile partial grads
    Host->>Reduce: launch flashattn_bwd_reduce_dq (accumulate dQ_partial -> dQ)
    Reduce->>Dumper: dump reduce artifacts
    Reduce-->>Host: return final dQ
    Host->>Dumper: finalize artifact writes
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive); ✅ Passed checks (1 passed).
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/amd/example_amd_flash_attn_bwd.py (1)
622-628: Default parameter mismatch between function signature and CLI. The `main()` function defaults (batch=1, heads=8) differ from the CLI argument defaults (batch=4, heads=16). This inconsistency may confuse developers who call `main()` directly without arguments. Align the defaults:

```diff
 def main(batch: int = 1,
-         heads: int = 8,
+         heads: int = 16,
          seq_len: int = 4096,
          dim: int = 128,
          is_causal: bool = False,
          groups: int = 1,
          dump_dir: str | None = None):
```

Or update the CLI defaults to match the function signature.
Also applies to: 792-794
tilelang/autotuner/param.py (1)
274-341: Fix cache-load gating: use explicit `None` checks and handle empty-list params.

- Correctness: `kernel_params` can be an empty list. Change `and kernel_params` to `and kernel_params is not None` to allow valid cache hits with zero parameters.
- Correctness: avoid truthiness-based fallback checks. Change `if not host_kernel_source or not device_kernel_source:` to explicit `is None` checks to distinguish "not loaded" from "legitimately empty."
- Major risk: the wrapped-source split on `"\n\n"` assumes the delimiter never appears within device or host source. Since the wrapped format is `device_kernel_source + "\n\n" + host_kernel_source`, any internal newlines at source boundaries could corrupt the split. Treat un-splittable wrapped sources as a cache miss rather than guessing; optionally add a TODO to document the actual, unambiguous delimiter expected by `get_kernel_source()`.

```diff
-        if not host_kernel_source or not device_kernel_source:
+        if host_kernel_source is None or device_kernel_source is None:
             try:
                 wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
                 if os.path.exists(wrapped_kernel_path):
@@
-                    # Try to split wrapped source (format: device_source + "\n\n" + host_source)
-                    if "\n\n" in wrapped_source:
-                        parts = wrapped_source.split("\n\n", 1)
-                        device_kernel_source = parts[0]
-                        host_kernel_source = parts[1] if len(parts) > 1 else ""
-                    else:
-                        # If no separator, assume it's all device source
-                        device_kernel_source = wrapped_source
-                        host_kernel_source = ""
-            except Exception as e:
+                    # Format: device_source + "\n\n" + host_source
+                    if "\n\n" in wrapped_source:
+                        parts = wrapped_source.split("\n\n", 1)
+                        device_kernel_source = parts[0]
+                        host_kernel_source = parts[1]
+                    else:
+                        host_kernel_source = None
+                        device_kernel_source = None
+            except Exception as e:  # noqa: BLE001 (best-effort cache read)
                 logger.error(f"Error loading wrapped kernel source code from disk: {e}")
@@
-        if host_kernel_source is not None and device_kernel_source is not None and kernel_params:
+        if host_kernel_source is not None and device_kernel_source is not None and kernel_params is not None:
             return JITKernel.from_database(
```
🧹 Nitpick comments (5)
examples/amd/example_amd_flash_attn_bwd.py (4)
92-130: Redundant `arch` check after guaranteed assignment. At line 108, the condition `if arch:` is redundant because lines 98-100 ensure `arch` is always assigned a value when it was initially `None`. The `get_rocm_arch` function returns a default `"gfx900"` if detection fails.

```diff
-        if arch:
-            cmd.append(f"--offload-arch={arch}")
+        cmd.append(f"--offload-arch={arch}")
```
174-184: Early return prevents independent assembly generation. If HSACO compilation fails (line 176), the function returns early and skips assembly generation. Since `_compile_hip_to_asm` compiles directly from source (not from the HSACO), assembly generation could still succeed independently. Consider removing the early return to allow assembly generation to proceed on its own:

```diff
     except Exception as err:  # noqa: BLE001
         print(f"[TileLang] HIP compilation failed for {hip_path}: {err}")
-        return
```
538-558: Reduction kernel may have a suboptimal memory access pattern. The reduction loop iterates sequentially over `num_tiles` with a parallel inner loop over `dim`. For large `num_tiles`, this results in strided memory access across the tile dimension. Consider whether a different parallelization strategy (e.g., a parallel reduction across tiles) would improve performance. For now, this implementation is correct and should work; performance optimization can be deferred if profiling shows this kernel is a bottleneck.
752-769: Tensor allocation inside the benchmark loop affects measurement accuracy. The `dQ_partial_bench` allocation (lines 757-759) occurs inside the timed benchmark loop, adding memory-allocation overhead to each iteration. This makes the TileLang benchmark slightly unfair compared to the reference, which doesn't have equivalent allocations inside its loop. Consider pre-allocating `dQ_partial_bench` before the timing loop, or accepting this overhead as part of the realistic usage pattern:

```diff
 def run_complete_fwd_bwd():
     o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
     delta_tl_bench = bwd_prep(o_tl_bench, dO)
-    dQ_partial_bench = torch.zeros((batch, num_kv_tiles, seq_len, heads, dim),
-                                   dtype=torch.float16,
-                                   device=device)
     dK_bench = torch.zeros_like(k, dtype=torch.float32)
     dV_bench = torch.zeros_like(v, dtype=torch.float32)
```

And pre-allocate outside the benchmark function, or document that this overhead is intentionally included.
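The pre-allocation pattern the comment suggests can be shown with a minimal NumPy stand-in (hypothetical names; `np.zeros` stands in for the torch allocations, and the in-place add stands in for a kernel launch):

```python
import time
import numpy as np

shape = (4, 16, 512, 8, 64)

def bench(fn, repeat=5):
    fn()  # warmup
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    return (time.perf_counter() - start) / repeat

# Allocation inside the timed body: every iteration pays the alloc cost.
def alloc_inside():
    buf = np.zeros(shape, dtype=np.float16)
    buf += 1  # stand-in for the kernel launch

# Pre-allocated buffer: the timed body only zeroes and reuses it.
buf = np.zeros(shape, dtype=np.float16)
def alloc_outside():
    buf.fill(0)
    np.add(buf, 1, out=buf)  # stand-in for the kernel launch, reusing buf

t_in, t_out = bench(alloc_inside), bench(alloc_outside)
print(f"alloc-inside {t_in * 1e3:.2f} ms vs pre-allocated {t_out * 1e3:.2f} ms")
```

The same trade-off applies on the GPU, where allocation may additionally trigger caching-allocator or synchronization effects that skew timings further.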
tilelang/autotuner/param.py (1)
177-222: Don't silently write `None` host/device sources; also address Ruff BLE001 (`except Exception`). Right now `get_host_source()`/`get_device_source()` returning `None` (or a non-str) will fall into the broad `except` and just warn, leaving partially populated caches.

```diff
         # Save host and device kernel source separately
         try:
             # Try to get host and device source from adapter
             adapter = kernel.adapter
             if hasattr(adapter, 'get_host_source') and hasattr(adapter, 'get_device_source'):
                 host_source = adapter.get_host_source()
                 device_source = adapter.get_device_source()
+                if host_source is None or device_source is None:
+                    raise ValueError("adapter returned None host/device source")
                 host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
                 if verbose:
                     logger.debug(f"Saving host kernel source code to file: {host_kernel_path}")
                 with open(host_kernel_path, "w") as f:
                     f.write(host_source)
@@
                 with open(device_kernel_path, "w") as f:
                     f.write(device_source)
@@
-        except Exception as e:
+        except Exception as e:  # noqa: BLE001 (best-effort cache write)
             logger.warning(f"Error saving host/device kernel source code to disk: {e}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/amd/example_amd_flash_attn_bwd.py (13 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (5 hunks)
- src/op/math.cc (1 hunks)
- src/tl_templates/hip/common.h (1 hunks)
- src/tl_templates/hip/gemm.h (2 hunks)
- tilelang/autotuner/param.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-03T06:24:11.411Z
Learnt from: Rachmanino
Repo: tile-ai/tilelang PR: 1175
File: src/op/math.cc:44-52
Timestamp: 2025-11-03T06:24:11.411Z
Learning: In tilelang's `src/op/math.cc`, the `infinity_op` function uses `std::numeric_limits<float>::infinity()` as a placeholder for all float types (including float64 and bfloat16). The codegen layer (PrintConst:Inf) handles the correct infinity value based on the dtype field of the FloatImm node, so the specific C++ template argument doesn't matter.
Applied to files:
src/op/math.cc
🧬 Code graph analysis (4)
examples/amd/example_amd_flash_attn_fwd.py (5)
- tilelang/language/allocate.py (1): alloc_fragment (59-70)
- tilelang/language/fill.py (2): fill (13-46), clear (49-73)
- tilelang/tileop/gemm/gemm_base.py (1): policy (119-120)
- tilelang/language/loop.py (1): Parallel (11-31)
- tilelang/language/copy.py (1): copy (13-101)
tilelang/autotuner/param.py (2)
- tilelang/jit/adapter/tvm_ffi.py (4): get_host_source (299-303), get_device_source (305-309), from_database (267-297), func (195-262)
- tilelang/jit/adapter/cython/adapter.py (1): from_database (148-207)
examples/amd/example_amd_flash_attn_bwd.py (5)
- tilelang/contrib/rocm.py (2): find_rocm_path (264-283), get_rocm_arch (230-261)
- tilelang/contrib/hipcc.py (1): compile_hip (19-96)
- tilelang/language/allocate.py (2): alloc_fragment (59-70), alloc_shared (27-42)
- tilelang/env.py (1): enable (106-108)
- tilelang/language/loop.py (2): Pipelined (56-93), Parallel (11-31)
src/tl_templates/hip/gemm.h (1)
- src/tl_templates/cuda/gemm_sm90.h (1): gemm_sr (328-336)
🪛 Ruff (0.14.8)
tilelang/autotuner/param.py
220-220: Do not catch blind exception: Exception
(BLE001)
291-291: Do not catch blind exception: Exception
(BLE001)
312-312: Do not catch blind exception: Exception
(BLE001)
examples/amd/example_amd_flash_attn_bwd.py
122-122: subprocess call: check for execution of untrusted input
(S603)
130-130: Avoid specifying long messages outside the exception class
(TRY003)
154-154: Probable use of insecure hash functions in hashlib: sha1
(S324)
🔇 Additional comments (17)
src/op/math.cc (1)
57-63: HIP intrinsic registration for `tl.infinity` is correct. The `"hip.FLowerIntrinsic"` attribute key matches the pattern used throughout the codebase. The lowering pipeline in `src/transform/lower_intrin.cc` constructs attribute queries as `target + ".FLowerIntrinsic"` (where target is "hip"), and `src/target/intrin_rule_hip.cc` consistently uses `"hip.FLowerIntrinsic"` across all HIP intrinsic registrations. The registration mirrors the CUDA setup appropriately.

examples/amd/example_amd_flash_attn_fwd.py (8)
110-113: LGTM! While loop structure is correct.The manual loop variable initialization and while condition properly implement the split-Q parallelism strategy, allowing multiple thread blocks to process different Q blocks concurrently.
125-137: Excellent refactor to use register-based fragments for Q. Moving Q from shared memory to register-based fragments (`Q_fragment`) reduces shared memory pressure and access latency, which should improve overall kernel performance. The copy operation is correctly updated to target the new fragment.
165-165: Correctly updated GEMM to use Q_fragment. The GEMM call is properly updated to use the new register-based `Q_fragment`, completing the refactor consistently.
173-175: LGTM! Scaling operation is correctly placed. The attention scores (`acc_s`) are scaled by `1/sqrt(dim)` after the GEMM and before the softmax computation, which is the standard Flash Attention approach.
182-186: Excellent numerical stability improvement. The explicit check for `-infinity` in `m_prev[i]` avoids computing `exp(-inf - m_i[i])`, which would be 0 anyway. Setting `scale_factor[i] = 0.0` directly is both more numerically stable and more efficient. This correctly handles the case where no valid attention scores existed in previous blocks.
193-196: Numerical stability improvement for masked positions. Similar to the scale-factor handling, this avoids computing `exp(-inf - m_i[i])` for masked attention positions (e.g., positions violating causal constraints). Setting these directly to 0.0 is both correct and more efficient.
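The guard these two comments praise can be illustrated with a scalar online-softmax update (a sketch of the idea, not the kernel's exact code):

```python
import math

def rescale_factor(m_prev, m_new):
    """Scale factor for previously accumulated values in online softmax.

    When no valid score has been seen yet (m_prev == -inf), exp(m_prev - m_new)
    would be 0.0 anyway, but computing it directly risks nan when m_new is also
    -inf (a fully masked row): (-inf) - (-inf) is nan. Returning 0.0 explicitly
    is both safe and cheaper than evaluating exp.
    """
    if m_prev == float("-inf"):
        return 0.0
    return math.exp(m_prev - m_new)

print(rescale_factor(float("-inf"), 2.0))            # 0.0, no exp evaluated
print(rescale_factor(float("-inf"), float("-inf")))  # 0.0, avoids nan
print(rescale_factor(1.0, 3.0))                      # exp(-2)
```

The same reasoning applies per masked score: any position whose running max is `-inf` contributes exactly zero weight, so writing 0.0 directly skips both the subtraction and the exponential.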
214-214: Correct loop increment for split-Q parallelism. The increment `current_bx + num_split_q` properly implements the split-Q strategy, where thread blocks with different `b_split` values process interleaved Q blocks in parallel.
253-254: Note: CLI defaults updated for larger test configuration.The default batch size increased from 1 to 2 and heads from 8 to 16, providing a more representative workload for testing and benchmarking.
src/tl_templates/hip/common.h (1)
113-117: Consider constraining the reference overload to prevent ambiguity with pointer lvalues. Both `AtomicAdd(T1* ...)` and `AtomicAdd(T1& ...)` are viable for pointer lvalue arguments (e.g., `float* p; AtomicAdd(p, v)`), creating template overload ambiguity. While no problematic calls exist in the current codebase, this could cause issues with future code.
- Add
static_assert(!std::is_pointer<T1>::value, "...")to the reference overload with#include <type_traits>.- Remove the pointer overload entirely to match the CUDA design, which provides only a reference overload.
examples/amd/example_amd_flash_attn_bwd.py (5)
247-260: LGTM! Fragment-based Q storage improves register utilization. The change from shared memory to fragment storage for Q (`Q_fragment`) is a good optimization for the forward pass, allowing better register reuse during the QK GEMM operation.
370-394: LGTM! Delta preprocessing kernel is correct. The preprocessing kernel correctly computes `Delta = sum(O * dO)` per sequence position, which is needed for the backward-pass softmax gradient computation.
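The preprocessing step can be checked against a direct NumPy computation (shapes assumed for illustration: `O` and `dO` are `[seq_len, dim]` for one head):

```python
import numpy as np

seq_len, dim = 8, 16
rng = np.random.default_rng(1)
O = rng.standard_normal((seq_len, dim)).astype(np.float32)
dO = rng.standard_normal((seq_len, dim)).astype(np.float32)

# Delta[i] = sum_d O[i, d] * dO[i, d]: one scalar per sequence position.
# In the FlashAttention backward pass it enters the softmax gradient as
# dS = P * (dP - Delta), so precomputing it avoids rematerializing O * dO
# inside the main backward loop.
delta = (O * dO).sum(axis=-1)
print(delta.shape)  # (8,)
```

A per-row dot product like this is cheap and memory-bound, which is why fusing it into the main backward kernel (as a later commit in this PR does) mostly saves a kernel launch rather than compute.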
477-509: Clear GEMM ordering optimization with good comments. The reordered GEMM sequence (dV first, then dP) with explicit comments explaining the rationale improves code maintainability. The `ds_compute` allocation inside the loop avoids layout conflicts, as noted in the comment.
561-620: LGTM! Comprehensive debug and benchmark utilities. The `debug_tensor_comparison` function provides thorough diagnostics, including NaN/Inf detection, relative differences, and the position of the maximum error. The benchmark function follows standard warmup/repeat patterns.
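A minimal version of such a diagnostic looks like this (a hypothetical helper; the example's actual `debug_tensor_comparison` is more thorough):

```python
import numpy as np

def compare_tensors(name, actual, expected, atol=1e-3, rtol=1e-3):
    """Report NaN/Inf counts, max absolute/relative error, and its position."""
    bad = int(np.isnan(actual).sum() + np.isinf(actual).sum())
    abs_err = np.abs(actual - expected)
    rel_err = abs_err / (np.abs(expected) + 1e-8)  # epsilon guards div-by-zero
    idx = np.unravel_index(np.argmax(abs_err), abs_err.shape)
    ok = np.allclose(actual, expected, atol=atol, rtol=rtol)
    print(f"{name}: nan/inf={bad} max_abs={abs_err.max():.3e} "
          f"max_rel={rel_err.max():.3e} at {idx} pass={ok}")
    return ok

ref = np.ones((4, 4), dtype=np.float32)
out = ref.copy()
out[2, 3] += 5e-4  # small perturbation, within tolerance
print(compare_tensors("dQ", out, ref))  # True
```

Reporting the position of the maximum error is particularly useful for kernel debugging: a mismatch clustered at tile boundaries usually points to an indexing or masking bug rather than accumulated rounding error.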
449-452: No action needed: the `continue` statement is in a regular Python loop, not a TileLang kernel loop. Lines 449-452 show a standard Python `for` loop over `range(groups)` with a Python `continue` statement. The TileLang pipelined loop (`for k in T.Pipelined(...)`) begins at line 454, after the `continue`. The condition skips entire iterations of the Python loop when `q_head >= heads`, which is valid Python control flow. This has no TileLang-specific semantics concerns.

src/tl_templates/hip/gemm.h (1)
379-387: The `gemm_sr(...)` wrapper matches the CUDA entry-point shape; nice parity. The wrapper cleanly mirrors the CUDA `gemm_sr` pattern (delegating into `GemmTensorOp::body_sr`), which should keep specialization plumbing consistent across backends.

tilelang/autotuner/param.py (1)
27-28: Good: explicit host/device cache filenames make the cache layout clearer.
```python
if not dump_dir:
    dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")
```
Kernel artifacts are always dumped even without explicit user request. When `dump_dir` is `None` (the default), lines 633-634 set it to `"tilelang_kernel_dumps"`, causing artifacts to be written to the current working directory. This may be unexpected behavior for users who didn't explicitly request artifact dumping.
Consider keeping `dump_dir` as `None` to disable dumping by default:

```diff
-    if not dump_dir:
-        dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")
+    # dump_dir remains None if not specified, skipping artifact dumping
```

Or alternatively, make the default behavior explicit in the help text and require an opt-in flag like `--dump` to enable artifact generation.
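The opt-in flag the suggestion describes could look like the following argparse sketch (flag names are hypothetical, not the example's actual CLI):

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--dump", action="store_true",
                    help="Opt in to dumping kernel artifacts (off by default).")
parser.add_argument("--dump_dir", default=None,
                    help="Where to write artifacts; a non-None value implies dumping.")

args = parser.parse_args([])  # no flags passed: dumping stays disabled

dump_dir = args.dump_dir
if args.dump and dump_dir is None:
    # Only fall back to a default directory when the user explicitly opted in.
    dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")

# Without --dump or --dump_dir, dump_dir stays None and no artifacts are written.
print(dump_dir)  # None
```

This keeps the zero-flag invocation side-effect free while still letting `--dump` alone pick a sensible default location.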
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 633-634, the code
unconditionally sets dump_dir to os.path.join(os.getcwd(),
"tilelang_kernel_dumps") when dump_dir is None, causing kernel artifacts to be
dumped by default; change behavior so dumping is opt-in: either leave dump_dir
as None when not provided (so no dumps are written), or add a CLI flag (e.g.,
--dump) and only set dump_dir when that flag is present, and update the help
text to document the opt-in behavior.
```cpp
  static TL_DEVICE void body_sr(A_type *A_shared, B_type *B_local,
                                C_type *C_local) {
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;

    auto lane_id = tid % warp_size;
    auto tx = lane_id;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;

    constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;

    A_type A_local[warp_rows * kPack * local_size_a];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch A from shared memory into register
      for (int i = 0; i < warp_rows; i++) {
        const auto l = warp_m * warp_row_tiles + i * micro_size_x;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
          if constexpr (TransposeA) {
            auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
            A_local[i * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                    r + row, l + col)];
          } else {
            auto [row, col] = reverse_index_map(lane_id, local_id);
            A_local[i * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                    l + row, r + col)];
          }
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
            auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
            // B is already in register, layout:
            // [inner_k][warp_cols][kPack][local_size_b]
            auto b_ptr = ((B_type *)B_local) +
                         (ki * warp_cols * kPack + j * kPack + kp) * vec_size;
            // A is loaded from shared memory, layout:
            // [warp_rows][kPack][local_size_a]
            auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * vec_size;

            // Use the trait to select the correct MFMA instruction, either
            // fp8, fp16 or bf16 currently
            MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
          }
        }
      }
    }
  }
```
Fix newly introduced unused locals in `body_sr` (can fail under `-Werror`). `body_sr` declares `tx`, `local_size_b`, `local_size_c`, and `last_dim_b` but never uses them (notably, `last_dim_b` is new here). If your HIP build treats warnings as errors, this will break compilation.
```diff
   static TL_DEVICE void body_sr(A_type *A_shared, B_type *B_local,
                                 C_type *C_local) {
     auto tid = threadIdx.x;
     auto warp_id = tid / warp_size;
     auto warp_n = warp_id / block_row_warps;
     auto warp_m = warp_id % block_row_warps;
     auto warp_row_tiles = warp_rows * micro_size_x;
     auto warp_col_tiles = warp_cols * micro_size_y;
     auto lane_id = tid % warp_size;
-    auto tx = lane_id;
     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
-    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
-    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
     constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
-    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;
```

If you want to keep symmetry with other paths for readability, consider `[[maybe_unused]]` instead.
🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 293 to 352, body_sr declares unused
locals tx, local_size_b, local_size_c, and last_dim_b which can trigger -Werror;
remove the unused variables or mark them [[maybe_unused]] (preferred to retain
symmetry/readability) so they no longer produce unused-variable warnings,
ensuring compilation succeeds under warning-as-error builds.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# First, let's look at the file structure and find the gemm_sr function signature
cd /tmp/repo 2>/dev/null || cd .
find . -name "gemm.h" -path "*/hip/*" 2>/dev/null | head -5
Repository: tile-ai/tilelang
Length of output: 89
🏁 Script executed:
#!/bin/bash
# Search for gemm_sr calls and instantiations
rg -n "gemm_sr\s*<" -C 3 --max-count=20
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Search for gemm_sr function calls with parentheses
rg -n "gemm_sr\s*\(" -C 3 --max-count=20
Length of output: 13571
🏁 Script executed:
#!/bin/bash
# Search for B_local allocations and usage
rg -n "B_local" -C 2 --max-count=30
Length of output: 50373
🏁 Script executed:
#!/bin/bash
# Search for body_sr calls
rg -n "body_sr\s*\(" -C 3 --max-count=20
Length of output: 2498
🏁 Script executed:
#!/bin/bash
# Let's look at the actual file to understand the gemm_sr and body_sr signature
cat -n src/tl_templates/hip/gemm.h | sed -n '250,290p'
Length of output: 2071
🏁 Script executed:
#!/bin/bash
# Also check the context around lines 379-387 mentioned at the end
cat -n src/tl_templates/hip/gemm.h | sed -n '370,395p'
Length of output: 1030
🏁 Script executed:
#!/bin/bash
# Find test cases and usage of gemm_sr
rg -n "gemm_sr" testing/python/ --max-count=30 -A 3 -B 1
Length of output: 4441
🏁 Script executed:
#!/bin/bash
# Look at the actual body_sr code around the indexing
cat -n src/tl_templates/hip/gemm.h | sed -n '240,290p'
Length of output: 2538
🏁 Script executed:
#!/bin/bash
# Check allocation and sizing of B_local more carefully
cat -n src/tl_templates/hip/gemm.h | sed -n '230,260p'
Length of output: 1457
🏁 Script executed:
#!/bin/bash
# Get the full body_sr implementation - it should be the static void body_sr starting around line 293
cat -n src/tl_templates/hip/gemm.h | sed -n '293,360p'
Length of output: 3321
🏁 Script executed:
#!/bin/bash
# Look at the run_gemm_sr implementation to see how B_region/B_local is allocated
cat -n testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py | sed -n '326,378p'
Length of output: 1520
🏁 Script executed:
#!/bin/bash
# Find matmul_sr definition in the test file
rg -n "def matmul_sr" testing/python/tilelibrary/ -A 50 --max-count=1 | head -80
Length of output: 5161
🏁 Script executed:
#!/bin/bash
# Also check if there's a tilelang source file that defines this
find . -name "*.py" -path "*/tilelang/*" -exec grep -l "def matmul_sr" {} \;
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Look for gemm_v2 definition
rg -n "def gemm_v2|gemm_v2\s*=" tilelang/language/ --max-count=5 -A 3
Length of output: 279
🏁 Script executed:
#!/bin/bash
# Also search for where gemm_sr is actually instantiated/used
rg -n "T.gemm_v2" tilelang/ --max-count=10 -B 2 -A 2
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Look at the gemm_v2 implementation
cat -n tilelang/language/gemm.py | sed -n '129,180p'
Length of output: 955
🏁 Script executed:
#!/bin/bash
# Check if there are any direct usages of gemm_sr in C++ test files or examples
rg -n "gemm_sr\s*<" src/ --max-count=20
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check if body_sr is only an internal implementation or if there's external usage
rg -n "body_sr" src/tl_templates/ -B 2 -A 2 | head -60
Length of output: 2922
Document the gemm_sr contract: pB must be sized [inner_k][warp_cols][kPack][vec_size] (and aligned).
body_sr indexes B_local as (ki * warp_cols * kPack + j * kPack + kp) * vec_size (lines 340-341), requiring callers to provide pB containing all inner_k slices, not just a single fragment. A mismatch causes out-of-bounds reads and silent data corruption.
Actionable fix:
- Add a doc comment on `gemm_sr` describing the required `pB` layout and size (similar to the existing comment at line 339).
- Consider using `alignas(8)` on the caller side for `pB`, since the MFMA helpers perform vectorized loads.
Also applies to: `gemm_rs` (lines 369-377), which has the analogous `body_rs` with similar layout expectations.
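To make the size contract concrete, here is a small host-side Python sketch (helper names are hypothetical; the real code is the HIP template) of the flat index `body_sr` uses and the minimum `pB` element count it implies:

```python
def b_local_index(ki, j, kp, warp_cols, k_pack, vec_size):
    """Flat element offset of the vec_size-wide slice body_sr reads,
    mirroring (ki * warp_cols * kPack + j * kPack + kp) * vec_size."""
    return (ki * warp_cols * k_pack + j * k_pack + kp) * vec_size

def required_pb_elems(inner_k, warp_cols, k_pack, vec_size):
    """Minimum pB length: all inner_k slices, not just one fragment."""
    # The largest offset is reached at ki=inner_k-1, j=warp_cols-1,
    # kp=k_pack-1, plus the vec_size elements of that final slice.
    return b_local_index(inner_k - 1, warp_cols - 1, k_pack - 1,
                         warp_cols, k_pack, vec_size) + vec_size

if __name__ == "__main__":
    # Example: inner_k=4, warp_cols=2, kPack=2, vec_size=4 -> 64 elements,
    # i.e. the full inner_k * warp_cols * kPack * vec_size product.
    print(required_pb_elems(4, 2, 2, 4))
```

A caller that sizes `pB` for a single `ki` slice would under-allocate by a factor of `inner_k`, which is exactly the out-of-bounds hazard described above.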
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/amd/example_amd_flash_attn_bwd.py (1)
408-516: Inconsistent `GemmWarpPolicy` usage between forward and backward examples

While the forward example uses the directly-imported `GemmWarpPolicy.FullRow`, the backward example uses `T.GemmWarpPolicy.FullRow`. Both are valid (the namespace does expose `GemmWarpPolicy`), but for consistency with the forward example and to match the direct import at line 19, use `GemmWarpPolicy.FullRow` without the `T.` prefix throughout:

```diff
- T.gemm(K_shared,
-        q_shared,
-        qkT,
-        transpose_B=True,
-        policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(K_shared,
+        q_shared,
+        qkT,
+        transpose_B=True,
+        policy=GemmWarpPolicy.FullRow)
- T.gemm(P_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(P_cast, do_shared, dv, policy=GemmWarpPolicy.FullRow)
- T.gemm(V_fragment,
-        do_shared,
-        dP,
-        transpose_B=True,
-        policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(V_fragment,
+        do_shared,
+        dP,
+        transpose_B=True,
+        policy=GemmWarpPolicy.FullRow)
- T.gemm(ds_compute, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(ds_compute, q_shared, dk, policy=GemmWarpPolicy.FullRow)
```
♻️ Duplicate comments (1)
examples/amd/example_amd_flash_attn_bwd.py (1)
622-635: Dumping is enabled by default (writes into CWD) — make it opt-in
`if not dump_dir: dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")` (Lines 633-635) forces artifact dumping even when the user didn't request it. This matches the earlier review concern. Prefer leaving `dump_dir=None` unless explicitly set via `--dump_dir` (or add an explicit `--dump` flag).

```diff
-    if not dump_dir:
-        dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")
+    # Keep dumping opt-in: only dump when --dump_dir is explicitly provided.
```
🧹 Nitpick comments (3)
examples/amd/example_amd_flash_attn_bwd.py (3)
75-129: hipcc subprocess: add a timeout + tighten option typing (dead `isinstance(options, str)` branch)

`_compile_hip_to_asm()` uses `subprocess.run(...)` (Lines 120-126) without a timeout; a hung compiler will stall the example. Also, `options` is annotated `list[str] | None` (Lines 92-93) but later checks `isinstance(options, str)` (Lines 112-115), which should be unreachable; consider simplifying to list-only handling.
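A sketch of the suggested hardening (function and parameter names here are hypothetical, not the example's actual API):

```python
import subprocess
import sys

def compile_with_timeout(cmd: list, timeout_s: float = 120.0) -> str:
    """Run a compiler command with a hard timeout so a hung compiler
    cannot stall the caller. Options are list-only, so no
    isinstance(options, str) branch is needed."""
    try:
        proc = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_s, check=False)
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"compiler timed out after {timeout_s}s") from exc
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr)
    return proc.stdout

# Stand-in for a real hipcc invocation:
print(compile_with_timeout([sys.executable, "-c", "print('ok')"]).strip())
```

`subprocess.TimeoutExpired` is raised after `timeout_s` seconds and the child is killed, so the example fails fast with a clear message instead of hanging.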
131-183: Artifact digest: prefer sha256/blake2 (sha1 gets flagged) and keep dumping strictly opt-in

Using `hashlib.sha1(...)` for a filename digest (Line 152) is functionally fine but gets flagged by security tooling; consider `sha256` or `blake2b` to avoid the noise. Also, since this writes source/HSACO/asm to disk, it should remain opt-in via `dump_dir` (see the separate comment on main's defaulting behavior).
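A minimal sketch of the digest swap (the helper name is hypothetical):

```python
import hashlib

def artifact_digest(source: str, length: int = 12) -> str:
    """Short, stable filename digest. blake2b (or sha256) is not flagged
    by S324 the way sha1 is; collision resistance is irrelevant here
    anyway, since the digest only namespaces dump files."""
    return hashlib.blake2b(
        source.encode("utf-8"), digest_size=16).hexdigest()[:length]

print(artifact_digest("__global__ void k() {}"))
```

The digest is deterministic for identical kernel source, so repeated dumps of the same kernel reuse the same filename.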
538-559: Reduction kernel correctness OK; performance may be dominated by serial num_tiles loop
`flashattn_bwd_reduce_dq` accumulates over `num_tiles` in a simple loop (Lines 550-554). This is straightforward, but likely bandwidth/latency heavy for large `num_tiles`. If this is meant for use beyond debugging, consider a more parallel reduction strategy.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/amd/example_amd_flash_attn_bwd.py (13 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/amd/example_amd_flash_attn_bwd.py (3)
tilelang/contrib/hipcc.py (1)
- compile_hip (19-91)
tilelang/language/allocate.py (1)
- alloc_fragment (71-84)
tilelang/language/copy.py (1)
- copy (14-95)
🪛 Ruff (0.14.8)
examples/amd/example_amd_flash_attn_bwd.py
120-120: subprocess call: check for execution of untrusted input
(S603)
128-128: Avoid specifying long messages outside the exception class
(TRY003)
152-152: Probable use of insecure hash functions in hashlib: sha1
(S324)
🔇 Additional comments (5)
examples/amd/example_amd_flash_attn_fwd.py (4)
11-23: Autotune tensor supply changed, but GPU-only supply helper is now unused — verify autotune still benchmarks on-device
Line 78 removed the explicit supply program, but `supply_tensors_gpu()` (Lines 11-23) strongly suggests this was needed to force GPU tensors. Please verify `@tilelang.autotune(..., cache_input_tensors=True)` still creates/keeps benchmark tensors on the intended device under ROCm/HIP (and doesn't silently fall back to CPU tensors or trigger host↔device copies).

Also applies to: 78-78
123-126: bx_loop_var while-loop looks consistent; sanity-check coverage and termination with the num_split_q stride

The `bx_loop_var = b_split` init (Lines 123-125) + `bx_loop_var = current_bx + num_split_q` step (Line 220) is a reasonable "grid-stride loop" over Q blocks. Just ensure your intended mapping is: each `b_split` lane covers `b_split, b_split + num_split_q, ...`, and that `num_split_q` isn't tuned so large that it hurts occupancy/coverage for small `num_q_blocks`.

Also applies to: 135-137, 220-220
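The coverage property to check can be simulated on the host; a sketch whose names mirror the kernel variables:

```python
def covered_blocks(b_split: int, num_split_q: int, num_q_blocks: int):
    """Q-block indices one b_split lane visits under the
    bx_loop_var = current_bx + num_split_q stepping."""
    return list(range(b_split, num_q_blocks, num_split_q))

num_split_q, num_q_blocks = 4, 10
all_blocks = sorted(
    bx for lane in range(num_split_q)
    for bx in covered_blocks(lane, num_split_q, num_q_blocks))
# Together the lanes must cover every Q block exactly once.
assert all_blocks == list(range(num_q_blocks))
```

Note that when `num_split_q > num_q_blocks`, the extra lanes cover nothing, which is the occupancy concern mentioned above.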
138-152: Q in fragment/register path is coherent; dataflow updates look internally consistent
Switching Q to `Q_fragment` (Lines 138-152) and feeding it into GEMM (Lines 170-177) matches the stated intent and reduces LDS pressure.

Also applies to: 170-177
187-203: Softmax stability guards are a solid improvement (m_prev=-inf / acc_s=-inf paths)
The `m_prev == -inf → scale_factor=0` path (Lines 187-193) and `acc_s == -inf → 0 else exp(...)` (Lines 198-203) are sensible protections against `exp(-inf - -inf)`-style NaNs.

examples/amd/example_amd_flash_attn_bwd.py (1)
245-259: Forward-in-bwd kernel updates mirror the fwd example and look consistent
The Q → `Q_fragment` switch (Lines 245-259) and GEMM input update (Lines 275-282) match the forward example's intent.

Also applies to: 275-282
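The softmax stability guards called out above (lines 187-203 of the forward kernel) can be checked in plain Python; this scalar sketch mirrors the two guarded branches:

```python
import math

def safe_rescale(m_prev: float, m_new: float) -> float:
    """Accumulator scale factor: 0.0 before any row max exists,
    avoiding exp(-inf - -inf) = exp(nan) = nan."""
    return 0.0 if m_prev == -math.inf else math.exp(m_prev - m_new)

def safe_exp_score(s: float, m_new: float) -> float:
    """Probability term: fully masked (-inf) scores map to exactly 0.0."""
    return 0.0 if s == -math.inf else math.exp(s - m_new)

assert safe_rescale(-math.inf, -math.inf) == 0.0   # guarded path, no NaN
assert safe_exp_score(-math.inf, -math.inf) == 0.0
assert safe_rescale(0.0, 1.0) == math.exp(-1.0)    # normal path unchanged
```

The unguarded expressions would propagate NaN through the running accumulator for any fully masked row, which is exactly what the explicit checks prevent.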
dQ_partial: T.Tensor([batch, T.ceildiv(seq_len, block_M), seq_len, heads,
                      dim], dtype),
dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)):
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file structure and the specific lines mentioned
wc -l examples/amd/example_amd_flash_attn_bwd.py
Length of output: 106
🏁 Script executed:
# Check the tensor definition at lines 416-418
sed -n '410,425p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 1200
🏁 Script executed:
# Check the allocation in main() around lines 686-689
sed -n '680,695p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 731
🏁 Script executed:
# Check the allocation in benchmarking around lines 750-753
sed -n '745,760p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 717
🏁 Script executed:
# Look for default values (batch, seq_len, block_M, heads, dim)
grep -n "batch\|seq_len\|block_M\|heads\|dim" examples/amd/example_amd_flash_attn_bwd.py | head -40
Length of output: 2437
🏁 Script executed:
# Search for where parameters are defined/initialized with values
sed -n '1,100p' examples/amd/example_amd_flash_attn_bwd.py | grep -E "batch|seq_len|heads|dim|="
Length of output: 1682
🏁 Script executed:
# Check the main function and parameter initialization
sed -n '650,690p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 1826
🏁 Script executed:
# Look for argparse or default parameter definitions
grep -n "argparse\|default\|parse_args\|batch\|seq_len\|heads\|dim" examples/amd/example_amd_flash_attn_bwd.py | tail -50
Length of output: 3418
🏁 Script executed:
# Check the beginning of main or where defaults are set
sed -n '600,650p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 1521
🏁 Script executed:
# Verify that both allocations are truly separate tensors (not reused)
sed -n '680,700p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 957
🏁 Script executed:
# Check the benchmarking function to see if it reuses the tensor or creates a new one
sed -n '740,765p' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 1128
🏁 Script executed:
# Let's calculate actual memory with realistic parameters
python3 << 'EOF'
import math
# CLI defaults
batch_cli = 4
heads_cli = 16
seq_len = 4096
dim_cli = 64
# Function defaults
batch_func = 1
heads_func = 8
dim_func = 128
# block_M options
block_M_values = [32, 64, 128, 256]
print("Memory calculation for dQ_partial (float16 = 2 bytes per element)")
print("=" * 70)
print("\nWith CLI defaults (batch=4, heads=16, seq_len=4096, dim=64):")
for block_M in block_M_values:
num_tiles = math.ceil(seq_len / block_M)
size_bytes = batch_cli * num_tiles * seq_len * heads_cli * dim_cli * 2
size_mb = size_bytes / (1024 * 1024)
size_gb = size_bytes / (1024 * 1024 * 1024)
print(f" block_M={block_M:3d}: {num_tiles:4d} tiles → {size_mb:7.1f} MB ({size_gb:.3f} GB) per allocation")
print("\nWith function defaults (batch=1, heads=8, seq_len=4096, dim=128):")
for block_M in block_M_values:
num_tiles = math.ceil(seq_len / block_M)
size_bytes = batch_func * num_tiles * seq_len * heads_func * dim_func * 2
size_mb = size_bytes / (1024 * 1024)
size_gb = size_bytes / (1024 * 1024 * 1024)
print(f" block_M={block_M:3d}: {num_tiles:4d} tiles → {size_mb:7.1f} MB ({size_gb:.3f} GB) per allocation")
print("\nWorse case with larger batch (batch=8, heads=16, seq_len=4096, dim=128):")
for block_M in [32, 64]: # smallest block_M values
num_tiles = math.ceil(seq_len / block_M)
size_bytes = 8 * num_tiles * seq_len * heads_cli * dim_func * 2
size_mb = size_bytes / (1024 * 1024)
size_gb = size_bytes / (1024 * 1024 * 1024)
print(f" block_M={block_M:3d}: {num_tiles:4d} tiles → {size_mb:7.1f} MB ({size_gb:.3f} GB) per allocation")
print("\nTwo allocations total with CLI defaults (block_M=32):")
num_tiles = math.ceil(seq_len / 32)
size_bytes_one = batch_cli * num_tiles * seq_len * heads_cli * dim_cli * 2
size_bytes_two = size_bytes_one * 2
print(f" Total for both allocations: {size_bytes_two / (1024**3):.2f} GB")
EOF
Length of output: 146
dQ_partial allocation scales dangerously with batch and dim
dQ_partial has shape [batch, num_tiles, seq_len, heads, dim] where num_tiles = ceildiv(seq_len, block_M) (lines 416-418). With float16 dtype, this is allocated twice: once in main() (lines 686-689) and again during benchmarking (lines 750-753).
Memory usage grows linearly with batch and dim, and inversely with block_M. With the default CLI parameters (batch=4, heads=16, seq_len=4096, dim=64), each float16 allocation is ~256 MB at block_M=512 but grows to ~4 GB at block_M=32, and the dual allocation during benchmarking compounds this. Consider (a) reducing default problem sizes, (b) gating this path behind a flag, or (c) using a less memory-intensive accumulation strategy.
🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 416-418, dQ_partial
is allocated with shape [batch, ceildiv(seq_len, block_M), seq_len, heads, dim]
and is created twice (once in main and again during benchmarking), causing large
memory use; fix by avoiding the duplicate large allocation and using a
memory-efficient accumulation: either (1) allocate dQ_partial once and reuse the
same buffer for benchmarking (move allocation to a shared scope and pass the
same reference), or (2) change the algorithm to accumulate dQ in-place per tile
(emit partial results into the final dQ buffer rather than storing full
dQ_partial), and add a CLI toggle (e.g., --use-partial-accum) to gate the
high-memory path while reducing default problem sizes if needed. Ensure any new
flag defaults to the low-memory behavior and clean up the second allocation site
so only one large buffer exists or none if using in-place accumulation.
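For a quick back-of-envelope check of the shape flagged above, a host-side sketch (independent of the example's actual CLI wiring):

```python
import math

def dq_partial_bytes(batch, heads, seq_len, dim, block_m, elem_bytes=2):
    """Bytes for one [batch, ceildiv(seq_len, block_M), seq_len, heads, dim]
    float16 allocation, per the dQ_partial shape above."""
    num_tiles = math.ceil(seq_len / block_m)
    return batch * num_tiles * seq_len * heads * dim * elem_bytes

# batch=4, heads=16, seq_len=4096, dim=64:
print(dq_partial_bytes(4, 16, 4096, 64, block_m=512) // 2**20, "MiB")
print(dq_partial_bytes(4, 16, 4096, 64, block_m=32) // 2**30, "GiB")
```

Because `num_tiles` doubles every time `block_M` halves, small tile sizes are the main driver of the footprint, and allocating the buffer twice doubles it again.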
loop_end_k = (
    T.ceildiv(q_block_offset +
              block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
🧩 Analysis chain
🏁 Script executed:
find . -type f -name "*flash*" 2>/dev/null
Length of output: 422
🏁 Script executed:
find . -type d -name "amd" 2>/dev/null | head -10
Length of output: 123
🏁 Script executed:
git ls-files | grep amd | head -20
Length of output: 617
🏁 Script executed:
find . -name "example_amd_flash_attn_fwd.py" 2>/dev/null
Length of output: 104
🏁 Script executed:
rg "ceildiv.*block_N" -l 2>/dev/null
Length of output: 9224
🏁 Script executed:
rg "T.copy" -l | head -20
Length of output: 959
🏁 Script executed:
rg "T\.copy" -A 5 -B 5 | head -100
Length of output: 10180
🏁 Script executed:
ls -la examples/amd/ 2>/dev/null || echo "Directory does not exist"
Length of output: 337
🏁 Script executed:
rg "class.*Tile|def.*T\." --type py | head -40
Length of output: 5203
🏁 Script executed:
rg "\"copy\"|'copy'" --type py -B 3 -A 10 | head -80
Length of output: 42
🏁 Script executed:
rg "\[.*:.*\+.*block" --type py | head -40
Length of output: 6360
🏁 Script executed:
rg "guard|predicate|mask.*tile|mask.*load" --type py | head -30
Length of output: 2916
🏁 Script executed:
rg "kv_idx.*block_N|kv_idx.*\+" --type py -B 2 -A 5 | head -60
Length of output: 2812
🏁 Script executed:
git ls-files "*amd*"
Length of output: 617
🏁 Script executed:
grep -r "T\.copy" . --include="*.md" --include="*.py" --include="*.rst" 2>/dev/null | head -20
Length of output: 2426
🏁 Script executed:
git log --oneline 2>/dev/null | head -20
Length of output: 112
🏁 Script executed:
rg "kv_idx" --type py -B 3 -A 3 | head -80
Length of output: 2720
🏁 Script executed:
sed -n '145,170p' examples/amd/example_amd_flash_attn_fwd.py
Length of output: 1277
🏁 Script executed:
sed -n '1,50p' examples/amd/example_amd_flash_attn_fwd.py
Length of output: 1926
🏁 Script executed:
grep -n "loop_end_k\|block_N" examples/amd/example_amd_flash_attn_fwd.py
Length of output: 1281
🏁 Script executed:
rg "def copy" --type py | grep -v test | head -20
Length of output: 145
🏁 Script executed:
cat docs/programming_guides/language_basics.md | head -200
Length of output: 6311
🏁 Script executed:
sed -n '1,100p' tilelang/language/tir/op.py
Length of output: 2904
🏁 Script executed:
rg "class Tensor" tilelang/language/ --type py -A 30 | head -100
Length of output: 6014
🏁 Script executed:
cat tilelang/language/tir/op.py | grep -A 50 "class Buffer"
Length of output: 42
🏁 Script executed:
rg "T\.min|min\(.*seq_len|min\(.*loop" examples/flash_attention/ --type py | head -20
Length of output: 1690
🏁 Script executed:
sed -n '159,164p' examples/amd/example_amd_flash_attn_fwd.py
Length of output: 405
🏁 Script executed:
ls -la tilelang/transform/ 2>/dev/null || find tilelang -type d -name "*transform*"
Length of output: 505
🏁 Script executed:
rg "buffer_store.*predicate\|predicate.*buffer" tilelang/ --type py | head -20
Length of output: 42
🏁 Script executed:
find tilelang -name "*.py" | xargs grep -l "\"copy\"" | xargs grep -B 5 -A 20 "def copy\|\"copy\"\:"
Length of output: 44
🏁 Script executed:
grep -n "def copy" tilelang/language/copy.py -A 50 | head -100
Length of output: 2544
🏁 Script executed:
grep -n "LegalizeSafeMemoryAccess\|guard\|bounds" tilelang/language/transform/ -r --include="*.py" | head -20
Length of output: 121
🏁 Script executed:
rg "LegalizeSafeMemoryAccess" tilelang/ --type py
Length of output: 389
🏁 Script executed:
find tilelang -name "*legalize*" -o -name "*safe*" -type f | head -10
Length of output: 42
🏁 Script executed:
grep -l "safe\|guard\|bound" tilelang/transform/*.py
Length of output: 136
🏁 Script executed:
cat docs/compiler_internals/tensor_checks.md
Length of output: 14884
🏁 Script executed:
ls testing/python/amd/
Length of output: 160
🏁 Script executed:
sed -n '168,178p' examples/flash_attention/example_mha_fwd_bshd.py
Length of output: 576
Tail-tile bounds: add explicit iteration limit for last K/V tile
loop_end_k is computed via ceildiv(seq_len, block_N) (Lines 153-156), causing the loop to iterate beyond actual sequence data when seq_len % block_N != 0. The subsequent T.copy(K[bz, kv_idx:kv_idx+block_N, ...], K_shared) attempts to read past seq_len (e.g., indices 96–127 from a 100-element dimension). While Python slicing clips OOB indices, the copied data becomes incomplete (4 elements instead of expected 32), leaving K_shared tail uninitialized. Add explicit bounds: use loop_end_k = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) (per example_mha_fwd_bshd.py pattern) to prevent partial-tile copies into full-sized shared buffers. Also, the causal case incorrectly uses ceildiv(q_block_offset + block_M, block_N) instead of the bounded min-form.
🤖 Prompt for AI Agents
examples/amd/example_amd_flash_attn_fwd.py around lines 153 to 156: the computed
loop_end_k allows iterating past the actual sequence length for tail K/V tiles
causing partial slices to be copied into full-sized shared buffers and leaving
tails uninitialized; change the causal branch to clamp the
ceildiv(q_block_offset + block_M, block_N) with the ceildiv of (bx+1)*block_M
using T.min (matching example_mha_fwd_bshd.py pattern) so loop_end_k =
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) when
causal, otherwise keep T.ceildiv(seq_len, block_N), ensuring full-sized shared
copies never read past seq_len.
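The suggested clamp can be sanity-checked on the host; a sketch mirroring the proposed expression:

```python
import math

def ceildiv(a: int, b: int) -> int:
    return math.ceil(a / b)

def loop_end_k(seq_len, block_m, block_n, bx, is_causal):
    """Clamped K/V tile count: the causal bound never exceeds the
    number of tiles that actually exist in the sequence."""
    full = ceildiv(seq_len, block_n)
    if not is_causal:
        return full
    return min(full, ceildiv((bx + 1) * block_m, block_n))

# Tail case: seq_len=100, block_N=16 -> only 7 real tiles exist, but the
# unclamped causal bound for bx=1, block_M=64 would be ceildiv(128, 16) = 8.
assert loop_end_k(100, 64, 16, bx=1, is_causal=True) == 7
```

Without the `min`, the causal branch would issue one extra tile copy past `seq_len` whenever `(bx + 1) * block_M` exceeds the sequence length.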
def main(batch: int = 1,
         heads: int = 8,
         seq_len: int = 4096,
         dim: int = 128,
         is_causal: bool = False,
         groups: int = 1):
Defaults mismatch: main() defaults differ from CLI defaults
main(batch=1, heads=8, ...) (Line 224-226) conflicts with argparse defaults batch=2, heads=16 (Line 257-258). Consider aligning them to avoid confusion when calling main() programmatically vs via CLI.
Also applies to: 257-263
🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_fwd.py around lines 224-229 (and also
affecting lines 257-263), the default parameter values in main(batch=1, heads=8,
...) do not match the argparse defaults (batch=2, heads=16); update them to be
consistent by either changing main()'s default arguments to match the CLI
defaults or centralizing the defaults as module-level constants and using those
constants both in the main signature and when configuring argparse, so the
programmatic and CLI entry points share identical defaults.
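One way to centralize the defaults as the prompt suggests; a sketch with made-up constant names, not the example's actual code:

```python
import argparse

# Shared module-level defaults so main() and the CLI cannot drift apart.
DEFAULTS = {"batch": 2, "heads": 16, "seq_len": 4096, "dim": 128}

def main(batch=DEFAULTS["batch"], heads=DEFAULTS["heads"],
         seq_len=DEFAULTS["seq_len"], dim=DEFAULTS["dim"]):
    return batch, heads, seq_len, dim

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    for name, default in DEFAULTS.items():
        parser.add_argument(f"--{name}", type=int, default=default)
    return parser.parse_args(argv)

args = parse_args([])
# Programmatic and CLI entry points now agree by construction.
assert (args.batch, args.heads, args.seq_len, args.dim) == main()
```

With a single `DEFAULTS` table, changing a default in one place updates both entry points, which removes the class of mismatch flagged above.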
…ash attention examples

* Simplified function signatures and argument formatting in `example_amd_flash_attn_bwd.py` and `example_amd_flash_attn_fwd.py`.
* Enhanced code clarity by adjusting spacing and line breaks in tensor copy operations.
* Updated argument parsing in both examples for consistency in style.
* Minor comments added in `gemm.h` to clarify layout descriptions.
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_fwd.py (1)
12-23: Consider removing the unused `supply_tensors_gpu` function.

This function is no longer referenced after the autotune decorator was updated (line 79) to use `cache_input_tensors=True`. If it's not used elsewhere, consider removing it to reduce code clutter.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/amd/example_amd_flash_attn_bwd.py (13 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (5 hunks)
- src/tl_templates/hip/gemm.h (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/tl_templates/hip/gemm.h
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/tl_templates/hip/gemm.h
🧬 Code graph analysis (2)
src/tl_templates/hip/gemm.h (1)
src/tl_templates/cuda/gemm_sm90.h (1)
- gemm_sr (328-336)
examples/amd/example_amd_flash_attn_bwd.py (2)
tilelang/contrib/rocm.py (2)
- find_rocm_path (267-286)
- get_rocm_arch (231-264)
tilelang/contrib/hipcc.py (1)
- compile_hip (19-91)
🪛 Ruff (0.14.8)
examples/amd/example_amd_flash_attn_bwd.py
117-117: subprocess call: check for execution of untrusted input
(S603)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
149-149: Probable use of insecure hash functions in hashlib: sha1
(S324)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (9)
src/tl_templates/hip/gemm.h (2)
293-354: LGTM: SR computation path follows established patterns.

The `body_sr` implementation correctly mirrors the structure of `body_rs`, loading A from shared memory and B from registers, then performing MFMA accumulation. The logic is sound and consistent with other compute paths.

Note: The unused variable warnings and layout documentation concerns are already covered in previous review comments.

381-389: LGTM: Entry point follows established conventions.

The `gemm_sr` function follows the same pattern as `gemm_rs` and correctly delegates to `GemmTensorOp::body_sr`.

examples/amd/example_amd_flash_attn_fwd.py (2)
139-149: LGTM: Q_fragment migration is clean and well-documented.

The shift from Q_shared to Q_fragment is clearly documented (line 139 comment) and consistently applied throughout the kernel. This aligns with the SR computation pattern introduced in the HIP GEMM templates.

Also applies to: 167-167

184-198: LGTM: Improved numerical stability with explicit infinity checks.

The explicit checks for `-infinity` prevent NaN propagation from `exp(-inf - val)` and correctly set scale factors to 0.0. This is more robust than relying on implicit IEEE 754 behavior.

examples/amd/example_amd_flash_attn_bwd.py (5)
79-88: LGTM: Include path retrieval is clean.

The `_get_hip_include_options` function correctly retrieves template and Composable Kernel paths from the environment and formats them as compiler options.

90-126: LGTM: Assembly compilation helper is sound.

The `_compile_hip_to_asm` function correctly uses temporary directories, properly constructs hipcc commands, and handles errors appropriately. The static analysis warnings (S603, TRY003) are false positives for this use case.

128-179: LGTM: Kernel artifact dumping is robust and well-structured.

The `dump_kernel_artifacts` function handles edge cases gracefully (missing kernel, empty source), uses appropriate error handling for compilation failures, and provides informative output. The `sha1` usage (static analysis S324) is appropriate here since it's only for generating a digest for filenames, not for security purposes.

444-447: Verify: Group iteration uses Python `range` instead of a TileLang parallel construct.

The group loop at lines 444-447 uses Python's `range`, meaning iterations execute sequentially rather than in parallel. For small `groups` values (1-4), this is likely acceptable, but if `groups` can be large, consider whether this should use a TileLang parallel construct or if the kernel grid should be extended to parallelize over groups.

Is the sequential group iteration intentional, or should this be parallelized for larger group counts?

524-546: LGTM: Reduction kernel correctly accumulates partial gradients.

The `flashattn_bwd_reduce_dq` kernel properly sums dQ_partial across tiles with accumulation in float32 for numerical stability. The kernel grid parallelizes over (batch, heads, seq_len) while the tile dimension is reduced sequentially, which is the appropriate choice.
@Alex4210987 Thanks! How much speedup can we achieve from these improvements?
…nd update HIP compile callback

* Added support for fast math optimizations in `example_amd_flash_attn_bwd.py` by updating the `@tilelang.jit` decorator.
* Modified the `tilelang_callback_hip_compile` function to accept a `pass_config` parameter and conditionally enable fast math during compilation.
* Updated the `LibraryGenerator` class to include fast math options in the compilation command.
* Ensured consistency in fast math handling across different components of the codebase.

…tionality

* Introduced fast math support in the LibraryGenerator class, aligning behavior with CUDA through pass_configs.
* Added methods to get and set source and library paths, enhancing the flexibility of the LibraryGenerator.
* Ensured proper handling of fast math options during HIP compilation.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/engine/lower.py (1)
113-146: Fix CI lint: `target` is unused in `tilelang_callback_hip_compile`.

Ruff flags Line 113 (ARG001). If `target` must remain in the callback signature, rename it to `_target` or add a targeted noqa.

```diff
- def tilelang_callback_hip_compile(code, target, pass_config=None):
+ def tilelang_callback_hip_compile(code, _target, pass_config=None):  # noqa: ARG001
```

tilelang/jit/adapter/libgen.py (1)
109-131: Guard against `self.pass_configs is None` before `.get(...)` in the HIP fast-math block.

If `assign_pass_configs()` wasn't called, this will raise at runtime.

```diff
-        if self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH):
+        cfg = self.pass_configs or {}
+        if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH):
             deprecated_warning(
                 "TL_DISABLE_FAST_MATH",
                 "TL_ENABLE_FAST_MATH",
                 "0.1.7",
             )
-            enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
+            enable_fast_math = not cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
         else:
-            enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
+            enable_fast_math = cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
```
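On those `.get(...)` lookups: because `PassConfigKey` mixes in `str`, the member itself and its `.value` behave identically as dict keys. A self-contained sketch with a stand-in enum (the key string here is an assumption, not tilelang's real value):

```python
from enum import Enum

class PassConfigKey(str, Enum):
    """Stand-in for a str-enum config key."""
    TL_ENABLE_FAST_MATH = "tl.enable_fast_math"

cfg = {"tl.enable_fast_math": True}

# str supplies __eq__ and __hash__ via the MRO, so the member hashes and
# compares exactly like its underlying string value.
assert PassConfigKey.TL_ENABLE_FAST_MATH == "tl.enable_fast_math"
assert cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH) is True
assert cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value) is True
```

This is why code that mixes member-keyed and `.value`-keyed lookups still works; it is a style inconsistency rather than a bug (see the related nitpick below on unifying the pattern).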
♻️ Duplicate comments (4)
examples/amd/example_amd_flash_attn_bwd.py (4)
410-410: dQ_partial memory usage concern acknowledged.

As flagged in the previous review, the `dQ_partial` tensor with shape `[batch, num_tiles, seq_len, heads, dim]` can consume significant memory. This issue is tracked in the past review comments.

621-622: Default dump_dir behavior creates unwanted artifacts.

As flagged in the previous review, when `dump_dir` is `None`, it's unconditionally set to a default path, causing artifacts to be written even when the user didn't request them. This issue is tracked in past review comments.

674-674: Duplicate dQ_partial allocation increases memory footprint.

As flagged in the previous review, `dQ_partial` is allocated here and again at line 736 during benchmarking, potentially doubling memory usage. This issue is tracked in past review comments.

736-740: Duplicate dQ_partial allocation in benchmarking.

This is the second allocation of `dQ_partial` (first at line 674). The memory concern was flagged in the previous review and is tracked there.
🧹 Nitpick comments (4)
tilelang/engine/lower.py (1)
124-141: Inconsistent use of `.value` across pass_config access patterns.

`PassConfigKey` is a str-enum (it inherits from `str, Enum`), so `PassConfigKey.TL_ENABLE_FAST_MATH` and `PassConfigKey.TL_ENABLE_FAST_MATH.value` both evaluate to the same string. However, `lower.py` uses explicit `.value` (lines 126, 128, 131) while `libgen.py` accesses the enum members directly without `.value` (lines 63, 69, 71, 111, 117, 119). Both patterns work identically due to str-enum behavior, but unify on one for clarity: either consistently use `.value` or consistently use the enum members directly.

examples/amd/example_amd_flash_attn_bwd.py (3)
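The str-enum interchangeability claimed above can be checked in isolation. A minimal sketch (the class below is a hypothetical mirror of tilelang's `PassConfigKey`; the value string is an assumption for illustration):

```python
from enum import Enum


class PassConfigKey(str, Enum):
    # Hypothetical mirror of a str-enum config key.
    TL_ENABLE_FAST_MATH = "tl.enable_fast_math"


# The member IS a str: it equals its own .value, so both spellings
# address the same dict slot.
assert PassConfigKey.TL_ENABLE_FAST_MATH == PassConfigKey.TL_ENABLE_FAST_MATH.value

cfg = {PassConfigKey.TL_ENABLE_FAST_MATH: True}
assert cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value) is True   # lookup by plain string
assert cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH) is True         # lookup by enum member
```

Either spelling works at runtime; the review's point is only that the codebase should pick one and use it consistently.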
90-126: Verify subprocess call security in assembly generation.

The `subprocess.run` call at line 117 constructs a command from the `arch` and `options` parameters. While `arch` comes from `hipcc.get_rocm_arch()` and `options` from `_get_hip_include_options()` (both trusted paths), ensure that future modifications don't introduce untrusted input into this command construction.
242-252: Replace Chinese comment with English.

The comment at line 242 is in Chinese. For international collaboration and maintainability, please use English. Apply this diff:

```diff
- # Forward: Q在register里, K/V在LDS里
+ # Forward: Q in register fragments, K/V in shared memory (LDS)
```
417-417: Replace Chinese comment with English.

The comment at line 417 is in Chinese. For international collaboration, please use English. Apply this diff:

```diff
- # Backward: K在shared里, V在register里, Q/QT/dO/dOT在LDS里
+ # Backward: K in shared memory, V in register fragments, Q/dO in shared memory (LDS)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/amd/example_amd_flash_attn_bwd.py (15 hunks)
- tilelang/engine/lower.py (2 hunks)
- tilelang/jit/adapter/libgen.py (3 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
tilelang/engine/lower.py
🧬 Code graph analysis (2)
examples/amd/example_amd_flash_attn_bwd.py (1)
tilelang/contrib/hipcc.py (1)
- compile_hip (19-91)
tilelang/jit/adapter/libgen.py (4)
tilelang/transform/pass_config.py (1)
- PassConfigKey (6-154)

tilelang/utils/deprecated.py (1)
- deprecated_warning (1-10)

tilelang/jit/adapter/ctypes/adapter.py (2)
- srcpath (275-277)
- libpath (280-282)

tilelang/jit/adapter/cython/adapter.py (2)
- srcpath (361-363)
- libpath (366-368)
🪛 GitHub Actions: CI
tilelang/engine/lower.py
[error] 1-1: Command failed with exit code 1 while running: python -m compileall -q -f tilelang
tilelang/jit/adapter/libgen.py
[error] 189-189: Command failed: python -m compileall -q -f tilelang. IndentationError: unexpected indent (libgen.py, line 189)
🪛 Ruff (0.14.8)
tilelang/engine/lower.py
113-113: Unused function argument: target
(ARG001)
examples/amd/example_amd_flash_attn_bwd.py
117-117: subprocess call: check for execution of untrusted input
(S603)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
149-149: Probable use of insecure hash functions in hashlib: sha1
(S324)
tilelang/jit/adapter/libgen.py
189-189: Unexpected indentation
(invalid-syntax)
🔇 Additional comments (13)
examples/amd/example_amd_flash_attn_bwd.py (13)
128-179: LGTM! Well-structured artifact dumping utility.

The function properly handles edge cases (`None` checks, missing source), creates directories safely, and provides clear error messages for compilation failures. The separation of concerns (HIP source, HSACO, assembly) is clean.
182-182: LGTM! Fast math optimization enabled.

Enabling `TL_ENABLE_FAST_MATH` via `pass_configs` is appropriate for this performance-critical kernel.
243-270: LGTM! Q moved to register fragments for better performance.

The change from `Q_shared` to `Q_fragment` moves Q into register storage, which should improve memory access patterns and performance. The usage in `T.copy` (line 252) and `T.gemm` (line 270) is consistent with this change.
353-353: LGTM! Preprocess kernel configuration.

The jit decorator correctly specifies `out_idx=[2]` for the Delta output tensor and enables fast math.
380-380: LGTM! Backward kernel configuration.

Fast math is enabled for performance optimization. Note that `out_idx` is not specified here because the kernel writes to multiple output tensors (dQ_partial, dK, dV) that are provided as arguments.
414-419: LGTM! Backward kernel restructured for per-tile processing.

The kernel iteration structure changed to `(head_kv, ceildiv(seq_len, block_M), batch)` to process K/V tiles, and V is now stored in `V_fragment` (registers) instead of shared memory, consistent with the forward pass optimization.
444-495: LGTM! Per-group iteration and optimized GEMM ordering.

The per-group iteration logic (lines 444-447) correctly handles grouped query attention by iterating over each Q head group for the current K/V head. The GEMM reordering (dV → dP → dK → dQ) is well-documented and improves `V_fragment` reuse since V is already in registers.
524-545: LGTM! Reduction kernel correctly accumulates dQ_partial.

The reduction kernel properly accumulates per-tile partial gradients into the final `dQ_out` tensor. The use of `accum_dtype` (float) for accumulation ensures numerical stability during the summation.
609-616: LGTM! Main function signature extended with dump_dir parameter.

The addition of the optional `dump_dir` parameter allows users to specify where kernel artifacts should be saved, improving flexibility for debugging and analysis workflows.
654-685: LGTM! Kernel artifact dumping integrated at all stages.

The `dump_kernel_artifacts` calls are appropriately placed after each kernel compilation (forward, preprocess, backward, reduce, postprocess), providing comprehensive debugging capabilities when `dump_dir` is specified.
671-682: LGTM! dQ reduction properly integrated.

The reduction workflow is correctly implemented: allocate `dQ_partial`, run the backward kernel to populate it, then use `reduce_kernel` to accumulate tiles into `dQ_accum`. The `num_kv_tiles` calculation matches the backward kernel's tile dimension.
768-770: Increased default parameters compound memory usage.

The defaults changed from `batch=1, heads=8` to `batch=4, heads=16`. Combined with the `dQ_partial` memory usage flagged in previous reviews, these larger defaults may cause out-of-memory issues on GPUs with limited memory. Consider reducing these defaults or documenting the minimum GPU memory requirements in the help text or README.
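To make the footprint concrete, the partial-dQ buffer size can be computed directly from the shape flagged in earlier comments. A sketch (`seq_len=1024` and `block_M=64` are illustrative assumptions, not values taken from the example):

```python
import math


def dq_partial_bytes(batch, seq_len, heads, dim, block_M, itemsize=4):
    """Bytes for a [batch, ceil(seq_len/block_M), seq_len, heads, dim]
    buffer; itemsize=4 assumes float32 accumulation."""
    num_tiles = math.ceil(seq_len / block_M)
    return batch * num_tiles * seq_len * heads * dim * itemsize


# Old vs. new CLI defaults at illustrative seq_len=1024, dim=64, block_M=64:
old = dq_partial_bytes(batch=1, seq_len=1024, heads=8, dim=64, block_M=64)
new = dq_partial_bytes(batch=4, seq_len=1024, heads=16, dim=64, block_M=64)
assert new == 8 * old     # 4x batch x 2x heads: eight times the partial-dQ memory
```

Under these assumptions the new defaults already need 256 MiB for `dQ_partial` alone, before the duplicate benchmark allocation is counted.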
774-777: LGTM! CLI argument for dump_dir properly integrated.

The `dump_dir` argument is correctly added to the CLI parser and passed to `main()`, allowing users to optionally specify where to save kernel artifacts.
* Removed unnecessary getter and setter methods for source and library paths in the LibraryGenerator class. * Streamlined the set_src_path method to improve code clarity and maintainability.
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/amd/example_amd_flash_attn_fwd.py (1)
151-151: Verify: Tail-tile bounds handling for non-divisible sequence lengths.

When `seq_len % block_N != 0`, `loop_end_k = T.ceildiv(q_block_offset + block_M, block_N)` may iterate beyond valid sequence data. The subsequent `T.copy(K[bz, kv_idx:kv_idx+block_N, ...], K_shared)` could read past `seq_len` (e.g., indices 96–127 from a 100-element dimension).

Consider adding an explicit bound similar to the pattern used in other flash attention examples: `loop_end_k = T.min(T.ceildiv(seq_len, block_N), T.ceildiv(q_block_offset + block_M, block_N))` for the causal case.

Run the following script to check whether this pattern is used elsewhere:

```shell
#!/bin/bash
# Search for T.min usage with loop_end bounds in flash attention examples
rg -n "loop_end.*T\.min|T\.min.*ceildiv.*seq_len" examples/ --type=py -C3
```

examples/amd/example_amd_flash_attn_bwd.py (1)
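The overrun arithmetic behind the suggested clamp can be checked with plain integers. A minimal sketch (the values `70/64/32` are illustrative, not taken from the example):

```python
def ceildiv(a, b):
    """Integer ceiling division, mirroring T.ceildiv."""
    return (a + b - 1) // b


seq_len, block_M, block_N = 70, 64, 32
q_block_offset = 64                       # last Q tile starts here

# Unbounded: the loop would visit K tile 3, covering indices 96..127,
# which lies entirely past seq_len = 70.
loop_end_k = ceildiv(q_block_offset + block_M, block_N)
assert loop_end_k == 4

# Clamped: stop at the last tile that actually holds data.
loop_end_k_safe = min(ceildiv(seq_len, block_N), loop_end_k)
assert loop_end_k_safe == 3
```

The clamp only changes the count when the Q-tile extent rounds past the sequence length, so it is free in the common divisible case.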
621-622: Consider: Make artifact dumping opt-in by default.

The current code unconditionally sets `dump_dir` to `"tilelang_kernel_dumps"` when it's `None`, causing artifacts to be written even when users don't explicitly request them. This could be unexpected behavior. Consider one of these approaches:

- Keep `dump_dir` as `None` when not specified (no dumping by default)
- Add a separate `--enable-dump` flag and only set `dump_dir` when that flag is present
- Update the help text to clearly document that artifacts are dumped by default

🔎 Suggested fix to make dumping opt-in:

```diff
-    if not dump_dir:
-        dump_dir = os.path.join(os.getcwd(), "tilelang_kernel_dumps")
+    # Keep dump_dir as None if not specified - dumping is opt-in
```
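The first option amounts to leaving the CLI default at `None` and dumping only when the flag is given. A sketch of the wiring (the flag name mirrors the example's `dump_dir` argument but is otherwise illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# Opt-in: no artifacts are written unless the user names a directory.
parser.add_argument("--dump_dir", default=None,
                    help="If set, write kernel artifacts (HIP source, HSACO, asm) here.")

args = parser.parse_args([])                      # flag absent: dumping stays off
assert args.dump_dir is None

args = parser.parse_args(["--dump_dir", "out"])   # explicit opt-in
assert args.dump_dir == "out"
```

Downstream code then simply guards each dump call with `if dump_dir is not None`, so the default run leaves no artifacts behind.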
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_bwd.py (1)
736-740: Note: Duplicate dQ_partial allocation in benchmarking.

A new `dQ_partial_bench` is allocated here for benchmarking. This is the second large allocation mentioned in past reviews. While necessary for isolated benchmarking, users should be aware it doubles memory usage during the benchmark phase. Consider documenting this in comments or adding a note about memory usage in the help text.
Related to past review comment about dual allocation.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/amd/example_amd_flash_attn_bwd.py (15 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (5 hunks)
- src/tl_templates/hip/common.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tl_templates/hip/common.h
🧰 Additional context used
🧬 Code graph analysis (2)
examples/amd/example_amd_flash_attn_bwd.py (4)
tilelang/tileop/base.py (1)
- GemmWarpPolicy (5-185)

tilelang/jit/kernel.py (1)
- kernel_source (644-645)

tilelang/contrib/hipcc.py (1)
- compile_hip (19-91)

tilelang/language/allocate.py (2)
- alloc_fragment (72-85)
- alloc_shared (40-55)
examples/amd/example_amd_flash_attn_fwd.py (1)
tilelang/autotuner/tuner.py (1)
- autotune (676-771)
🪛 Ruff (0.14.8)
examples/amd/example_amd_flash_attn_bwd.py
117-117: subprocess call: check for execution of untrusted input
(S603)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
149-149: Probable use of insecure hash functions in hashlib: sha1
(S324)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
examples/amd/example_amd_flash_attn_fwd.py (5)
79-79: LGTM: Autotune decorator updated.

The removal of the `supply_prog` parameter aligns with the updated autotune API. The `cache_input_tensors=True` setting is appropriate for this use case.
124-127: LGTM: Loop control refactored for clarity.

Introducing `bx_loop_var` as an explicit loop variable improves code readability and makes the loop progression more explicit.
139-140: LGTM: Q migrated to register fragments.

The change from `Q_shared` to `Q_fragment` moves Q data to registers, which should improve performance by reducing shared memory pressure and improving data locality. The comment clearly documents this design choice.

Also applies to: 149-149, 167-167
184-189: LGTM: Numerical stability improvements.

The explicit checks for `-infinity` before computing exponentials and scale factors prevent undefined behavior and improve numerical stability. This is defensive coding that correctly handles edge cases in the softmax computation.

Also applies to: 195-198
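The guard matters because a fully masked row makes the naive softmax evaluate `exp(-inf - (-inf))`, which is NaN. A scalar sketch of the pattern in pure Python (not the kernel code itself):

```python
import math


def masked_softmax(scores):
    """Row softmax that tolerates fully masked rows (all scores -inf)."""
    m = max(scores)
    if m == float("-inf"):
        # Every position is masked: emit zeros instead of NaN
        # from exp(-inf - (-inf)).
        return [0.0] * len(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


assert masked_softmax([float("-inf")] * 3) == [0.0, 0.0, 0.0]
row = masked_softmax([1.0, 2.0, float("-inf")])
assert row[2] == 0.0 and abs(sum(row) - 1.0) < 1e-12
```

The same early-out on the row maximum is what the kernel's `-infinity` checks provide, at negligible cost relative to the exponentials.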
221-221: LGTM: Default parameters now consistent.

The CLI defaults (`batch=2, heads=16`) now match the programmatic defaults in the `main()` function signature, resolving the previous mismatch. This ensures consistent behavior whether the script is called via CLI or imported and called directly.

Also applies to: 248-249
examples/amd/example_amd_flash_attn_bwd.py (13)
1-11: LGTM: New imports for artifact dumping.

The added imports support the new kernel artifact dumping functionality (HIP source, HSACO, assembly). All imports are standard library or already-used dependencies.
128-179: LGTM: Kernel artifact dumping implementation.

The `dump_kernel_artifacts` function provides useful debugging capabilities by saving HIP source, HSACO binaries, and assembly output. The error handling is appropriate, catching exceptions and logging them without crashing the main execution flow.
182-182: LGTM: Fast-math optimization enabled.

Adding `pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}` to the JIT decorator enables aggressive floating-point optimizations. This is appropriate for flash attention kernels, where slight numerical differences are acceptable in exchange for significant performance gains.

Also applies to: 353-353, 380-380, 506-506, 524-524
242-243: LGTM: Q migrated to register fragments in forward pass.

Consistent with the forward-only file, Q is now stored in `Q_fragment` (registers) instead of shared memory, improving performance by reducing memory traffic.

Also applies to: 252-252, 270-270
410-410: Note: dQ_partial tensor shape requires careful memory management.

The tensor shape `[batch, T.ceildiv(seq_len, block_M), seq_len, heads, dim]` introduces an additional dimension for per-tile partial gradients. This is allocated twice (lines 674 and 736), which was flagged in past reviews as potentially consuming significant memory.

The new reduce kernel (lines 524-546) accumulates these partials, which is the correct approach for parallelizing dQ computation across tiles. However, users should be aware of the memory implications with large batch sizes or sequence lengths.
Related to past review comment about memory usage.
414-419: LGTM: Backward kernel grid and data placement updated.

The kernel grid is now `(head_kv, T.ceildiv(seq_len, block_M), batch)`, which parallelizes over K/V heads and tiles. The data placement (K in shared memory, V in fragments) is documented and appropriate for the backward pass computation pattern.
444-496: LGTM: Backward computation restructured for efficiency.

The computation flow has been optimized:

- The outer loop over groups handles multi-query attention correctly
- The GEMM order is optimized to compute dV first (reusing `V_fragment` in registers)
- Per-tile dQ accumulation into `dQ_partial` enables parallelization
- Causal masking is applied correctly to P_acc
The comments clearly document the optimized GEMM ordering rationale.
524-546: LGTM: dQ reduction kernel implementation.

The new `flashattn_bwd_reduce_dq` kernel correctly accumulates per-tile dQ partials into the final dQ gradient. The implementation:
- Iterates over all tiles
- Accumulates in float32 for numerical accuracy
- Handles the shape transformation correctly
This kernel is necessary for the parallelized backward pass approach.
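The reduction the kernel performs is just a sum over the tile axis with a float accumulator. A pure-Python sketch of the same accumulation (tiny illustrative shapes, not the kernel's actual dimensions):

```python
# dQ_partial laid out as [num_tiles][seq_len][dim]; the reduce step
# sums over the tile axis into a float accumulator.
num_tiles, seq_len, dim = 3, 4, 2
dq_partial = [[[float(t + s + d) for d in range(dim)]
               for s in range(seq_len)]
              for t in range(num_tiles)]

dq = [[0.0] * dim for _ in range(seq_len)]       # float32-style accumulator
for t in range(num_tiles):
    for s in range(seq_len):
        for d in range(dim):
            dq[s][d] += dq_partial[t][s][d]

assert dq[0][0] == 0.0 + 1.0 + 2.0               # sum over tiles at (s=0, d=0)
```

In the real kernel each `(s, d)` slot is handled by a thread and the tile loop is the only serial dependency, which is why the partial-then-reduce split parallelizes well.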
609-617: LGTM: Main signature extended with dump_dir parameter.

Adding `dump_dir` as an optional parameter with a `None` default is clean API design. It allows callers to opt into artifact dumping.
654-654: LGTM: Kernel artifacts dumped at key points.

The strategic placement of `dump_kernel_artifacts` calls captures all major kernel stages:
- Forward pass
- Backward preprocess
- Backward main kernel
- dQ reduction
- Backward postprocess
This provides comprehensive debugging visibility.
Also applies to: 659-659, 669-669, 681-681, 685-685
671-682: LGTM: dQ_partial allocation and reduction integrated.

The flow correctly:

- Allocates `dQ_partial` based on the backward kernel's `block_M` configuration
- Passes it to the backward kernel for per-tile accumulation
- Uses the reduce kernel to sum partials into final dQ
The implementation is correct and necessary for the parallelized approach.
768-770: LGTM: CLI defaults updated and dump_dir argument added.

The CLI defaults are now `batch=4, heads=16, dim=64`, which are more realistic test configurations. The `dump_dir` argument is properly integrated with clear help text.

Note: The `main()` function signature defaults (`batch=1, heads=8, dim=128`) differ from the CLI defaults. This is acceptable if intentional, but consider documenting why they differ (e.g., CLI defaults optimized for quick validation).

Also applies to: 774-774, 777-777
75-126: Review subprocess usage and hash function implementation.

The `_compile_hip_to_asm` function uses `subprocess.run()` with a list-based command, which safely prevents shell injection. While `asm_path` is derived from an external `dump_dir` parameter, the list-based invocation and filename sanitization prevent command injection.

SHA-1 is acceptable in non-security contexts for non-sensitive identifiers and checksums, making its use for generating 8-character filename digests appropriate here.
Summary by CodeRabbit
New Features
Improvements
Refactoring