@@ -831,6 +831,62 @@ RegisterOperators op(
831831 },
832832 aliasAnalysisFromSchema ()),
833833
834+ Operator (
835+ " xpu::_convolution_mish(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, Scalar beta, Scalar threshold) -> Tensor" ,
836+ [](const Node* node) -> Operation {
837+ return [](Stack& stack) {
838+ at::Tensor input = std::move (peek (stack, 0 , 15 )).toTensor ();
839+ auto result = torch::jit::xpu::_convolution_mish (
840+ input,
841+ (std::move (peek (stack, 1 , 15 ))).toTensor (),
842+ toOptionalTensor (std::move (peek (stack, 2 , 15 ))),
843+ (std::move (peek (stack, 3 , 15 ))).toIntVector (),
844+ (std::move (peek (stack, 4 , 15 ))).toIntVector (),
845+ (std::move (peek (stack, 5 , 15 ))).toIntVector (),
846+ (std::move (peek (stack, 6 , 15 ))).toBool (),
847+ (std::move (peek (stack, 7 , 15 ))).toIntVector (),
848+ (std::move (peek (stack, 8 , 15 ))).toInt (),
849+ (std::move (peek (stack, 9 , 15 ))).toBool (),
850+ (std::move (peek (stack, 10 , 15 ))).toBool (),
851+ (std::move (peek (stack, 11 , 15 ))).toBool (),
852+ (std::move (peek (stack, 12 , 15 ))).toBool (),
853+ (std::move (peek (stack, 13 , 15 ))).toScalar (),
854+ (std::move (peek (stack, 14 , 15 ))).toScalar ());
855+ drop (stack, 15 );
856+ pack (stack, std::move (result));
857+ };
858+ },
859+ aliasAnalysisFromSchema ()),
860+
861+ Operator (
862+ " xpu::_convolution_mish_add(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, Scalar beta, Scalar threshold, Tensor(a!) accumu, *, Scalar alpha) -> Tensor" ,
863+ [](const Node* node) -> Operation {
864+ return [](Stack& stack) {
865+ at::Tensor input = std::move (peek (stack, 0 , 17 )).toTensor ();
866+ auto result = torch::jit::xpu::_convolution_mish_add (
867+ input,
868+ (std::move (peek (stack, 1 , 17 ))).toTensor (),
869+ toOptionalTensor (std::move (peek (stack, 2 , 17 ))),
870+ (std::move (peek (stack, 3 , 17 ))).toIntVector (),
871+ (std::move (peek (stack, 4 , 17 ))).toIntVector (),
872+ (std::move (peek (stack, 5 , 17 ))).toIntVector (),
873+ (std::move (peek (stack, 6 , 17 ))).toBool (),
874+ (std::move (peek (stack, 7 , 17 ))).toIntVector (),
875+ (std::move (peek (stack, 8 , 17 ))).toInt (),
876+ (std::move (peek (stack, 9 , 17 ))).toBool (),
877+ (std::move (peek (stack, 10 , 17 ))).toBool (),
878+ (std::move (peek (stack, 11 , 17 ))).toBool (),
879+ (std::move (peek (stack, 12 , 17 ))).toBool (),
880+ (std::move (peek (stack, 13 , 17 ))).toScalar (),
881+ (std::move (peek (stack, 14 , 17 ))).toScalar (),
882+ (std::move (peek (stack, 15 , 17 ))).toTensor (),
883+ (std::move (peek (stack, 16 , 17 ))).toScalar ());
884+ drop (stack, 17 );
885+ pack (stack, std::move (result));
886+ };
887+ },
888+ aliasAnalysisFromSchema ()),
889+
834890 });
835891} // namespace jit
836892} // namespace torch
0 commit comments