@@ -873,6 +873,88 @@ torch::jit::RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
 
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor(a!) out_opt, Tensor "
+        "mul_input) -> Tensor(a!)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto mul_tensor = std::move(peek(stack, 3, 4).toTensor());
+            auto mul_input_data = mul_tensor.item();
+            // divide mul_input to reuse the dil_matmul_div function
+            auto div_input_data = 1.0f / mul_input_data.to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 2, 4))),
+                div_input_data);
+            drop(stack, 4);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor(a!) out_opt, Scalar "
+        "mul_input) -> Tensor(a!)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            // divide mul_input to reuse the dil_matmul_div function
+            auto div_input_data =
+                1.0f / (std::move(peek(stack, 3, 4))).toScalar().to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 2, 4))),
+                div_input_data);
+            drop(stack, 4);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Tensor mul_input) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto mul_tensor = (std::move(peek(stack, 2, 3))).toTensor();
+            auto mul_input_data = mul_tensor.item();
+            // divide mul_input to reuse the dil_matmul_div function
+            auto div_input_data = 1.0f / mul_input_data.to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toTensor(),
+                at::Tensor(),
+                div_input_data);
+            drop(stack, 3);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::matmul_mul(Tensor left, Tensor right, Scalar mul_input) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            // divide mul_input to reuse the dil_matmul_div function
+            auto div_input_data =
+                1.0f / (std::move(peek(stack, 2, 3))).toScalar().to<float>();
+            auto result = dil_matmul_div(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toTensor(),
+                at::Tensor(),
+                div_input_data);
+            drop(stack, 3);
+            torch::jit::pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::bmm_add(Tensor input, Tensor batch1, Tensor batch2, Scalar alpha) -> "
         "Tensor",
@@ -944,7 +1026,7 @@ torch::jit::RegisterOperators op({
           auto scale_tensor = std::move(peek(stack, 4, 7).toTensor());
           auto scale_data = scale_tensor.item();
           // divide scale to reuse the dil_mha_scores_calc function
-          auto div_scale_data = 1 / scale_data.to<float>();
+          auto div_scale_data = 1.0f / scale_data.to<float>();
           auto result = dil_mha_scores_calc(
               peek(stack, 0, 7).toTensor(),
               peek(stack, 1, 7).toTensor(),
@@ -964,7 +1046,7 @@ torch::jit::RegisterOperators op({
9641046 " Scalar scale, int softmax_dim, ScalarType ? dtype) -> Tensor" ,
9651047 [](Stack& stack) {
9661048 // divide scale to reuse dil_mha_scores_calc function
967- auto div_scale_data = 1 / peek (stack, 4 , 7 ).toScalar ().to <float >();
1049+ auto div_scale_data = 1 . 0f / peek (stack, 4 , 7 ).toScalar ().to <float >();
9681050 auto result = dil_mha_scores_calc (
9691051 peek (stack, 0 , 7 ).toTensor (),
9701052 peek (stack, 1 , 7 ).toTensor (),
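
All of the overloads in this change lean on the same identity: `x * s == x / (1 / s)`, which is why each `ipex::matmul_mul` variant and the `mha_scores_calc` scale handling can simply pass the reciprocal to the existing `dil_matmul_div` / `dil_mha_scores_calc` kernels. Below is a minimal standalone libtorch sketch, not part of this patch, that double-checks the identity; the tensor shapes and the 0.125 scale are arbitrary choices for illustration.

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Arbitrary operands; the check holds for any nonzero scale.
  auto left = torch::randn({2, 3});
  auto right = torch::randn({3, 4});
  const float mul_input = 0.125f;

  // Multiplying the matmul result by mul_input ...
  auto mul_result = torch::matmul(left, right) * mul_input;
  // ... is equivalent to dividing it by the reciprocal, which is what the
  // registered operators hand to dil_matmul_div.
  auto div_result = torch::matmul(left, right) / (1.0f / mul_input);

  std::cout << std::boolalpha
            << torch::allclose(mul_result, div_result) << std::endl;  // true
  return 0;
}
```

As a side note, `1 / scale_data.to<float>()` already performed float division through promotion, so the switch to `1.0f` in the two hunks above appears to be about consistency with the new operators rather than a numeric fix.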