From 3bc5d371fae80ada9fd0690aef0f9a2050729edc Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 03:10:33 +0000 Subject: [PATCH 1/6] Add conv roofline --- .../float8/float8_inference_roofline.py | 490 ++++++++++++++---- torchao/testing/training/roofline_utils.py | 12 + 2 files changed, 405 insertions(+), 97 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 188bb46224..3c705e1095 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -53,10 +53,30 @@ from torchao.testing.training.roofline_utils import ( get_inference_float8_mem_sympy, get_inference_gemm_time_sympy, + get_specs, + BYTES_PER_EL_BF16, + BYTES_PER_EL_FLOAT8, + KERNEL_LAUNCH_OVERHEAD_SEC, ) from torchao.utils import is_MI300 +def _validate_conv_params( + op_name: str, + kernel_size: Optional[int], + D: Optional[int], + H: Optional[int], + W: Optional[int], +): + """Validate conv operation parameters.""" + if op_name == "conv2d": + assert H is not None and W is not None, "H and W required for conv2d" + assert kernel_size is not None, "kernel_size required for conv2d" + elif op_name == "conv3d": + assert D is not None and H is not None and W is not None, "D, H, W required for conv3d" + assert kernel_size is not None, "kernel_size required for conv3d" + + @torch.no_grad() def get_gpu_kernel_time(m, x, trace_filename=None): # warm up @@ -165,6 +185,254 @@ def do_matmul(A, B): return bf16_time_s, f8_time_s +def get_conv_equivalent_gemm_dims( + op_name: str, + batch: int, + in_channels: int, + out_channels: int, + kernel_size: int, + D: Optional[int], + H: int, + W: int, + stride: int = 1, + padding: int = 0, +): + """ + Get GEMM dimensions from unfold. 
+ + Args: + op_name: "conv2d" or "conv3d" + batch: Batch size + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Kernel size + D: Depth dimension (required for conv3d) + H: Height dimension (required for conv2d/conv3d) + W: Width dimension (required for conv2d/conv3d) + stride: Stride value + padding: Padding value + + Returns: + Tuple[int, int, int]: (gemm_M, gemm_K, gemm_N) + gemm_M: Number of output spatial positions + gemm_K: Size of each filter (in_channels * kernel volume) + gemm_N: Number of filters (out_channels) + """ + device = torch.device("cuda") + + _validate_conv_params(op_name, kernel_size, D, H, W) + + if op_name == "conv2d": + x = torch.randn(batch, in_channels, H, W, device=device) + unfolded = torch.nn.functional.unfold( + x, kernel_size=(kernel_size, kernel_size), + stride=stride, padding=padding + ) + batch_out, K, L = unfolded.shape + gemm_M = batch_out * L + gemm_K = K + gemm_N = out_channels + + elif op_name == "conv3d": + x = torch.randn(batch, in_channels, D, H, W, device=device) + + # Note: torch.nn.Unfold only supports 4-D tensors + # (https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html) + # For 3D conv, reshape (B,C,D,H,W) -> (B*D,C,H,W) and unfold H,W + B, C, D_in, H_in, W_in = x.shape + x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in) + unfolded = torch.nn.functional.unfold( + x_reshaped, kernel_size=(kernel_size, kernel_size), + stride=stride, padding=padding + ) + + D_out = (D - kernel_size + 2 * padding) // stride + 1 + _, K_2d, L_2d = unfolded.shape + + # GEMM dimensions: account for depth in K + gemm_K = K_2d * kernel_size # C * kernel_size³ + gemm_M = B * D_out * L_2d + gemm_N = out_channels + + else: + raise ValueError(f"Unsupported op_name: {op_name}") + + return gemm_M, gemm_K, gemm_N + + +def benchmark_im2col_unfold( + op_name: str, + batch: int, + in_channels: int, + kernel_size: int, + D: Optional[int], + H: int, + W: int, + stride: int = 1, + padding: int = 0, + dtype=torch.bfloat16, +): + """ + Benchmark unfold operation. 
+ + Args: + op_name: "conv2d" or "conv3d" + batch: Batch size + in_channels: Number of input channels + kernel_size: Kernel size + D: Depth dimension (required for conv3d) + H: Height dimension (required for conv2d/conv3d) + W: Width dimension (required for conv2d/conv3d) + stride: Stride value + padding: Padding value + dtype: Data type + + Returns: + Measured time in seconds + """ + device = torch.device("cuda") + + _validate_conv_params(op_name, kernel_size, D, H, W) + + # Unfold doesn't support FP8; return -1 for unsupported dtypes + if dtype not in (torch.bfloat16, torch.float16, torch.float32): + return -1 + + # Create input tensor + if op_name == "conv2d": + x = torch.randn(batch, in_channels, H, W, dtype=dtype, device=device) + elif op_name == "conv3d": + x = torch.randn(batch, in_channels, D, H, W, dtype=dtype, device=device) + else: + raise ValueError(f"Unsupported op_name: {op_name}") + + def _run_unfold(): + if op_name == "conv2d": + return torch.nn.functional.unfold( + x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding + ) + else: # conv3d: reshape to 4D since unfold only supports 4D + B, C, D_in, H_in, W_in = x.shape + x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in) + return torch.nn.functional.unfold( + x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding + ) + + # Warm up + for _ in range(2): + _ = _run_unfold() + torch.cuda.synchronize() + + # Benchmark + n_iter = 10 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(n_iter): + _ = _run_unfold() + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / 1000.0 / n_iter + + +def get_im2col_memory_overhead_sympy( + op_name: str, + batch: int, + in_channels: int, + out_channels: int, + kernel_size: int, + D: Optional[int], + H: int, + W: int, + stride: int = 1, + padding: int = 0, + dtype=torch.bfloat16, + gpu_name: Optional[str] = None, +): + """ + Calculate the memory overhead for im2col transformation in conv operations. + + Im2col unfolds the input tensor into a 2D matrix for efficient GEMM computation. + This involves: + 1. Reading the input tensor (batch × in_channels × spatial_dims) + 2. Writing the im2col matrix (output_spatial_positions × kernel_volume) + + The im2col matrix is typically much larger than the input due to overlapping + windows, especially with stride=1 and larger kernels. 
+ + Args: + op_name: "conv2d" or "conv3d" + batch: Batch size + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Kernel size + D: Depth dimension (required for conv3d) + H: Height dimension (required for conv2d/conv3d) + W: Width dimension (required for conv2d/conv3d) + stride: Stride value + padding: Padding value + dtype: Data type + gpu_name: GPU name for specs + + Returns: + sympy expression for im2col memory overhead in seconds + """ + _validate_conv_params(op_name, kernel_size, D, H, W) + specs = get_specs(gpu_name) + + # Determine bytes per element based on dtype + if dtype == torch.bfloat16: + bytes_per_el = BYTES_PER_EL_BF16 + elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + bytes_per_el = BYTES_PER_EL_FLOAT8 + else: + bytes_per_el = BYTES_PER_EL_BF16 # default + + if op_name == "conv2d": + + # Input size + input_numel = batch * in_channels * H * W + + # Output spatial dimensions + H_out = (H - kernel_size + 2 * padding) // stride + 1 + W_out = (W - kernel_size + 2 * padding) // stride + 1 + + # Im2col matrix size: (batch * H_out * W_out) × (in_channels * kernel_size^2) + im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size + + elif op_name == "conv3d": + # Input size + input_numel = batch * in_channels * D * H * W + + # Output spatial dimensions + D_out = (D - kernel_size + 2 * padding) // stride + 1 + H_out = (H - kernel_size + 2 * padding) // stride + 1 + W_out = (W - kernel_size + 2 * padding) // stride + 1 + + # Im2col matrix size: (batch * D_out * H_out * W_out) × (in_channels * kernel_size^3) + im2col_numel = batch * D_out * H_out * W_out * in_channels * kernel_size * kernel_size * kernel_size + + else: + raise ValueError(f"Unsupported op_name: {op_name}") + + # Memory traffic: read input + write im2col matrix + # Note: In practice, some implementations may avoid materializing the full im2col + # matrix, but we model the worst case here + bytes_read = input_numel * bytes_per_el + bytes_write = im2col_numel * bytes_per_el + total_bytes = bytes_read + bytes_write + + # Convert to time using memory bandwidth + im2col_time_s = total_bytes / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] + + # Account for kernel launch overhead + im2col_time_s = sympy.Max(im2col_time_s, KERNEL_LAUNCH_OVERHEAD_SEC) + + return im2col_time_s + + def run( outfile: str, recipe_name: str, @@ -196,23 +464,10 @@ def run( # `kernel_size`: kernel_size for conv3d / conv2d """ _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"] - assert op_name in _SUPPORTED_OPS, ( - f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}" - ) - if op_name == "conv2d": - assert H is not None and W is not None, ( - "Expected D, H, W to be specified for conv2d" - ) - assert kernel_size is not None, ( - "Expected kernel_size to be specified for conv2d" - ) - elif op_name == "conv3d": - assert D is not None and H is not None and W is not None, ( - "Expected D, H, W to be specified for conv3d" - ) - assert kernel_size is not None, ( - "Expected kernel_size to be specified for conv3d" - ) + assert op_name in _SUPPORTED_OPS, f"Unsupported op: {op_name}, supported: {_SUPPORTED_OPS}" + + if op_name in ("conv2d", "conv3d"): + _validate_conv_params(op_name, kernel_size, D, H, W) config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -234,68 +489,62 @@ def run( M, K, N = sympy.symbols("M K N") - if op_name == "linear": - fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( - M, - K, - N, - recipe_name, - # TODO(future): also enable 
fusion modeling here + # Create symbolic roofline expressions (same for linear and conv) + fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( + M, K, N, recipe_name, + ) + bf16_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.bfloat16, None + ) + + if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float4_e2m1fn_x2, recipe_name ) - bf16_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.bfloat16, None + else: + gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float8_e4m3fn, gemm_recipe_name ) - - if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float4_e2m1fn_x2, recipe_name - ) - else: - gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float8_e4m3fn, gemm_recipe_name - ) - print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) - print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) - print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) + + print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) + print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) + print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) + print() + + if op_name in ("conv2d", "conv3d"): + print(f"{op_name}: GEMM dimensions from unfold, roofline from symbolic expressions") print() - else: - # TODO: enable roofline analysis for conv - pass + elif op_name != "linear": + raise ValueError(f"Unsupported op_name: {op_name}") - # Note: roofline for conv2d/conv3d is not added yet, so most of the - # things for conv2d/conv3d we'll left out for now headers = [ - "fwd_M", - "fwd_K", - "fwd_N", - "D", - "H", - "W", - "kernel_size", - # roofline - gemm time (fwd + bwd, 3 gemms) - "r_bf16_gemm_s", - "r_fp8_gemm_s", - # roofline - fp8 overhead time (by counting reads/writes in the ideal case) + # Shape parameters + "fwd_M", "fwd_K", "fwd_N", "D", "H", "W", "kernel_size", + + # Roofline: GEMM time + "r_bf16_gemm_s", "r_fp8_gemm_s", + + # Roofline: im2col overhead + "r_im2col_bf16_s", "r_im2col_fp8_s", + + # Roofline: FP8 quantization overhead "r_fp8_ovhd_s", - # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid) - "r_fp8_gemm_and_ovhd_s", - "r_fp8_gemm_and_ovhd_spdp", - # benchmarks - gemm time (fwd + bwd, 3 gemms) - "b_bf16_gemm_s", - "b_fp8_gemm_s", - # benchmarks - e2e LNLinearSigmoid time fwd + bwd - "b_bf16_e2e_s", - "b_fp8_e2e_s", - # note that e2e speedup is not the same as the roofline speedup: - # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time) - # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid) - # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple - # we don't break them out and don't have a roofline for them. 
- "b_fp8_e2e_spdp", - # how well benchmarked gemms match roofline predicted gemms - "rb_bf16_gemm_ratio", - "rb_fp8_gemm_ratio", + + # Roofline: GEMM-only metrics + "r_fp8_gemm_and_ovhd_s", "r_fp8_gemm_and_ovhd_spdp", + + # Roofline: Total (im2col + GEMM + quantization) + "r_bf16_total_s", "r_fp8_total_s", "r_fp8_total_spdp", + + # Benchmarks: Direct GEMM + "b_bf16_gemm_s", "b_fp8_gemm_s", + + # Benchmarks: End-to-end + "b_bf16_e2e_s", "b_fp8_e2e_s", "b_fp8_e2e_spdp", + + # Roofline vs benchmark ratios + "rb_bf16_gemm_ratio", "rb_fp8_gemm_ratio", ] results = [] @@ -322,11 +571,14 @@ def run( ) r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - - # if enabled, also measured observed gemm time + + # Linear ops don't have im2col overhead + r_im2col_bf16_s, r_im2col_fp8_s = 0, 0 + r_bf16_total_s = r_bf16_gemm_time_s + r_fp8_total_s = r_fp8_gemm_and_ovhd_s + r_total_spdp = r_speedup b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - rb_bf16_gemm_ratio = -1 - rb_fp8_gemm_ratio = -1 + rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1 if do_benchmarks: # TODO(future): make the bf16 gemm times exactly match the e2e @@ -346,21 +598,58 @@ def run( rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s - else: - # roofline analysis for conv2d/conv3d are not added yet - r_bf16_gemm_time_s = None - r_fp8_gemm_time_s = None - - r_fp8_ovhd_time_s = None - r_fp8_gemm_and_ovhd_s = None - r_speedup = None - - # real gemm benchmark time, also not added yet - # if enabled, also measured observed gemm time + elif op_name in ("conv2d", "conv3d"): + # Get GEMM dimensions from unfold + stride, padding = 1, 0 + gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims( + op_name=op_name, + batch=M_val, + in_channels=K_val, + out_channels=N_val, + kernel_size=kernel_size, + D=D, + H=H, + W=W, + stride=stride, + padding=padding, + ) + + # Use pre-computed symbolic expressions (created upfront) + r_bf16_gemm_time_s = float( + bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_gemm_time_s = float( + fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_ovhd_time_s = float( + fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + + # Compute combined metrics + r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) + + # Roofline im2col overhead (theoretical) + r_im2col_bf16_s = float(get_im2col_memory_overhead_sympy( + op_name, M_val, K_val, N_val, kernel_size, + D, H, W, stride=1, padding=0, dtype=torch.bfloat16 + )) + r_im2col_fp8_s = r_im2col_bf16_s * 0.5 + + # Roofline total: im2col + GEMM + quantization + r_bf16_total_s = r_bf16_gemm_time_s + r_im2col_bf16_s + r_fp8_total_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_im2col_fp8_s + r_total_spdp = r_bf16_total_s / r_fp8_total_s + + print(f" -> Im2col: BF16={r_im2col_bf16_s*1e6:.2f} µs, FP8={r_im2col_fp8_s*1e6:.2f} µs") + print(f" -> Speedup: GEMM only={r_speedup:.3f}x | Total={r_total_spdp:.3f}x") + + # GEMM benchmarks not yet implemented for conv ops b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - # gemm roofline ratio achieved in real benchmark - rb_bf16_gemm_ratio = -1 - rb_fp8_gemm_ratio = -1 + rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1 + + else: + raise ValueError(f"Unsupported op_name: {op_name}") b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 if do_benchmarks: @@ 
-476,22 +765,29 @@ def run( H, W, kernel_size, - # roofline - gemm + # Roofline: GEMM r_bf16_gemm_time_s, r_fp8_gemm_time_s, - # roofline - fp8 overhead + # Roofline: im2col + r_im2col_bf16_s, + r_im2col_fp8_s, + # Roofline: FP8 quantization r_fp8_ovhd_time_s, - # roofline - gemm + overhead, and speedup + # Roofline: GEMM-only r_fp8_gemm_and_ovhd_s, r_speedup, - # benchmarks - gemm + # Roofline: Total + r_bf16_total_s, + r_fp8_total_s, + r_total_spdp, + # Benchmarks: GEMM b_bf16_gemm_time_s, b_fp8_gemm_time_s, - # benchmarks - e2e, and speedup + # Benchmarks: e2e b_bf16_e2e_time_s, b_fp8_e2e_time_s, b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20), - # gemm ratios + # Roofline vs benchmark ratios rb_bf16_gemm_ratio, rb_fp8_gemm_ratio, ] diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index bf234b3717..83e4b516cb 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -27,6 +27,18 @@ # which would hit about 2.2k GBPS on Meta's H100 variant "pct_achievable_mem_bw": 0.92, }, + "NVIDIA H200": { + # H200 has same compute as H100 but more memory and higher bandwidth + # https://www.nvidia.com/en-us/data-center/h200/, divide by 2 because no sparsity + "bf16_peak_tops": 989e12, + "fp8_peak_tops": 1979e12, + # 4.8 TB per second for H200 (double the standard H100) + "peak_mem_bw_bytes_sec": 4.8e12, + # copy from H100 + "pct_achievable_gemm_tops": 0.78, + # copy from H100 + "pct_achievable_mem_bw": 0.92, + }, "NVIDIA B200": { # https://resources.nvidia.com/en-us-blackwell-architecture, page 19, # divide by 2 because no sparsity From 79cdaecf9f63b2b174e3bcfe0b5a55ac9feabff0 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 06:12:24 +0000 Subject: [PATCH 2/6] Add conv roofline --- .../float8/float8_inference_roofline.py | 250 ++---------------- 1 file changed, 23 insertions(+), 227 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 3c705e1095..2655ada10c 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -53,10 +53,6 @@ from torchao.testing.training.roofline_utils import ( get_inference_float8_mem_sympy, get_inference_gemm_time_sympy, - get_specs, - BYTES_PER_EL_BF16, - BYTES_PER_EL_FLOAT8, - KERNEL_LAUNCH_OVERHEAD_SEC, ) from torchao.utils import is_MI300 @@ -198,7 +194,10 @@ def get_conv_equivalent_gemm_dims( padding: int = 0, ): """ - Get GEMM dimensions from unfold. + Get equivalent GEMM dimensions for a conv operation. + + Uses torch.nn.functional.unfold to derive the correct GEMM dimensions + that correspond to the conv operation. 
Args: op_name: "conv2d" or "conv3d" @@ -235,10 +234,6 @@ def get_conv_equivalent_gemm_dims( elif op_name == "conv3d": x = torch.randn(batch, in_channels, D, H, W, device=device) - - # Note: torch.nn.Unfold only supports 4-D tensors - # (https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html) - # For 3D conv, reshape (B,C,D,H,W) -> (B*D,C,H,W) and unfold H,W B, C, D_in, H_in, W_in = x.shape x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in) unfolded = torch.nn.functional.unfold( @@ -260,179 +255,6 @@ def get_conv_equivalent_gemm_dims( return gemm_M, gemm_K, gemm_N -def benchmark_im2col_unfold( - op_name: str, - batch: int, - in_channels: int, - kernel_size: int, - D: Optional[int], - H: int, - W: int, - stride: int = 1, - padding: int = 0, - dtype=torch.bfloat16, -): - """ - Benchmark unfold operation. - - Args: - op_name: "conv2d" or "conv3d" - batch: Batch size - in_channels: Number of input channels - kernel_size: Kernel size - D: Depth dimension (required for conv3d) - H: Height dimension (required for conv2d/conv3d) - W: Width dimension (required for conv2d/conv3d) - stride: Stride value - padding: Padding value - dtype: Data type - - Returns: - Measured time in seconds - """ - device = torch.device("cuda") - - _validate_conv_params(op_name, kernel_size, D, H, W) - - # Unfold doesn't support FP8; return -1 for unsupported dtypes - if dtype not in (torch.bfloat16, torch.float16, torch.float32): - return -1 - - # Create input tensor - if op_name == "conv2d": - x = torch.randn(batch, in_channels, H, W, dtype=dtype, device=device) - elif op_name == "conv3d": - x = torch.randn(batch, in_channels, D, H, W, dtype=dtype, device=device) - else: - raise ValueError(f"Unsupported op_name: {op_name}") - - def _run_unfold(): - if op_name == "conv2d": - return torch.nn.functional.unfold( - x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding - ) - else: # conv3d: reshape to 4D since unfold only supports 4D - B, C, D_in, H_in, W_in = x.shape - x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in) - return torch.nn.functional.unfold( - x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding - ) - - # Warm up - for _ in range(2): - _ = _run_unfold() - torch.cuda.synchronize() - - # Benchmark - n_iter = 10 - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - for _ in range(n_iter): - _ = _run_unfold() - end.record() - torch.cuda.synchronize() - - return start.elapsed_time(end) / 1000.0 / n_iter - - -def get_im2col_memory_overhead_sympy( - op_name: str, - batch: int, - in_channels: int, - out_channels: int, - kernel_size: int, - D: Optional[int], - H: int, - W: int, - stride: int = 1, - padding: int = 0, - dtype=torch.bfloat16, - gpu_name: Optional[str] = None, -): - """ - Calculate the memory overhead for im2col transformation in conv operations. - - Im2col unfolds the input tensor into a 2D matrix for efficient GEMM computation. - This involves: - 1. Reading the input tensor (batch × in_channels × spatial_dims) - 2. Writing the im2col matrix (output_spatial_positions × kernel_volume) - - The im2col matrix is typically much larger than the input due to overlapping - windows, especially with stride=1 and larger kernels. 
- - Args: - op_name: "conv2d" or "conv3d" - batch: Batch size - in_channels: Number of input channels - out_channels: Number of output channels - kernel_size: Kernel size - D: Depth dimension (required for conv3d) - H: Height dimension (required for conv2d/conv3d) - W: Width dimension (required for conv2d/conv3d) - stride: Stride value - padding: Padding value - dtype: Data type - gpu_name: GPU name for specs - - Returns: - sympy expression for im2col memory overhead in seconds - """ - _validate_conv_params(op_name, kernel_size, D, H, W) - specs = get_specs(gpu_name) - - # Determine bytes per element based on dtype - if dtype == torch.bfloat16: - bytes_per_el = BYTES_PER_EL_BF16 - elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - bytes_per_el = BYTES_PER_EL_FLOAT8 - else: - bytes_per_el = BYTES_PER_EL_BF16 # default - - if op_name == "conv2d": - - # Input size - input_numel = batch * in_channels * H * W - - # Output spatial dimensions - H_out = (H - kernel_size + 2 * padding) // stride + 1 - W_out = (W - kernel_size + 2 * padding) // stride + 1 - - # Im2col matrix size: (batch * H_out * W_out) × (in_channels * kernel_size^2) - im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size - - elif op_name == "conv3d": - # Input size - input_numel = batch * in_channels * D * H * W - - # Output spatial dimensions - D_out = (D - kernel_size + 2 * padding) // stride + 1 - H_out = (H - kernel_size + 2 * padding) // stride + 1 - W_out = (W - kernel_size + 2 * padding) // stride + 1 - - # Im2col matrix size: (batch * D_out * H_out * W_out) × (in_channels * kernel_size^3) - im2col_numel = batch * D_out * H_out * W_out * in_channels * kernel_size * kernel_size * kernel_size - - else: - raise ValueError(f"Unsupported op_name: {op_name}") - - # Memory traffic: read input + write im2col matrix - # Note: In practice, some implementations may avoid materializing the full im2col - # matrix, but we model the worst case here - bytes_read = input_numel * bytes_per_el - bytes_write = im2col_numel * bytes_per_el - total_bytes = bytes_read + bytes_write - - # Convert to time using memory bandwidth - im2col_time_s = total_bytes / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] - - # Account for kernel launch overhead - im2col_time_s = sympy.Max(im2col_time_s, KERNEL_LAUNCH_OVERHEAD_SEC) - - return im2col_time_s - - def run( outfile: str, recipe_name: str, @@ -449,6 +271,9 @@ def run( H: Optional[int] = None, W: Optional[int] = None, kernel_size: Optional[int] = None, + stride: int = 1, + padding: int = 0, + verbose: bool = False, ): """ Args: @@ -457,11 +282,13 @@ def run( * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom` * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN * `n_limit (optional)`: if specified, only runs `n_limit` iterations - # `save_profile_traces (optional)`: if True, saves profiling traces - # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm - # `op_name`: linear, conv2d or conv3d, decides which op to benchmark - # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d - # `kernel_size`: kernel_size for conv3d / conv2d + * `save_profile_traces (optional)`: if True, saves profiling traces + * `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm + * `op_name`: linear, conv2d or conv3d, decides which op to benchmark + * `D`, `H`, `W`: spatial dimensions for conv3d / conv2d + * `kernel_size`: kernel_size for conv3d / conv2d + * `stride`: stride 
for conv ops (default: 1) + * `padding`: padding for conv ops (default: 0) """ _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"] assert op_name in _SUPPORTED_OPS, f"Unsupported op: {op_name}, supported: {_SUPPORTED_OPS}" @@ -481,6 +308,8 @@ def run( ["MKN", f"{M} {K} {N}"], ["DHW", f"{D} {H} {W}"], ["kernel_size", kernel_size], + ["stride", stride], + ["padding", padding], ] print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple")) @@ -513,7 +342,7 @@ def run( print() if op_name in ("conv2d", "conv3d"): - print(f"{op_name}: GEMM dimensions from unfold, roofline from symbolic expressions") + print(f"{op_name}: GEMM dimensions derived from conv params") print() elif op_name != "linear": raise ValueError(f"Unsupported op_name: {op_name}") @@ -525,17 +354,11 @@ def run( # Roofline: GEMM time "r_bf16_gemm_s", "r_fp8_gemm_s", - # Roofline: im2col overhead - "r_im2col_bf16_s", "r_im2col_fp8_s", - # Roofline: FP8 quantization overhead "r_fp8_ovhd_s", - # Roofline: GEMM-only metrics - "r_fp8_gemm_and_ovhd_s", "r_fp8_gemm_and_ovhd_spdp", - - # Roofline: Total (im2col + GEMM + quantization) - "r_bf16_total_s", "r_fp8_total_s", "r_fp8_total_spdp", + # Roofline: Total (GEMM + quantization) + "r_fp8_gemm_and_ovhd_s", "r_fp8_speedup", # Benchmarks: Direct GEMM "b_bf16_gemm_s", "b_fp8_gemm_s", @@ -572,11 +395,6 @@ def run( r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - # Linear ops don't have im2col overhead - r_im2col_bf16_s, r_im2col_fp8_s = 0, 0 - r_bf16_total_s = r_bf16_gemm_time_s - r_fp8_total_s = r_fp8_gemm_and_ovhd_s - r_total_spdp = r_speedup b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1 @@ -599,8 +417,6 @@ def run( rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s elif op_name in ("conv2d", "conv3d"): - # Get GEMM dimensions from unfold - stride, padding = 1, 0 gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims( op_name=op_name, batch=M_val, @@ -625,24 +441,11 @@ def run( fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) ) - # Compute combined metrics r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - # Roofline im2col overhead (theoretical) - r_im2col_bf16_s = float(get_im2col_memory_overhead_sympy( - op_name, M_val, K_val, N_val, kernel_size, - D, H, W, stride=1, padding=0, dtype=torch.bfloat16 - )) - r_im2col_fp8_s = r_im2col_bf16_s * 0.5 - - # Roofline total: im2col + GEMM + quantization - r_bf16_total_s = r_bf16_gemm_time_s + r_im2col_bf16_s - r_fp8_total_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_im2col_fp8_s - r_total_spdp = r_bf16_total_s / r_fp8_total_s - - print(f" -> Im2col: BF16={r_im2col_bf16_s*1e6:.2f} µs, FP8={r_im2col_fp8_s*1e6:.2f} µs") - print(f" -> Speedup: GEMM only={r_speedup:.3f}x | Total={r_total_spdp:.3f}x") + print(f" -> GEMM dims: M={gemm_M}, K={gemm_K}, N={gemm_N}") + print(f" -> Speedup: {r_speedup:.3f}x") # GEMM benchmarks not yet implemented for conv ops b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 @@ -768,18 +571,11 @@ def run( # Roofline: GEMM r_bf16_gemm_time_s, r_fp8_gemm_time_s, - # Roofline: im2col - r_im2col_bf16_s, - r_im2col_fp8_s, - # Roofline: FP8 quantization + # Roofline: FP8 quantization overhead r_fp8_ovhd_time_s, - # Roofline: GEMM-only + # Roofline: Total (GEMM + quantization) r_fp8_gemm_and_ovhd_s, r_speedup, - # Roofline: Total - r_bf16_total_s, - r_fp8_total_s, - 
r_total_spdp, # Benchmarks: GEMM b_bf16_gemm_time_s, b_fp8_gemm_time_s, From f4c7a6edbd5ff2ae6c12172d37f72f1c2dc2105c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 18:08:33 +0000 Subject: [PATCH 3/6] updates --- .../float8/float8_inference_roofline.py | 219 ++++++++---------- 1 file changed, 102 insertions(+), 117 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 2655ada10c..1b1b32a8db 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -57,22 +57,6 @@ from torchao.utils import is_MI300 -def _validate_conv_params( - op_name: str, - kernel_size: Optional[int], - D: Optional[int], - H: Optional[int], - W: Optional[int], -): - """Validate conv operation parameters.""" - if op_name == "conv2d": - assert H is not None and W is not None, "H and W required for conv2d" - assert kernel_size is not None, "kernel_size required for conv2d" - elif op_name == "conv3d": - assert D is not None and H is not None and W is not None, "D, H, W required for conv3d" - assert kernel_size is not None, "kernel_size required for conv3d" - - @torch.no_grad() def get_gpu_kernel_time(m, x, trace_filename=None): # warm up @@ -219,8 +203,6 @@ def get_conv_equivalent_gemm_dims( """ device = torch.device("cuda") - _validate_conv_params(op_name, kernel_size, D, H, W) - if op_name == "conv2d": x = torch.randn(batch, in_channels, H, W, device=device) unfolded = torch.nn.functional.unfold( @@ -244,7 +226,6 @@ def get_conv_equivalent_gemm_dims( D_out = (D - kernel_size + 2 * padding) // stride + 1 _, K_2d, L_2d = unfolded.shape - # GEMM dimensions: account for depth in K gemm_K = K_2d * kernel_size # C * kernel_size³ gemm_M = B * D_out * L_2d gemm_N = out_channels @@ -273,7 +254,6 @@ def run( kernel_size: Optional[int] = None, stride: int = 1, padding: int = 0, - verbose: bool = False, ): """ Args: @@ -291,10 +271,24 @@ def run( * `padding`: padding for conv ops (default: 0) """ _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"] - assert op_name in _SUPPORTED_OPS, f"Unsupported op: {op_name}, supported: {_SUPPORTED_OPS}" - - if op_name in ("conv2d", "conv3d"): - _validate_conv_params(op_name, kernel_size, D, H, W) + assert op_name in _SUPPORTED_OPS, ( + f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}" + ) + + if op_name == "conv2d": + assert H is not None and W is not None, ( + "Expected D, H, W to be specified for conv2d" + ) + assert kernel_size is not None, ( + "Expected kernel_size to be specified for conv2d" + ) + elif op_name == "conv3d": + assert D is not None and H is not None and W is not None, ( + "Expected D, H, W to be specified for conv3d" + ) + assert kernel_size is not None, ( + "Expected kernel_size to be specified for conv3d" + ) config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -318,56 +312,69 @@ def run( M, K, N = sympy.symbols("M K N") - # Create symbolic roofline expressions (same for linear and conv) - fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( - M, K, N, recipe_name, - ) - bf16_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.bfloat16, None - ) - - if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float4_e2m1fn_x2, recipe_name + if op_name == "linear": + fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( + M, + K, + N, + recipe_name, + # TODO(future): also enable fusion modeling here ) - else: - gemm_recipe_name = 
"mxfp8" if recipe_name.startswith("mxfp8") else None - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float8_e4m3fn, gemm_recipe_name + bf16_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.bfloat16, None ) - - print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) - print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) - print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) - print() - - if op_name in ("conv2d", "conv3d"): - print(f"{op_name}: GEMM dimensions derived from conv params") + + if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float4_e2m1fn_x2, recipe_name + ) + else: + gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float8_e4m3fn, gemm_recipe_name + ) + print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) + print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) + print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) print() - elif op_name != "linear": - raise ValueError(f"Unsupported op_name: {op_name}") + else: + # TODO: enable roofline analysis for conv + pass + + # Note: roofline for conv2d/conv3d is not added yet, so most of the + # things for conv2d/conv3d we'll left out for now headers = [ - # Shape parameters - "fwd_M", "fwd_K", "fwd_N", "D", "H", "W", "kernel_size", - - # Roofline: GEMM time - "r_bf16_gemm_s", "r_fp8_gemm_s", - - # Roofline: FP8 quantization overhead + "fwd_M", + "fwd_K", + "fwd_N", + "D", + "H", + "W", + "kernel_size", + # roofline - gemm time (fwd + bwd, 3 gemms) + "r_bf16_gemm_s", + "r_fp8_gemm_s", + # roofline - fp8 overhead time (by counting reads/writes in the ideal case) "r_fp8_ovhd_s", - - # Roofline: Total (GEMM + quantization) - "r_fp8_gemm_and_ovhd_s", "r_fp8_speedup", - - # Benchmarks: Direct GEMM - "b_bf16_gemm_s", "b_fp8_gemm_s", - - # Benchmarks: End-to-end - "b_bf16_e2e_s", "b_fp8_e2e_s", "b_fp8_e2e_spdp", - - # Roofline vs benchmark ratios - "rb_bf16_gemm_ratio", "rb_fp8_gemm_ratio", + # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid) + "r_fp8_gemm_and_ovhd_s", + "r_fp8_gemm_and_ovhd_spdp", + # benchmarks - gemm time (fwd + bwd, 3 gemms) + "b_bf16_gemm_s", + "b_fp8_gemm_s", + # benchmarks - e2e LNLinearSigmoid time fwd + bwd + "b_bf16_e2e_s", + "b_fp8_e2e_s", + # note that e2e speedup is not the same as the roofline speedup: + # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time) + # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid) + # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple + # we don't break them out and don't have a roofline for them. 
+ "b_fp8_e2e_spdp", + # how well benchmarked gemms match roofline predicted gemms + "rb_bf16_gemm_ratio", + "rb_fp8_gemm_ratio", ] results = [] @@ -394,10 +401,11 @@ def run( ) r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - - b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1 + # if enabled, also measured observed gemm time + b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 if do_benchmarks: # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm @@ -416,43 +424,20 @@ def run( rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s - elif op_name in ("conv2d", "conv3d"): - gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims( - op_name=op_name, - batch=M_val, - in_channels=K_val, - out_channels=N_val, - kernel_size=kernel_size, - D=D, - H=H, - W=W, - stride=stride, - padding=padding, - ) - - # Use pre-computed symbolic expressions (created upfront) - r_bf16_gemm_time_s = float( - bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) - ) - r_fp8_gemm_time_s = float( - fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) - ) - r_fp8_ovhd_time_s = float( - fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) - ) - - r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s - r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - - print(f" -> GEMM dims: M={gemm_M}, K={gemm_K}, N={gemm_N}") - print(f" -> Speedup: {r_speedup:.3f}x") - - # GEMM benchmarks not yet implemented for conv ops - b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1 - else: - raise ValueError(f"Unsupported op_name: {op_name}") + # roofline analysis for conv2d/conv3d are not added yet + r_bf16_gemm_time_s = None + r_fp8_gemm_time_s = None + r_fp8_ovhd_time_s = None + r_fp8_gemm_and_ovhd_s = None + r_speedup = None + + # real gemm benchmark time, also not added yet + # if enabled, also measured observed gemm time + b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + # gemm roofline ratio achieved in real benchmark + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 if do_benchmarks: @@ -460,23 +445,23 @@ def run( if op_name == "conv2d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv2d(K_val, N_val, kernel_size, bias=False) + nn.Conv2d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False) ).to(memory_format=torch.channels_last) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv2d(K_val, N_val, kernel_size, bias=False), + nn.Conv2d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False), nn.ReLU(), ).to(memory_format=torch.channels_last) elif op_name == "conv3d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv3d(K_val, N_val, kernel_size, bias=False) + nn.Conv3d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False) ).to(memory_format=torch.channels_last_3d) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv3d(K_val, N_val, kernel_size, bias=False), + nn.Conv3d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False), nn.ReLU(), ).to(memory_format=torch.channels_last_3d) else: @@ -568,22 +553,22 @@ def run( H, W, kernel_size, - # Roofline: GEMM + # roofline - gemm 
r_bf16_gemm_time_s, r_fp8_gemm_time_s, - # Roofline: FP8 quantization overhead + # roofline - fp8 overhead r_fp8_ovhd_time_s, - # Roofline: Total (GEMM + quantization) + # roofline - gemm + overhead, and speedup r_fp8_gemm_and_ovhd_s, r_speedup, - # Benchmarks: GEMM + # benchmarks - gemm b_bf16_gemm_time_s, b_fp8_gemm_time_s, - # Benchmarks: e2e + # benchmarks - e2e, and speedup b_bf16_e2e_time_s, b_fp8_e2e_time_s, b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20), - # Roofline vs benchmark ratios + # gemm ratios rb_bf16_gemm_ratio, rb_fp8_gemm_ratio, ] From 828cb02a60ab4bc163c0c3c07ed72ecdd4ada88f Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 21:04:22 +0000 Subject: [PATCH 4/6] updates --- .../float8/float8_inference_roofline.py | 155 +++++++++++------- 1 file changed, 95 insertions(+), 60 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 1b1b32a8db..41ea16d321 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -178,61 +178,51 @@ def get_conv_equivalent_gemm_dims( padding: int = 0, ): """ - Get equivalent GEMM dimensions for a conv operation. - - Uses torch.nn.functional.unfold to derive the correct GEMM dimensions - that correspond to the conv operation. - + Get equivalent GEMM dimensions for a conv operation using analytical calculation. + + Conv operations can be expressed as implicit GEMM. This function computes + the equivalent GEMM dimensions without creating any tensors. + Args: op_name: "conv2d" or "conv3d" batch: Batch size in_channels: Number of input channels out_channels: Number of output channels - kernel_size: Kernel size + kernel_size: Kernel size (assumes square/cubic kernel) D: Depth dimension (required for conv3d) - H: Height dimension (required for conv2d/conv3d) - W: Width dimension (required for conv2d/conv3d) + H: Height dimension + W: Width dimension stride: Stride value padding: Padding value - + Returns: Tuple[int, int, int]: (gemm_M, gemm_K, gemm_N) - gemm_M: Number of output spatial positions - gemm_K: Size of each filter (in_channels * kernel volume) + gemm_M: Number of output spatial positions (batch * spatial_output_size) + gemm_K: Size of each filter (in_channels * kernel_volume) gemm_N: Number of filters (out_channels) """ - device = torch.device("cuda") - if op_name == "conv2d": - x = torch.randn(batch, in_channels, H, W, device=device) - unfolded = torch.nn.functional.unfold( - x, kernel_size=(kernel_size, kernel_size), - stride=stride, padding=padding - ) - batch_out, K, L = unfolded.shape - gemm_M = batch_out * L - gemm_K = K + # Output spatial dimensions + H_out = (H + 2 * padding - kernel_size) // stride + 1 + W_out = (W + 2 * padding - kernel_size) // stride + 1 + + gemm_M = batch * H_out * W_out + gemm_K = in_channels * kernel_size * kernel_size gemm_N = out_channels - + elif op_name == "conv3d": - x = torch.randn(batch, in_channels, D, H, W, device=device) - B, C, D_in, H_in, W_in = x.shape - x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in) - unfolded = torch.nn.functional.unfold( - x_reshaped, kernel_size=(kernel_size, kernel_size), - stride=stride, padding=padding - ) - - D_out = (D - kernel_size + 2 * padding) // stride + 1 - _, K_2d, L_2d = unfolded.shape - - gemm_K = K_2d * kernel_size # C * kernel_size³ - gemm_M = B * D_out * L_2d + # Output spatial dimensions + D_out = (D + 2 * padding - kernel_size) // stride + 1 + H_out = (H + 2 * padding - kernel_size) // stride 
+ 1 + W_out = (W + 2 * padding - kernel_size) // stride + 1 + + gemm_M = batch * D_out * H_out * W_out + gemm_K = in_channels * kernel_size * kernel_size * kernel_size gemm_N = out_channels - + else: raise ValueError(f"Unsupported op_name: {op_name}") - + return gemm_M, gemm_K, gemm_N @@ -312,7 +302,9 @@ def run( M, K, N = sympy.symbols("M K N") - if op_name == "linear": + # Roofline model setup: linear uses M/K/N directly, conv uses equivalent + # implicit GEMM dimensions (computed per-iteration in the loop below) + if op_name in ("linear", "conv2d", "conv3d"): fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( M, K, @@ -338,21 +330,17 @@ def run( print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) print() else: - # TODO: enable roofline analysis for conv pass - # Note: roofline for conv2d/conv3d is not added yet, so most of the - # things for conv2d/conv3d we'll left out for now - headers = [ - "fwd_M", - "fwd_K", - "fwd_N", + "fwd_M", # for conv: batch size + "fwd_K", # for conv: in_channels + "fwd_N", # for conv: out_channels "D", "H", "W", "kernel_size", - # roofline - gemm time (fwd + bwd, 3 gemms) + # roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims) "r_bf16_gemm_s", "r_fp8_gemm_s", # roofline - fp8 overhead time (by counting reads/writes in the ideal case) @@ -425,17 +413,36 @@ def run( rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s else: - # roofline analysis for conv2d/conv3d are not added yet - r_bf16_gemm_time_s = None - r_fp8_gemm_time_s = None - r_fp8_ovhd_time_s = None - r_fp8_gemm_and_ovhd_s = None - r_speedup = None - - # real gemm benchmark time, also not added yet - # if enabled, also measured observed gemm time + # For conv ops, compute equivalent GEMM dimensions + # M_val=batch, K_val=in_channels, N_val=out_channels + gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims( + op_name=op_name, + batch=M_val, + in_channels=K_val, + out_channels=N_val, + kernel_size=kernel_size, + D=D, + H=H, + W=W, + stride=stride, + padding=padding, + ) + + # use roofline model to estimate gemm time using equivalent GEMM dims + r_bf16_gemm_time_s = float( + bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_gemm_time_s = float( + fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_ovhd_time_s = float( + fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) + + # gemm benchmarks for conv not implemented, as conv uses implicit GEMM b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - # gemm roofline ratio achieved in real benchmark rb_bf16_gemm_ratio = -1 rb_fp8_gemm_ratio = -1 @@ -445,23 +452,51 @@ def run( if op_name == "conv2d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv2d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False) + nn.Conv2d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) ).to(memory_format=torch.channels_last) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv2d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False), + nn.Conv2d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), nn.ReLU(), ).to(memory_format=torch.channels_last) elif op_name == "conv3d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv3d(K_val, N_val, kernel_size, stride=stride, padding=padding, 
bias=False) + nn.Conv3d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) ).to(memory_format=torch.channels_last_3d) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv3d(K_val, N_val, kernel_size, stride=stride, padding=padding, bias=False), + nn.Conv3d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), nn.ReLU(), ).to(memory_format=torch.channels_last_3d) else: From e815fd47432bc966e69bcd7da9dd499953547f67 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 14:07:17 -0800 Subject: [PATCH 5/6] updates --- .../float8/float8_inference_roofline.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 41ea16d321..1fd981a27f 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -54,7 +54,7 @@ get_inference_float8_mem_sympy, get_inference_gemm_time_sympy, ) -from torchao.utils import is_MI300 +from torchao.utils import is_MI300, is_sm_at_least_100 @torch.no_grad() @@ -447,7 +447,22 @@ def run( rb_fp8_gemm_ratio = -1 b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 - if do_benchmarks: + # Check hardware requirements for conv operations + skip_conv_benchmarks = ( + do_benchmarks + and op_name in ("conv2d", "conv3d") + and not is_sm_at_least_100() + ) + + if skip_conv_benchmarks: + print( + f"WARNING: Skipping {op_name} benchmarks for shape ({M_val}, {K_val}, {N_val}). " + f"Float8 convolution requires SM 10.0+ (Blackwell/B100 GPUs). " + f"Current GPU: {torch.cuda.get_device_name(0)} with SM {torch.cuda.get_device_capability()}. " + f"Roofline model estimates are still valid." + ) + + if do_benchmarks and not skip_conv_benchmarks: # create the model if op_name == "conv2d": if not enable_fusion_modeling: From 30dc79344ee008db46be0ad19fcf08acee5c40e1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 3 Dec 2025 18:23:03 -0800 Subject: [PATCH 6/6] minor fixes --- torchao/testing/training/roofline_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 83e4b516cb..bf234b3717 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -27,18 +27,6 @@ # which would hit about 2.2k GBPS on Meta's H100 variant "pct_achievable_mem_bw": 0.92, }, - "NVIDIA H200": { - # H200 has same compute as H100 but more memory and higher bandwidth - # https://www.nvidia.com/en-us/data-center/h200/, divide by 2 because no sparsity - "bf16_peak_tops": 989e12, - "fp8_peak_tops": 1979e12, - # 4.8 TB per second for H200 (double the standard H100) - "peak_mem_bw_bytes_sec": 4.8e12, - # copy from H100 - "pct_achievable_gemm_tops": 0.78, - # copy from H100 - "pct_achievable_mem_bw": 0.92, - }, "NVIDIA B200": { # https://resources.nvidia.com/en-us-blackwell-architecture, page 19, # divide by 2 because no sparsity
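
Note (illustrative, not part of the diffs above): the core idea this series converges on in patch 4 is the analytical conv -> implicit-GEMM dimension mapping, which lets the same symbolic M/K/N roofline expressions serve both linear and conv ops with only the substituted values differing. A minimal standalone sketch of that mapping, assuming a square/cubic kernel and symmetric stride/padding as in the patch; the worked shape is hypothetical, not a benchmark default:

def conv_equivalent_gemm_dims(op_name, batch, in_channels, out_channels,
                              kernel_size, D=None, H=None, W=None,
                              stride=1, padding=0):
    """Return (gemm_M, gemm_K, gemm_N) for a conv expressed as implicit GEMM."""
    h_out = (H + 2 * padding - kernel_size) // stride + 1
    w_out = (W + 2 * padding - kernel_size) // stride + 1
    if op_name == "conv2d":
        gemm_M = batch * h_out * w_out              # output spatial positions
        gemm_K = in_channels * kernel_size ** 2     # filter size
    elif op_name == "conv3d":
        d_out = (D + 2 * padding - kernel_size) // stride + 1
        gemm_M = batch * d_out * h_out * w_out
        gemm_K = in_channels * kernel_size ** 3
    else:
        raise ValueError(f"Unsupported op_name: {op_name}")
    return gemm_M, gemm_K, out_channels             # gemm_N = number of filters

# Hypothetical shape: batch=32, C_in=64, C_out=128, k=3, H=W=56, stride=1, padding=0
# -> M = 32 * 54 * 54 = 93312, K = 64 * 9 = 576, N = 128
print(conv_equivalent_gemm_dims("conv2d", 32, 64, 128, 3, H=56, W=56))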
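Continuing the note, a sketch of how those dims feed the symbolic roofline, using the same helpers and argument patterns this series already imports (GEMM-only; for a full estimate the fp8 quantization overhead term from get_inference_float8_mem_sympy would be added for the chosen recipe). It needs a CUDA GPU that roofline_utils has specs for:

import sympy
import torch

from torchao.testing.training.roofline_utils import get_inference_gemm_time_sympy

# Build the symbolic GEMM time expressions exactly as run() does for non-mx recipes
M, K, N = sympy.symbols("M K N")
bf16_gemm_time = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
fp8_gemm_time = get_inference_gemm_time_sympy(M, K, N, torch.float8_e4m3fn, None)

# Substitute the implicit-GEMM dims from the conv2d example above
gemm_M, gemm_K, gemm_N = 93312, 576, 128
r_bf16 = float(bf16_gemm_time.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N))
r_fp8 = float(fp8_gemm_time.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N))
print(f"gemm-only roofline speedup: {r_bf16 / r_fp8:.3f}x")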