diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 188bb46224..1fd981a27f 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -54,7 +54,7 @@ get_inference_float8_mem_sympy, get_inference_gemm_time_sympy, ) -from torchao.utils import is_MI300 +from torchao.utils import is_MI300, is_sm_at_least_100 @torch.no_grad() @@ -165,6 +165,67 @@ def do_matmul(A, B): return bf16_time_s, f8_time_s +def get_conv_equivalent_gemm_dims( + op_name: str, + batch: int, + in_channels: int, + out_channels: int, + kernel_size: int, + D: Optional[int], + H: int, + W: int, + stride: int = 1, + padding: int = 0, +): + """ + Get equivalent GEMM dimensions for a conv operation using analytical calculation. + + Conv operations can be expressed as implicit GEMM. This function computes + the equivalent GEMM dimensions without creating any tensors. + + Args: + op_name: "conv2d" or "conv3d" + batch: Batch size + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Kernel size (assumes square/cubic kernel) + D: Depth dimension (required for conv3d) + H: Height dimension + W: Width dimension + stride: Stride value + padding: Padding value + + Returns: + Tuple[int, int, int]: (gemm_M, gemm_K, gemm_N) + gemm_M: Number of output spatial positions (batch * spatial_output_size) + gemm_K: Size of each filter (in_channels * kernel_volume) + gemm_N: Number of filters (out_channels) + """ + if op_name == "conv2d": + # Output spatial dimensions + H_out = (H + 2 * padding - kernel_size) // stride + 1 + W_out = (W + 2 * padding - kernel_size) // stride + 1 + + gemm_M = batch * H_out * W_out + gemm_K = in_channels * kernel_size * kernel_size + gemm_N = out_channels + + elif op_name == "conv3d": + # Output spatial dimensions + D_out = (D + 2 * padding - kernel_size) // stride + 1 + H_out = (H + 2 * padding - kernel_size) // stride + 1 + W_out = (W + 2 * padding - kernel_size) // stride + 1 + + gemm_M = batch * D_out * H_out * W_out + gemm_K = in_channels * kernel_size * kernel_size * kernel_size + gemm_N = out_channels + + else: + raise ValueError(f"Unsupported op_name: {op_name}") + + return gemm_M, gemm_K, gemm_N + + def run( outfile: str, recipe_name: str, @@ -181,6 +242,8 @@ def run( H: Optional[int] = None, W: Optional[int] = None, kernel_size: Optional[int] = None, + stride: int = 1, + padding: int = 0, ): """ Args: @@ -189,16 +252,19 @@ def run( * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom` * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN * `n_limit (optional)`: if specified, only runs `n_limit` iterations - # `save_profile_traces (optional)`: if True, saves profiling traces - # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm - # `op_name`: linear, conv2d or conv3d, decides which op to benchmark - # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d - # `kernel_size`: kernel_size for conv3d / conv2d + * `save_profile_traces (optional)`: if True, saves profiling traces + * `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm + * `op_name`: linear, conv2d or conv3d, decides which op to benchmark + * `D`, `H`, `W`: spatial dimensions for conv3d / conv2d + * `kernel_size`: kernel_size for conv3d / conv2d + * `stride`: stride for conv ops (default: 1) + * `padding`: padding for conv ops (default: 0) """ _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"] assert op_name in _SUPPORTED_OPS, ( f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}" ) + if op_name == "conv2d": assert H is not None and W is not None, ( "Expected D, H, W to be specified for conv2d" @@ -226,6 +292,8 @@ def run( ["MKN", f"{M} {K} {N}"], ["DHW", f"{D} {H} {W}"], ["kernel_size", kernel_size], + ["stride", stride], + ["padding", padding], ] print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple")) @@ -234,7 +302,9 @@ def run( M, K, N = sympy.symbols("M K N") - if op_name == "linear": + # Roofline model setup: linear uses M/K/N directly, conv uses equivalent + # implicit GEMM dimensions (computed per-iteration in the loop below) + if op_name in ("linear", "conv2d", "conv3d"): fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( M, K, @@ -260,20 +330,17 @@ def run( print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) print() else: - # TODO: enable roofline analysis for conv pass - # Note: roofline for conv2d/conv3d is not added yet, so most of the - # things for conv2d/conv3d we'll left out for now headers = [ - "fwd_M", - "fwd_K", - "fwd_N", + "fwd_M", # for conv: batch size + "fwd_K", # for conv: in_channels + "fwd_N", # for conv: out_channels "D", "H", "W", "kernel_size", - # roofline - gemm time (fwd + bwd, 3 gemms) + # roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims) "r_bf16_gemm_s", "r_fp8_gemm_s", # roofline - fp8 overhead time (by counting reads/writes in the ideal case) @@ -327,7 +394,6 @@ def run( b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 rb_bf16_gemm_ratio = -1 rb_fp8_gemm_ratio = -1 - if do_benchmarks: # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm @@ -347,44 +413,105 @@ def run( rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s else: - # roofline analysis for conv2d/conv3d are not added yet - r_bf16_gemm_time_s = None - r_fp8_gemm_time_s = None + # For conv ops, compute equivalent GEMM dimensions + # M_val=batch, K_val=in_channels, N_val=out_channels + gemm_M, gemm_K, gemm_N = get_conv_equivalent_gemm_dims( + op_name=op_name, + batch=M_val, + in_channels=K_val, + out_channels=N_val, + kernel_size=kernel_size, + D=D, + H=H, + W=W, + stride=stride, + padding=padding, + ) - r_fp8_ovhd_time_s = None - r_fp8_gemm_and_ovhd_s = None - r_speedup = None + # use roofline model to estimate gemm time using equivalent GEMM dims + r_bf16_gemm_time_s = float( + bf16_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_gemm_time_s = float( + fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_ovhd_time_s = float( + fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N) + ) + r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - # real gemm benchmark time, also not added yet - # if enabled, also measured observed gemm time + # gemm benchmarks for conv not implemented, as conv uses implicit GEMM b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - # gemm roofline ratio achieved in real benchmark rb_bf16_gemm_ratio = -1 rb_fp8_gemm_ratio = -1 b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 - if do_benchmarks: + # Check hardware requirements for conv operations + skip_conv_benchmarks = ( + do_benchmarks + and op_name in ("conv2d", "conv3d") + and not is_sm_at_least_100() + ) + + if skip_conv_benchmarks: + print( + f"WARNING: Skipping {op_name} benchmarks for shape ({M_val}, {K_val}, {N_val}). " + f"Float8 convolution requires SM 10.0+ (Blackwell/B100 GPUs). " + f"Current GPU: {torch.cuda.get_device_name(0)} with SM {torch.cuda.get_device_capability()}. " + f"Roofline model estimates are still valid." + ) + + if do_benchmarks and not skip_conv_benchmarks: # create the model if op_name == "conv2d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv2d(K_val, N_val, kernel_size, bias=False) + nn.Conv2d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) ).to(memory_format=torch.channels_last) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv2d(K_val, N_val, kernel_size, bias=False), + nn.Conv2d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), nn.ReLU(), ).to(memory_format=torch.channels_last) elif op_name == "conv3d": if not enable_fusion_modeling: m_orig = nn.Sequential( - nn.Conv3d(K_val, N_val, kernel_size, bias=False) + nn.Conv3d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) ).to(memory_format=torch.channels_last_3d) else: m_orig = nn.Sequential( nn.ReLU(), - nn.Conv3d(K_val, N_val, kernel_size, bias=False), + nn.Conv3d( + K_val, + N_val, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ), nn.ReLU(), ).to(memory_format=torch.channels_last_3d) else: