[BugFix]: avoid closure-captured mutable shapes in flash_attn and sparse_mla kernel #71
base: refactor
Conversation
Summary of Changes
Hello @RMLYC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request implements a critical bug fix to enhance the stability and reliability of the AutoTuner: previously, mutable shape objects were inadvertently captured by closures within the kernels, which could lead to crashes during tuning.
Code Review
This pull request aims to fix an AutoTuner crash by moving mutable shape definitions into the JIT scope and converting them to immutable tuples. The changes across flash_attn/fwd.py and flash_attn/bwd.py correctly implement this fix. However, in top/kernels/deepseek_mla/sparse_mla.py, while the shape definitions have been moved, they are still defined as lists instead of tuples. I've added a comment with a suggestion to change them to tuples to ensure the fix is consistently applied and effective.
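For context, a minimal sketch of the failure mode, using hypothetical names (make_kernel_buggy, shape) rather than the actual kernel code:

# Hypothetical illustration, not the actual flash_attn/sparse_mla code.
def make_kernel_buggy(batch, heads, dim):
    shape = [batch, heads, dim]   # mutable list defined outside the kernel

    def kernel():
        # The closure captures the list object by reference, so any later
        # mutation of `shape` changes what this (possibly cached) kernel sees.
        return tuple(shape)

    shape[0] *= 2                 # e.g. a tuning sweep mutates the shape in place
    return kernel                 # kernel() now observes the mutated shape

def make_kernel_fixed(batch, heads, dim):
    def kernel():
        # Defining the shape inside the kernel as an immutable tuple means
        # each kernel owns its own value and nothing can mutate it later.
        shape = (batch, heads, dim)
        return shape

    return kernel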
Lines 26–31: The comment is too long and should be reformatted into multiple lines. We should also add lint / pre-commit checks to enforce line-length limits. Regarding the content, please split this into a concise high-level comment followed by short, focused notes or TODOs. For example:
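(Illustrative only; the comment text below is hypothetical, not the actual lines 26–31.)

# High-level: one or two lines stating what this block does.
# Supporting notes: keep each detail short and on its own line.
# TODO: move longer rationale into the module docstring or design notes.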
Which file? Currently, the code uses pre-commit for formatting checks.
sparse_mla.py
done
| "you should fix the logic involving CP0 (cp_rank == 0) " | ||
| "to make sure q with pos < KV_Stride - 1 is masked. " | ||
| "(Or you may ignore this handling if NaN in these q's outputs " | ||
| "does not affect other results; this was reported by wangding.)") |
The assertion message is overly verbose, mixes multiple concerns, and includes informal wording and personal attribution, which makes it hard to read and unsuitable for production code.
Assertions should clearly state the invariant being checked, not embed long explanations, speculative handling, or historical context. This logic also encodes a non-obvious contract between q_start_index_s, kv_stride, and CP0 masking that should be documented explicitly rather than hidden inside an assert string.
I suggest:
- Make the assertion concise and factual. Keep the assert focused on what must be true, not why;
- Move the explanation into a comment (or docstring). Explain why this invariant exists using structured comments, not an assert message.
# When q_start_index_s != 0, CP0 is expected to mask queries with
# position < kv_stride - 1. Otherwise, short CP lengths may produce
# invalid (NaN) outputs for early queries.
if q_start_index_s != 0:
    assert q_start_index_s > kv_stride, (
        "q_start_index_s must be greater than kv_stride when non-zero"
    )
assert kv_group == 1, 'here we solve the heads padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)'
# print(f'padded_H = {padded_H}, heads = {heads}')
assert kv_group == 1, ("Automatic heads padding is only supported when kv_group == 1. "
                       "Otherwise handle Q and Output copies with an appropriate mask.")
The assertion message mixes a policy decision (“only supported”) with implementation guidance (“handle Q and Output copies with a mask”). This makes the intent unclear and puts design instructions inside an assert. The invariant should be stated succinctly, while the rationale and future work should be documented separately.
Suggested rewrite:
# Automatic head padding is only valid when kv_group == 1.
# For kv_group > 1, Q and output tensors must be masked explicitly
# to handle padded heads correctly.
# TODO: Support automatic head padding for kv_group > 1
# by applying appropriate masks during Q and output copies.
if padded_H != heads:
    assert kv_group == 1, (
        "Automatic head padding requires kv_group == 1"
    )
The behavior may change if CP0 logic is fixed (there is a compiler-related
bug affecting that logic).
"""
assert dim == tilelang.math.next_power_of_2(
The original docstring mixes background, assumptions, speculative fixes, and historical context in a single paragraph, which makes it hard to scan.
Suggested rewrite:
"""
Implements sparse attention.
Note:
- The first (kv_stride - 1) token outputs may be NaN due to current CP0 handling.
- These outputs are typically unused and assumed to be safe to ignore.
- During the backward pass, care must be taken to prevent NaN propagation
(e.g., in expressions like `delta = out * dout`).
- A common mitigation is to zero out invalid outputs before backward
computation.
This behavior is a known limitation and may change once the CP0 logic
is corrected (currently affected by a compiler-related issue).
"""