Skip to content

Commit 88e2f09

Browse files
authored
jit: support yolo mish, yolo mish add on release branch (#1669)
* jit: support yolo mish, yolo mish add on release branch
1 parent 5a3823c commit 88e2f09

File tree

9 files changed

+359
-13
lines changed

9 files changed

+359
-13
lines changed

csrc/aten/operators/Conv.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,81 @@ Tensor convolution_silu(
856856
attr);
857857
}
858858

859+
Tensor convolution_mish(
860+
const Tensor& input_r,
861+
const Tensor& weight_r,
862+
const Tensor& bias_r,
863+
IntArrayRef stride_,
864+
IntArrayRef padding_,
865+
IntArrayRef dilation_,
866+
bool transposed_,
867+
IntArrayRef output_padding_,
868+
int64_t groups_,
869+
Scalar scale,
870+
Scalar alpha,
871+
Scalar beta) {
872+
// only support scale = 1.0f in oneDNN for non-quantized case.
873+
TORCH_CHECK(
874+
scale.to<float>() == 1.f && alpha.to<float>() == 1.f,
875+
"only support convolution silu fusion with silu scale equals to 1, alpha equal to 1");
876+
Attr attr;
877+
attr.append_post_eltwise(
878+
/* relu_scale */ 1.0,
879+
/* alpha */ 1.f,
880+
/* beta */ 0.f,
881+
attr.kind_with_mish);
882+
return _convolution(
883+
input_r,
884+
weight_r,
885+
bias_r,
886+
stride_,
887+
padding_,
888+
dilation_,
889+
transposed_,
890+
output_padding_,
891+
groups_,
892+
attr);
893+
}
894+
895+
Tensor convolution_mish_add(
896+
const Tensor& input_r,
897+
const Tensor& weight_r,
898+
const Tensor& bias_r,
899+
IntArrayRef stride_,
900+
IntArrayRef padding_,
901+
IntArrayRef dilation_,
902+
bool transposed_,
903+
IntArrayRef output_padding_,
904+
int64_t groups_,
905+
Tensor& accumu,
906+
Scalar scale,
907+
Scalar alpha,
908+
Scalar beta) {
909+
// only support scale = 1.0f in oneDNN for non-quantized case.
910+
TORCH_CHECK(
911+
scale.to<float>() == 1.f && alpha.to<float>() == 1.f,
912+
"only support convolution silu fusion with silu scale equals to 1, alpha equal to 1");
913+
Attr attr;
914+
attr.append_post_eltwise(
915+
/* relu_scale */ 1.0,
916+
/* alpha */ 1.f,
917+
/* beta */ 0.f,
918+
attr.kind_with_mish)
919+
.append_post_sum(/* sum_scale */ scale.to<float>()); // append post op sum
920+
return _convolution_out(
921+
accumu,
922+
input_r,
923+
weight_r,
924+
bias_r,
925+
stride_,
926+
padding_,
927+
dilation_,
928+
transposed_,
929+
output_padding_,
930+
groups_,
931+
attr);
932+
}
933+
859934
Tensor convolution_sigmoid(
860935
const Tensor& input_r,
861936
const Tensor& weight_r,

csrc/intrinsic/intrinsic.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,35 @@ at::Tensor dequantize_tensor_per_channel_affine(
346346
const at::Tensor& zero_points,
347347
int64_t axis);
348348

349+
Tensor convolution_mish(
350+
const Tensor& input_r,
351+
const Tensor& weight_r,
352+
const Tensor& bias_r,
353+
IntArrayRef stride_,
354+
IntArrayRef padding_,
355+
IntArrayRef dilation_,
356+
bool transposed_,
357+
IntArrayRef output_padding_,
358+
int64_t groups_,
359+
Scalar scale,
360+
Scalar alpha,
361+
Scalar beta);
362+
363+
Tensor convolution_mish_add(
364+
const Tensor& input_r,
365+
const Tensor& weight_r,
366+
const Tensor& bias_r,
367+
IntArrayRef stride_,
368+
IntArrayRef padding_,
369+
IntArrayRef dilation_,
370+
bool transposed_,
371+
IntArrayRef output_padding_,
372+
int64_t groups_,
373+
Tensor& accumu,
374+
Scalar scale,
375+
Scalar alpha,
376+
Scalar beta);
377+
349378
} // namespace AtenIpexTypeXPU
350379
} // namespace at
351380

csrc/jit/accelerated_ops.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ static auto permute_contiguous_sym =
7575
Symbol::fromQualString("xpu::permute_contiguous");
7676
static auto convolution_silu_sym =
7777
Symbol::fromQualString("xpu::_convolution_silu");
78-
78+
static auto _convolution_mish_sym =
79+
Symbol::fromQualString("xpu::_convolution_mish");
80+
static auto _convolution_mish_add_sym =
81+
Symbol::fromQualString("xpu::_convolution_mish_add");
7982
// Fold weights of batch_norm with conv2d's
8083
static auto fold_weight_sym = Symbol::fromQualString("xpu::fold_weight");
8184
static auto fold_bias_sym = Symbol::fromQualString("xpu::fold_bias");

csrc/jit/dpcpp_ops.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,77 @@ at::Tensor _convolution_silu(
11261126
0.0);
11271127
}
11281128

1129+
at::Tensor _convolution_mish(
1130+
const at::Tensor& input,
1131+
const at::Tensor& weight,
1132+
const at::Tensor& bias,
1133+
at::IntArrayRef stride_,
1134+
at::IntArrayRef padding_,
1135+
at::IntArrayRef dilation_,
1136+
bool transposed_,
1137+
at::IntArrayRef output_padding_,
1138+
int64_t groups_,
1139+
bool benchmark,
1140+
bool deterministic,
1141+
bool cudnn_enabled,
1142+
bool allow_tf32,
1143+
Scalar beta,
1144+
Scalar threshold) {
1145+
RECORD_FUNCTION(
1146+
"_convolution_mish", std::vector<c10::IValue>({input, weight, bias}));
1147+
const OptionalDeviceGuard device_guard(device_of(input));
1148+
return at::AtenIpexTypeXPU::convolution_mish(
1149+
input,
1150+
weight,
1151+
bias,
1152+
stride_,
1153+
padding_,
1154+
dilation_,
1155+
transposed_,
1156+
output_padding_,
1157+
groups_,
1158+
1.0,
1159+
1.0,
1160+
0.0);
1161+
}
1162+
1163+
at::Tensor _convolution_mish_add(
1164+
const at::Tensor& input,
1165+
const at::Tensor& weight,
1166+
const at::Tensor& bias,
1167+
at::IntArrayRef stride_,
1168+
at::IntArrayRef padding_,
1169+
at::IntArrayRef dilation_,
1170+
bool transposed_,
1171+
at::IntArrayRef output_padding_,
1172+
int64_t groups_,
1173+
bool benchmark,
1174+
bool deterministic,
1175+
bool cudnn_enabled,
1176+
bool allow_tf32,
1177+
Scalar beta,
1178+
Scalar threshold,
1179+
Tensor accumu,
1180+
Scalar scale) {
1181+
RECORD_FUNCTION(
1182+
"_convolution_silu", std::vector<c10::IValue>({input, weight, bias}));
1183+
const OptionalDeviceGuard device_guard(device_of(input));
1184+
return at::AtenIpexTypeXPU::convolution_mish_add(
1185+
input,
1186+
weight,
1187+
bias,
1188+
stride_,
1189+
padding_,
1190+
dilation_,
1191+
transposed_,
1192+
output_padding_,
1193+
groups_,
1194+
accumu,
1195+
scale,
1196+
1.0,
1197+
0.0);
1198+
}
1199+
11291200
} // namespace xpu
11301201
} // namespace jit
11311202
} // namespace torch

csrc/jit/dpcpp_ops.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,42 @@ at::Tensor _convolution_silu(
324324
bool cudnn_enabled,
325325
bool allow_tf32);
326326

327+
at::Tensor _convolution_mish(
328+
const at::Tensor& input_r,
329+
const at::Tensor& weight_r,
330+
const at::Tensor& bias_r,
331+
at::IntArrayRef stride_,
332+
at::IntArrayRef padding_,
333+
at::IntArrayRef dilation_,
334+
bool transposed_,
335+
at::IntArrayRef output_padding_,
336+
int64_t groups_,
337+
bool benchmark,
338+
bool deterministic,
339+
bool cudnn_enabled,
340+
bool allow_tf32,
341+
Scalar beta,
342+
Scalar threshold);
343+
344+
at::Tensor _convolution_mish_add(
345+
const at::Tensor& input_r,
346+
const at::Tensor& weight_r,
347+
const at::Tensor& bias_r,
348+
at::IntArrayRef stride_,
349+
at::IntArrayRef padding_,
350+
at::IntArrayRef dilation_,
351+
bool transposed_,
352+
at::IntArrayRef output_padding_,
353+
int64_t groups_,
354+
bool benchmark,
355+
bool deterministic,
356+
bool cudnn_enabled,
357+
bool allow_tf32,
358+
Scalar beta,
359+
Scalar threshold,
360+
Tensor accumu,
361+
Scalar alpha);
362+
327363
} // namespace xpu
328364
} // namespace jit
329365
} // namespace torch

csrc/jit/fusion_pass.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,10 @@ OpFuser::RuleTab OpFuser::dnnlRules = {
506506
xpu::_convolution_sum_relu_sym},
507507
{{Symbol::fromQualString("aten::_convolution"),
508508
Symbol::fromQualString("aten::silu_")},
509-
xpu::convolution_silu_sym}};
509+
xpu::convolution_silu_sym},
510+
{{Symbol::fromQualString("aten::_convolution"), xpu::softplus_tanh_mul_sym},
511+
xpu::_convolution_mish_sym},
512+
{{xpu::_convolution_mish_sym, aten::add_}, xpu::_convolution_mish_add_sym}};
510513

511514
void FusionPass(std::shared_ptr<Graph>& graph) {
512515
// Pattern based fusion was lack of alias analysis
@@ -531,4 +534,4 @@ static RegisterPreFusionPass pass_3([](std::shared_ptr<Graph>& g) {
531534
});
532535

533536
} // namespace jit
534-
} // namespace torch
537+
} // namespace torch

csrc/jit/register_dnnl_jit_ops.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,62 @@ RegisterOperators op(
831831
},
832832
aliasAnalysisFromSchema()),
833833

834+
Operator(
835+
"xpu::_convolution_mish(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, Scalar beta, Scalar threshold) -> Tensor",
836+
[](const Node* node) -> Operation {
837+
return [](Stack& stack) {
838+
at::Tensor input = std::move(peek(stack, 0, 15)).toTensor();
839+
auto result = torch::jit::xpu::_convolution_mish(
840+
input,
841+
(std::move(peek(stack, 1, 15))).toTensor(),
842+
toOptionalTensor(std::move(peek(stack, 2, 15))),
843+
(std::move(peek(stack, 3, 15))).toIntVector(),
844+
(std::move(peek(stack, 4, 15))).toIntVector(),
845+
(std::move(peek(stack, 5, 15))).toIntVector(),
846+
(std::move(peek(stack, 6, 15))).toBool(),
847+
(std::move(peek(stack, 7, 15))).toIntVector(),
848+
(std::move(peek(stack, 8, 15))).toInt(),
849+
(std::move(peek(stack, 9, 15))).toBool(),
850+
(std::move(peek(stack, 10, 15))).toBool(),
851+
(std::move(peek(stack, 11, 15))).toBool(),
852+
(std::move(peek(stack, 12, 15))).toBool(),
853+
(std::move(peek(stack, 13, 15))).toScalar(),
854+
(std::move(peek(stack, 14, 15))).toScalar());
855+
drop(stack, 15);
856+
pack(stack, std::move(result));
857+
};
858+
},
859+
aliasAnalysisFromSchema()),
860+
861+
Operator(
862+
"xpu::_convolution_mish_add(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, Scalar beta, Scalar threshold, Tensor(a!) accumu, *, Scalar alpha) -> Tensor",
863+
[](const Node* node) -> Operation {
864+
return [](Stack& stack) {
865+
at::Tensor input = std::move(peek(stack, 0, 17)).toTensor();
866+
auto result = torch::jit::xpu::_convolution_mish_add(
867+
input,
868+
(std::move(peek(stack, 1, 17))).toTensor(),
869+
toOptionalTensor(std::move(peek(stack, 2, 17))),
870+
(std::move(peek(stack, 3, 17))).toIntVector(),
871+
(std::move(peek(stack, 4, 17))).toIntVector(),
872+
(std::move(peek(stack, 5, 17))).toIntVector(),
873+
(std::move(peek(stack, 6, 17))).toBool(),
874+
(std::move(peek(stack, 7, 17))).toIntVector(),
875+
(std::move(peek(stack, 8, 17))).toInt(),
876+
(std::move(peek(stack, 9, 17))).toBool(),
877+
(std::move(peek(stack, 10, 17))).toBool(),
878+
(std::move(peek(stack, 11, 17))).toBool(),
879+
(std::move(peek(stack, 12, 17))).toBool(),
880+
(std::move(peek(stack, 13, 17))).toScalar(),
881+
(std::move(peek(stack, 14, 17))).toScalar(),
882+
(std::move(peek(stack, 15, 17))).toTensor(),
883+
(std::move(peek(stack, 16, 17))).toScalar());
884+
drop(stack, 17);
885+
pack(stack, std::move(result));
886+
};
887+
},
888+
aliasAnalysisFromSchema()),
889+
834890
});
835891
} // namespace jit
836892
} // namespace torch

csrc/oneDNN/Conv.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -537,16 +537,16 @@ static at::Tensor convolution(
537537
bia_m = dpcpp_onednn_memory(bia_md, engine, bia_.data_ptr());
538538
xpu::oneDNN::reorder(bia, bia_, reorder_attr);
539539

540-
// Following is for saving bias correctly.
541-
// TODO: Need a general solution for bias caching
542-
#ifndef BUILD_JIT_QUANTIZATION_SAVE
543-
if (weight_cache_optimization) {
544-
strm.wait();
545-
// FIXME: thread safty
546-
auto bia_opt_ctx = DPCPPTensorContext::release_tensor_ctx(bia_);
547-
DPCPPTensorContext::set_tensor_ctx(bia, std::move(bia_opt_ctx));
548-
}
549-
#endif
540+
// Following is for saving bias correctly.
541+
// TODO: Need a general solution for bias caching
542+
#ifndef BUILD_JIT_QUANTIZATION_SAVE
543+
if (weight_cache_optimization) {
544+
strm.wait();
545+
// FIXME: thread safty
546+
auto bia_opt_ctx = DPCPPTensorContext::release_tensor_ctx(bia_);
547+
DPCPPTensorContext::set_tensor_ctx(bia, std::move(bia_opt_ctx));
548+
}
549+
#endif
550550
}
551551
}
552552

0 commit comments

Comments
 (0)