test/quantization/pt2e/test_x86inductor_fusion.py (72 additions, 0 deletions)
@@ -3047,6 +3047,78 @@ def test_fp8_q_attention_block(self):
annotate_matmul=annotate_matmul, is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
def test_fp8_scaled_embedding_bag(self):
dtype = torch.float8_e4m3fn

class FP8QDQEmbeddingBag(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight_scale = 2.0

def forward(
self,
weight,
input,
offsets=None,
):
weight = (
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=weight.data,
scale=torch.tensor([self.weight_scale]),
output_dtype=torch.float,
)
)

return torch.nn.functional.embedding_bag(
input,
weight,
offsets,
mode="sum",
include_last_offset=True,
)

        EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
        EMBEDDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
        EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
        EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

        EMBEDDINGBAG_TEST_PARAMS = list(
            itertools.product(
                EMBEDDINGBAG_MULTIHOT_SIZES,
                EMBEDDINGBAG_BAG_SIZES,
                EMBEDDINGBAG_VECTOR_SIZES,
                EMBEDDINGBAG_INDEX_DTYPES,
            )
        )

        for multi_hot, batch_size, vector_size, index_type in EMBEDDINGBAG_TEST_PARAMS:
with torch.no_grad():
mod = FP8QDQEmbeddingBag()

weight = torch.randn((1000, vector_size)).to(dtype)
indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(
index_type
)

def matcher_check_fn():
self.assertEqual(
counters["inductor"]["scaled_embedding_bag_matcher_count"], 1
)

self._test_common(
mod,
(weight, indices, offsets),
matcher_check_fn,
)


instantiate_parametrized_tests(TestPatternMatcher)
if __name__ == "__main__":
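As an aside, a rough way to reproduce the check that `matcher_check_fn` performs, outside the test harness, is to compile the FP8 QDQ module under inductor freezing and read the same pattern-matcher counter. This is only a sketch under assumptions: it presumes the torchao x86 inductor passes are registered the way the `_test_common` helper (not shown in this hunk) sets them up, so the exact setup may differ.

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters  # pattern-matcher hit counters used by the test

# Toy inputs mirroring one parameter combination from the test above.
vector_size, multi_hot, batch_size = 128, 2, 128
weight = torch.randn(1000, vector_size).to(torch.float8_e4m3fn)
indices = torch.randint(1000, (batch_size * multi_hot,))
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot)

mod = FP8QDQEmbeddingBag()  # the module defined inside the test above

counters.clear()
with torch.no_grad(), inductor_config.patch(freezing=True):
    compiled = torch.compile(mod)
    compiled(weight, indices, offsets)

print(counters["inductor"]["scaled_embedding_bag_matcher_count"])  # 1 when the rewrite fired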
torchao/quantization/pt2e/inductor_passes/x86.py (110 additions, 0 deletions)
@@ -3,6 +3,7 @@
import copy
import functools
import itertools
import operator
from typing import Any

import torch
@@ -2851,6 +2852,113 @@ def _register_qlinear_binary_fusion():
)


def _register_scaled_embedding_bag_pass(pattern, pass_number, dtype=torch.float32):
@register_freezing_graph_pattern(
pattern,
pass_number=pass_number,
)
def scaled_embedding_bag(match: Match, *args, **kwargs):
assert dtype in [torch.float32, torch.bfloat16]

getitem_node = match.output_node()
embedding_bag_node = getitem_node.args[0]
assert embedding_bag_node.target is aten._embedding_bag_forward_only.default

embedding_bag_weight_index = 0
if dtype == torch.float32:
# pattern: embedding_bag -> dequant
dequant_node = embedding_bag_node.args[embedding_bag_weight_index]
else:
# pattern: embedding_bag -> to_bf16 -> dequant
weight_to_bf16_node = embedding_bag_node.args[embedding_bag_weight_index]
dequant_node = weight_to_bf16_node.args[0]

assert dequant_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
]

# Weight QParams
qw, w_scale = kwargs["x"], kwargs["x_scale"]

# Input Params
indices, offsets, mode, include_last_offset = (
kwargs["indices"],
kwargs["offsets"],
kwargs["mode"],
kwargs["include_last_offset"],
)
        # only fp32 output is supported for now; support for more output dtypes is planned as a next step
o_scale = 1.0

graph = match.graph
with graph.inserting_before(getitem_node):
new_args: tuple[Any, ...] = (
qw,
indices,
offsets,
w_scale,
o_scale,
mode,
include_last_offset,
torch.float,
)

new_embedding_bag_node = graph.call_function(
torch.ops.torchao._scaled_embedding_bag.default, args=new_args
)

getitem_node.replace_all_uses_with(new_embedding_bag_node)
new_embedding_bag_node.meta.update(embedding_bag_node.meta)

graph.erase_node(getitem_node)
graph.erase_node(embedding_bag_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
# Erase the dequant pattern
graph.erase_node(dequant_node)

counters["inductor"]["scaled_embedding_bag_matcher_count"] += 1
counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)


def _generate_scaled_embedding_bag_patterns(dq_pattern):
embedding_bag_pattern = CallFunction(
torch.ops.aten._embedding_bag_forward_only.default,
dq_pattern,
KeywordArg("indices"),
KeywordArg("offsets"),
Arg(),
KeywordArg("mode"),
KeywordArg("sparse"),
Arg(),
KeywordArg("include_last_offset"),
)
return CallFunction(
operator.getitem,
embedding_bag_pattern,
KeywordArg("item"),
)


def _register_quantization_embeddingbag_pass():
for dtype in [torch.float32, torch.bfloat16]:
_register_scaled_embedding_bag_pass(
_generate_scaled_embedding_bag_patterns(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=False, is_fp8=True
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
),
pass_number=1,
dtype=dtype,
) # pass_number=0 to run before weight prepack
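For orientation, the call produced by the rewrite looks roughly like the sketch below. The argument order follows the `new_args` tuple assembled in `scaled_embedding_bag` above; the concrete values are illustrative only, and running this requires the torchao cpp kernels to be built (the same condition the test skips on).

import torch

# Illustrative inputs: a 1000-row FP8 table, 128 bags of 2 indices each.
qweight = torch.randn(1000, 128).to(torch.float8_e4m3fn)
indices = torch.randint(1000, (256,))
offsets = torch.arange(0, 258, 2)  # 129 offsets because include_last_offset=True

out = torch.ops.torchao._scaled_embedding_bag.default(
    qweight,
    indices,
    offsets,
    torch.tensor([2.0]),  # weight scale (w_scale)
    1.0,                  # output scale (o_scale); only fp32 output with scale 1.0 for now
    0,                    # mode 0 -> "sum"
    True,                 # include_last_offset
    torch.float,          # output dtype
)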


@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2874,6 +2982,8 @@ def _register_quantization_weight_pack_pass():
_register_qlinear_unary_fusion()
_register_qlinear_binary_fusion()

_register_quantization_embeddingbag_pass()
Collaborator:
Will it fail on ARM? If so, put it in the if branch above instead.

Contributor (author):
Done



def quant_lift_up(module_graph: torch.fx.graph.Graph):
"""