diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 9b37f8c7b29..60892e20551 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -164,6 +164,14 @@ def add_llm_args(parser): default=None, nargs='+') + # cute dsl op configs + parser.add_argument('--use_cute_dsl_blockscaling_mm', + default=False, + action='store_true') + parser.add_argument('--use_cute_dsl_blockscaling_bmm', + default=False, + action='store_true') + return parser @@ -267,6 +275,8 @@ def setup_llm(args, **kwargs): trust_remote_code=args.trust_remote_code, gather_generation_logits=args.return_generation_logits, max_beam_width=args.max_beam_width, + use_cute_dsl_blockscaling_mm=args.use_cute_dsl_blockscaling_mm, + use_cute_dsl_blockscaling_bmm=args.use_cute_dsl_blockscaling_bmm, **kwargs) use_beam_search = args.max_beam_width > 1 diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index d497ace49b2..141e6e5053d 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -5,7 +5,7 @@ from tensorrt_llm.logger import logger -from ..._utils import get_sm_version +from ..._utils import get_sm_version, is_sm_100f from ...math_utils import pad_up from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) @@ -351,6 +351,8 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \ Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel + from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \ + BlockwiseGemmKernel from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \ Sm100BlockScaledPersistentDenseGemmKernel from ..cute_dsl_kernels.blackwell.utils import make_ptr @@ -531,7 +533,7 @@ def forward( **kwargs, ) -> torch.Tensor: """ - Performs fp8 blockwise gemm operation using CuTe DSL. + Performs fp4 blockwise gemm operation using CuTe DSL. Args: inputs (List[torch.Tensor]): @@ -560,6 +562,7 @@ def forward( a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0] + # it is mk x nk # Allocate output tensor from UserBuffers or regular CUDA memory if self.to_userbuffers: @@ -2390,3 +2393,669 @@ def _( ): m, k = input.size(0), input.size(1) * 2 return torch.empty(m, k, dtype=output_dtype, device=input.device) + + @cute.jit + def permute(tensor, perm): + """ + General dimension permutation function. + Args: + tensor: Input tensor. + perm: A tuple indicating the new order of dimensions, e.g., (1,2,0) means permuting dimensions 0,1,2 to 1,2,0. 
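+        Example (illustrative; assuming a rank-3 tensor t whose layout shape is (B, M, K)):
+            t_mkb = permute(t, (1, 2, 0))  # view with layout shape (M, K, B)
+            # Only the layout (shape/stride order) changes; no data is moved.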
+ """ + layout = tensor.layout + shapes = cute.shape(layout) # Get original shape + strides = layout.stride # Get original strides + + # Rearrange shape and stride according to perm + new_shapes = tuple(shapes[p] for p in perm) + new_strides = tuple(strides[p] for p in perm) + + # Create new layout and tensor + new_layout = cute.make_layout(new_shapes, stride=new_strides) + return cute.make_tensor(tensor.iterator, new_layout) + + @cute.jit + def append_ones_wrapper(a: cute.Tensor): + a_layout = a.layout + a_layout = cute.append(a_layout, + cute.make_layout(1, stride=1), + up_to_rank=3) + new_a = cute.make_tensor(a.iterator, a_layout) + return new_a + + class CuteDSLFp8BlackwellLinear(TunableRunner): + kernel_cache = dict() + + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(2, 1, fp8_a_sf_m_shape), ), + ) + + def __init__(self): + super().__init__() + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + # Early exit: Check SM version - CuteDSL FP8 is only supported on Blackwell. + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {sm_version} is not supported. " + f"CuteDSL FP8 only supports SM 100 family. Skipping all tactics." + ) + return [] + + m = inputs[0].shape[0] + n = inputs[1].shape[0] + k = inputs[0].shape[1] + batch_size = 1 + # m,k + a_major = "k" + # n, k + b_major = "k" + # m, n + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if BlockwiseGemmKernel.can_implement( + cutlass.Float8E4M3FN, # ab_dtype, + cutlass.Float32, # acc_dtype, + cutlass.BFloat16, # c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> torch.Tensor: + """ + Performs fp8 blockwise (deepgemm like) operation using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (m, k), dtype: fp8. + inputs[1]: Weight tensor of shape (n, k), dtype: fp8. + inputs[2]: Input scale tensor of shape (k//128, m), dtype: fp32. + inputs[3]: Weight scale tensor of shape (n, k//128), dtype: fp32. + tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + + Returns: + torch.Tensor: Output tensor of shape (m, n), dtype: bf16. 
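+
+            Example (illustrative shapes only, assuming m=256, n=512, k=4096 and the
+            shapes documented above):
+                inputs[0]: (256, 4096) fp8, inputs[1]: (512, 4096) fp8,
+                inputs[2]: (32, 256) fp32, inputs[3]: (512, 32) fp32
+                -> returns a (256, 512) bf16 tensor.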
+            """
+            if isinstance(tactic, tuple):
+                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
+            else:
+                # fallback to default tactic
+                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
+                    False,
+                    (128, 128),
+                    (1, 1),
+                ]
+            a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
+            m, n, k = a_tensor.shape[0], b_tensor.shape[0], a_tensor.shape[1]
+            # TODO: analyze the logic of 'to_userbuffers'
+            c = torch.empty(*(m, n), dtype=torch.bfloat16, device="cuda")
+
+            def torch_to_cutlass_dtype(
+                    torch_dtype: torch.dtype) -> cutlass.DataType:
+                if torch_dtype == torch.float8_e4m3fn:
+                    return cutlass.Float8E4M3FN
+                elif torch_dtype == torch.float8_e5m2:
+                    return cutlass.Float8E5M2
+                else:
+                    raise ValueError(f"Unsupported dtype: {torch_dtype}")
+
+            a_ptr = self.make_cute_dsl_global_pointer(
+                a_tensor, torch_to_cutlass_dtype(a_tensor.dtype), 16)
+            b_ptr = self.make_cute_dsl_global_pointer(
+                b_tensor, torch_to_cutlass_dtype(b_tensor.dtype), 16)
+            a_sf_ptr = self.make_cute_dsl_global_pointer(
+                a_sf_tensor, cutlass.Float32, 16)
+            b_sf_ptr = self.make_cute_dsl_global_pointer(
+                b_sf_tensor, cutlass.Float32, 16)
+            c_ptr = self.make_cute_dsl_global_pointer(c, cutlass.BFloat16, 16)
+
+            # get stream
+            torch_stream = torch.cuda.current_stream()
+            stream = cuda.CUstream(torch_stream.cuda_stream)
+
+            cache_key = (
+                use_2cta_instrs,
+                mma_tiler_mn,
+                cluster_shape_mn,
+            )
+            if cache_key not in self.__class__.kernel_cache:
+                gemm = BlockwiseGemmKernel(
+                    cutlass.Float32,  # acc_dtype,
+                    use_2cta_instrs=use_2cta_instrs,
+                    mma_tiler_mn=mma_tiler_mn,
+                    cluster_shape_mn=cluster_shape_mn,
+                )
+                # Compute max active clusters on current device
+                hardware_info = cutlass.utils.HardwareInfo()
+                max_active_clusters = hardware_info.get_max_active_clusters(
+                    cluster_shape_mn[0] * cluster_shape_mn[1])
+
+                compiled_gemm = cute.compile(
+                    gemm.wrapper,
+                    m,
+                    n,
+                    k,
+                    1,  # batch
+                    a_ptr,
+                    b_ptr,
+                    a_sf_ptr,
+                    b_sf_ptr,
+                    c_ptr,
+                    max_active_clusters,
+                    stream,
+                )
+                self.__class__.kernel_cache[cache_key] = compiled_gemm
+            else:
+                compiled_gemm = self.__class__.kernel_cache[cache_key]
+
+            # launch gemm kernel: re-pass the runtime arguments used at compile
+            # time (max_active_clusters is assumed to be a compile-time constant
+            # and is omitted, following the convention of the other CuTe DSL
+            # runners in this file)
+            compiled_gemm(m, n, k, 1, a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr,
+                          stream)
+            return c
+
+    # a/b: fp8, scale: fp32, output: bf16
+    @torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell",
+                             mutates_args=(),
+                             device_types="cuda")
+    def cute_dsl_fp8_gemm_blackwell(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        tuner = AutoTuner.get()
+
+        cute_dsl_fp8_gemm_blackwell_runner = CuteDSLFp8BlackwellLinear()
+        _, best_tactic = tuner.choose_one(
+            "trtllm::cute_dsl_fp8_gemm_blackwell::gemm",
+            [cute_dsl_fp8_gemm_blackwell_runner],
+            CuteDSLFp8BlackwellLinear.tuning_config,
+            [input, weight, input_scale, weight_scale],
+        )
+        return cute_dsl_fp8_gemm_blackwell_runner(
+            inputs=[input, weight, input_scale, weight_scale],
+            tactic=best_tactic,
+        )
+
+    @torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell")
+    def _(
+        mat_a: torch.Tensor,
+        mat_b: torch.Tensor,
+        input_scale: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ):
+        # [m, k]
+        shape = [i for i in mat_a.shape]
+        # [n, k]
+        shape[-1] = mat_b.shape[-2]
+        # output is fixed as bf16
+        ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
+        return ret
+
+    class CuteDSLFp8BlackwellBmm(TunableRunner):
+        kernel_cache = dict()
+
+        tuning_config = TuningConfig(
+            dynamic_tensor_specs=(DynamicTensorSpec(
+                0, 1, get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2), ),
+            constraint_specs=(ConstraintSpec(2, 1, 
fp8_a_sf_m_shape), ), + ) + + def __init__(self): + super().__init__() + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {sm_version} is not supported. " + f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics." + ) + return [] + # [b, m, k] + batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[ + 0].shape[2] + # [b, n, k] + n = inputs[1].shape[1] + # m,k + a_major = "k" + # n, k + b_major = "k" + # m, n + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if BlockwiseGemmKernel.can_implement( + cutlass.Float8E4M3FN, # ab_dtype, + cutlass.Float32, # acc_dtype, + cutlass.BFloat16, # c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> None: + """ + Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (b, m, k), dtype: fp8. + inputs[1]: Weight tensor of shape (b, n, k), dtype: fp8. + inputs[2]: Input scale tensor of shape (b, pad_up(m, 4), ceil_div(k, 128)), dtype: fp32. + inputs[3]: Weight scale tensor of shape (B, Wn, Wk), dtype: fp32. + tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + + Returns: + torch.Tensor: Output tensor of shape (b, m, n), dtype: bf16. 
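+
+            Example (illustrative shapes only, assuming b=8, m=32, n=256, k=512, and
+            assuming Wn = ceil_div(n, 128) and Wk = ceil_div(k, 128)):
+                inputs[0]: (8, 32, 512) fp8, inputs[1]: (8, 256, 512) fp8,
+                inputs[2]: (8, 32, 4) fp32, inputs[3]: (8, 2, 4) fp32
+                -> output of shape (8, 32, 256), dtype bf16.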
+ """ + if isinstance(tactic, tuple): + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic + else: + # fallback to default tactic + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [ + False, + (128, 128), + (1, 1), + ] + + a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs + batch_size, m, k = a_tensor.shape[0], a_tensor.shape[ + 1], a_tensor.shape[2] + b_tensor.shape[1] + w_n, w_k = b_sf.shape[1], b_sf.shape[2] + + # """ + a_tmp = a.permute(1, 2, 0).view(torch.uint8) + b_tmp = b.permute(1, 2, 0).view(torch.uint8) + c_tmp = c.permute(1, 2, 0) + weight_scale_tmp = b_sf.permute(1, 2, 0) + + m_padded = pad_up(m, 4) + input_scale_tmp = a_sf[0:m_padded * w_k * batch_size] + input_scale_tmp = input_scale_tmp.reshape(batch_size, -1, m_padded) + input_scale_tmp = ( + input_scale_tmp[:batch_size, :w_k, :m].contiguous().permute( + 2, 1, 0)) + + mA = cute.runtime.from_dlpack( + a_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=1) + mB = cute.runtime.from_dlpack( + b_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=1) + mC = cute.runtime.from_dlpack( + c_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=1) + mA.element_type = cutlass.Float8E4M3FN + mB.element_type = cutlass.Float8E4M3FN + + mSFA = cute.runtime.from_dlpack( + input_scale_tmp, + assumed_align=16).mark_layout_dynamic(leading_dim=0) + mSFB = cute.runtime.from_dlpack( + weight_scale_tmp, + assumed_align=16).mark_layout_dynamic(leading_dim=1) + # """ + + # a_tmp = a.view(torch.uint8) + # b_tmp = b.view(torch.uint8) + # mA = from_dlpack( + # a_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mB = from_dlpack( + # b_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mC = from_dlpack( + # c, assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mA.element_type = cutlass.Float8E4M3FN + # mB.element_type = cutlass.Float8E4M3FN + + # # Note: mSFA is column major + # mSFA = from_dlpack( + # a_sf, assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mSFB = from_dlpack( + # b_sf, assumed_align=16).mark_layout_dynamic(leading_dim=2) + + # get stream + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + cache_key = (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + if cache_key not in self.__class__.kernel_cache: + gemm = BlockwiseGemmKernel( + cutlass.Float32, # acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + + @cute.jit + def bmm_permute_wrapper( + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + a_sf: cute.Tensor, + b_sf: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + a = permute(a, (1, 2, 0)) + b = permute(b, (1, 2, 0)) + c = permute(c, (1, 2, 0)) + a_sf = permute(a_sf, (2, 1, 0)) + b_sf = permute(b_sf, (1, 2, 0)) + gemm(a, b, c, a_sf, b_sf, max_active_clusters, stream) + + compiled_gemm = cute.compile( + gemm, + # bmm_permute_wrapper, + mA, + mB, + mC, + mSFA, + mSFB, + max_active_clusters, + stream, + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + # launch gemm kernel + compiled_gemm(mA, mB, mC, mSFA, mSFB, stream) + + # a/b: fp8, scale: fp32, out: bf16 + @torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell", + mutates_args=(), + device_types="cuda") + def 
cute_dsl_fp8_bmm_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + out: torch.Tensor, + ) -> None: + tuner = AutoTuner.get() + + cute_dsl_fp8_bmm_blackwell_runner = CuteDSLFp8BlackwellBmm() + + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_fp8_bmm_blackwell::gemm", + [cute_dsl_fp8_bmm_blackwell_runner], + CuteDSLFp8BlackwellBmm.tuning_config, + [input, weight, input_scale, weight_scale, out], + ) + cute_dsl_fp8_bmm_blackwell_runner( + inputs=[input, weight, input_scale, weight_scale, out], + tactic=best_tactic, + ) + + @torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell") + def _( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + out: torch.Tensor, + ) -> None: + batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2] + n = mat_b.shape[1] + if out.dtype != torch.bfloat16: + assert False, "out.dtype != bf16" + if out.shape != (batch_size, m, n): + assert False, "out.shape != (batch_size, m, n)" + + # class CuteDSLFp8BlackwellGroupGemm(TunableRunner): + # kernel_cache = dict() + + # def __init__(self): + # super().__init__() + + # def get_valid_tactics( + # self, + # inputs: List[torch.Tensor], + # profile: OptimizationProfile, + # use_2cta_instrs: bool = False, + # mma_tiler_mn: Tuple[int, int] = (128, 128), + # cluster_shape_mn: Tuple[int, int] = (1, 1), + # **kwargs, + # ) -> List[int]: + # # [m, k] + # m, k = inputs[0].shape[0], inputs[0].shape[1] + # # [group_num, n, k] + # group_num, n, k = inputs[1].shape[0], inputs[1].shape[1], inputs[ + # 1].shape[2] + # # m,k + # a_major = "k" + # # n, k + # b_major = "k" + # # m, n + # c_major = "n" + # is_valid = BlockwiseContiguousGroupedGemmKernel.can_implement( + # cutlass.Float8E4M3FN, # ab_dtype,ab_dtype, + # cutlass.Float32, # acc_dtype, + # cutlass.BFloat16, # c_dtype, + # use_2cta_instrs, + # mma_tiler_mn, + # cluster_shape_mn, + # m, + # n, + # k, + # group_num, + # a_major, + # b_major, + # c_major, + # ) + # if is_valid: + # return [0] + # else: + # return [] + + # def forward( + # self, + # inputs: List[torch.Tensor], + # use_2cta_instrs: bool = False, + # mma_tiler_mn: Tuple[int, int] = (128, 128), + # cluster_shape_mn: Tuple[int, int] = (1, 1), + # tactic: int = -1, + # ) -> torch.Tensor: + # """Performs grouped fp8 gemm operation using CuTe DSL. + # :param a: Input tensor of shape (M, K) + # :type a: torch.Tensor, type: fp8 + # :param b: Weight tensor of shape (G, N, K) + # :type b: torch.Tensor, type: fp8 + # :param a_sf: Input scale tensor of shape (K//128, M). + # :type a_sf: torch.Tensor, type: fp32 + # :param b_sf: Weight scale tensor of shape (G, N//128, K//128) + # :type b_sf: torch.Tensor, type: fp32 + # :return: Output tensor of shape (M, N). Note: N/K should be divisible by 128. 
+ # :rtype: torch.Tensor, type: bf16 + # """ + + # a, b, a_sf, b_sf, group_offset = inputs + # m, n = a.shape[0], b.shape[1] + # c = torch.empty(*(m, n), dtype=torch.bfloat16, device="cuda") + + # a_tmp = a.view(torch.uint8) + # b_tmp = b.view(torch.uint8) + + # mA = from_dlpack(a_tmp, + # assumed_align=16).mark_layout_dynamic(leading_dim=1) + # mB = from_dlpack(b_tmp, + # assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mC = from_dlpack(c, assumed_align=16).mark_layout_dynamic(leading_dim=1) + # mA.element_type = cutlass.Float8E4M3FN + # mB.element_type = cutlass.Float8E4M3FN + + # mSFB = from_dlpack(b_sf, + # assumed_align=16).mark_layout_dynamic(leading_dim=2) + # mSFA = from_dlpack(a_sf, + # assumed_align=16).mark_layout_dynamic(leading_dim=1) + # group_offset_cute_tensor = from_dlpack( + # group_offset).mark_layout_dynamic() + + # # get stream + # torch_stream = torch.cuda.current_stream() + # stream = cuda.CUstream(torch_stream.cuda_stream) + + # cache_key = ( + # use_2cta_instrs, + # mma_tiler_mn, + # cluster_shape_mn, + # ) + # if cache_key not in CuteDSLFp8BlackwellGroupGemm.kernel_cache: + # gemm = BlockwiseContiguousGroupedGemmKernel( + # cutlass.Float32, # acc_dtype, + # use_2cta_instrs=use_2cta_instrs, + # mma_tiler_mn=mma_tiler_mn, + # cluster_shape_mn=cluster_shape_mn, + # ) + # # Compute max active clusters on current device + # hardware_info = cutlass.utils.HardwareInfo() + # max_active_clusters = hardware_info.get_max_active_clusters( + # cluster_shape_mn[0] * cluster_shape_mn[1]) + + # @cute.jit + # def group_gemm_permute_wrapper( + # a: cute.Tensor, + # b: cute.Tensor, + # c: cute.Tensor, + # a_sf: cute.Tensor, + # b_sf: cute.Tensor, + # group_offset: cute.Tensor, + # max_active_clusters: cutlass.Constexpr, + # stream: cuda.CUstream, + # ): + # a = append_ones_wrapper(a) + # c = append_ones_wrapper(c) + # b = permute(b, (1, 2, 0)) + # b_sf = permute(b_sf, (1, 2, 0)) + # a_sf = permute(a_sf, (1, 0)) + # a_sf = append_ones_wrapper(a_sf) + # gemm(a, b, c, a_sf, b_sf, group_offset, max_active_clusters, + # stream) + + # compiled_gemm = cute.compile( + # group_gemm_permute_wrapper, + # mA, + # mB, + # mC, + # mSFA, + # mSFB, + # group_offset_cute_tensor, + # max_active_clusters, + # stream, + # ) + # CuteDSLFp8BlackwellGroupGemm.kernel_cache[cache_key] = compiled_gemm + # else: + # compiled_gemm = CuteDSLFp8BlackwellGroupGemm.kernel_cache[cache_key] + + # # launch gemm kernel + # compiled_gemm(mA, mB, mC, mSFA, mSFB, group_offset_cute_tensor, stream) + # return c + + # # a/b: fp8, scale: fp32, out: bf16 + # @torch.library.custom_op("trtllm::cute_dsl_fp8_group_gemm_blackwell", + # mutates_args=(), + # device_types="cuda") + # def cute_dsl_fp8_group_gemm_blackwell( + # input: torch.Tensor, + # weight: torch.Tensor, + # input_scale: torch.Tensor, + # weight_scale: torch.Tensor, + # group_offset: torch.Tensor, + # ) -> torch.Tensor: + + # cute_dsl_fp8_group_gemm_blackwell_runner = CuteDSLFp8BlackwellGroupGemm() + # return cute_dsl_fp8_group_gemm_blackwell_runner( + # inputs=[input, weight, input_scale, weight_scale, group_offset], + # tactic=0, + # use_2cta_instrs=False, + # mma_tiler_mn=(128, 128), + # cluster_shape_mn=(1, 1), + # ) + + # @torch.library.register_fake("trtllm::cute_dsl_fp8_group_gemm_blackwell") + # def _( + # mat_a: torch.Tensor, + # mat_b: torch.Tensor, + # input_scale: torch.Tensor, + # weight_scale: torch.Tensor, + # group_offset: torch.Tensor, + # ) -> torch.Tensor: + # m, k = mat_a.shape[0], mat_a.shape[1] + # num_group, n, k = mat_b.shape[0], 
mat_b.shape[1], mat_b.shape[2] + # return mat_a.new_empty((m, n), dtype=torch.bfloat16) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py new file mode 100644 index 00000000000..0ee48134edd --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py @@ -0,0 +1,2559 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+# This file is copied and modified from cutlass example https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py

+
+import math
+from typing import Tuple, Type, Union
+
+import cuda.bindings.driver as cuda
+import cutlass
+import cutlass.cute as cute
+import cutlass.pipeline as pipeline
+import cutlass.utils as utils
+import cutlass.utils.blackwell_helpers as sm100_utils
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
+
+"""
+High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell
+architecture using CUTE DSL.
+- Matrix A is MxKxL, L is batch dimension, A can be row-major("K")
+- Matrix B is NxKxL, L is batch dimension, B can be column-major("K")
+- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
+- Scale factor SFA is applied per (1, 128) block of A, i.e. one scale per row of A per 128-wide K block
+- Scale factor SFB is applied per (128, 128) block of B
+- For each K iteration, the kernel computes the partial product A * B and accumulates it into the final
+  accumulator scaled by SFA * SFB
+
+This GEMM kernel supports the following features:
+    - Utilizes Tensor Memory Access (TMA) for efficient memory operations
+    - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations
+    - Implements TMA multicast with cluster to reduce L2 memory traffic
+    - Supports persistent tile scheduling to better overlap memory load/store with mma between tiles
+    - Supports warp specialization to avoid explicit pipelining between mainloop load and mma
+
+This GEMM works as follows:
+1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
+2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA
+   operations.
+3. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
+4. EPILOGUE warp:
+    - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
+    - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final
+    - Type convert Final matrix to output type.
+    - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations.
+
+SM100 tcgen05.mma instructions operate as follows:
+- Read matrix A from SMEM
+- Read matrix B from SMEM
+- Write accumulator to TMEM
+The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
+
+.. code-block:: bash
+
+    python examples/blackwell/blockwise_gemm/blockwise_gemm.py \
+      --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
+      --scale_dtype Float32 \
+      --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
+      --mnkl 4096,4096,4096,4
+
+To collect performance with NCU profiler:
+
+.. 
code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below BlockwiseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class BlockwiseGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, + e5m2) and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2), + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. 
+ :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: 
+ raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. 
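+        :note: Expected shapes (a sketch, assuming the default (1, 128) scale granularity
+            for A and (128, 128) for B configured in _setup_attributes): a is (M, K, L) fp8,
+            b is (N, K, L) fp8, c is (M, N, L) fp16/bf16, sfa is (M, K // 128, L) fp32 and
+            sfb is (N // 128, K // 128, L) fp32.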
+ """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition(cute.make_identity_layout(c.shape), self.epi_tile) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: 
cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_scale_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. 
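+
+        Warp specialization (as configured in __init__): one scheduler warp publishes tile
+        coordinates, one TMA warp loads A/B tiles, one warp loads SFA/SFB via cp.async,
+        one MMA warp issues tcgen05.mma, four warps perform the scaled accumulator update,
+        and four warps run the epilogue store.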
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize tile info pipeline 
(barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + defer_sync=True, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for 
TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = 
tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # store the tile info + cur_tile_coord = work_tile.tile_idx + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2] + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = 
ab_pipeline.producer_try_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tAsSFA, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tBsSFB, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = 
cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire(scale_producer_state, peek_scale_empty_status) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, 
acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = 
pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait(scale_consumer_state) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait(scale_consumer_state, peek_scale_full_status) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive 
accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + 
tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + mma_tile_coord_mnl[2], + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, epi_consumer_state.index)] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then 
use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide(tAcc_final[((None, None), 0, 0, None)], epi_tile) + + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)]) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, None, 
None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes(tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_)) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes(tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_)) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
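For orientation, the accumulator update that the acc-update warps above perform per K tile (Final = acc * SFA * SFB + Final) can be written as a plain PyTorch reference. This is a sketch only; the array names and the per-row/per-tile scale shapes are illustrative assumptions, not this kernel's API.

.. code-block:: python

    import torch

    def acc_update_reference(acc_per_ktile, sfa_rows, sfb_tile):
        # acc_per_ktile: (num_k_tiles, cta_m, cta_n) raw partial sums, one per K tile
        # sfa_rows:      (num_k_tiles, cta_m)        per-row scales for A
        # sfb_tile:      (num_k_tiles,)              one scale per B tile (illustrative granularity)
        final = torch.zeros(acc_per_ktile.shape[1:], dtype=torch.float32)
        for k in range(acc_per_ktile.shape[0]):
            scale = sfa_rows[k][:, None] * sfb_tile[k]   # broadcast the scales over N
            final += acc_per_ktile[k] * scale            # final = acc * SFA * SFB + final
        return final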
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. 
+ :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. 
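As a standalone illustration of the shared-memory budgeting performed in `_compute_stages` above: reserve the barrier helpers, the initial C stages and the scale buffers, fit as many A/B stages as possible into one CTA's share of shared memory, then hand any leftover bytes to extra C stages. The byte counts in the example call are made up; the real ones come from the smem layouts.

.. code-block:: python

    def budget_stages(smem_capacity, occupancy, ab_bytes_per_stage,
                      c_bytes_per_stage, num_c_stage, scale_bytes,
                      mbar_helpers_bytes=1024):
        # Bytes reserved before A/B staging: barriers, initial C stages, scale buffers.
        reserved = mbar_helpers_bytes + c_bytes_per_stage * num_c_stage + scale_bytes
        # Fit as many A/B stages as possible into one CTA's share of smem.
        num_ab_stage = (smem_capacity // occupancy - reserved) // ab_bytes_per_stage
        # Give whatever is left over to additional C (epilogue) stages.
        leftover = (smem_capacity
                    - occupancy * ab_bytes_per_stage * num_ab_stage
                    - occupancy * reserved)
        num_c_stage += leftover // (occupancy * c_bytes_per_stage)
        return num_ab_stage, num_c_stage

    # Example with assumed sizes: 192 KiB smem, occupancy 1, 32 KiB per A/B stage,
    # 4 KiB per C stage, 2 initial C stages, 2 KiB of scale buffers.
    print(budget_stages(192 * 1024, 1, 32 * 1024, 4 * 1024, 2, 2 * 1024))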
+ :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in 
(128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be 
implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not BlockwiseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + @cute.jit + def wrapper( + self, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + l: cutlass.Constexpr, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + a_sf_ptr: cute.Pointer, + b_sf_ptr: cute.Pointer, + c_ptr: cute.Pointer, + max_active_clusters: cutlass.Constexpr, + current_stream: cuda.CUstream, + ): + """Executes the wrapped GEMM kernel with dynamically shaped tensors. + + Args: + m (int): The M dimension of the GEMM problem. + n (int): The N dimension of the GEMM problem. + k (int): The K dimension of the GEMM problem. + l (cutlass.Constexpr): The batch dimension (L) of the GEMM problem. + a_ptr (cute.Pointer): Pointer to the A tensor. + b_ptr (cute.Pointer): Pointer to the B tensor. + a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A. + b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B. + c_ptr (cute.Pointer): Pointer to the C tensor. + max_active_clusters (cutlass.Constexpr): Maximum number of active + clusters. + current_stream (cuda.CUstream): CUDA stream for the operation. + """ + + # m, k, l + a_tensor = cute.make_tensor( + a_ptr, + layout=cute.make_ordered_layout((m, k, l), order=(1, 0, 2)), + ) + # n, k, l + b_tensor = cute.make_tensor( + b_ptr, + layout=cute.make_ordered_layout( + (n, k, l), + order=(1, 0, 2), + ), + ) + # m, n, l + c_tensor = cute.make_tensor( + c_ptr, + layout=cute.make_ordered_layout( + (m, n, l), + order=(1, 0, 2), + ), + ) + # k//128, m, l + sfa_tensor = cute.make_tensor( + a_sf_ptr, + layout=cute.make_ordered_layout( + (math.ceil(k, 128), m, l), + order=(2, 1, 0), + ), + ) + # n//128, k//128, l + sfb_tensor = cute.make_tensor( + b_sf_ptr, + layout=cute.make_ordered_layout( + (math.ceil(n, 128), math.ceil(k, 128), l), + order=(1, 0, 2), + ), + ) + + self( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + max_active_clusters, + current_stream, + ) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/continuous_grouped_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/continuous_grouped_gemm.py new file mode 100644 index 00000000000..4f287c41e6c --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/continuous_grouped_gemm.py @@ -0,0 +1,2518 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This file is copied and modified from cutlass example https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/continuous_grouped_gemm.py + +import math +from typing import Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +""" +High-performance persistent blockwise contiguous grouped dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture +using CUTE DSL. +- Matrix A is MxKx1, A can be row-major("K"), ValidM is composed of valid m in different groups +- Matrix B is NxKxL, B can be column-major("K"), L is grouped dimension +- Matrix C is MxNx1, C can be row-major("N"), ValidM is composed of valid m in different groups +- Each block will apply the scale factor SFA +- Each row will apply the scale factor SFB +- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB + +Matrix A/C Memory Layout Diagrams: + + ``` + Group 0 Group 1 Group 2 + -+---------+---------+---------+ + | | | | + K| ValidM0 | ValidM1 | ValidM2 | + | | | | + -+---------+---------+---------+ + |<- ValidM ->| + ``` + Note: the Group(L) dimension will be flatted into M dimension, and the rest Group(L) size is 1. + each ValidM will be aligned to 128. 
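A host-side sketch of how such a contiguous grouped layout can be produced is shown below. It is illustrative only: the helper name is hypothetical, and whether the group-index mapping is kept per row or per 128-row tile is an assumption rather than something this file specifies.

.. code-block:: python

    import torch

    def build_contiguous_grouped_m(tokens_per_group, tile_m=128):
        group_of_tile, valid_ranges, m_total = [], [], 0
        for g, m_g in enumerate(tokens_per_group):
            valid_m = -(-m_g // tile_m) * tile_m           # pad each group's M up to a multiple of 128
            valid_ranges.append((m_total, m_total + m_g))  # rows that hold real tokens for group g
            group_of_tile += [g] * (valid_m // tile_m)     # one group id per 128-row tile
            m_total += valid_m
        return m_total, valid_ranges, torch.tensor(group_of_tile, dtype=torch.int32)

    # e.g. groups with 100, 300 and 64 tokens -> total M = 128 + 384 + 128 = 640
    print(build_contiguous_grouped_m([100, 300, 64])[0])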
+
+This GEMM kernel supports the following features:
+    - Utilizes Tensor Memory Access (TMA) for efficient memory operations
+    - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations
+    - Implements TMA multicast with cluster to reduce L2 memory traffic
+    - Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
+    - Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
+
+This GEMM works as follows:
+1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
+2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations.
+3. MMA warp: Perform matrix multiply-accumulate (MMA) operations using the tcgen05.mma instruction.
+4. EPILOGUE warp:
+    - Load the completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
+    - Apply the scale factor and update the final accumulator: Final = C * SFA * SFB + Final
+    - Type-convert the Final matrix to the output type.
+    - Store the C matrix from registers (RMEM) to shared memory (SMEM) and then to global memory (GMEM) with TMA operations.
+
+SM100 tcgen05.mma instructions operate as follows:
+- Read matrix A from SMEM
+- Read matrix B from SMEM
+- Write accumulator to TMEM
+The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
+
+.. code-block:: bash
+
+    python examples/blackwell/blockwise_gemm/contiguous_grouped_gemm.py \
+        --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
+        --scale_dtype Float32 \
+        --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
+        --mnkl 256,4096,4096,16
+
+To collect performance with the NCU profiler:
+
+.. code-block:: bash
+
+    ncu python examples/blackwell/blockwise_gemm/contiguous_grouped_gemm.py \
+        --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
+        --scale_dtype Float32 \
+        --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
+        --mnkl 256,4096,4096,16
+
+
+Constraints are the same as in dense_gemm.py:
+* Supported input data types: fp8 (e4m3fn)
+  see the detailed valid dtype combinations in the BlockwiseContiguousGroupedGemmKernel class documentation below
+* A/B tensors must have the same data type
+* MMA tiler M must be 64/128/256
+* MMA tiler N must be 128, to align with the scaleB requirement
+* Cluster shape M/N must be positive and a power of 2, total cluster size <= 16
+* Cluster shape M must be a multiple of 2
+* The contiguous dimension of the A/B/C tensors must be at least 16-byte aligned
+"""
+
+
+class BlockwiseContiguousGroupedGemmKernel:
+    """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2)
+    and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
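As a reference for the intended math only (plain PyTorch, fp32 accumulation): A carries one scale per row per 128-wide K block and B one scale per 128x128 block, following the scale granularities configured in `_setup_attributes`; everything else in this sketch is illustrative and not this class's API.

.. code-block:: python

    import torch

    def blockwise_scaled_gemm_ref(a, b, sfa, sfb, block=128):
        # a: (m, k), b: (n, k), any float dtype
        # sfa: (ceil(k/block), m)            one scale per row per K block
        # sfb: (ceil(n/block), ceil(k/block)) one scale per 128x128 block of B
        m, k = a.shape
        n = b.shape[0]
        out = torch.zeros(m, n, dtype=torch.float32)
        for kb in range((k + block - 1) // block):
            ks = slice(kb * block, min((kb + 1) * block, k))
            partial = a[:, ks].float() @ b[:, ks].float().t()           # (m, n) partial sum
            scale_a = sfa[kb][:, None]                                  # (m, 1)
            scale_b = sfb[:, kb].repeat_interleave(block)[:n][None, :]  # (1, n)
            out += partial * scale_a * scale_b                          # scale each K block's product
        return out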
+ + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseContiguousGroupedGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2), + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. 
+ :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: 
+ raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + gidx_mapping: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param gidx_mapping: Mapping from m index to group index + :type gidx_mapping: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. 
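+
+        Reference semantics (an illustrative sketch only; the names below are not part of
+        the kernel API). Per CTA tile, with K processed in 128-wide blocks, the acc-update
+        warps effectively compute::
+
+            for kb in range(num_k_blocks):                 # one K block per pipeline stage
+                acc = a_blk[kb] @ b_blk[kb].T              # fp8 x fp8 -> fp32 partial tile
+                final += acc * sfa[:, kb:kb + 1] * sfb[kb] # per-row and per-block rescale
+            c_tile = final.astype(c_dtype)                 # converted in the epilogue warps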
+ """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition(cute.make_identity_layout(c.shape), self.epi_tile) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: 
cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_scale_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + gidx_mapping, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + gidx_mapping: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. 
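+
+        Warp specialization (as configured in ``__init__``): warps 0-3 (acc-update) rescale
+        each partial accumulator with SFA/SFB and keep the running fp32 result in tensor
+        memory; warps 4-7 (epilogue) convert to the C dtype, stage through shared memory and
+        issue the TMA store; warp 8 issues the tcgen05 MMA instructions; warp 9 issues the
+        TMA loads of A/B; warp 10 loads the scale factors via cp.async; warp 11 publishes
+        tile and group indices for the persistent scheduler. The warps are decoupled by the
+        ab/scale/acc/epi/tile-info pipelines created below.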
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize tile info pipeline 
(barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + defer_sync=True, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for 
TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = 
tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # get the group info + cur_tile_coord = work_tile.tile_idx + gidx = 0 + if work_tile.is_valid_tile: + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + + # store the tile info + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = gidx + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + + # advance to next tile + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # Peek (try_wait) AB buffer empty for 
k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + gidx = gidx_mapping[cur_tile_coord[0] * self.cta_tile_shape_mnk[0]] + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tAsSFA, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tBsSFB, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + 0, + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + 0, + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + 
tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire(scale_producer_state, peek_scale_empty_status) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # MMA warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = 
acc_pipeline.producer_try_acquire(acc_producer_state) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + 
pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # Acc update warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait(scale_consumer_state) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait(scale_consumer_state, peek_scale_full_status) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + 
scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state 
= pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # Epilogue warp don't care about gidx + gidx = 0 + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = gidx + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, epi_consumer_state.index)] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: 
cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide(tAcc_final[((None, None), 0, 0, None)], epi_tile) + + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)]) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = 
thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes(tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_)) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes(tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_)) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
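+        In the epilogue loop, this r2s copy fills one of ``num_c_stage`` rotating shared
+        memory buffers, which the first epilogue warp then drains to global memory with the
+        bulk-tensor TMA store built in ``epilog_gmem_copy_and_partition``.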
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. 
+        :type tiled_mma: cute.TiledMma
+        :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
+        :type mma_tiler_mnk: tuple[int, int, int]
+        :param a_dtype: Data type of operand A.
+        :type a_dtype: type[cutlass.Numeric]
+        :param b_dtype: Data type of operand B.
+        :type b_dtype: type[cutlass.Numeric]
+        :param epi_tile: The epilogue tile shape.
+        :type epi_tile: cute.Tile
+        :param c_dtype: Data type of operand C (output).
+        :type c_dtype: type[cutlass.Numeric]
+        :param c_layout: Layout of operand C.
+        :type c_layout: utils.LayoutEnum
+        :param sfa_dtype: Data type of the A scale factors.
+        :type sfa_dtype: type[cutlass.Numeric]
+        :param sfb_dtype: Data type of the B scale factors.
+        :type sfb_dtype: type[cutlass.Numeric]
+        :param sfa_count: Number of A scale-factor elements per CTA tile (per scale stage).
+        :type sfa_count: int
+        :param sfb_count: Number of B scale-factor elements per CTA tile (per scale stage).
+        :type sfb_count: int
+        :param num_smem_capacity: Total available shared memory capacity in bytes.
+        :type num_smem_capacity: int
+        :param occupancy: Target number of CTAs per SM (occupancy).
+        :type occupancy: int
+
+        :return: A tuple containing the computed number of stages for:
+            (ACC stages, A/B operand stages, C stages, scale stages, tile-info stages)
+        :rtype: tuple[int, int, int, int, int]
+        """
+        # Default ACC stages
+        num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6
+
+        # Default C stages
+        num_c_stage = 2
+
+        # Default ScaleA/B stages
+        num_scale_stage = 10
+
+        # Default Tile info stages
+        num_tile_stage = 2
+
+        # Calculate smem layout and size for one stage of A, B, and C
+        a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
+            tiled_mma,
+            mma_tiler_mnk,
+            a_dtype,
+            1,  # a temporary single stage, used only to size one stage
+        )
+        b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
+            tiled_mma,
+            mma_tiler_mnk,
+            b_dtype,
+            1,  # a temporary single stage, used only to size one stage
+        )
+        c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+            c_dtype,
+            c_layout,
+            epi_tile,
+            1,
+        )
+
+        ab_bytes_per_stage = cute.size_in_bytes(
+            a_dtype, a_smem_layout_stage_one
+        ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+        # 1024B alignment
+        mbar_helpers_bytes = 1024
+        c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+        c_bytes = c_bytes_per_stage * num_c_stage
+        sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage
+        sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage
+        scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024
+
+        # Calculate A/B stages:
+        # Start with total smem per CTA (capacity / occupancy)
+        # Subtract reserved bytes and initial C stages bytes
+        # Divide remaining by bytes needed per A/B stage
+        num_ab_stage = (
+            num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes + scale_bytes)
+        ) // ab_bytes_per_stage
+
+        # Refine epilogue stages:
+        # Calculate remaining smem after allocating for A/B stages and reserved bytes
+        # Add remaining unused smem to epilogue
+        num_c_stage += (
+            num_smem_capacity
+            - occupancy * ab_bytes_per_stage * num_ab_stage
+            - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes)
+        ) // (occupancy * c_bytes_per_stage)
+        return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage
+
+    @staticmethod
+    def _compute_grid(
+        c: cute.Tensor,
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        cluster_shape_mn: Tuple[int, int],
+        max_active_clusters: cutlass.Constexpr,
+    ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
+        """Use persistent tile scheduler to compute the grid size for the output tensor C.
+
+        :param c: The output tensor C
+        :type c: cute.Tensor
+        :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+        :type cta_tile_shape_mnk: tuple[int, int, int]
+        :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
+        :type cluster_shape_mn: tuple[int, int]
+        :param max_active_clusters: Maximum number of active clusters.
+ :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in 
(128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0] + # Skip invalid cluster tiler shape since contiguous layout can't handle oob access + # The contiguous layout means the aligned data is stored in a contiguous manner. + # It can't handle runtime oob when alignment is not align with the tile_M, + # since the problem shape of TMA store can't be changed at runtime. + if cluster_tiler_m not in [64, 128]: + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor 
+        :type m: int
+        :param n: The number of columns in the B tensor
+        :type n: int
+        :param k: The number of columns in the A tensor
+        :type k: int
+        :param l: The number of batches (the L mode shared by A, B and C)
+        :type l: int
+        :param a_major: The major axis of the A tensor
+        :type a_major: str
+        :param b_major: The major axis of the B tensor
+        :type b_major: str
+        :param c_major: The major axis of the C tensor
+        :type c_major: str
+
+        :return: True if the gemm can be implemented, False otherwise
+        :rtype: bool
+        """
+        can_implement = True
+        # Skip unsupported types
+        if not BlockwiseContiguousGroupedGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype):
+            can_implement = False
+        # Skip invalid mma tile shape and cluster shape
+        if not BlockwiseContiguousGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
+            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
+        ):
+            can_implement = False
+        # Skip illegal problem shape for load/store alignment
+        if not BlockwiseContiguousGroupedGemmKernel.is_valid_tensor_alignment(
+            m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
+        ):
+            can_implement = False
+        # Skip unsupported A/B layout
+        if not (a_major == "k" and b_major == "k"):
+            can_implement = False
+        return can_implement
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 148ec5e2e3f..4e318b7e4ba 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]):
 
     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
 
+    # cute dsl op configs
+    use_cute_dsl_blockscaling_mm: bool = False
+    use_cute_dsl_blockscaling_bmm: bool = False
+
     _frozen: bool = field(default=False, init=False, repr=False)
 
     # If true, ONLY the vision encoder part of the full model is loaded/executed.
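Of the checks that `can_implement` composes above (dtypes, MMA tiler and cluster shape, tensor alignment, A/B layout), the 16-byte alignment rule in `is_valid_tensor_alignment` is the least obvious: the extent of an operand's contiguous (major) mode must be a whole multiple of `16 * 8 / dtype.width` elements. Below is a minimal standalone sketch of that arithmetic in plain Python, detached from the kernel; the helper names and the example shape are illustrative only.

```python
# Standalone sketch of the 16B-alignment rule from is_valid_tensor_alignment.
# The helper names, dtype bit-widths and example problem shape are assumptions
# for illustration; they are not part of the kernel's API.


def contiguous_elements_per_16_bytes(dtype_bits: int) -> int:
    # Number of elements of this dtype that fill one 16-byte (128-bit) chunk.
    return 16 * 8 // dtype_bits


def major_extent_is_aligned(major_extent: int, dtype_bits: int) -> bool:
    # The contiguous (major) extent must cover a whole number of 16-byte chunks.
    return major_extent % contiguous_elements_per_16_bytes(dtype_bits) == 0


# Example: A (m, k) and B (n, k) are k-major FP8, C (m, n) is n-major BF16.
m, n, k = 128, 2112, 7168
FP8_BITS, BF16_BITS = 8, 16
assert major_extent_is_aligned(k, FP8_BITS)   # A and B: k must be a multiple of 16
assert major_extent_is_aligned(n, BF16_BITS)  # C: n must be a multiple of 8
```

This is also why every `(k, n)` pair in the unit tests further down is a multiple of 16 and 8 respectively; shapes that break the rule are rejected by `can_implement`.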
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 383ebf82961..0b64e076e26 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -235,6 +235,9 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim + self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm + self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + qkv_shard_indices_mapping = { "q": (0, self.q_size * (2 if self.attn_output_gate else 1)), "k": @@ -260,7 +263,8 @@ def __init__( force_dynamic_quantization=config.force_dynamic_quantization, disable_deep_gemm=disable_deep_gemm, use_custom_cublas_mm=use_custom_cublas_mm, - fused_weight_shard_indices_mapping=qkv_shard_indices_mapping) + fused_weight_shard_indices_mapping=qkv_shard_indices_mapping, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) @@ -279,7 +283,8 @@ def __init__( allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, disable_deep_gemm=disable_deep_gemm, - use_custom_cublas_mm=use_custom_cublas_mm) + use_custom_cublas_mm=use_custom_cublas_mm, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend @@ -660,6 +665,7 @@ def fp8_block_scaling_bmm_out( mat2_scale: torch.Tensor, out: torch.Tensor, mat2_dequant: Optional[torch.Tensor] = None, + use_cute_dsl_blockscaling_bmm: bool = False, ) -> torch.Tensor: sm_version = get_sm_version() if sm_version == 90 or sm_version == 89 or sm_version == 120: @@ -673,7 +679,17 @@ def fp8_block_scaling_bmm_out( out.copy_(output) elif is_sm_100f(sm_version): - torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out) + if use_cute_dsl_blockscaling_bmm: + mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( + mat1) + torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8, + mat1_scale, mat2_scale, + out) + mat1_scale = None + else: + torch.bmm(mat1.transpose(0, 1), + mat2_dequant.transpose(1, 2), + out=out) else: raise NotImplementedError(f"SM{sm_version} is not supported") @@ -816,6 +832,9 @@ def __init__( quant_config = config.get_quant_config() self.quant_config = quant_config + self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm + self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + if not self.is_lite: self.kv_a_proj_with_mqa = Linear( hidden_size, @@ -825,7 +844,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank, eps=rms_norm_eps, @@ -841,7 +861,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) else: self.kv_a_proj_with_mqa = Linear( hidden_size, @@ -851,7 +872,8 @@ def __init__( 
quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_proj = Linear( self.q_lora_rank, @@ -863,7 +885,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_b_proj = self.q_proj self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank, @@ -880,7 +903,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) # This parameter will view into self.kv_b_proj.weight after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. # Used in forward_absorption only @@ -912,7 +936,8 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, reduce_output=reduce_output, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: @@ -1048,7 +1073,7 @@ def create_weights(self): ), requires_grad=False, ) - if is_sm_100f(): + if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm: assert self.dtype == torch.bfloat16 self.k_b_proj_trans_dequant = nn.Parameter( torch.empty( @@ -1809,6 +1834,7 @@ def forward_absorption_generation( self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ), lambda: self.mqa.mla_rope_generation( fused_q, @@ -1887,6 +1913,7 @@ def forward_absorption_generation( self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -1942,6 +1969,7 @@ def forward_absorption_context( self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -1997,6 +2025,7 @@ def forward_absorption_context( self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2064,6 +2093,7 @@ def forward_sparse_mla_kvcache_bf16( self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2140,6 +2170,7 @@ def forward_sparse_mla_kvcache_bf16( self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 5c90a364934..799e3c2c1d8 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -740,10 +740,9 @@ def apply(self, module: Linear, input: torch.Tensor, if is_sm_100f(): 
             if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
-                # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                     input)
-                output = torch.ops.trtllm.fp8_block_scaling_gemm(
+                output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
                     act_input_fp8, module.weight, act_input_sf,
                     module.weight_scale)
             else:
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index dac655b1c33..c2f6a09a45b 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -300,6 +300,13 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
     return scale_shape * 2
 
 
+def fp8_a_sf_m_shape(input_shapes: List[List[int]]):
+    input_shape = input_shapes[0]
+    assert len(input_shape) == 2 or len(input_shape) == 3
+    has_batch = len(input_shape) == 3
+    # M of the activation: dim 0 for a 2D (m, k) input, dim 1 for a batched (b, m, k) input.
+    m = input_shape[1] if has_batch else input_shape[0]
+    return align(m, 4) if has_batch else m
+
+
 _enable_piecewise_cuda_graph = True
diff --git a/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py b/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
index 5c60b772986..f78256f40e4 100644
--- a/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
+++ b/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
@@ -109,6 +109,51 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
     torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skipif(
+    not isSM100Family(),
+    reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
+)
+@pytest.mark.parametrize(
+    "k, n",
+    [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
+     (2048, 7168), (1024, 1024)],
+)
+@pytest.mark.parametrize(
+    "m",
+    [7, 64, 128, 4096],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.bfloat16],
+)
+def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n):
+
+    torch.random.manual_seed(0)
+    a = torch.randn((m, k), device='cuda', dtype=dtype) / k
+    b = torch.randn((n, k), device='cuda', dtype=dtype) / k
+
+    act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
+    act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)
+
+    output_expected = a @ b.t()
+
+    # tune
+    with autotune():
+        cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
+            act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
+
+    # run the tuned kernel
+    cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
+        act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
+    diff = calc_diff(cute_dsl_output, output_expected)
+    assert diff < 1e-3
+    torch.testing.assert_close(cute_dsl_output,
+                               output_expected,
+                               atol=1e-3,
+                               rtol=1e-3)
+
+
 @pytest.mark.skipif(
     getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
     reason="The test is for Hopper and Ada only. Current SM is %d." %
@@ -158,6 +203,60 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
     torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skipif(
+    not isSM100Family(),
+    reason="The test is for Blackwell. Current SM is %d."
% getSMVersion(), +) +@pytest.mark.parametrize( + "k, n", + [(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)], +) +@pytest.mark.parametrize( + "m", + [7, 64, 128], +) +@pytest.mark.parametrize( + "num_groups", + [4, 8, 16], +) +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16], +) +def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups): + + torch.random.manual_seed(0) + a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k + + a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a) + + print(a_fp8.shape, a_scales.shape) + + b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k + b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn) + b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128), + device='cuda', + dtype=torch.float) + + for i in range(num_groups): + b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i]) + + output_expected = torch.einsum('mgk,gnk->gmn', a, b) + output = torch.empty((num_groups, m, n), + device='cuda', + dtype=torch.bfloat16) + # tune + with autotune(): + torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales, + b_scales, output) + # run the tuned kernel + torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales, + b_scales, output) + diff = calc_diff(output, output_expected) + assert diff < 1e-3 + torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3) + + def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA, valsB, dqSfsB, quantizeOutput, tileSize): for mi in range(mM):