
Commit bde089d

Fix matmul post scalar op fusion (#1257)
* fix matmul post scalar fusion
* rename related code from div to mul
1 parent ac8b947 commit bde089d
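
The change teaches the fusion passes to fold a scalar multiply that follows a matmul into a single ipex::matmul_mul op, and to skip the rewrite when the multiplier is not a true scalar. A minimal eager-mode sketch of the two cases (module names and shapes below are illustrative, not from the repository; the fused node would only appear after JIT tracing with the extension loaded):

import torch
import torch.nn as nn

class ScaledMatmul(nn.Module):  # hypothetical helper, mirrors the new MatmulMul test module
    def forward(self, x, y):
        # aten::matmul followed by aten::mul with a Python scalar:
        # this is the graph shape the ipex::matmul_mul rewrite targets.
        return torch.matmul(x, y) * 0.125

class TensorScaledMatmul(nn.Module):  # multiplier is a 1-D tensor, not a scalar
    def forward(self, x, y):
        # the new is_scalar filter rejects this multiplier
        # (1-element 1-D tensor rather than a Scalar or 0-dim tensor), so no fusion.
        return torch.matmul(x, y) * torch.ones(1, dtype=x.dtype)

x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)
print(ScaledMatmul()(x, y).shape)        # torch.Size([10, 3, 5])
print(TensorScaledMatmul()(x, y).shape)  # torch.Size([10, 3, 5])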

8 files changed, +182 −21 lines changed

csrc/jit/cpu/kernels/Mha.cpp

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ at::Tensor dil_transfree_vit_mha(
   value.resize_({batchSize, sequenceSize, head_num, head_size})
       .transpose_(1, 2);
 
-  bmm_impl(query, key, qk, ideep::attr_t(), {}, 1.f / dim_per_head);
+  bmm_impl(query, key, qk, ideep::attr_t(), {}, dim_per_head);
   qk = dil_softmax_(qk, softmax_dim, dtype);
 
   auto output = dil_mha_matmul_trans(qk, value);

csrc/jit/passes/graph_rewrite.cpp

Lines changed: 19 additions & 10 deletions
@@ -251,16 +251,19 @@ void FuseMatmulDivOrMul(std::shared_ptr<Graph>& graph) {
         return (%r) )";
   std::string fused_matmul_mul = R"(
       graph(%x, %y, %z):
-        %ones : float = prim::Constant[value=1.0]()
-        %z_ = aten::div(%ones, %z)
-        %r = ipex::matmul_div(%x, %y, %z_)
+        %r = ipex::matmul_mul(%x, %y, %z)
         return (%r) )";
   std::string fused_matmul_mul_with_out = R"(
       graph(%x, %y, %z, %out):
-        %ones : float = prim::Constant[value=1.0]()
-        %z_ = aten::div(%ones, %z)
-        %r = ipex::matmul_div(%x, %y, %out, %z_)
+        %r = ipex::matmul_mul(%x, %y, %out, %z)
         return (%r) )";
+  auto filter_scalar = [](const Match& match,
+                          const std::unordered_map<std::string, Value*>& vmap) {
+    Node* node = match.anchor;
+    auto target_value = node->input(1);
+    return utils::is_scalar(target_value);
+  };
+
   for (auto const& it : div_ops) {
     at::jit::TemplateEnv env;
     env.s("div_op", it);

@@ -281,7 +284,7 @@ void FuseMatmulDivOrMul(std::shared_ptr<Graph>& graph) {
         aten_mul_pattern.format(env), fused_matmul_mul);
     rewriter.RegisterRewritePattern(
         aten_mul_pattern_with_out.format(env), fused_matmul_mul_with_out);
-    rewriter.runOnGraph(graph);
+    rewriter.runOnGraph(graph, filter_scalar);
   }
 }
 

@@ -309,11 +312,16 @@ void PostScalarDivOrMul(std::shared_ptr<Graph>& graph) {
         %qk = aten::matmul(%q, %k)
         %r = aten::mul(%qk, %scale)
         return (%r) )";
-
+  auto filter_scalar = [](const Match& match,
+                          const std::unordered_map<std::string, Value*>& vmap) {
+    Node* node = match.anchor;
+    auto target_value = node->input(0)->node()->input(1);
+    return utils::is_scalar(target_value);
+  };
   SubgraphRewriter rewriter;
   rewriter.RegisterRewritePattern(div_matmul, matmul_div);
   rewriter.RegisterRewritePattern(mul_matmul, matmul_mul);
-  rewriter.runOnGraph(graph);
+  rewriter.runOnGraph(graph, filter_scalar);
 }
 
 // MHA fusion covers aten::softmax, ipex::softmax and ipex::softmax_:

@@ -495,7 +503,8 @@ void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph) {
         // This constant fill value could be either 0-dim tensor or just a
         // scalar
         auto fill_value_node = qk_node->input(2)->node();
-        if (fill_value_node->kind() != prim::Constant) {
+        if (fill_value_node->kind() != prim::Constant ||
+            !utils::is_scalar(qk_node->input(2))) {
          return false;
         }
 
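The filter added to PostScalarDivOrMul guards the rewrite that moves a scalar multiply or divide from before the matmul to after it. That move is only an identity for a genuine scalar (or 0-dim tensor); a shaped tensor multiplier broadcasts against the matmul input rather than its output, so the pass must skip it. A small sketch of the invariant (assumed shapes, not taken from the tests):

import torch

q = torch.randn(2, 3, 4)
k = torch.randn(2, 4, 5)

s = 0.125                 # Python scalar: accepted by the filter
s0 = torch.tensor(0.125)  # 0-dim tensor: also accepted

# scalar scaling commutes with matmul, so pre- and post-matmul forms agree
assert torch.allclose(torch.matmul(q * s, k), torch.matmul(q, k) * s)
assert torch.allclose(torch.matmul(q * s0, k), torch.matmul(q, k) * s0)

# a shaped tensor multiplier (e.g. torch.ones(1) or a per-element scale) does
# not pass utils::is_scalar, so the rewrite is conservatively skipped for it.
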
csrc/jit/passes/graph_rewrite_mha.cpp

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ void FusedTransFreeMha(std::shared_ptr<Graph>& graph) {
         %key_ = aten::select(%qkv2, %select_dim, %key_select)
         %value = aten::select(%qkv2, %select_dim, %value_select)
         %key = aten::transpose(%key_, %trans_a, %trans_b)
-        %bmm1 = ipex::matmul_div(%query, %key, %scale)
+        %bmm1 = ipex::matmul_mul(%query, %key, %scale)
         %smx = ipex::softmax(%bmm1, %trans_b, %dtype)
         %bmm2 = aten::matmul(%smx, %value)
         %context_layer = aten::transpose(%bmm2, %key_select, %value_select)

csrc/jit/passes/register_dnnl_jit_ops.cpp

Lines changed: 84 additions & 2 deletions
@@ -873,6 +873,88 @@ torch::jit::RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
 
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor(a!) out_opt, Tensor "
+        "mul_input) -> Tensor(a!)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto mul_tensor = std::move(peek(stack, 3, 4).toTensor());
+            auto mul_input_data = mul_tensor.item();
+            // divide mul_input to reuse dil_matmul_div function
+            auto div_input_data = 1.0f / mul_input_data.to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 2, 4))),
+                div_input_data);
+            drop(stack, 4);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor(a!) out_opt, Scalar "
+        "mul_input) -> Tensor(a!)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            // divide mul_input to reuse dil_matmul_div function
+            auto div_input_data =
+                1.0f / (std::move(peek(stack, 3, 4))).toScalar().to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 2, 4))),
+                div_input_data);
+            drop(stack, 4);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor mul_input) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto mul_tensor = (std::move(peek(stack, 2, 3))).toTensor();
+            auto mul_input_data = mul_tensor.item();
+            // divide mul_input to reuse dil_matmul_div function
+            auto div_input_data = 1.0f / mul_input_data.to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toTensor(),
+                at::Tensor(),
+                div_input_data);
+            drop(stack, 3);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Scalar mul_input) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            // divide mul_input to reuse dil_matmul_div function
+            auto div_input_data =
+                1.0f / (std::move(peek(stack, 2, 3))).toScalar().to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toTensor(),
+                at::Tensor(),
+                div_input_data);
+            drop(stack, 3);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::bmm_add(Tensor input, Tensor batch1, Tensor batch2, Scalar alpha) -> "
         "Tensor",

@@ -944,7 +1026,7 @@ torch::jit::RegisterOperators op({
            auto scale_tensor = std::move(peek(stack, 4, 7).toTensor());
            auto scale_data = scale_tensor.item();
            // divide scale to reuse dil_mha_scores_calc function
-           auto div_scale_data = 1 / scale_data.to<float>();
+           auto div_scale_data = 1.0f / scale_data.to<float>();
            auto result = dil_mha_scores_calc(
                peek(stack, 0, 7).toTensor(),
                peek(stack, 1, 7).toTensor(),

@@ -964,7 +1046,7 @@ torch::jit::RegisterOperators op({
         "Scalar scale, int softmax_dim, ScalarType ? dtype) -> Tensor",
         [](Stack& stack) {
           // divide scale to reuse dil_mha_scores_calc function
-          auto div_scale_data = 1 / peek(stack, 4, 7).toScalar().to<float>();
+          auto div_scale_data = 1.0f / peek(stack, 4, 7).toScalar().to<float>();
           auto result = dil_mha_scores_calc(
               peek(stack, 0, 7).toTensor(),
               peek(stack, 1, 7).toTensor(),
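
Each ipex::matmul_mul overload reuses the existing dil_matmul_div kernel by passing the reciprocal of the multiplier, as the in-code comments note. A quick numeric check of that identity (a sketch with assumed shapes, independent of the extension):

import torch

x = torch.randn(8, 16)
y = torch.randn(16, 32)
s = 0.125

out_mul = torch.matmul(x, y) * s
out_div = torch.matmul(x, y) / (1.0 / s)
# Matches up to float rounding; exact here because 0.125 and its reciprocal
# are powers of two.
assert torch.allclose(out_mul, out_div)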

csrc/jit/passes/utils.cpp

Lines changed: 19 additions & 0 deletions
@@ -171,6 +171,25 @@ bool is_contiguous(c10::TensorTypePtr tensor) {
   return is_contiguous;
 }
 
+// Check if the target IValue is a scalar or a 0-dim scalar tensor
+bool is_scalar(torch::jit::Value* target_value) {
+  if (!toIValue(target_value).has_value()) {
+    return false;
+  }
+  if (toIValue(target_value).value().isScalar()) {
+    return true;
+  } else if (toIValue(target_value).value().isTensor()) {
+    auto target_tensor_dim = target_value->type()->cast<TensorType>()->dim();
+    if (!target_tensor_dim.has_value()) {
+      return false;
+    } else {
+      return target_tensor_dim.value() == 0 ? true : false;
+    }
+  } else {
+    return false;
+  }
+}
+
 } // namespace utils
 } // namespace graph_rewrite
 } // namespace jit

csrc/jit/passes/utils.h

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ supported_non_unary_post_op_fusion_set();
 bool is_channelslast(c10::TensorType tensor);
 // Check if the memory format of the tensor is Contiguous
 bool is_contiguous(c10::TensorTypePtr tensor);
-
+// Check if the target IValue is a scalar or a 0-dim scalar tensor
+bool is_scalar(torch::jit::Value* target_value);
 } // namespace utils
 } // namespace graph_rewrite
 } // namespace jit

tests/cpu/test_jit.py

Lines changed: 53 additions & 4 deletions
@@ -846,6 +846,28 @@ def forward(self, x):
         else:
             return mm_res.div_(torch.ones(mm_res_shape,dtype=x.dtype)+1)
 
+class MatmulMul(nn.Module):
+    def __init__(self, mul_scalar=False, with_out=False):
+        super(MatmulMul, self).__init__()
+        self.with_out = with_out
+        self.mul_scalar = mul_scalar
+    def forward(self, x):
+        mm_res = None
+        y = torch.transpose(x, -1, -2).contiguous()
+        mm_res_shape = x.size()[:-1] + (y.size()[-1:])
+        if not self.mul_scalar:
+            x = x * (torch.ones([1],dtype=x.dtype) + 1)
+        if self.with_out:
+            mm_res = torch.randn(mm_res_shape, dtype=x.dtype)
+            mm_res = torch.matmul(x, y, out=mm_res)
+        else:
+            mm_res = torch.matmul(x, y)
+        if self.mul_scalar:
+            mm_res = mm_res * 0.125
+        else:
+            mm_res = mm_res * (torch.ones([1],dtype=x.dtype) + 1)
+        return mm_res
+
 class TransposedMatmulDiv(nn.Module):
     def __init__(self):
         super(TransposedMatmulDiv, self).__init__()

@@ -1282,7 +1304,6 @@ def _test_output_bf16(self, base_model, x, kind_in_graph=None, kind_not_in_graph
         #bf16, jit trace path
         trace_graph = trace_fused_model.graph_for(x3)
         fused_tresult = trace_fused_model(x3)
-
         self.assertEqual(fused_tresult, result, prec=prec)
         self.assertEqual(fused_tresult.dtype, torch.bfloat16)
 

@@ -3337,9 +3358,37 @@ def fn(input, weight, bias):
         self.assertEqual(scripted_fn(input, weight, bias), result)
         self.assertEqual(traced_fn(input, weight, bias), result)
 
-    def test_matmul_div(self):
+    def test_matmul_div_or_mul(self):
         inputs = [torch.randn(10, 3, 4), torch.randn(3, 4)]
         for x in inputs:
+            self._test_output(
+                MatmulMul(mul_scalar=True, with_out=False),
+                x,
+                kind_in_graph="ipex::matmul_mul",
+                kind_not_in_graph=None)
+            self._test_output(
+                MatmulMul(mul_scalar=True, with_out=True),
+                x,
+                kind_in_graph="ipex::matmul_mul",
+                kind_not_in_graph=None)
+            self._test_output(
+                MatmulMul(mul_scalar=False, with_out=True),
+                x,
+                kind_in_graph=None,
+                kind_not_in_graph="ipex::matmul_mul")
+            self._test_output_bf16(
+                MatmulMul(mul_scalar=True, with_out=False),
+                x.to(torch.bfloat16),
+                kind_in_graph="ipex::matmul_mul",
+                kind_not_in_graph=None,
+                prec=5e-2)
+            self._test_output_bf16(
+                MatmulMul(mul_scalar=True, with_out=True),
+                x.to(torch.bfloat16),
+                kind_in_graph="ipex::matmul_mul",
+                kind_not_in_graph=None,
+                prec=5e-2)
+
             self._test_output(
                 MatmulDivOutplace(div_scalar=True, with_out=True),
                 x,

@@ -3466,14 +3515,14 @@ def test_transposed_matmuldiv(self):
                 fused_mod = traced_mod.graph_for(x1[i], y1[j])
                 out = traced_mod(x1[i], y1[j])
                 expected = model(x1[i], y1[j])
-                self.assertTrue(any(n.kind() == "ipex::matmul_div" for n in fused_mod.nodes()))
+                self.assertTrue(any(n.kind() == "ipex::matmul_mul" for n in fused_mod.nodes()))
                 self.assertEqual(out, expected, prec=1e-4)
                 with torch.cpu.amp.autocast(), torch.no_grad():
                     traced_mod = torch.jit.trace(model, (x1[i].bfloat16(), y1[j].bfloat16()))
                     fused_mod = traced_mod.graph_for(x1[i].bfloat16(), y1[j].bfloat16())
                     out = traced_mod(x1[i].bfloat16(), y1[j].bfloat16())
                     expected = model(x1[i].bfloat16(), y1[j].bfloat16())
-                    self.assertTrue(any(n.kind() == "ipex::matmul_div" for n in fused_mod.nodes()))
+                    self.assertTrue(any(n.kind() == "ipex::matmul_mul" for n in fused_mod.nodes()))
                     self.assertEqual(out, expected, prec=1e-1)
 
     def test_bmm_add(self):

tests/cpu/test_mha.py

Lines changed: 3 additions & 2 deletions
@@ -131,7 +131,7 @@ def test_transfree_mha_bf16(self):
             for _ in range(2):
                 mha_jit = mha_ipex(mat, mask_base)
                 vit_mha_jit = vit_mha_ipex(mat)
-
+
             mha_ref = mha_model(mat, mask_base)
             vit_mha_ref = vit_mha_model(mat)
 

@@ -141,6 +141,7 @@ def test_transfree_mha_bf16(self):
             mha_graph = mha_ipex.graph_for(mat, mask_base)
             vit_mha_graph = vit_mha_ipex.graph_for(mat)
 
+
             self.assertTrue(any(n.kind() == "ipex::transfree_mha" for n in mha_graph.nodes()))
             self.assertTrue(any(n.kind() == "ipex::transfree_vit_mha" for n in vit_mha_graph.nodes()))
 

@@ -336,7 +337,7 @@ def test_fake_mha_fp32(self):
             fake_mha_jit.append(fake_mha_ipex[i](mat))
             fake_mha_ref.append(fake_mha_model[i](mat))
             fake_mha_graph = fake_mha_ipex[i].graph_for(mat)
-            self.assertTrue(any(n.kind() == "ipex::matmul_div" for n in fake_mha_graph.nodes()))
+            self.assertTrue(any(n.kind() == "ipex::matmul_mul" for n in fake_mha_graph.nodes()))
             with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as p:
                 fake_mha_ipex[i](mat)
                 if i == 6:
