
Conversation

@xwhzz commented Sep 3, 2025

No description provided.

chengyupku and others added 30 commits June 3, 2025 10:32
* [Draft] Add grouped query attention fwd and bwd kernels

* remove profiling the performance from the check function
* [Draft] Add mamba chunk scan attention fwd kernel

* [Draft] Add mamba chunk scan attention fwd kernel

* [Draft] Add mamba chunk scan attention fwd kernel

* [Draft] Add args.tune in mamba_chunk_scan test
* [Dev] Add mamba chunk state attention fwd kernel

* [Dev] Add mamba chunk state attention fwd kernel

* add type ignore

* add type ignore
* Add `type:ignore` to improve hint style

* Add MHA fwd/bwd kernels & tests for BSHD layout

* Remove redundant code

* Improve naming

* Run yapf and ruff
* Add blocksparse_flash_attention fwd kernel

* Add blocksparse_flash_attention bwd kernel

* fix style and typo
- change all kernel class names to PascalCase (e.g. MLA_kernel to MLAKernel), and improve code formatting for readability. Update relevant documents to reflect these changes.
* [Draft] Add grouped query attention fwd and bwd kernels

* remove profiling the performance from the check function

* [Dev] Add bitnet prefill and decode kernels
* Add MHA decode kernel & update naming.

* Correct spelling.
* Add linear attention recurrent kernel.

* Run yapf and ruff

* Remove unnecessary pkg
* [Dev] Add GQA decode kernel and update naming

* [Style] Format code

* [Fix] Correct loop range of gqa_decode_split_ref using ceiling division
* Add utils.py

* Unify testing method.

* Add padding to support seq_len indivisible by chunk_size.

* Fix typo in MHAKernel

* Fix lint

* Fix lint

* [fix] generate `do` without requiring gradient

* [fix] generate `do` without requiring gradient
…ls, [Fix] Resolve bugs in autotune of GQA kernels (tile-ai#22)

* [Dev] Add autotune support for GQA forward and backward passes

* [Fix] Resolve bugs in autotune of GQA kernels

* [Dev] Add autotune support for MHA decode, forward and backward kernels
@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @xwhzz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a vertical slash sparse attention mechanism implemented with custom CUDA kernels and TileLang. It includes the kernel implementation, a new kernel abstraction, and benchmarking infrastructure to validate its efficiency and correctness.

Highlights

  • New Sparse Attention Mechanism: Introduced a novel 'Vertical Slash Sparse Attention' mechanism, designed for efficient computation in deep learning models.
  • CUDA Kernel for Index Conversion: Added a custom CUDA kernel (vertical_slash_index.cu) and its C++ bindings (kernels.cpp) to dynamically convert and manage sparse attention indexes, optimizing data access patterns.
  • TileLang and Triton Integration: The new attention kernel is implemented using TileLang for performance optimization, with a Triton-based reference program for parity checking and benchmarking.
  • Performance Benchmarking: Comprehensive benchmarks show speedups of up to 1.8x over baseline Triton implementations on NVIDIA H100 GPUs, with improved TFLOPS and I/O bandwidth.
  • Kernel Abstraction and Utilities: A new KernelBase abstract class provides a standardized interface for kernel implementations, complemented by enhanced utility functions for robust parity checks and detailed performance analysis.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a vertical slash sparse attention mechanism, including CUDA kernels, a TileLang implementation, and benchmark results. The changes involve adding new files for the kernel implementation, modifying the __init__.py files to include the new kernel, and adding a test file to evaluate the performance of the new attention mechanism.

o = mod(q, k, v, block_count, block_offset, column_count, column_index)
return o

@staticmethod

high

The backward function is currently a pass statement. Implement the backward pass for gradient calculation, or explicitly raise an error if gradients are not supported.
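A minimal sketch of the second option, assuming the forward pass lives in a torch.autograd.Function wrapper; the class and argument names here are illustrative, not identifiers from this PR:

import torch


class _VSSparseAttentionFn(torch.autograd.Function):
    """Hypothetical autograd wrapper around the compiled sparse-attention kernel."""

    @staticmethod
    def forward(ctx, kernel_mod, q, k, v, block_count, block_offset, column_count, column_index):
        # kernel_mod stands in for the compiled TileLang kernel (`mod` in the snippet above).
        return kernel_mod(q, k, v, block_count, block_offset, column_count, column_index)

    @staticmethod
    def backward(ctx, grad_output):
        # Fail loudly instead of silently returning None gradients for every input.
        raise NotImplementedError(
            "vertical slash sparse attention does not implement backward; "
            "run it under torch.no_grad() or detach the inputs")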

return kernel_func(block_M, block_N, num_stages)


@torch.compile

medium

Consider adding a docstring to describe what this class does. This will improve readability and maintainability.

Comment on lines +231 to +232
class VerticalSlashSparseAttentionKernel(KernelBase):
map_dtype = {torch.float16: "float16", torch.bfloat16: "bfloat16"}

medium

Consider adding a docstring to describe what this class does. This will improve readability and maintainability.
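For example, a class-level docstring along these lines, extending the two quoted lines above, would address the comment; the description is inferred from this PR's summary and is only a suggestion:

class VerticalSlashSparseAttentionKernel(KernelBase):
    """TileLang kernel for vertical slash sparse attention.

    Computes attention over the sparse pattern described by
    block_count/block_offset and column_count/column_index, as produced by
    the CUDA index-conversion kernel (vertical_slash_index.cu).
    Supported torch dtypes are the keys of `map_dtype`.
    """

    map_dtype = {torch.float16: "float16", torch.bfloat16: "bfloat16"}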

Comment on lines 487 to 494
def pytorch_ref_program(self, *args, **kwargs):

def attention_func(queries, keys, values, attention_mask):
attention_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(queries.size(-1))
attention_weights += attention_mask.to(queries.dtype) * torch.finfo(queries.dtype).min
attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
attention_output = torch.matmul(attention_weights, values)
return attention_output

medium

This function pytorch_ref_program is defined but never called. Either remove it if it's not needed, or ensure it's integrated into the testing or usage of the class.
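If the function is meant to stay as an independent PyTorch reference, one option is to benchmark it alongside the existing reference. This is a sketch against the test code quoted below; it assumes `performance` accepts additional reference callables and that `pytorch_ref_program` takes the same arguments as `ref_program`:

# Hypothetical change to the test: exercise the PyTorch reference as well.
perf = performance(
    kernel,
    [kernel.ref_program, kernel.pytorch_ref_program],
    q, k, v,
    block_count=block_count, block_offset=block_offset,
    column_count=column_count, column_index=column_index,
)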

seq_len=SEQ_LEN,
head_dim=D_HEAD,
)
partity(kernel, q, k, v, block_count=block_count, block_offset=block_offset, column_count=column_count, column_index=column_index)

medium

Consider adding a try/except block around the partity call to catch potential assertion errors and provide more informative error messages, which is helpful for debugging.
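A minimal sketch of that wrapping, assuming `partity` raises AssertionError when the kernel output diverges from the reference; all other names come from the test code quoted above:

try:
    partity(
        kernel, q, k, v,
        block_count=block_count, block_offset=block_offset,
        column_count=column_count, column_index=column_index)
except AssertionError as err:
    # Re-raise with the configuration attached so failures are easier to debug.
    raise AssertionError(
        f"parity check failed for seq_len={SEQ_LEN}, head_dim={D_HEAD}") from err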

)
partity(kernel, q, k, v, block_count=block_count, block_offset=block_offset, column_count=column_count, column_index=column_index)

perf = performance(kernel, [kernel.ref_program], q, k, v, block_count=block_count, block_offset=block_offset, column_count=column_count, column_index=column_index)

medium

Consider adding a try/except block around the performance call to catch potential errors and provide more informative error messages, which is helpful for debugging.
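Similarly for the benchmark call, a sketch that attaches the failing configuration to the error, under the same assumptions about the surrounding test variables:

try:
    perf = performance(
        kernel, [kernel.ref_program], q, k, v,
        block_count=block_count, block_offset=block_offset,
        column_count=column_count, column_index=column_index)
except Exception as err:
    raise RuntimeError(
        f"benchmark failed for seq_len={SEQ_LEN}, head_dim={D_HEAD}") from err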

@xwhzz changed the title from "[Example] Add vertical slash sparse attention and benchmark results" to "[Dev] Add vertical slash sparse attention and benchmark results" on Sep 3, 2025