72 changes: 72 additions & 0 deletions test/quantization/pt2e/test_x86inductor_fusion.py
@@ -3047,6 +3047,78 @@ def test_fp8_q_attention_block(self):
annotate_matmul=annotate_matmul, is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_fp8_scaled_embedding_bag(self):
dtype = torch.float8_e4m3fn

class FP8QDQEmbeddingBag(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight_scale = 2.0

def forward(
self,
weight,
input,
offsets=None,
):
weight = (
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=weight.data,
scale=torch.tensor([self.weight_scale]),
output_dtype=torch.float,
)
)

return torch.nn.functional.embedding_bag(
input,
weight,
offsets,
mode="sum",
include_last_offset=True,
)

EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

EMBEDDINGBAG_TEST_PARAMS = list(
itertools.product(
EMBEDDINGBAG_MULTIHOT_SIZES,
EMBEDDINGBAG_BAG_SIZES,
EMBEDDINGBAG_VECTOR_SIZES,
EMBEDDINGBAG_INDEX_DTYPES,
)
)

for multi_hot, batch_size, vector_size, index_type in EMBEDDINGBAG_TEST_PARAMS:
with torch.no_grad():
mod = FP8QDQEmbeddingBag()

weight = torch.randn((1000, vector_size)).to(dtype)
indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(
index_type
)

def matcher_check_fn():
self.assertEqual(
counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
)

self._test_common(
mod,
(weight, indices, offsets),
matcher_check_fn,
)


instantiate_parametrized_tests(TestPatternMatcher)
if __name__ == "__main__":
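As a reading aid (not part of the diff): a minimal sketch of what this test exercises. It assumes the torchao cpp kernels are built, the x86 quantization passes are registered, and it reuses the FP8QDQEmbeddingBag module defined above; the shapes and the direct use of torch.compile with freezing are illustrative assumptions rather than the exact mechanics of _test_common.

import torch
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config

# One illustrative configuration: multi_hot=2, batch_size=128, vector_size=128.
mod = FP8QDQEmbeddingBag()
weight = torch.randn((1000, 128)).to(torch.float8_e4m3fn)
indices = torch.randint(1000, (128 * 2,))
offsets = torch.arange(0, (128 + 1) * 2, 2)

counters.clear()
with torch.no_grad(), inductor_config.patch(freezing=True):
    # Freezing is what runs the registered graph patterns, including the
    # scaled_embedding_bag matcher added in x86.py below.
    out = torch.compile(mod)(weight, indices, offsets)

# The fusion pass bumps this counter once per rewritten embedding bag.
assert counters["inductor"]["scaled_embedding_bag_matcher_count"] == 1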
109 changes: 109 additions & 0 deletions torchao/quantization/pt2e/inductor_passes/x86.py
@@ -3,6 +3,7 @@
import copy
import functools
import itertools
import operator
from typing import Any

import torch
@@ -2851,6 +2852,113 @@ def _register_qlinear_binary_fusion():
)


def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
@register_freezing_graph_pattern(
pattern,
pass_number=pass_number,
)
def scaled_embedding_bag(match: Match, *args, **kwargs):
assert dtype in [torch.float32, torch.bfloat16]

getitem_node = match.output_node()
embedding_bag_node = getitem_node.args[0]
assert embedding_bag_node.target is aten._embedding_bag_forward_only.default

embedding_bag_weight_index = 0
if dtype == torch.float32:
# pattern: embedding_bag -> dequant
dequant_node = embedding_bag_node.args[embedding_bag_weight_index]
else:
# pattern: embedding_bag -> to_bf16 -> dequant
weight_to_bf16_node = embedding_bag_node.args[embedding_bag_weight_index]
dequant_node = weight_to_bf16_node.args[0]

assert dequant_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
]

# Weight QParams
qw, w_scale = kwargs["x"], kwargs["x_scale"]

# Input Params
indices, offsets, mode, include_last_offset = (
kwargs["indices"],
kwargs["offsets"],
kwargs["mode"],
kwargs["include_last_offset"],
)
# Only fp32 output is supported for now; supporting more output dtypes is a follow-up.
o_scale = 1.0

graph = match.graph
with graph.inserting_before(getitem_node):
new_args: tuple[Any, ...] = (
qw,
indices,
offsets,
w_scale,
o_scale,
mode,
include_last_offset,
torch.float,
)

new_embedding_bag_node = graph.call_function(
torch.ops.torchao._scaled_embedding_bag.default, args=new_args
)

getitem_node.replace_all_uses_with(new_embedding_bag_node)
new_embedding_bag_node.meta.update(embedding_bag_node.meta)

graph.erase_node(getitem_node)
graph.erase_node(embedding_bag_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
# Erase the dequant pattern
graph.erase_node(dequant_node)

counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)


def _generate_scaled_embedding_bag_patterns(dq_pattern):
embedding_bag_pattern = CallFunction(
torch.ops.aten._embedding_bag_forward_only.default,
dq_pattern,
KeywordArg("indices"),
KeywordArg("offsets"),
Arg(),
KeywordArg("mode"),
KeywordArg("sparse"),
Arg(),
KeywordArg("include_last_offset"),
)
return CallFunction(
operator.getitem,
embedding_bag_pattern,
KeywordArg("item"),
)


def _register_quantization_embeddingbag_pass():
for dtype in [torch.float32, torch.bfloat16]:
_register_scaled_embedding_bag_pass(
_generate_scaled_embedding_bag_patterns(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=False, is_fp8=True
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
),
pass_number=1,
dtype=dtype,
) # registered with pass_number=1


@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2873,6 +2981,7 @@ def _register_quantization_weight_pack_pass():
_register_qconv_binary_fusion()
_register_qlinear_unary_fusion()
_register_qlinear_binary_fusion()
_register_quantization_embeddingbag_pass()


def quant_lift_up(module_graph: torch.fx.graph.Graph):
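For orientation only: read from new_args in scaled_embedding_bag() above, the replacement op appears to take (qweight, indices, offsets, weight_scale, output_scale, mode, include_last_offset, output_dtype). The sketch below calls it directly and compares against a dequantize-then-embedding_bag reference; the exact schema of torchao::_scaled_embedding_bag, the integer mode encoding (0 == "sum"), and the tolerance are assumptions inferred from this pass and the new test, not from an op declaration, and the op requires the torchao cpp kernels to be built.

import torch

qweight = torch.randn(1000, 128).to(torch.float8_e4m3fn)
w_scale = torch.tensor([2.0])
indices = torch.randint(1000, (256,))
offsets = torch.arange(0, 258, 2)  # include_last_offset=True, so the last offset equals len(indices)

# Argument order mirrors new_args built by the pass above (assumed schema).
fused = torch.ops.torchao._scaled_embedding_bag.default(
    qweight,
    indices,
    offsets,
    w_scale,
    1.0,          # o_scale: only fp32 output is supported, so it stays 1.0
    0,            # mode: 0 taken to mean "sum", matching aten embedding_bag conventions
    True,         # include_last_offset
    torch.float,  # output dtype
)

# Reference path: dequantize the FP8 weight per-tensor, then run embedding_bag.
ref = torch.nn.functional.embedding_bag(
    indices,
    qweight.to(torch.float) * w_scale,
    offsets,
    mode="sum",
    include_last_offset=True,
)
torch.testing.assert_close(fused, ref, atol=1e-2, rtol=1e-2)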