Commit 736b6f4

[mxfp8 moe training] parallelize along col blocks in scale blocked format kernel for groups along K
stack-info: PR: #3416, branch: danielvegamyhre/stack/85
1 parent a6dbf45 commit 736b6f4
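
For context on the kernel being optimized: cuBLAS block-scaled GEMMs consume E8M0 scales in a swizzled layout built from 128-row x 4-column tiles, and the kernels touched by this PR produce that layout separately for each token group along K. The snippet below is a minimal plain-PyTorch sketch of the single-matrix rearrangement, with `to_blocked_ref` as an illustrative name rather than an API from this PR; the grouped kernels additionally split the scale columns at each group's end offset and pad each group independently. Per the commit title, the speedup presumably comes from spreading the launch grid over column blocks within each group instead of looping over them in one program.

import torch


def to_blocked_ref(scales: torch.Tensor) -> torch.Tensor:
    # scales: (M, K // block_size) E8M0 scales, held as uint8 for simplicity.
    rows, cols = scales.shape
    padded_rows = ((rows + 127) // 128) * 128
    padded_cols = ((cols + 3) // 4) * 4
    padded = torch.zeros(
        padded_rows, padded_cols, dtype=scales.dtype, device=scales.device
    )
    padded[:rows, :cols] = scales
    # Carve the matrix into (128, 4) tiles and make each tile contiguous.
    tiles = padded.view(padded_rows // 128, 128, padded_cols // 4, 4).permute(0, 2, 1, 3)
    # Within each tile, interleave the four 32-row sub-blocks into a (32, 16) layout.
    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return swizzled.flatten()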

9 files changed: +1069 -19 lines
Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch.utils.cpp_extension import load
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import (
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups_naive,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.utils import is_sm_at_least_100

# Build CUDA kernel directly using torch.utils.cpp_extension.load
mxfp8_cuda = None
try:
    if not is_sm_at_least_100():
        raise RuntimeError("CUDA kernel requires SM100+ GPU (Blackwell or later)")

    # Get the kernel source directory
    KERNEL_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "..",
        "..",
        "..",
        "torchao",
        "csrc",
        "cuda",
        "mx_kernels",
    )
    KERNEL_DIR = os.path.normpath(KERNEL_DIR)

    print("Compiling CUDA kernel...")
    mxfp8_cuda = load(
        name="mx_block_rearrange_2d_K_groups",
        sources=[
            os.path.join(KERNEL_DIR, "mxfp8_extension.cpp"),
            os.path.join(KERNEL_DIR, "mxfp8_cuda.cu"),
            os.path.join(KERNEL_DIR, "mx_block_rearrange_2d_K_groups.cu"),
        ],
        extra_cuda_cflags=[
            "-O3",
            "--use_fast_math",
            "-std=c++17",
            "-gencode=arch=compute_100,code=sm_100",
            "-gencode=arch=compute_120,code=compute_120",
        ],
        extra_cflags=["-O3", "-std=c++17"],
        verbose=True,
    )
    print("✓ CUDA kernel compilation successful!")
except (ImportError, RuntimeError) as e:
    print(f"⚠ CUDA kernel not available: {e}")
    print("The benchmark will only run 'naive' and 'parallel' Triton versions.\n")

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int]
    num_groups: int
    version: str  # "naive", "parallel", or "cuda"


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float
    mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Scale tensors have shape (M, total_K // block_size); the token groups lie along the total_K dim.
    block_size = 32
    input_shapes = [
        (5120, 16384 // block_size),
        (5120, 131072 // block_size),
        (8192, 16384 // block_size),
        (8192, 131072 // block_size),
        (7168, 16384 // block_size),
        (7168, 131072 // block_size),
        (2048, 16384 // block_size),
        (2048, 131072 // block_size),
    ]
    num_groups = [8]
    versions = ["naive", "parallel", "cuda"]

    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # Select which kernel to benchmark based on version
    if version == "naive":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups_naive
    elif version == "parallel":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups
    elif version == "cuda":
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups
    else:
        raise ValueError(f"Unknown version: {version}")

    # Run kernel to get output shape
    out_scales = kernel_fn(
        input_tensor,
        input_group_offsets,
    )

    # Benchmark the kernel
    assert input_tensor.is_contiguous()
    time_us = benchmark_cuda_function_in_microseconds(
        kernel_fn,
        input_tensor,
        input_group_offsets,
    )

    # Calculate memory bandwidth
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = out_scales.numel() * bytes_per_output_el

    mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)

    return ExperimentResult(
        time_us=time_us,
        mem_bw_gbps=mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    # Group experiments by input shape
    shapes_dict = {}
    for exp in experiments:
        shape_key = exp.config.input_shape
        if shape_key not in shapes_dict:
            shapes_dict[shape_key] = {}
        shapes_dict[shape_key][exp.config.version] = exp.result

    headers = [
        "kernel_version",
        "input_shape",
        "time_us",
        "mem_bw_gbps",
        "fastest_version",
    ]

    rows = []
    for shape, versions in shapes_dict.items():
        # Find fastest version for this shape
        fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]

        # Add rows for each version
        for version, result in versions.items():
            rows.append(
                [
                    version,
                    f"({shape[0]}, {shape[1]})",
                    f"{result.time_us:.2f}",
                    round(result.mem_bw_gbps, 3),
                    fastest_version,
                ]
            )

    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
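
As a sanity check on the bandwidth formula above (hypothetical numbers, not measurements): the (8192, 131072 // 32) = (8192, 4096) shape reads 8192 * 4096 = 33,554,432 one-byte scales and writes roughly the same amount again plus per-group padding, about 0.067 GB in total, so a run that took 100 µs would report on the order of 0.067 / 1e-4 ≈ 670 GB/s.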

setup.py

Lines changed: 1 addition & 0 deletions
@@ -702,6 +702,7 @@ def get_extensions():
     mxfp8_sources = [
         os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
         os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
+        os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"),
     ]

     # Only add the extension if the source files exist AND we are building for sm100

test/prototype/moe_training/test_kernels.py

Lines changed: 59 additions & 0 deletions
@@ -354,3 +354,62 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     # Check quantized values
     torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
     assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
+@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
+@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
+def test_cuda_mx_block_rearrange_2d_K_groups(
+    m: int,
+    total_k: int,
+    n_groups: int,
+):
+    """
+    Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference.
+    This kernel rearranges E8M0 scales to block-scaled swizzle format for cuBLAS Tmem.
+    """
+    from torchao.prototype import mxfp8_cuda
+
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, total_k, device=device)
+
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+
+    # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
+    input_group_offsets = generate_jagged_offs(
+        n_groups, total_k, multiple_of=block_size, device=device
+    )
+    scale_group_offsets = input_group_offsets // block_size
+
+    # Triton reference implementation
+    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
+        e8m0_scales,
+        scale_group_offsets,
+    )
+
+    # CUDA kernel implementation
+    cuda_out_scales = mxfp8_cuda.mx_block_rearrange_2d_K_groups(
+        e8m0_scales.view(torch.uint8),
+        scale_group_offsets,
+    )
+
+    # Check that outputs match
+    assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), (
+        "CUDA and Triton blocked scales not equal"
+    )
+
+    # Verify output shape
+    expected_rows = ((m + 127) // 128) * 128  # Padded to multiple of 128
+    expected_cols = (
+        e8m0_scales.size(1) + n_groups * 4
+    )  # Original cols + padding per group
+    assert cuda_out_scales.shape == (expected_rows, expected_cols), (
+        f"Output shape mismatch: expected {(expected_rows, expected_cols)}, got {cuda_out_scales.shape}"
+    )
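
A note on the offset convention exercised by this test: per the comment above, generate_jagged_offs returns cumulative group-end offsets along total_K, each a multiple of block_size, so integer division by block_size maps them onto scale-column boundaries. The values below are made up purely for illustration.

# Hypothetical example (total_k=256, n_groups=4, block_size=32):
input_group_offsets = torch.tensor([64, 96, 192, 256])
scale_group_offsets = input_group_offsets // 32  # tensor([2, 3, 6, 8]) -> per-group scale-column end offsets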
