From b9b422981b89a58eaeba34c2318a8e6462baae04 Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Sun, 30 Nov 2025 20:47:41 -0500
Subject: [PATCH 1/3] support fp8 scaled_embedding_bag pattern match

---
 .../pt2e/test_x86inductor_fusion.py          |  72 ++++++++++++
 .../quantization/pt2e/inductor_passes/x86.py | 110 ++++++++++++++++++
 2 files changed, 182 insertions(+)

diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index 520b5fbdfb..e477ac8d46 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -3047,6 +3047,78 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_scaled_embedding_bag(self):
+        dtype = torch.float8_e4m3fn
+
+        class FP8QDQEmbeddingBag(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight_scale = 2.0
+
+            def forward(
+                self,
+                weight,
+                input,
+                offsets=None,
+            ):
+                weight = (
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                        tensor=weight.data,
+                        scale=torch.tensor([self.weight_scale]),
+                        output_dtype=torch.float,
+                    )
+                )
+
+                return torch.nn.functional.embedding_bag(
+                    input,
+                    weight,
+                    offsets,
+                    mode="sum",
+                    include_last_offset=True,
+                )
+
+        EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
+        EMBEDDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
+        EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
+        EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]
+
+        EMBEDDINGBAG_TEST_PARAMS = list(
+            itertools.product(
+                EMBEDDINGBAG_MULTIHOT_SIZES,
+                EMBEDDINGBAG_BAG_SIZES,
+                EMBEDDINGBAG_VECTOR_SIZES,
+                EMBEDDINGBAG_INDEX_DTYPES,
+            )
+        )
+
+        for multi_hot, batch_size, vector_size, index_type in EMBEDDINGBAG_TEST_PARAMS:
+            with torch.no_grad():
+                mod = FP8QDQEmbeddingBag()
+
+                weight = torch.randn((1000, vector_size)).to(dtype)
+                indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
+                offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(
+                    index_type
+                )
+
+                def matcher_check_fn():
+                    self.assertEqual(
+                        counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
+                    )
+
+                self._test_common(
+                    mod,
+                    (weight, indices, offsets),
+                    matcher_check_fn,
+                )
+

 instantiate_parametrized_tests(TestPatternMatcher)

 if __name__ == "__main__":
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index a0aef11541..2db27c7de9 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -3,6 +3,7 @@
 import copy
 import functools
 import itertools
+import operator
 from typing import Any

 import torch
@@ -2851,6 +2852,113 @@ def _register_qlinear_binary_fusion():
     )


+def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
+    @register_freezing_graph_pattern(
+        pattern,
+        pass_number=pass_number,
+    )
+    def scaled_embedding_bag(match: Match, *args, **kwargs):
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        getitem_node = match.output_node()
+        embedding_bag_node = getitem_node.args[0]
+        assert embedding_bag_node.target is aten._embedding_bag_forward_only.default
+
+        embedding_bag_weight_index = 0
+        if dtype == torch.float32:
+            # pattern: embedding_bag -> dequant
+            dequant_node = embedding_bag_node.args[embedding_bag_weight_index]
+        else:
+            # pattern: embedding_bag -> to_bf16 -> dequant
+            weight_to_bf16_node = embedding_bag_node.args[embedding_bag_weight_index]
+            dequant_node = weight_to_bf16_node.args[0]
+
+        assert dequant_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+        ]
+
+        # Weight QParams
+        qw, w_scale = kwargs["x"], kwargs["x_scale"]
+
+        # Input Params
+        indices, offsets, mode, include_last_offset = (
+            kwargs["indices"],
+            kwargs["offsets"],
+            kwargs["mode"],
+            kwargs["include_last_offset"],
+        )
+        # only support fp32 output, next setp support more dtype
+        o_scale = 1.0
+
+        graph = match.graph
+        with graph.inserting_before(getitem_node):
+            new_args: tuple[Any, ...] = (
+                qw,
+                indices,
+                offsets,
+                w_scale,
+                o_scale,
+                mode,
+                include_last_offset,
+                torch.float,
+            )
+
+            new_embedding_bag_node = graph.call_function(
+                torch.ops.torchao._scaled_embedding_bag.default, args=new_args
+            )
+
+            getitem_node.replace_all_uses_with(new_embedding_bag_node)
+            new_embedding_bag_node.meta.update(embedding_bag_node.meta)
+
+            graph.erase_node(getitem_node)
+            graph.erase_node(embedding_bag_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+            # Erase the dequant pattern
+            graph.erase_node(dequant_node)
+
+        counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
+        counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
+
+
+def _generate_scaled_embedding_bag_patterns(dq_pattern):
+    embedding_bag_pattern = CallFunction(
+        torch.ops.aten._embedding_bag_forward_only.default,
+        dq_pattern,
+        KeywordArg("indices"),
+        KeywordArg("offsets"),
+        Arg(),
+        KeywordArg("mode"),
+        KeywordArg("sparse"),
+        Arg(),
+        KeywordArg("include_last_offset"),
+    )
+    return CallFunction(
+        operator.getitem,
+        embedding_bag_pattern,
+        KeywordArg("item"),
+    )
+
+
+def _register_quantization_embeddingbag_pass():
+    for dtype in [torch.float32, torch.bfloat16]:
+        _register_scaled_embedding_bag_pass(
+            _generate_scaled_embedding_bag_patterns(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(
+                        is_tensor_overload=False, is_fp8=True
+                    ),
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+            ),
+            pass_number=1,
+            dtype=dtype,
+        )  # pass_number=0 to run before weight prepack
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2874,6 +2982,8 @@ def _register_quantization_weight_pack_pass():
     _register_qlinear_unary_fusion()
     _register_qlinear_binary_fusion()

+    _register_quantization_embeddingbag_pass()
+

 def quant_lift_up(module_graph: torch.fx.graph.Graph):
     """

From 49bc7387e3748264cf7138cecd042a288b7d037a Mon Sep 17 00:00:00 2001
From: shiyang-weng
Date: Thu, 4 Dec 2025 15:10:37 +0800
Subject: [PATCH 2/3] Update torchao/quantization/pt2e/inductor_passes/x86.py

Co-authored-by: Xia Weiwen
---
 torchao/quantization/pt2e/inductor_passes/x86.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index 2db27c7de9..d5f6a22bed 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -2889,7 +2889,7 @@ def scaled_embedding_bag(match: Match, *args, **kwargs):
             kwargs["mode"],
             kwargs["include_last_offset"],
         )
-        # only support fp32 output, next setp support more dtype
+        # only support fp32 output, next step to support more dtype
         o_scale = 1.0

         graph = match.graph

From f6519e34c19476136e83da0293b1273298d78e35 Mon Sep 17 00:00:00 2001
From: shiyang-weng
Date: Thu, 4 Dec 2025 15:24:13 +0800
Subject: [PATCH 3/3] refine code

---
 torchao/quantization/pt2e/inductor_passes/x86.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index d5f6a22bed..7f8a515fcd 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -2981,8 +2981,7 @@ def _register_quantization_weight_pack_pass():
     _register_qconv_binary_fusion()
     _register_qlinear_unary_fusion()
     _register_qlinear_binary_fusion()
-
-    _register_quantization_embeddingbag_pass()
+    _register_quantization_embeddingbag_pass()


 def quant_lift_up(module_graph: torch.fx.graph.Graph):
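
A quick way to exercise the fused kernel outside the Inductor pattern matcher is to call torchao::_scaled_embedding_bag directly, using the same argument order that new_args is built with in the pass. The snippet below is a sketch, not part of the patch series: it assumes a torchao build that ships the CPU kernel (the same precondition the test's skipIf checks via torch._C._dispatch_dump) and that the weight scale and o_scale act as plain multiplicative scales, so the result should line up with the eager dequant + embedding_bag reference from the test.

# Sketch only: mirrors the dequant -> embedding_bag pattern the pass matches
# and the argument order used for new_args in scaled_embedding_bag().
import torch

qweight = torch.randn(1000, 128).to(torch.float8_e4m3fn)
weight_scale = torch.tensor([2.0])
indices = torch.randint(1000, (64,))
offsets = torch.arange(0, 72, 8)  # 8 bags of 8 indices, include_last_offset=True

# Reference: eager dequant followed by embedding_bag (what the test module does).
dq_weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
    tensor=qweight, scale=weight_scale, output_dtype=torch.float
)
ref = torch.nn.functional.embedding_bag(
    indices, dq_weight, offsets, mode="sum", include_last_offset=True
)

# Fused op, args in the order new_args is assembled:
# (weight, indices, offsets, w_scale, o_scale, mode, include_last_offset, output_dtype)
fused = torch.ops.torchao._scaled_embedding_bag(
    qweight, indices, offsets, weight_scale, 1.0, 0, True, torch.float
)

# Assumes the kernel applies w_scale/o_scale the same way the reference does.
torch.testing.assert_close(ref, fused)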