diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index 520b5fbdfb..e477ac8d46 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -3047,6 +3047,78 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_scaled_embedding_bag(self):
+        dtype = torch.float8_e4m3fn
+
+        class FP8QDQEmbeddingBag(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight_scale = 2.0
+
+            def forward(
+                self,
+                weight,
+                input,
+                offsets=None,
+            ):
+                weight = (
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                        tensor=weight.data,
+                        scale=torch.tensor([self.weight_scale]),
+                        output_dtype=torch.float,
+                    )
+                )
+
+                return torch.nn.functional.embedding_bag(
+                    input,
+                    weight,
+                    offsets,
+                    mode="sum",
+                    include_last_offset=True,
+                )
+
+        EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
+        EMBEDDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
+        EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
+        EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]
+
+        EMBEDDINGBAG_TEST_PARAMS = list(
+            itertools.product(
+                EMBEDDINGBAG_MULTIHOT_SIZES,
+                EMBEDDINGBAG_BAG_SIZES,
+                EMBEDDINGBAG_VECTOR_SIZES,
+                EMBEDDINGBAG_INDEX_DTYPES,
+            )
+        )
+
+        for multi_hot, batch_size, vector_size, index_type in EMBEDDINGBAG_TEST_PARAMS:
+            with torch.no_grad():
+                mod = FP8QDQEmbeddingBag()
+
+                weight = torch.randn((1000, vector_size)).to(dtype)
+                indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
+                offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(
+                    index_type
+                )
+
+                def matcher_check_fn():
+                    self.assertEqual(
+                        counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
+                    )
+
+                self._test_common(
+                    mod,
+                    (weight, indices, offsets),
+                    matcher_check_fn,
+                )
+
 instantiate_parametrized_tests(TestPatternMatcher)
 
 if __name__ == "__main__":
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index a0aef11541..7f8a515fcd 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -3,6 +3,7 @@
 import copy
 import functools
 import itertools
+import operator
 from typing import Any
 
 import torch
@@ -2851,6 +2852,113 @@ def _register_qlinear_binary_fusion():
     )
 
 
+def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
+    @register_freezing_graph_pattern(
+        pattern,
+        pass_number=pass_number,
+    )
+    def scaled_embedding_bag(match: Match, *args, **kwargs):
+        assert dtype in [torch.float32, torch.bfloat16]
+
+        getitem_node = match.output_node()
+        embedding_bag_node = getitem_node.args[0]
+        assert embedding_bag_node.target is aten._embedding_bag_forward_only.default
+
+        embedding_bag_weight_index = 0
+        if dtype == torch.float32:
+            # pattern: embedding_bag -> dequant
+            dequant_node = embedding_bag_node.args[embedding_bag_weight_index]
+        else:
+            # pattern: embedding_bag -> to_bf16 -> dequant
+            weight_to_bf16_node = embedding_bag_node.args[embedding_bag_weight_index]
+            dequant_node = weight_to_bf16_node.args[0]
+
+        assert dequant_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+        ]
+
+        # Weight QParams
+        qw, w_scale = kwargs["x"], kwargs["x_scale"]
+
+        # Input Params
+        indices, offsets, mode, include_last_offset = (
+            kwargs["indices"],
+            kwargs["offsets"],
+            kwargs["mode"],
+            kwargs["include_last_offset"],
+        )
+        # Only fp32 output is supported for now; more output dtypes will follow.
+        o_scale = 1.0
+
+        graph = match.graph
+        with graph.inserting_before(getitem_node):
+            new_args: tuple[Any, ...] = (
+                qw,
+                indices,
+                offsets,
+                w_scale,
+                o_scale,
+                mode,
+                include_last_offset,
+                torch.float,
+            )
+
+            new_embedding_bag_node = graph.call_function(
+                torch.ops.torchao._scaled_embedding_bag.default, args=new_args
+            )
+
+            getitem_node.replace_all_uses_with(new_embedding_bag_node)
+            new_embedding_bag_node.meta.update(embedding_bag_node.meta)
+
+            graph.erase_node(getitem_node)
+            graph.erase_node(embedding_bag_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+            # Erase the dequant pattern
+            graph.erase_node(dequant_node)
+
+        counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
+        counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
+
+
+def _generate_scaled_embedding_bag_patterns(dq_pattern):
+    embedding_bag_pattern = CallFunction(
+        torch.ops.aten._embedding_bag_forward_only.default,
+        dq_pattern,
+        KeywordArg("indices"),
+        KeywordArg("offsets"),
+        Arg(),
+        KeywordArg("mode"),
+        KeywordArg("sparse"),
+        Arg(),
+        KeywordArg("include_last_offset"),
+    )
+    return CallFunction(
+        operator.getitem,
+        embedding_bag_pattern,
+        KeywordArg("item"),
+    )
+
+
+def _register_quantization_embeddingbag_pass():
+    for dtype in [torch.float32, torch.bfloat16]:
+        _register_scaled_embedding_bag_pass(
+            _generate_scaled_embedding_bag_patterns(
+                _may_generate_pattern_with_dtype_convert(
+                    get_dequantize_per_tensor_activation_pattern(
+                        is_tensor_overload=False, is_fp8=True
+                    ),
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+            ),
+            pass_number=1,
+            dtype=dtype,
+        )
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2873,6 +2981,7 @@ def _register_quantization_weight_pack_pass():
     _register_qconv_binary_fusion()
     _register_qlinear_unary_fusion()
     _register_qlinear_binary_fusion()
+    _register_quantization_embeddingbag_pass()
 
 
 def quant_lift_up(module_graph: torch.fx.graph.Graph):
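
For reference, a minimal eager-mode sketch of the rewrite this pass performs on the fp32 path: an fp8 weight dequantized through torch.ops.torchao.dequantize_affine_float8_non_decomposed and fed into embedding_bag (the pattern exercised by FP8QDQEmbeddingBag above) is collapsed into a single torch.ops.torchao._scaled_embedding_bag call whose argument order follows the new_args tuple built in scaled_embedding_bag. The tensor sizes and the scale value below are illustrative only, the sketch assumes the torchao::_scaled_embedding_bag CPU kernel is built (the test's skip condition), and it shows the call shapes rather than asserting numerical equivalence.

import torch

# Eager form of the matched pattern: fp8 weight -> dequant -> embedding_bag(mode="sum").
weight_fp8 = torch.randn(1000, 128).to(torch.float8_e4m3fn)  # illustrative sizes
w_scale = torch.tensor([2.0])
indices = torch.randint(1000, (64,))  # e.g. batch_size=4, multi_hot=16
offsets = torch.arange(0, 80, 16)  # include_last_offset=True layout

weight_fp32 = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
    tensor=weight_fp8, scale=w_scale, output_dtype=torch.float
)
ref = torch.nn.functional.embedding_bag(
    indices, weight_fp32, offsets, mode="sum", include_last_offset=True
)

# Fused call emitted by the pass, mirroring new_args:
#   (qw, indices, offsets, w_scale, o_scale, mode, include_last_offset, output_dtype)
# mode=0 is aten's integer code for "sum"; o_scale is fixed to 1.0 (fp32 output only).
fused = torch.ops.torchao._scaled_embedding_bag.default(
    weight_fp8, indices, offsets, w_scale, 1.0, 0, True, torch.float
)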