 from torchao.testing.training.roofline_utils import (
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
-    get_specs,
-    BYTES_PER_EL_BF16,
-    BYTES_PER_EL_FLOAT8,
-    KERNEL_LAUNCH_OVERHEAD_SEC,
 )
 from torchao.utils import is_MI300
 
@@ -198,7 +194,10 @@ def get_conv_equivalent_gemm_dims(
     padding: int = 0,
 ):
     """
-    Get GEMM dimensions from unfold.
+    Get equivalent GEMM dimensions for a conv operation.
+
+    Uses torch.nn.functional.unfold to derive the correct GEMM dimensions
+    that correspond to the conv operation.
 
     Args:
         op_name: "conv2d" or "conv3d"
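As context for the docstring above, here is a minimal, self-contained sketch of the unfold-to-GEMM equivalence it refers to; the concrete shapes are illustrative and not taken from this change:

```python
import torch

# Illustrative shapes only; this just demonstrates the unfold -> GEMM equivalence.
batch, in_channels, out_channels, k, H, W = 2, 3, 8, 3, 16, 16
x = torch.randn(batch, in_channels, H, W)
w = torch.randn(out_channels, in_channels, k, k)

# unfold produces (batch, in_channels * k * k, num_output_positions):
# the GEMM K dim is in_channels * k * k, the GEMM M dim is batch * num_output_positions.
cols = torch.nn.functional.unfold(x, kernel_size=k)
out = w.view(out_channels, -1) @ cols  # (batch, out_channels, H_out * W_out)

H_out = W_out = H - k + 1  # stride=1, padding=0
ref = torch.nn.functional.conv2d(x, w)
assert torch.allclose(out.view(batch, out_channels, H_out, W_out), ref, atol=1e-4)
```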
@@ -235,10 +234,6 @@ def get_conv_equivalent_gemm_dims(
 
     elif op_name == "conv3d":
         x = torch.randn(batch, in_channels, D, H, W, device=device)
-
-        # Note: torch.nn.Unfold only supports 4-D tensors
-        # (https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html)
-        # For 3D conv, reshape (B,C,D,H,W) -> (B*D,C,H,W) and unfold H,W
         B, C, D_in, H_in, W_in = x.shape
         x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
         unfolded = torch.nn.functional.unfold(
@@ -260,179 +255,6 @@ def get_conv_equivalent_gemm_dims(
     return gemm_M, gemm_K, gemm_N
 
 
-def benchmark_im2col_unfold(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-):
275- """
276- Benchmark unfold operation.
277-
278- Args:
279- op_name: "conv2d" or "conv3d"
280- batch: Batch size
281- in_channels: Number of input channels
282- kernel_size: Kernel size
283- D: Depth dimension (required for conv3d)
284- H: Height dimension (required for conv2d/conv3d)
285- W: Width dimension (required for conv2d/conv3d)
286- stride: Stride value
287- padding: Padding value
288- dtype: Data type
289-
290- Returns:
291- Measured time in seconds
292- """
-    device = torch.device("cuda")
-
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-
-    # Unfold doesn't support FP8; return -1 for unsupported dtypes
-    if dtype not in (torch.bfloat16, torch.float16, torch.float32):
-        return -1
-
-    # Create input tensor
-    if op_name == "conv2d":
-        x = torch.randn(batch, in_channels, H, W, dtype=dtype, device=device)
-    elif op_name == "conv3d":
-        x = torch.randn(batch, in_channels, D, H, W, dtype=dtype, device=device)
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    def _run_unfold():
-        if op_name == "conv2d":
-            return torch.nn.functional.unfold(
-                x, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-        else:  # conv3d: reshape to 4D since unfold only supports 4D
-            B, C, D_in, H_in, W_in = x.shape
-            x_reshaped = x.permute(0, 2, 1, 3, 4).reshape(B * D_in, C, H_in, W_in)
-            return torch.nn.functional.unfold(
-                x_reshaped, kernel_size=(kernel_size, kernel_size), stride=stride, padding=padding
-            )
-
-    # Warm up
-    for _ in range(2):
-        _ = _run_unfold()
-    torch.cuda.synchronize()
-
-    # Benchmark
-    n_iter = 10
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-
-    start.record()
-    for _ in range(n_iter):
-        _ = _run_unfold()
-    end.record()
-    torch.cuda.synchronize()
-
-    return start.elapsed_time(end) / 1000.0 / n_iter
-
-
-def get_im2col_memory_overhead_sympy(
-    op_name: str,
-    batch: int,
-    in_channels: int,
-    out_channels: int,
-    kernel_size: int,
-    D: Optional[int],
-    H: int,
-    W: int,
-    stride: int = 1,
-    padding: int = 0,
-    dtype=torch.bfloat16,
-    gpu_name: Optional[str] = None,
-):
354- """
355- Calculate the memory overhead for im2col transformation in conv operations.
356-
357- Im2col unfolds the input tensor into a 2D matrix for efficient GEMM computation.
358- This involves:
359- 1. Reading the input tensor (batch × in_channels × spatial_dims)
360- 2. Writing the im2col matrix (output_spatial_positions × kernel_volume)
361-
362- The im2col matrix is typically much larger than the input due to overlapping
363- windows, especially with stride=1 and larger kernels.
364-
365- Args:
366- op_name: "conv2d" or "conv3d"
367- batch: Batch size
368- in_channels: Number of input channels
369- out_channels: Number of output channels
370- kernel_size: Kernel size
371- D: Depth dimension (required for conv3d)
372- H: Height dimension (required for conv2d/conv3d)
373- W: Width dimension (required for conv2d/conv3d)
374- stride: Stride value
375- padding: Padding value
376- dtype: Data type
377- gpu_name: GPU name for specs
378-
379- Returns:
380- sympy expression for im2col memory overhead in seconds
381- """
-    _validate_conv_params(op_name, kernel_size, D, H, W)
-    specs = get_specs(gpu_name)
-
-    # Determine bytes per element based on dtype
-    if dtype == torch.bfloat16:
-        bytes_per_el = BYTES_PER_EL_BF16
-    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        bytes_per_el = BYTES_PER_EL_FLOAT8
-    else:
-        bytes_per_el = BYTES_PER_EL_BF16  # default
-
-    if op_name == "conv2d":
-
-        # Input size
-        input_numel = batch * in_channels * H * W
-
-        # Output spatial dimensions
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * H_out * W_out) × (in_channels * kernel_size^2)
-        im2col_numel = batch * H_out * W_out * in_channels * kernel_size * kernel_size
-
-    elif op_name == "conv3d":
-        # Input size
-        input_numel = batch * in_channels * D * H * W
-
-        # Output spatial dimensions
-        D_out = (D - kernel_size + 2 * padding) // stride + 1
-        H_out = (H - kernel_size + 2 * padding) // stride + 1
-        W_out = (W - kernel_size + 2 * padding) // stride + 1
-
-        # Im2col matrix size: (batch * D_out * H_out * W_out) × (in_channels * kernel_size^3)
-        im2col_numel = batch * D_out * H_out * W_out * in_channels * kernel_size * kernel_size * kernel_size
-
-    else:
-        raise ValueError(f"Unsupported op_name: {op_name}")
-
-    # Memory traffic: read input + write im2col matrix
-    # Note: In practice, some implementations may avoid materializing the full im2col
-    # matrix, but we model the worst case here
-    bytes_read = input_numel * bytes_per_el
-    bytes_write = im2col_numel * bytes_per_el
-    total_bytes = bytes_read + bytes_write
-
-    # Convert to time using memory bandwidth
-    im2col_time_s = total_bytes / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
-
-    # Account for kernel launch overhead
-    im2col_time_s = sympy.Max(im2col_time_s, KERNEL_LAUNCH_OVERHEAD_SEC)
-
-    return im2col_time_s
-
-
 def run(
     outfile: str,
     recipe_name: str,
@@ -513,7 +335,7 @@ def run(
     print()
 
     if op_name in ("conv2d", "conv3d"):
-        print(f"{op_name}: GEMM dimensions from unfold, roofline from symbolic expressions")
+        print(f"{op_name}: GEMM dimensions derived from conv params")
         print()
     elif op_name != "linear":
         raise ValueError(f"Unsupported op_name: {op_name}")
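As a side note on the updated print message, a hedged sketch of the closed-form conv-to-GEMM dimension mapping; the helper name below is hypothetical, and the script derives the same dimensions via `get_conv_equivalent_gemm_dims`:

```python
def conv2d_gemm_dims(batch, in_channels, out_channels, H, W, kernel_size, stride=1, padding=0):
    """Hypothetical helper: closed-form GEMM dims equivalent to the unfold-based derivation."""
    H_out = (H - kernel_size + 2 * padding) // stride + 1
    W_out = (W - kernel_size + 2 * padding) // stride + 1
    gemm_M = batch * H_out * W_out           # one GEMM row per output position
    gemm_K = in_channels * kernel_size ** 2  # one GEMM column per (channel, kernel offset)
    gemm_N = out_channels                    # one output channel per GEMM output column
    return gemm_M, gemm_K, gemm_N

# e.g. conv2d_gemm_dims(32, 64, 128, 56, 56, 3, padding=1) -> (100352, 576, 128)
```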
@@ -525,17 +347,11 @@ def run(
         # Roofline: GEMM time
         "r_bf16_gemm_s", "r_fp8_gemm_s",
 
-        # Roofline: im2col overhead
-        "r_im2col_bf16_s", "r_im2col_fp8_s",
-
         # Roofline: FP8 quantization overhead
         "r_fp8_ovhd_s",
 
-        # Roofline: GEMM-only metrics
-        "r_fp8_gemm_and_ovhd_s", "r_fp8_gemm_and_ovhd_spdp",
-
-        # Roofline: Total (im2col + GEMM + quantization)
-        "r_bf16_total_s", "r_fp8_total_s", "r_fp8_total_spdp",
+        # Roofline: Total (GEMM + quantization)
+        "r_fp8_gemm_and_ovhd_s", "r_fp8_speedup",
 
         # Benchmarks: Direct GEMM
         "b_bf16_gemm_s", "b_fp8_gemm_s",
@@ -572,11 +388,6 @@ def run(
         r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
         r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
-        # Linear ops don't have im2col overhead
-        r_im2col_bf16_s, r_im2col_fp8_s = 0, 0
-        r_bf16_total_s = r_bf16_gemm_time_s
-        r_fp8_total_s = r_fp8_gemm_and_ovhd_s
-        r_total_spdp = r_speedup
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
         rb_bf16_gemm_ratio, rb_fp8_gemm_ratio = -1, -1
 
@@ -625,24 +436,11 @@ def run(
             fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
         )
 
-        # Compute combined metrics
         r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
         r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
-        # Roofline im2col overhead (theoretical)
-        r_im2col_bf16_s = float(get_im2col_memory_overhead_sympy(
-            op_name, M_val, K_val, N_val, kernel_size,
-            D, H, W, stride=1, padding=0, dtype=torch.bfloat16
-        ))
-        r_im2col_fp8_s = r_im2col_bf16_s * 0.5
-
-        # Roofline total: im2col + GEMM + quantization
-        r_bf16_total_s = r_bf16_gemm_time_s + r_im2col_bf16_s
-        r_fp8_total_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_im2col_fp8_s
-        r_total_spdp = r_bf16_total_s / r_fp8_total_s
-
-        print(f"  -> Im2col: BF16={r_im2col_bf16_s * 1e6:.2f} µs, FP8={r_im2col_fp8_s * 1e6:.2f} µs")
-        print(f"  -> Speedup: GEMM only={r_speedup:.3f}x | Total={r_total_spdp:.3f}x")
+        print(f"  -> GEMM dims: M={gemm_M}, K={gemm_K}, N={gemm_N}")
+        print(f"  -> Speedup: {r_speedup:.3f}x")
 
         # GEMM benchmarks not yet implemented for conv ops
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
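For readers unfamiliar with the substitution pattern in the hunk above, a minimal sketch of evaluating a sympy roofline expression at concrete GEMM dims; the expression and peak-throughput number are stand-ins, not the values returned by `get_inference_gemm_time_sympy`:

```python
import sympy

M, K, N = sympy.symbols("M K N")
peak_flops = 1e15  # illustrative peak throughput, not a specific GPU spec
gemm_time_sympy = 2 * M * K * N / peak_flops  # stand-in roofline expression

# Same .subs() chain as in the script, evaluated at concrete dims.
gemm_M, gemm_K, gemm_N = 100352, 576, 128
t_s = float(gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N))
print(f"  -> roofline GEMM time: {t_s * 1e6:.2f} us")
```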
@@ -768,18 +566,11 @@ def run(
             # Roofline: GEMM
             r_bf16_gemm_time_s,
             r_fp8_gemm_time_s,
-            # Roofline: im2col
-            r_im2col_bf16_s,
-            r_im2col_fp8_s,
-            # Roofline: FP8 quantization
+            # Roofline: FP8 quantization overhead
             r_fp8_ovhd_time_s,
-            # Roofline: GEMM-only
+            # Roofline: Total (GEMM + quantization)
             r_fp8_gemm_and_ovhd_s,
             r_speedup,
-            # Roofline: Total
-            r_bf16_total_s,
-            r_fp8_total_s,
-            r_total_spdp,
             # Benchmarks: GEMM
             b_bf16_gemm_time_s,
             b_fp8_gemm_time_s,