@@ -20,6 +20,67 @@ namespace cpu {
 using at::IntArrayRef;
 using at::Tensor;
 
+//! function: is_add_broadcast_supported_by_onednn
+/*!
+ * This is a workaround check, since oneDNN does not fully support
+ * matmul+binary_add fusion for every kind of add-input broadcast dims;
+ * depending on the add-input broadcast dims, oneDNN matmul+binary_add
+ * falls into a reference path in some cases. This function accepts the
+ * verified supported cases and lets the unsupported ones fall back.
+ *
+ * The verified supported cases use the following oneDNN non_broadcast_mask:
+ * 2D: oneDNN non_broadcast_mask = {0, 2, 3}
+ * 3D: oneDNN non_broadcast_mask = {0, 2, 4, 5, 7}
+ * 4D: oneDNN non_broadcast_mask = {0, 2, 8, 9, 13, 15}
+ *
+ * For example:
+ * For 4D tensors, left has shape [8, 2, 4, 6] and right has shape
+ * [8, 2, 6, 4], so the matmul shape is [8, 2, 4, 4], and post_add_tensor
+ * has shape [8, 1, 1, 4]. The corresponding non_broadcast_mask is
+ * (1 << 0) + (1 << 3) = 9, which is supported.
+ *
+ * \param left: the left operand of matmul
+ * \param right: the right operand of matmul
+ * \param post_add_tensor: the post add input tensor
+ * \return: whether the post add input broadcast is supported by oneDNN for
+ * matmul+binary_add fusion
+ */
+bool is_add_broadcast_supported_by_onednn(
+    const at::Tensor& left,
+    const at::Tensor& right,
+    const at::Tensor& post_add_tensor) {
+  auto non_broadcast_mask = 0;
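+  // Bit i of non_broadcast_mask is set when post_add_tensor is not
+  // broadcast along dim i: the last dim must match right's last dim
+  // (the matmul output's last dim), every other dim must match left's.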
+  for (int i = 0; i < left.dim(); i++) {
+    if (post_add_tensor.size(i) != 1) {
+      if (i == left.dim() - 1) {
+        non_broadcast_mask +=
+            post_add_tensor.size(i) == right.size(i) ? 1 << i : 0;
+      } else {
+        non_broadcast_mask +=
+            post_add_tensor.size(i) == left.size(i) ? 1 << i : 0;
+      }
+    }
+  }
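+  // Accept only the masks verified not to hit the oneDNN reference path
+  // (see the table in the comment above).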
+  if (left.dim() == 4) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 8 || non_broadcast_mask == 9 ||
+        non_broadcast_mask == 13 || non_broadcast_mask == 15) {
+      return true;
+    }
+  } else if (left.dim() == 3) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 4 || non_broadcast_mask == 5 ||
+        non_broadcast_mask == 7) {
+      return true;
+    }
+  } else if (left.dim() == 2) {
+    if (non_broadcast_mask == 0 || non_broadcast_mask == 2 ||
+        non_broadcast_mask == 3) {
+      return true;
+    }
+  }
+
+  return false;
+}
 //! function: sumproduct_pair
 /*!
  *
@@ -266,14 +327,19 @@ static Tensor sumproduct_pair(
   left = left.permute(lpermutation).reshape(left_shape);
   right = right.permute(rpermutation).reshape(right_shape);
 
-  // Tensor result = at::bmm(left, right);
-  auto _input = arg.is_contiguous() ? arg : arg.contiguous();
-  ideep::tensor onednn_input = itensor_view_from_dense(_input);
-  auto op_attr = ideep::attr_t::fuse_binary(
-      dnnl::algorithm::binary_add, onednn_input.get_desc());
-  Tensor result =
-      bmm_impl(left, right, at::Tensor(), op_attr, {onednn_input}, 1.0f);
-
+  // Do the computation: use the oneDNN fused matmul+binary_add kernel when
+  // the add input's broadcast pattern is supported, otherwise fall back to
+  // an eager matmul followed by a scaled add.
+  Tensor result;
+  if (is_add_broadcast_supported_by_onednn(left, right, arg)) {
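+    // itensor_view_from_dense wraps the dense storage without a copy, so
+    // make sure the post-add input is contiguous first.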
+    auto _input = arg.is_contiguous() ? arg : arg.contiguous();
+    ideep::tensor onednn_input = itensor_view_from_dense(_input);
+    auto op_attr = ideep::attr_t::fuse_binary(
+        dnnl::algorithm::binary_add, onednn_input.get_desc());
+    result = bmm_impl(left, right, at::Tensor(), op_attr, {onednn_input}, 1.0f);
+  } else {
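+    // Unsupported broadcast pattern: fall back to an eager matmul followed
+    // by a scaled add of the post-add tensor.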
+    result = at::matmul(left, right);
+    auto f_alpha = alpha.to<float>();
+    result = result + f_alpha * arg;
+  }
   result = result.view(out_size).permute(opermutation);
 
   // finally squeeze summed dimensions if desired
@@ -666,6 +732,7 @@ at::Tensor einsum_binary(
     const c10::List<at::Tensor>& operands,
     const at::Tensor& add_arg,
     const c10::Scalar& alpha) {
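+  // Record a profiler event so this op shows up in IPEX/PyTorch profiling
+  // traces.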
+  IPEX_RECORD_FUNCTION("dil_einsum_binary", c10::ArrayRef<c10::IValue>({}));
   auto prepare_res = einsum_prepare(equation, operands);
   bool has_zero_size_dim = std::get<0>(prepare_res);
   auto out_size = std::get<1>(prepare_res);