Commit 2779937

quantization: enable torch.Tensor.matmul quantization (#878)
1 parent 955bcff commit 2779937

3 files changed: +47 -3 lines changed

intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py

Lines changed: 3 additions & 2 deletions
@@ -37,6 +37,7 @@
     F.linear,
     torch._C._nn.linear,
     torch.matmul,
+    torch.Tensor.matmul,
     F.embedding_bag,
     torch.embedding_bag,
 ])
@@ -348,8 +349,8 @@ def iterate_and_apply_convert(
         args = torch.quantize_per_channel(args, scale, zp, ch_axis, dtype)
         args = args.dequantize()
     else:
-        # white list, conv, linear, matmul, we alsy covert it's input to bflat16 firstly, and then inser q+dq
-        if str(op) in conv_linear_ops + [str(torch.matmul)] + embedding_op or str(type(op)) in conv_linear_modules:
+        # white list, conv, linear, matmul, we always convert it's input to bflat16 firstly, and then inser q+dq
+        if str(op) in conv_linear_ops + [str(torch.matmul), str(torch.Tensor.matmul)] + embedding_op or str(type(op)) in conv_linear_modules:
             if torch.is_autocast_cpu_enabled() and core.get_autocast_dtype() == torch.bfloat16:
                 if args.dtype == torch.float32:
                     args = args.to(torch.bfloat16)
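
The new allow-list entry is needed because the method spelling x.matmul(y) dispatches through torch.Tensor.matmul, a different callable object from the functional torch.matmul, so quantization state tracking would otherwise miss it. A minimal sketch of that distinction (illustration only, not code from this commit):

    import torch

    # The functional form and the Tensor-method form are distinct callables,
    # so each needs its own entry in the quantization allow-list.
    print(torch.Tensor.matmul is torch.matmul)  # False

    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 4, 5)
    # Both spellings compute the same result but are recorded as different ops.
    assert torch.equal(torch.matmul(x, y), x.matmul(y))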

intel_extension_for_pytorch/ao/quantization/_recipe.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
     str(nn.SiLU), str(F.silu), str(torch.Tensor.sigmoid), str(torch.sigmoid), str(F.sigmoid), str(nn.Sigmoid), str(F.gelu), str(nn.GELU)]
 conv_gemm_ops = [str(F.conv2d), str(nn.Conv2d), str(F.conv3d), str(nn.Conv3d), str(torch.conv2d), str(torch.conv3d), \
     str(F.conv_transpose2d), str(torch.nn.ConvTranspose2d), str(F.conv_transpose3d), str(torch.nn.ConvTranspose3d),
-    str(torch.conv_transpose2d), str(torch.conv_transpose2d), str(F.linear), str(nn.Linear), str(torch.matmul)]
+    str(torch.conv_transpose2d), str(torch.conv_transpose2d), str(F.linear), str(nn.Linear), str(torch.matmul), str(torch.Tensor.matmul)]
 rnn_ops = [str(torch.nn.LSTM)]
 
 # Those ops only support s8->s8 path, and also require the qscheme is per_tensor_symmetric.
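
_recipe.py keys its op lists on str(op), so the method form needs its own str(torch.Tensor.matmul) entry in conv_gemm_ops alongside str(torch.matmul). A rough sketch of that string-keyed lookup (the trimmed list and the is_gemm_like helper are illustrative, not the actual recipe code):

    import torch

    # Illustrative subset of the recipe's string-keyed GEMM op list.
    conv_gemm_ops = [str(torch.matmul), str(torch.Tensor.matmul)]

    def is_gemm_like(op):
        # str(op) is stable within one process, so both the functional and the
        # Tensor-method spellings must be listed to be recognized.
        return str(op) in conv_gemm_ops

    print(is_gemm_like(torch.matmul))         # True
    print(is_gemm_like(torch.Tensor.matmul))  # True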

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 43 additions & 0 deletions
@@ -1062,6 +1062,49 @@ def forward(self, x, y):
         self.assertFused(graph, ['aten::dequantize', 'aten::matmul', 'aten::div'])
         self.checkPatterns(graph, patterns)
 
+    def test_bmm_method_bf16(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x, y):
+                mm_res = x.matmul(y)
+                return mm_res
+
+        x = torch.randn(1, 16, 384, 64) * 0.1
+        y = torch.randn(1, 1, 64, 384) * 0.1
+        patterns = [
+            ["aten::to", "aten::quantize_per_tensor"],
+            ["aten::to", "aten::quantize_per_tensor"],
+            ["aten::dequantize", "aten::to", "aten::matmul"],
+        ]
+        m = M()
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+        # single aten::to won't be rewritten by llga backend
+        self.assertFused(graph, ['aten::dequantize', 'aten::matmul'])
+        self.checkPatterns(graph, patterns)
+
+    def test_bmm_method_fp32(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x, y):
+                mm_res = x.matmul(y)
+                return mm_res
+
+        x = torch.randn(1, 16, 384, 64) * 0.1
+        y = torch.randn(1, 1, 64, 384) * 0.1
+        patterns = [
+            ["aten::dequantize", "aten::matmul"],
+        ]
+        m = M()
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        self.assertFused(graph, ['aten::dequantize', 'aten::matmul'])
+        self.checkPatterns(graph, patterns)
+
     def test_strided_bmm_div_int8_in_bf16_out(self):
         class M(nn.Module):
             def __init__(self):