[Dev] Add vertical slash sparse attention and benchmark results #25
base: main
Conversation
- …ipts, and basic kernel implementations
- …rmatting in test and kernel files
- … include runtime requirements
- …nel initialization
- [Draft] Add grouped query attention fwd and bwd kernels; remove performance profiling from the check function
- [Draft] Add mamba chunk scan attention fwd kernel; add `args.tune` in the mamba_chunk_scan test
- [Dev] Add mamba chunk state attention fwd kernel; add `type: ignore`
- Add `type: ignore` to improve hint style; add MHA fwd/bwd kernels and tests for the BSHD layout; remove redundant code; improve naming; run yapf and ruff
- Add blocksparse_flash_attention fwd and bwd kernels; fix style and typos
- Change all kernel class names from lowercase to camelCase (e.g. MLA_kernel to MLAKernel) and optimize code format for better readability; update relevant documents to reflect these changes
- [Draft] Add grouped query attention fwd and bwd kernels; remove performance profiling from the check function; [Dev] add bitnet prefill and decode kernels
- Add MHA decode kernel and update naming; correct spelling
- Add linear attention recurrent kernel; run yapf and ruff; remove unnecessary pkg
- [Dev] Add GQA decode kernel and update naming; [Style] format code; [Fix] correct loop range of gqa_decode_split_ref using ceiling division
- Add utils.py; unify testing method; add padding to support seq_len indivisible by chunk_size; fix typo in MHAKernel; fix lint; [fix] generate `do` without requiring gradient
- …ls, [Fix] Resolve bugs in autotune of GQA kernels (tile-ai#22); [Dev] add autotune support for GQA forward and backward passes; [Fix] resolve bugs in autotune of GQA kernels; [Dev] add autotune support for MHA decode, forward and backward kernels
Summary of Changes
Hello @xwhzz, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a vertical slash sparse attention mechanism built on custom CUDA kernels and TileLang. It includes the full kernel implementation, a new kernel abstraction, and benchmarking infrastructure to validate its performance and correctness.
Highlights
- New Sparse Attention Mechanism: Introduced a novel 'Vertical Slash Sparse Attention' mechanism, designed for efficient computation in deep learning models.
- CUDA Kernel for Index Conversion: Added a custom CUDA kernel (`vertical_slash_index.cu`) and its C++ bindings (`kernels.cpp`) to dynamically convert and manage sparse attention indexes, optimizing data access patterns.
- TileLang and Triton Integration: The new attention kernel is implemented using TileLang for performance optimization, with a Triton-based reference program for parity checking and benchmarking.
- Performance Benchmarking: Comprehensive benchmarks demonstrate significant speedups (up to 1.8x) over baseline Triton implementations on NVIDIA H100 GPUs, showcasing improved TFlops and IO bandwidth.
- Kernel Abstraction and Utilities: A new `KernelBase` abstract class provides a standardized interface for kernel implementations, complemented by enhanced utility functions for robust parity checks and detailed performance analysis (a minimal sketch of such an interface follows below).
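The abstraction itself ships in this PR; the sketch below only illustrates the kind of interface such a base class typically standardizes. Apart from `ref_program` (which the test snippets further down call), the method names and tolerance defaults here are assumptions, not the PR's actual API.

```python
# Minimal sketch of a kernel-abstraction base class in the spirit of the
# KernelBase described above. Except for `ref_program`, names and defaults
# are assumptions for illustration, not this PR's actual interface.
from abc import ABC, abstractmethod

import torch


class SketchKernelBase(ABC):
    """Build an optimized kernel and expose a reference for parity checks and benchmarks."""

    @abstractmethod
    def kernel(self, *args, **kwargs) -> torch.Tensor:
        """Run the optimized (e.g. TileLang) implementation."""

    @abstractmethod
    def ref_program(self, *args, **kwargs) -> torch.Tensor:
        """Run a straightforward reference implementation (e.g. Triton or PyTorch)."""

    def check(self, *args, atol: float = 1e-2, rtol: float = 1e-2, **kwargs) -> None:
        """Parity check: compare the optimized kernel against the reference."""
        out = self.kernel(*args, **kwargs)
        ref = self.ref_program(*args, **kwargs)
        torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
```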
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. The supported commands are summarized below.

| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces a vertical slash sparse attention mechanism, including CUDA kernels, a TileLang implementation, and benchmark results. The changes involve adding new files for the kernel implementation, modifying the __init__.py files to include the new kernel, and adding a test file to evaluate the performance of the new attention mechanism.
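As a rough illustration of the `__init__.py` change described above, exposing a new kernel class usually amounts to a re-export along these lines (the module path and `__all__` contents are assumptions based on the file names visible in this PR, not the exact diff):

```python
# Hypothetical excerpt of top/kernel/__init__.py after this PR; the actual
# export list and module layout may differ.
from .vs_sparse_attention import VerticalSlashSparseAttentionKernel  # noqa: F401

__all__ = ["VerticalSlashSparseAttentionKernel"]
```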
```python
o = mod(q, k, v, block_count, block_offset, column_count, column_index)
return o


@staticmethod
```
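For context on the call above: in a vertical slash sparsity pattern, each query block attends to a small set of diagonal ("slash") key blocks plus individually selected ("vertical") key columns, and the four index tensors encode that layout per (batch, head, query block). The sketch below shows one plausible set of shapes; the exact shapes, dtypes, and how the CUDA index-conversion kernel fills them are defined by the PR, not here.

```python
# Plausible layout of the sparse-index tensors passed to the kernel above.
# Shapes, dtypes, and capacity constants are illustrative assumptions; the
# authoritative definitions live in vertical_slash_index.cu and the test file.
import torch

BATCH, HEADS, SEQ_LEN, BLOCK_M = 1, 32, 8192, 64
NUM_ROWS = (SEQ_LEN + BLOCK_M - 1) // BLOCK_M  # number of query blocks
NNZ_S, NNZ_V = 32, 1024  # capacity for kept slash blocks / vertical columns per query block

# How many key blocks each query block attends to, and their start offsets.
block_count = torch.zeros(BATCH, HEADS, NUM_ROWS, dtype=torch.int32)
block_offset = torch.zeros(BATCH, HEADS, NUM_ROWS, NNZ_S, dtype=torch.int32)

# How many individual key columns each query block keeps, and their indices.
column_count = torch.zeros(BATCH, HEADS, NUM_ROWS, dtype=torch.int32)
column_index = torch.zeros(BATCH, HEADS, NUM_ROWS, NNZ_V, dtype=torch.int32)
```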
```python
return kernel_func(block_M, block_N, num_stages)


@torch.compile
```
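The `kernel_func(block_M, block_N, num_stages)` call suggests a factory that specializes the kernel for one tile configuration, which is the usual hook for autotuning; the `@torch.compile` decorator in the same hunk compiles a PyTorch-level function (most likely the reference program) for a fairer baseline. A structural sketch of the factory pattern, with hypothetical names and no real kernel body:

```python
# Structural sketch only: how a tile-configuration factory is typically swept
# by an autotuner. Names and bodies are hypothetical, not this PR's TileLang code.
import itertools


def make_kernel(block_M: int, block_N: int, num_stages: int):
    """Return a callable specialized for one (block_M, block_N, num_stages) config."""

    def run(*tensors):
        # A real implementation would JIT-compile and launch a TileLang kernel
        # built with block_M x block_N tiles and `num_stages` pipeline stages.
        raise NotImplementedError("sketch only")

    return run


# Candidate configurations an autotuner might benchmark and pick from.
candidates = [
    make_kernel(bm, bn, stages)
    for bm, bn, stages in itertools.product((64, 128), (32, 64, 128), (1, 2, 3))
]
```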
```python
class VerticalSlashSparseAttentionKernel(KernelBase):
    map_dtype = {torch.float16: "float16", torch.bfloat16: "bfloat16"}
```
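The `map_dtype` attribute maps torch dtypes to the string names that a TileLang-style kernel builder expects. A small, hedged illustration of the usual lookup (only the dictionary itself comes from the PR; the class and method around it are hypothetical):

```python
# Only the map_dtype dictionary is from the PR; the class and method around it
# are a hypothetical illustration of how such a map is typically consumed.
import torch


class DtypeExample:
    map_dtype = {torch.float16: "float16", torch.bfloat16: "bfloat16"}

    def dtype_name(self, t: torch.Tensor) -> str:
        try:
            return self.map_dtype[t.dtype]
        except KeyError:
            raise TypeError(f"unsupported dtype {t.dtype}; expected float16 or bfloat16") from None


assert DtypeExample().dtype_name(torch.empty(1, dtype=torch.bfloat16)) == "bfloat16"
```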
top/kernel/vs_sparse_attention.py (outdated diff)
```python
def pytorch_ref_program(self, *args, **kwargs):

    def attention_func(queries, keys, values, attention_mask):
        attention_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(queries.size(-1))
        attention_weights += attention_mask.to(queries.dtype) * torch.finfo(queries.dtype).min
        attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
        attention_output = torch.matmul(attention_weights, values)
        return attention_output
```
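The reference above is plain scaled dot-product attention with an additive mask: positions where `attention_mask` is 1 are pushed to the dtype minimum before the softmax, so a strict-upper-triangular mask of ones yields causal attention. A self-contained usage sketch (shapes and the default float32 dtype are illustrative; the PR's kernels target float16/bfloat16):

```python
# Self-contained restatement of the dense reference above plus a causal-mask
# usage example. Shapes and dtype here are illustrative assumptions.
import math

import torch
from torch import nn


def attention_func(queries, keys, values, attention_mask):
    # Scaled QK^T, additive mask, float32 softmax, then weighted sum of values,
    # mirroring the reviewed reference implementation.
    attention_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(queries.size(-1))
    attention_weights += attention_mask.to(queries.dtype) * torch.finfo(queries.dtype).min
    attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
    return torch.matmul(attention_weights, values)


B, H, S, D = 1, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# 1 = masked out; masking the strict upper triangle yields causal attention.
causal_mask = torch.triu(torch.ones(S, S), diagonal=1)[None, None]  # broadcasts to (B, H, S, S)

out = attention_func(q, k, v, causal_mask)
print(out.shape)  # torch.Size([1, 4, 128, 64])
```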
```python
    seq_len=SEQ_LEN,
    head_dim=D_HEAD,
)
partity(kernel, q, k, v, block_count=block_count, block_offset=block_offset, column_count=column_count, column_index=column_index)

perf = performance(kernel, [kernel.ref_program], q, k, v, block_count=block_count, block_offset=block_offset, column_count=column_count, column_index=column_index)
```