[common] Add support for cuBLASLt GEMM for GroupedTensor #2502
base: main
Conversation
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci L0
Greptile Summary
Important Files Changed
Confidence score: 4/5
Sequence Diagram
sequenceDiagram
participant User
participant API as nvte_grouped_gemm
participant Validator as validate_grouped_gemm_inputs
participant Selector as select_grouped_operand
participant Kernel as setup_grouped_gemm_kernel
participant cuBLAS as cublasLtMatmul
participant HandleMgr as cublasHandleManager
User->>API: "Call nvte_grouped_gemm(transa, transb, alpha, A, B, beta, C, D, workspace_setup, workspace_cublas, stream)"
API->>Validator: "validate_grouped_gemm_inputs(A, B, C, D, alpha, beta)"
Validator-->>API: "Validation complete"
API->>Selector: "select_grouped_operand(A, transa, is_A=true)"
Selector-->>API: "A_sel (data pointer, dtype, transpose flag)"
API->>Selector: "select_grouped_operand(B, transb, is_A=false)"
Selector-->>API: "B_sel (data pointer, dtype, transpose flag)"
API->>API: "GroupedGemmSetupWorkspace::from_buffers(workspace_ptr, num_tensors)"
API->>Kernel: "launch_grouped_gemm_setup(workspace, A_sel, B_sel, C, D, alpha, beta, num_tensors, stream)"
Note over Kernel: "Populates pointer arrays and M/N/K dimensions for each matrix in group"
Kernel-->>API: "Setup arrays populated"
API->>HandleMgr: "GetHandle()"
HandleMgr-->>API: "cublasLtHandle_t"
API->>API: "init_matrix_layouts(descA, descB, descC, descD, workspace, A_sel, B_sel, D, num_tensors)"
API->>API: "init_matmul_desc(matmulDesc, op_A, op_B)"
API->>API: "set_fp8_scale_pointers(matmulDesc, A_sel, B_sel)"
API->>API: "select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m, avg_n, avg_k)"
API->>cuBLAS: "cublasLtMatmul(handle, matmulDesc, alpha_ptrs, A_ptrs, B_ptrs, beta_ptrs, C_ptrs, D_ptrs, algo, workspace, stream)"
Note over cuBLAS: "D = alpha * op(A) @ op(B) + beta * C for each matrix group"
cuBLAS-->>API: "Grouped GEMM complete"
API-->>User: "Return"
Additional Comments (4)
- tests/cpp/operator/test_grouped_gemm.cu, line 75 (link) logic: missing columnwise_data in move assignment (see the sketch after this list)
- tests/cpp/operator/test_grouped_gemm.cu, line 336-337 (link) logic: kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?
- tests/cpp/operator/test_grouped_gemm.cu, line 367-378 (link) logic: missing case for InputCase::kFP8Delayed
- transformer_engine/common/include/transformer_engine/gemm.h, line 265-266 (link) style: The avg_m, avg_n, avg_k parameters are not documented in the function comment. What do these average dimension parameters represent and how should they be computed?
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
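Regarding the first comment above (missing columnwise_data in the move assignment), here is a minimal, self-contained sketch of the pattern being suggested. The class and members below are illustrative only; they mirror the names in the review comment, not the actual test helper.

```cpp
#include <cstddef>
#include <utility>

// Illustration only: a move assignment should transfer every owned member,
// including columnwise_data, and leave the moved-from object empty.
struct GroupedTensorHolder {
  void *rowwise_data = nullptr;     // assumed member, for illustration
  void *columnwise_data = nullptr;  // the member the review says was skipped
  size_t bytes = 0;

  GroupedTensorHolder &operator=(GroupedTensorHolder &&other) noexcept {
    if (this != &other) {
      rowwise_data = std::exchange(other.rowwise_data, nullptr);
      columnwise_data = std::exchange(other.columnwise_data, nullptr);  // easy to forget
      bytes = std::exchange(other.bytes, 0);
    }
    return *this;
  }
};
```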
4 files reviewed, 4 comments
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Additional Comments (1)
- transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link) syntax: Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.
4 files reviewed, 1 comment
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
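A hedged sketch of the per-matrix alpha/beta indexing described in the commit above. This is not the actual setup_grouped_gemm_kernel; it only illustrates how each group idx can point at its own alpha[idx] / beta[idx] entry so that cuBLASLt sees per-matrix scaling.

```cpp
#include <cuda_runtime.h>

// Illustration only: each group idx reads its own scaling factors.
__global__ void fill_scale_pointers(const float *alpha, const float *beta,
                                    const float **alpha_ptrs, const float **beta_ptrs,
                                    int num_tensors) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_tensors) {
    alpha_ptrs[idx] = alpha + idx;  // per-matrix alpha
    beta_ptrs[idx] = beta + idx;    // per-matrix beta
  }
}
```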
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci
NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m,
const int64_t *avg_n, const int64_t *avg_k);
The average sizes seem like advanced configs that would be better to leave out of the top-level API. Can we move them inside NVTEMatmulConfig?
This was a suggestion from @ptrendx - users may want to set these if they know more about their shapes. For example, if there are multiple tensors with K dimension D and one tensor with K dimension equal to 1, telling cuBLAS that the average dimension is D may result in better performance.
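To make that heuristic concrete, here is a small illustrative sketch (not Transformer Engine code) of how a caller could pick a better avg_k hint than the arithmetic mean when one group is an outlier:

```cpp
#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

// Pick the most common K across groups instead of the mean, so a single
// K == 1 outlier does not drag the hint away from the dominant shape D.
int64_t pick_avg_k_hint(const std::vector<int64_t> &k_dims) {
  if (k_dims.empty()) return 0;
  std::map<int64_t, int> counts;
  for (int64_t k : k_dims) ++counts[k];
  return std::max_element(counts.begin(), counts.end(),
                          [](const auto &a, const auto &b) { return a.second < b.second; })
      ->first;
}
// e.g. k_dims = {D, D, D, 1} -> returns D, which would then be passed via avg_k.
```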
We should definitely expose these options, but it would be better to put them in NVTEMatmulConfig rather than the top-level function signature. If you look at nvte_cublas_gemm:
TransformerEngine/transformer_engine/common/include/transformer_engine/gemm.h
Lines 109 to 112 in 97a09c2
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                      NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                      NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                      int math_sm_count, cudaStream_t stream);
You see advanced options like pre_gelu_out and use_split_accumulator that are only needed for specialized cases. It's even worse for use_split_accumulator, since that's only relevant for FP8 on Hopper and now it's sitting uselessly in the API forever. Compare with the v2 API, which is much closer to the original GEMM API:
TransformerEngine/transformer_engine/common/include/transformer_engine/gemm.h
Lines 131 to 133 in 97a09c2
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream);
The idea of NVTEMatmulConfig is to hold these advanced, kernel-specific options in a way that lets us add or deprecate them easily without breaking the API.
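Purely for illustration, this is roughly what the config-based approach could look like. The attribute name and setter below are hypothetical placeholders, not the actual NVTEMatmulConfig API, and the tensors, workspaces, and config handle are assumed to be already constructed.

```cpp
// Hypothetical sketch: avg-dimension hints carried by NVTEMatmulConfig instead
// of extra top-level parameters. Names marked "hypothetical" do not exist in
// the current header.
int64_t avg_dims[3] = {avg_m, avg_n, avg_k};
nvte_set_matmul_config_attribute(config,                    // hypothetical setter
                                 kNVTEMatmulConfigAvgDims,  // hypothetical attribute
                                 avg_dims, sizeof(avg_dims));
nvte_grouped_gemm(transa, transb, alpha, A, B, beta, C, D,
                  workspace_setup, workspace_cublas, config, stream);
```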
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Additional Comments (2)
- transformer_engine/common/include/transformer_engine/gemm.h, line 256 (link) syntax: Documentation mentions a non-existent config parameter - this parameter is not in the function signature on lines 276-280
- tests/cpp/operator/test_grouped_gemm.cu, line 141 (link) logic: Alignment calculation appears incorrect - multiplying 16 bytes by 8 bits gives 128 bits, not the intended 16-byte alignment (see the sketch after this list)
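On the alignment comment above, a tiny self-contained illustration (independent of the test code) of padding a byte count up to a 16-byte boundary, with no conversion to bits involved:

```cpp
#include <cstddef>

// Round a byte count up to the next multiple of 16 bytes.
constexpr size_t kAlignment = 16;
constexpr size_t align_up(size_t bytes) {
  return (bytes + kAlignment - 1) / kAlignment * kAlignment;
}
static_assert(align_up(1) == 16 && align_up(16) == 16 && align_up(17) == 32);
```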
6 files reviewed, 2 comments
/te-ci L0
Additional Comments (4)
- transformer_engine/common/include/transformer_engine/gemm.h, line 260-279 (link) syntax: Missing config parameter in implementation. The function signature includes an NVTEMatmulConfig config parameter in the declaration (line 276), but the actual implementation in cublaslt_grouped_gemm.cu:498 doesn't accept this parameter. The call site in tests (lines 449-462 of test_grouped_gemm.cu) also omits it, passing 0 instead. Either add the parameter to the implementation or remove it from the declaration.
- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 592 (link) logic: cuBLAS version mismatch in error message. The error message mentions "cuBLAS 13.2+" but the compile-time guard checks for CUBLAS_VERSION >= 130100 (line 29), which is cuBLAS 13.1. Also, the header comment references cuBLAS 13.1 vs 13.2 inconsistently.
- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 472-473 (link) logic: Potential null pointer dereference when C is NULL. When C is NULL and use_null_c=true, C->data.dptr and C->dtype() will cause a segfault. The code sets inputC = outputD when C is NULL (line 525), but this happens after launch_grouped_gemm_setup is called, where C is still NULL. (See the sketch after this list.)
- tests/cpp/operator/test_grouped_gemm.cu, line 95-102 (link) style: Workspace size calculation doesn't match implementation. The test calculates 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes (6 pointer arrays total), but the implementation in GroupedGemmSetupWorkspace::from_buffers expects exactly 6 pointer arrays + 3 int arrays. The calculation is correct but the comment formatting makes it unclear. Consider: 6 * ptr_bytes + 3 * int_bytes.
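A hedged sketch of the fix suggested in the null-pointer comment above (the real code layout may differ): resolve the effective C operand before anything, including launch_grouped_gemm_setup, dereferences it.

```cpp
#include <transformer_engine/transformer_engine.h>  // NVTETensor (opaque handle)

// Sketch only: when C is NULL and beta == 0, substitute D as a placeholder
// *before* the setup kernel and layout code ever touch the C operand.
static inline NVTETensor resolve_input_c(NVTETensor C, NVTETensor D) {
  return (C != nullptr) ? C : D;
}
// ... then pass resolve_input_c(C, D) everywhere the grouped path reads C.
```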
7 files reviewed, 4 comments
Additional Comments (2)
- transformer_engine/common/include/transformer_engine/gemm.h, line 237-268 (link) syntax: Version inconsistency: the note section mentions cuBLAS 13.1+ but the requirements section specifies 13.2+. Need to align these versions consistently. Which cuBLAS version is actually required - 13.1 or 13.2?
- tests/cpp/operator/test_grouped_gemm.cu, line 341-344 (link) logic: Transpose logic appears inverted - for transa=true, A should be transposed, so the input shape should be (K,M) to produce an effective (M,K) for GEMM. Is the tensor shape logic correct for transpose operations? Typically transa=true means the input A matrix needs to be transposed during the operation.
7 files reviewed, 2 comments
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Additional Comments (1)
- tests/cpp/operator/test_grouped_gemm.cu, line 485 (link) syntax: incorrect version check - should be 130100, not 130200. The API requires cuBLAS 13.1+ (version 130100), but this conditional check uses 130200. This mismatch means tests will be skipped even on cuBLAS 13.1. (See the sketch after this list.)
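On the version-check comment, a minimal illustration (assuming the CUBLAS_VERSION macro from the cuBLAS headers) of keeping the implementation guard and the test threshold on the same value, 130100 for cuBLAS 13.1:

```cpp
#include <cublas_v2.h>  // provides CUBLAS_VERSION = major * 10000 + minor * 100 + patch

// cuBLAS 13.1 encodes as 130100; sharing one constant between the library
// guard and the test skip condition keeps the two from drifting apart.
constexpr int kMinGroupedGemmCublasVersion = 130100;
static_assert(kMinGroupedGemmCublasVersion == 13 * 10000 + 1 * 100,
              "cuBLAS 13.1 encodes as 130100");
#if CUBLAS_VERSION < 130100
// This is where the tests should skip (not fail) with "requires cuBLAS 13.1+".
#endif
```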
8 files reviewed, 1 comment
/te-ci |
Description
Adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts the NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K).
New API
Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes. A reference sketch of these semantics is shown below.
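To make the semantics concrete, here is a small self-contained reference in plain C++ (not the TE API) of what the grouped operation computes per group, with per-matrix alpha/beta as in the final revision and no transposes for simplicity:

```cpp
#include <cstdint>
#include <vector>

// Naive reference of the grouped GEMM semantics: for each group g,
// D_g = alpha[g] * A_g @ B_g + beta[g] * C_g, with row-major M x K, K x N
// and M x N matrices.
struct GroupDims { int64_t m, n, k; };

void grouped_gemm_reference(const std::vector<GroupDims> &dims,
                            const std::vector<const float *> &A,
                            const std::vector<const float *> &B,
                            const std::vector<const float *> &C,
                            const std::vector<float *> &D,
                            const std::vector<float> &alpha,
                            const std::vector<float> &beta) {
  for (size_t g = 0; g < dims.size(); ++g) {
    const auto [m, n, k] = dims[g];
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int64_t p = 0; p < k; ++p) acc += A[g][i * k + p] * B[g][p * n + j];
        D[g][i * n + j] = alpha[g] * acc + beta[g] * C[g][i * n + j];
      }
    }
  }
}
```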
Type of change
Changes
- GroupedGemmSetupWorkspace struct for the cuBLAS workspace layout
- test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)
Checklist: