
Commit b7b9359

Authored by jianan-gu and Jiong Gong
Add add broadcast checks for einsum+add fusion kernel (#885)
* add runtime check
* refine code
* refine code
* refine code
* add corner case from alphafold2
* Update intel_extension_for_pytorch/csrc/jit/cpu/kernels/Einsum.cpp

  Co-authored-by: Jiong Gong <jiong.gong@intel.com>

* refine code

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
1 parent 1717c44 commit b7b9359

File tree

4 files changed: +112 −9 lines changed

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Einsum.cpp

Lines changed: 75 additions & 8 deletions
```diff
@@ -20,6 +20,67 @@ namespace cpu {
 using at::IntArrayRef;
 using at::Tensor;
 
+//! function: is_add_broadcast_supported_by_onednn
+/*!
+ * This is a workaround check, since oneDNN does not fully support
+ * matmul+binary_add fusion for all kinds of add-input broadcast dims;
+ * depending on the add-input broadcast dims, oneDNN matmul+binary_add
+ * falls into a reference path in some cases. This function maps the
+ * verified supported cases and lets us fall back for the unsupported ones.
+ *
+ * The verified supported cases use the following oneDNN non_broadcast_mask:
+ * 2D: oneDNN non_broadcast_mask = {0, 2, 3}
+ * 3D: oneDNN non_broadcast_mask = {0, 2, 4, 5, 7}
+ * 4D: oneDNN non_broadcast_mask = {0, 2, 8, 9, 13, 15}
+ *
+ * For example:
+ * For 4D tensors, left has shape [8, 2, 4, 6] and right has shape
+ * [8, 2, 6, 4], so the matmul output shape is [8, 2, 4, 4], and
+ * post_add_tensor has shape [8, 1, 1, 4]. The corresponding
+ * non_broadcast_mask is 9 (bit 0 for dim 0 plus bit 3 for dim 3),
+ * which is supported.
+ *
+ * \param left: the left operand of matmul
+ * \param right: the right operand of matmul
+ * \param post_add_tensor: the post-add input tensor
+ * \return: whether the post-add input broadcast is supported by oneDNN for
+ * matmul+binary_add fusion
+ */
+bool is_add_broadcast_supported_by_onednn(
+    const at::Tensor& left,
+    const at::Tensor& right,
+    const at::Tensor& post_add_tensor) {
+  auto non_broadcast_mask = 0;
+  for (int i = 0; i < left.dim(); i++) {
+    if (post_add_tensor.size(i) != 1) {
+      if (i == left.dim() - 1) {
+        // the last output dim comes from `right`
+        non_broadcast_mask +=
+            post_add_tensor.size(i) == right.size(i) ? 1 << i : 0;
+      } else {
+        // all other output dims come from `left`
+        non_broadcast_mask +=
+            post_add_tensor.size(i) == left.size(i) ? 1 << i : 0;
+      }
+    }
+  }
+  if (left.dim() == 4) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 8 || non_broadcast_mask == 9 ||
+        non_broadcast_mask == 13 || non_broadcast_mask == 15) {
+      return true;
+    }
+  } else if (left.dim() == 3) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 4 || non_broadcast_mask == 5 ||
+        non_broadcast_mask == 7) {
+      return true;
+    }
+  } else if (left.dim() == 2) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 3) {
+      return true;
+    }
+  }
+
+  return false;
+}
 //! function: sumproduct_pair
 /*!
  *
```
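To make the mask arithmetic concrete, here is a minimal Python sketch of the same computation; the function name and shapes are for illustration only (the real check is the C++ above). It reproduces the worked example from the doc comment:

```python
import torch

def non_broadcast_mask(left, right, post_add):
    # Bit i is set when post_add keeps (does not broadcast over)
    # dim i of the matmul output, mirroring the C++ loop above.
    mask = 0
    for i in range(left.dim()):
        if post_add.size(i) != 1:
            # the last output dim comes from `right`, the rest from `left`
            ref = right if i == left.dim() - 1 else left
            if post_add.size(i) == ref.size(i):
                mask += 1 << i
    return mask

left = torch.randn(8, 2, 4, 6)
right = torch.randn(8, 2, 6, 4)
post_add = torch.randn(8, 1, 1, 4)  # broadcasts over dims 1 and 2
print(non_broadcast_mask(left, right, post_add))  # 9, a supported 4D mask
```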
```diff
@@ -266,14 +327,19 @@ static Tensor sumproduct_pair(
   left = left.permute(lpermutation).reshape(left_shape);
   right = right.permute(rpermutation).reshape(right_shape);
 
-  // Tensor result = at::bmm(left, right);
-  auto _input = arg.is_contiguous() ? arg : arg.contiguous();
-  ideep::tensor onednn_input = itensor_view_from_dense(_input);
-  auto op_attr = ideep::attr_t::fuse_binary(
-      dnnl::algorithm::binary_add, onednn_input.get_desc());
-  Tensor result =
-      bmm_impl(left, right, at::Tensor(), op_attr, {onednn_input}, 1.0f);
-
+  // now we do the computation
+  Tensor result;
+  if (is_add_broadcast_supported_by_onednn(left, right, arg)) {
+    auto _input = arg.is_contiguous() ? arg : arg.contiguous();
+    ideep::tensor onednn_input = itensor_view_from_dense(_input);
+    auto op_attr = ideep::attr_t::fuse_binary(
+        dnnl::algorithm::binary_add, onednn_input.get_desc());
+    result = bmm_impl(left, right, at::Tensor(), op_attr, {onednn_input}, 1.0f);
+  } else {
+    result = at::matmul(left, right);
+    auto f_alpha = alpha.to<float>();
+    result = result + f_alpha * arg;
+  }
   result = result.view(out_size).permute(opermutation);
 
   // finally squeeze summed dimensions if desired
```
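In PyTorch terms, the new fallback branch computes the same result without fusion; a self-contained sketch of that unfused path (shapes arbitrary, names matching the C++ locals):

```python
import torch

# Unfused equivalent of the fallback branch (sketch):
left = torch.randn(8, 2, 4, 6)
right = torch.randn(8, 2, 6, 4)
arg = torch.randn(8, 1, 1, 4)  # the post-add input, broadcast as needed
alpha = 1.0                    # alpha.to<float>() in the C++
result = torch.matmul(left, right) + alpha * arg  # shape [8, 2, 4, 4]
```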
```diff
@@ -666,6 +732,7 @@ at::Tensor einsum_binary(
     const c10::List<at::Tensor>& operands,
     const at::Tensor& add_arg,
     const c10::Scalar& alpha) {
+  IPEX_RECORD_FUNCTION("dil_einsum_binary", c10::ArrayRef<c10::IValue>({}));
   auto prepare_res = einsum_prepare(equation, operands);
   bool has_zero_size_dim = std::get<0>(prepare_res);
   auto out_size = std::get<1>(prepare_res);
```

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Einsum.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -30,5 +30,9 @@ at::Tensor einsum_binary(
     const at::Tensor& input,
     const c10::Scalar& alpha);
 
+bool is_add_broadcast_supported_by_onednn(
+    const at::Tensor& left,
+    const at::Tensor& right,
+    const at::Tensor& post_add_tensor);
 } // namespace cpu
 } // namespace torch_ipex
```

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_einsum.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -16,8 +16,9 @@ auto ipex_einsum_filter =
       auto equation =
           getIValue("equation", match_vmap, vmap).value().toStringView();
       int num_ops = std::count(equation.begin(), equation.end(), ',') + 1;
-      if (num_ops != 2)
+      if (num_ops != 2) {
         return false; // only process the 2 operands
+      }
       return true;
     };
```
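The filter's operand count is simply the number of comma-separated inputs in the equation string; a quick illustration of the condition in plain Python:

```python
equation = "bhid,bhjd->bhij"
num_ops = equation.count(",") + 1
assert num_ops == 2  # eligible: the rewrite only handles two operands
```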

tests/cpu/test_jit.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -2813,6 +2813,37 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
         model = EinsumAdd(("ij,j"))
         _test_fp32(model, input1, input2, bias)
 
+        bias = torch.randn(1, 4, 49, 49)
+        input1 = torch.randn(8, 4, 49, 32)
+        input2 = torch.randn(8, 4, 49, 32)
+        model_from_vit = EinsumAdd('bhid,bhjd->bhij')
+        _test_fp32(model_from_vit, input1, input2, bias)
+
+        bias = torch.randn(1, 1, 49, 49)
+        input1 = torch.randn(8, 6, 49, 32)
+        input2 = torch.randn(8, 6, 49, 32)
+        model_from_vit_v2 = EinsumAdd('bhid,bhjd->bhij')
+        _test_fp32(model_from_vit_v2, input1, input2, bias)
+
+        bias = torch.randn(8, 1, 1, 49)
+        input1 = torch.randn(8, 6, 49, 32)
+        input2 = torch.randn(8, 6, 49, 32)
+        model_from_vit_alphafold2_v1 = EinsumAdd('bhid,bhjd->bhij')
+        _test_fp32(model_from_vit_alphafold2_v1, input1, input2, bias)
+
+        bias = torch.randn(1, 1, 32)
+        input1 = torch.randn(6, 50, 32)
+        input2 = torch.randn(32, 32)
+        model_from_vit_alphafold2_v2 = EinsumAdd('bsh,ho->bso')
+        _test_fp32(model_from_vit_alphafold2_v2, input1, input2, bias)
+
+        bias = torch.randn(6, 1, 50)
+        input1 = torch.randn(6, 50, 32)
+        input2 = torch.randn(6, 32, 50)
+        model_from_vit_alphafold2_v3 = EinsumAdd('bsh,bho->bso')
+        _test_fp32(model_from_vit_alphafold2_v3, input1, input2, bias)
+
     def test_ipex_softmax(self):
         self._test_output(
             AtenSoftmaxRepalce(),
```
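`EinsumAdd` is defined elsewhere in tests/cpu/test_jit.py and is not part of this diff; a plausible minimal sketch of such a helper (hypothetical, for orientation only) would be:

```python
import torch

class EinsumAdd(torch.nn.Module):
    # Hypothetical sketch; the real helper lives elsewhere in test_jit.py
    # and is expected to fuse into ipex::einsum_binary after tracing.
    def __init__(self, equation):
        super().__init__()
        self.equation = equation

    def forward(self, input1, input2, bias):
        return torch.einsum(self.equation, input1, input2) + bias
```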
