Add logic for block-scaled tensors with GEMM swizzled scales #2486
base: main
Conversation
/te-ci L1
Additional Comments (5)
- `transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py`, lines 273-277 (link) — syntax: the return statement should return a tuple, not a dictionary.
- `transformer_engine/pytorch/csrc/common.h`, lines 389-390 (link) — style: `Float8E8M0` maps to `kByte`, which might cause confusion since `kByte` is used for both `DType::kByte` and `DType::kFloat8E8M0`. Is there a specific reason `Float8E8M0` maps to `kByte` instead of having its own PyTorch scalar type?
- `transformer_engine/common/include/transformer_engine/swizzle.h`, lines 7-8 (link) — syntax: the header comment incorrectly states this is `cast.h` and describes casting functions, but this is `swizzle.h` for swizzle operations.
- `transformer_engine/common/cast/dispatch/quantize.cuh`, lines 150-157 (link) — logic: forward quantization always uses the GEMM_READY format regardless of the tensor's `with_gemm_swizzled_scales` field, while backward quantization respects it (lines 294-303). This inconsistency could lead to scale-format mismatches. Should forward quantization also check `output_tensor->with_gemm_swizzled_scales` like the backward path does?
- `transformer_engine/pytorch/distributed.py`, lines 1082-1084 (link) — logic: bug: `quantizer(out)` is called when `quantizer` is `None`, which will raise `TypeError: 'NoneType' object is not callable` at runtime (see the sketch after this list).
65 files reviewed, 5 comments
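For the last comment above, a minimal sketch of the kind of guard that avoids calling a `None` quantizer (names are hypothetical, not the actual code in `distributed.py`):

```python
def _maybe_quantize(out, quantizer=None):
    """Quantize the gathered output only when a quantizer is provided.

    Toy illustration of the guard suggested in the review comment; the real
    code in transformer_engine/pytorch/distributed.py differs.
    """
    return quantizer(out) if quantizer is not None else out
```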
Description
All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors:
- a "compact" order
- a "GEMM-ready" order, in which the scales are swizzled into the layout the GEMM kernels expect
The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does have awareness of "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that bypass the swizzle kernel.
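As a rough illustration of the difference between the two orderings, here is a toy NumPy sketch; the tile shape and interleaving pattern that the real GEMM kernels expect are defined by cuBLAS and Transformer Engine and are not reproduced here:

```python
import numpy as np

def compact_scales(n_row_blocks, n_col_blocks):
    # "Compact" order: one scale byte per block, stored row-major next to the data.
    vals = np.arange(n_row_blocks * n_col_blocks) % 256
    return vals.astype(np.uint8).reshape(n_row_blocks, n_col_blocks)

def toy_swizzle(scales, tile_rows=128, tile_cols=4):
    # Toy "GEMM-ready" order: regroup the scales so that each tile of blocks is
    # contiguous in memory. The real swizzle pattern required by the GEMM
    # kernels is more involved; this only illustrates the reordering idea.
    r, c = scales.shape
    assert r % tile_rows == 0 and c % tile_cols == 0
    tiled = scales.reshape(r // tile_rows, tile_rows, c // tile_cols, tile_cols)
    return tiled.transpose(0, 2, 1, 3).reshape(-1)

swizzled = toy_swizzle(compact_scales(256, 8))  # flat array of 8192 scale bytes
```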
This PR adds a `with_gemm_swizzled_scales` field to the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an `optimize_for_gemm` option in the quantizer so that we can create GEMM-ready tensors that do not need communication or checkpointing. Finally, it rips out all the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
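A hypothetical usage sketch of the new quantizer option (the class path, dtype enum, and exact constructor signature below are assumptions for illustration; only the `optimize_for_gemm` flag itself is taken from this PR's description):

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# Hypothetical sketch (assumes a GPU that supports MXFP8): everything here
# besides the optimize_for_gemm flag is an assumption about the API.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, optimize_for_gemm=True)

# Quantizing straight into the GEMM-ready scale layout lets the GEMM skip the
# separate swizzle kernel; intended for tensors that do not need communication
# or checkpointing.
x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
x_mx = quantizer(x)
```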
Progress
- Add option to pre-swizzle weights

Closes #2446.
Type of change
Changes
Please list the changes introduced in this PR:
- `optimize_for_gemm` option in PyTorch quantizer

Checklist: