Skip to content

Commit 0d4a314

Browse files
Hardsigmoid (#904)
* This graph pass is to replace at::hardsigmoid with IPEX hardsigmoid. Because NNC pulls aten::hardsigmoidn into its fusion group while its performance might not be good enough if the most outer loop is small. Besides that, IPEX will use oneDNN post-op to use hard sigmoid. Hence, this graph pass is a workaround for this release and will be removed in the next major release. * Remove hardsigmoid from ut as it has been replaced with ipex hardsigmoid Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
1 parent 3223ca7 commit 0d4a314

File tree

6 files changed

+109
-38
lines changed

6 files changed

+109
-38
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
namespace torch_ipex {
6+
namespace cpu {
7+
8+
inline at::Tensor dil_hardsigmoid(const at::Tensor& self) {
9+
return at::hardsigmoid(self);
10+
}
11+
12+
} // namespace cpu
13+
} // namespace torch_ipex

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,11 +610,11 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
610610
// %y, %hy, %cy = aten::lstm(%ret, ...)
611611
void preprocessSizeForQLstm(std::shared_ptr<Graph>& graph) {
612612
const static std::string op_list_construct_same_states = R"(
613-
%hx.1 = aten::zeros(%sizes, %scalar_type, %layout, %device, %pin_memory)
613+
%hx.1 = aten::zeros(%sizes, %scalar_type, %layout, %device, %pin_memory)
614614
%state : Tensor[] = prim::ListConstruct(%hx.1, %hx.1) )";
615615

616616
const static std::string op_list_construct_diff_states = R"(
617-
%hx.1 = aten::zeros(%sizes, %scalar_type, %layout, %device, %pin_memory)
617+
%hx.1 = aten::zeros(%sizes, %scalar_type, %layout, %device, %pin_memory)
618618
%hx = aten::zeros(%sizes, %scalar_type, %layout, %device, %pin_memory)
619619
%state : Tensor[] = prim::ListConstruct(%hx.1, %hx) )";
620620

@@ -705,7 +705,7 @@ void replaceLstmWithQLstm(std::shared_ptr<Graph>& graph) {
705705
std::string QLstmPattern = complete_header + R"(
706706
%input : Tensor = aten::dequantize(%quantized_input) )" +
707707
weight_pattern + complete_LC + R"(
708-
%output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist)
708+
%output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist)
709709
%quantized_output = aten::quantize_per_tensor(%output, %scale, %zp, %dtype)
710710
return (%quantized_output, %hy, %cy) )";
711711

@@ -875,6 +875,23 @@ void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph) {
875875
ls_fusion.runOnGraph(graph);
876876
}
877877

878+
void ReplaceHardsigmoidWithIPEX(std::shared_ptr<Graph>& graph) {
879+
std::string aten_hardsigmoid = R"(
880+
graph(%x):
881+
%res = aten::hardsigmoid(%x)
882+
return (%res) )";
883+
884+
std::string ipex_hardsigmoid = R"(
885+
graph(%x):
886+
%res = ipex::hardsigmoid(%x)
887+
return (%res) )";
888+
889+
SubgraphRewriter hardsigmoid_replacement;
890+
hardsigmoid_replacement.RegisterRewritePattern(
891+
aten_hardsigmoid, ipex_hardsigmoid);
892+
hardsigmoid_replacement.runOnGraph(graph);
893+
}
894+
878895
} // namespace graph_rewrite
879896
} // namespace jit
880897
} // namespace torch

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ void fuseConvWithEltwise(std::shared_ptr<Graph>& graph);
4646
void fuseConvAddRelu(std::shared_ptr<Graph>& graph);
4747
void fuseBottleneck(std::shared_ptr<Graph>& graph);
4848

49+
// This graph pass is to replace at::hardsigmoid with IPEX hardsigmoid.
50+
// Because NNC pulls aten::hardsigmoidn into its fusion group while its
51+
// performance might not be good engouh if the mout outer loop is small. Besides
52+
// that, IPEX will use oneDNN post-op to fuse hard sigmoid. Hence, this graph
53+
// pass is a workaround for this release and will be removed in the next major
54+
// release.
55+
void ReplaceHardsigmoidWithIPEX(std::shared_ptr<Graph>& graph);
56+
4957
void RecordAtenLinearNodes(
5058
std::shared_ptr<Graph>& graph,
5159
std::unordered_set<Node*>& aten_linear);

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "csrc/jit/cpu/kernels/ConvTransposePacked.h"
1010
#include "csrc/jit/cpu/kernels/Einsum.h"
1111
#include "csrc/jit/cpu/kernels/Embeddingbag.h"
12+
#include "csrc/jit/cpu/kernels/Hardsigmoid.h"
1213
#include "csrc/jit/cpu/kernels/Interaction.h"
1314
#include "csrc/jit/cpu/kernels/LinearPacked.h"
1415
#include "csrc/jit/cpu/kernels/LinearSwishCustomized.h"
@@ -1140,7 +1141,19 @@ RegisterOperators op({
11401141
};
11411142
},
11421143
aliasAnalysisFromSchema()),
1144+
Operator(
1145+
"ipex::hardsigmoid(Tensor input) -> Tensor",
1146+
[](const Node* node) -> Operation {
1147+
return [](Stack* stack) {
1148+
auto result =
1149+
dil_hardsigmoid((std::move(peek(stack, 0, 1))).toTensor());
11431150

1151+
drop(stack, 1);
1152+
pack(stack, std::move(result));
1153+
return 0;
1154+
};
1155+
},
1156+
aliasAnalysisFromSchema()),
11441157
});
11451158
} // namespace jit
11461159
} // namespace torch

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
162162
graph_rewrite::fuseLinearAddRelu(graph);
163163
GRAPH_DUMP("After fuseLinearAddRelu.", graph);
164164

165+
GRAPH_DEBUG("Before replacing hardsigmoid", graph);
166+
graph_rewrite::ReplaceHardsigmoidWithIPEX(graph);
167+
GRAPH_DEBUG("After replacing hardsigmoid", graph);
168+
165169
graph_rewrite::FuseLinearSwishCustomized(graph);
166170
// fuse add+layernorm
167171
graph_rewrite::FuseAddLayerNorm(graph);

tests/cpu/test_jit.py

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def forward(self, x):
589589
a1 = self.conv_transpose(x)
590590
b1 = torch.sigmoid(a1)
591591
c1 = self.mul_op(a1, b1)
592-
return c1
592+
return c1
593593

594594
class ChannelShuffle_with_Static_Shape(nn.Module):
595595
def __init__(self, batchsize, num_channels, height, width, groups):
@@ -992,10 +992,10 @@ def _test_onednn_fp32(self, model, input, kind_in_graph=None, kind_not_in_graph=
992992
trace_graph = tr_model.graph_for(input)
993993
res_jit = tr_model(input)
994994
self.assertEqual(res_ref, res_jit)
995-
995+
996996
if kind_in_graph is not None:
997997
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
998-
998+
999999
if kind_not_in_graph is not None:
10001000
self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
10011001

@@ -1054,8 +1054,8 @@ def _test_fusion_unsupported_case(self, m, x, auto_kernel_selection=False, kind_
10541054
traced_model = torch.jit.trace(model, x).eval()
10551055
traced_model = torch.jit.freeze(traced_model)
10561056
tresult = traced_model(x)
1057-
trace_graph = traced_model.graph_for(x)
1058-
1057+
trace_graph = traced_model.graph_for(x)
1058+
10591059
if kind_in_graph is not None:
10601060
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
10611061

@@ -1088,7 +1088,7 @@ def test_jit_freeze(self):
10881088
self.assertTrue(all(n.kind() != pack_node for n in freeze_graph.nodes()))
10891089
# for non-freeze model, since op-ctx dose not have value, cannot re-pack for this path
10901090
self.assertTrue(any(n.kind() == imperative_node for n in trace_graph.nodes()))
1091-
1091+
10921092

10931093
def test_concat_linear(self):
10941094
def check_op_count(graph_str, op_names=[]):
@@ -1459,7 +1459,7 @@ def _test_conv_unary_fusion(self, op_list, seed=None):
14591459
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
14601460
torch.manual_seed(rand_seed)
14611461
else:
1462-
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
1462+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
14631463
torch.manual_seed(seed)
14641464

14651465
for dim in [2, 3]:
@@ -1502,7 +1502,7 @@ def _test_conv_transpose_unary_fusion(self, op_list, seed=None):
15021502
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
15031503
torch.manual_seed(rand_seed)
15041504
else:
1505-
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
1505+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
15061506
torch.manual_seed(seed)
15071507

15081508
for dim in [2, 3]:
@@ -1538,7 +1538,7 @@ def test_conv_unary_fusion(self):
15381538
self._test_conv_unary_fusion(unary_PyTorch_op_to_IPEX_op_map)
15391539
self._test_conv_unary_fusion(PyTorch_op_to_IPEX_op_fixed_seed_map, 1654064339261196288)
15401540

1541-
def test_conv_non_unary_fusion(self):
1541+
def test_conv_non_unary_fusion(self):
15421542
self._test_conv_unary_fusion(non_unary_PyTorch_op_to_IPEX_op_map)
15431543

15441544
def test_conv_fusion_unsupported_case(self):
@@ -1548,7 +1548,7 @@ def test_conv_fusion_unsupported_case(self):
15481548
out_channels = 16
15491549
in_channels = 3
15501550
kernel_size = 3
1551-
image_size = 16
1551+
image_size = 16
15521552
for eltwise in unsupported_PyTorch_op_to_IPEX_op_map:
15531553
input_size = [batch_size, in_channels, image_size, image_size]
15541554

@@ -1560,7 +1560,7 @@ def test_conv_fusion_unsupported_case(self):
15601560

15611561
x = torch.randn(input_size)
15621562
m = ConvEltwise(eltwise, dim, in_channels, out_channels, kernel_size, image_size, **op_input_list)
1563-
1563+
15641564
self._test_fusion_unsupported_case(
15651565
m,
15661566
x,
@@ -1640,7 +1640,7 @@ def test_output_frozen_conv_bn(self):
16401640
if use_channels_last:
16411641
x = x.to(memory_format=torch.channels_last)
16421642
model = model.to(memory_format=torch.channels_last)
1643-
1643+
16441644
model = ipex.optimize(model, dtype=dtype, conv_bn_folding=False)
16451645

16461646
with torch.cpu.amp.autocast(enabled=True, dtype=dtype), torch.no_grad():
@@ -2345,7 +2345,7 @@ def test_conv_transpose_unary_fusion(self):
23452345
self._test_conv_transpose_unary_fusion(unary_PyTorch_op_to_IPEX_op_map)
23462346
self._test_conv_transpose_unary_fusion(PyTorch_op_to_IPEX_op_fixed_seed_map, 1654583254233936896)
23472347

2348-
def test_conv_transpose_non_unary_fusion(self):
2348+
def test_conv_transpose_non_unary_fusion(self):
23492349
self._test_conv_transpose_unary_fusion(non_unary_PyTorch_op_to_IPEX_op_map)
23502350

23512351
def test_conv_transpose_fusion_unsupported_case(self):
@@ -2356,7 +2356,7 @@ def test_conv_transpose_fusion_unsupported_case(self):
23562356
in_channels = 3
23572357
kernel_size = 3
23582358
image_size = 8
2359-
2359+
23602360
for eltwise in unsupported_PyTorch_op_to_IPEX_op_map:
23612361
input_size = [batch_size, in_channels, image_size, image_size]
23622362

@@ -2405,7 +2405,7 @@ def test_conv_transpose_sigmoid_mul(self):
24052405
# x,
24062406
# kind_in_graph="ipex_prepack::conv_transpose_%s_run" % ipex_eltwise_op,
24072407
# kind_not_in_graph="ipex_prepack::conv_transpose_prepack",
2408-
# prec=prec)
2408+
# prec=prec)
24092409

24102410
def test_linear_auto_kernel_selection_fp32(self):
24112411
x = torch.rand(32, 3)
@@ -2559,22 +2559,22 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
25592559
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
25602560
torch.manual_seed(rand_seed)
25612561
else:
2562-
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
2563-
torch.manual_seed(seed)
2562+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, seed))
2563+
torch.manual_seed(seed)
25642564

25652565
for bias in [True, False]:
25662566
for eltwise in op_list:
25672567
input_size = [batch_size, in_channels]
2568-
2568+
25692569
unary_fusion_op = op_list[eltwise]
25702570
ipex_eltwise_op = unary_fusion_op.ipex_eltwise_op
25712571
bf16_supported = unary_fusion_op.bf16_supported
25722572
prec = unary_fusion_op.prec
25732573
op_input_list = unary_fusion_op.op_input_list
2574-
2574+
25752575
x = torch.randn(input_size)
25762576
m = LinearEltwise(eltwise, in_channels, out_channels, bias, **op_input_list)
2577-
2577+
25782578
self._test_output(
25792579
m,
25802580
x,
@@ -2583,7 +2583,7 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
25832583
m,
25842584
x,
25852585
kind_in_graph="ipex_prepack::linear_%s_run" % ipex_eltwise_op,
2586-
kind_not_in_graph="ipex_prepack::linear_prepack")
2586+
kind_not_in_graph="ipex_prepack::linear_prepack")
25872587
if bf16_supported:
25882588
self._test_output_bf16(
25892589
m,
@@ -2603,17 +2603,17 @@ def test_linear_fusion_unsupported_case(self):
26032603
batch_size = 3
26042604
out_channels = 32
26052605
in_channels = 3
2606-
bias = False
2606+
bias = False
26072607

26082608
for eltwise in unsupported_PyTorch_op_to_IPEX_op_map:
26092609
input_size = [batch_size, in_channels]
2610-
2610+
26112611
unary_fusion_op = unsupported_PyTorch_op_to_IPEX_op_map[eltwise]
26122612
ipex_eltwise_op = unary_fusion_op.ipex_eltwise_op
26132613
bf16_supported = unary_fusion_op.bf16_supported
26142614
prec = unary_fusion_op.prec
26152615
op_input_list = unary_fusion_op.op_input_list
2616-
2616+
26172617
x = torch.randn(input_size)
26182618
m = LinearEltwise(eltwise, in_channels, out_channels, bias, **op_input_list)
26192619

@@ -2839,7 +2839,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28392839
input2 = torch.randn(768, 2304)
28402840
model_v1 = EinsumAdd('bsh,ho->bso')
28412841
_test_fp32(model_v1, input1, input2, bias)
2842-
2842+
28432843
bias = torch.randn(1, 1, 1, 4)
28442844
input1 = torch.randn(12, 1, 4, 16)
28452845
input2 = torch.randn(12, 4, 4, 16)
@@ -2851,7 +2851,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28512851
input2 = torch.randn(768, 2304)
28522852
model_v1 = EinsumAddInplace('bsh,ho->bso')
28532853
_test_fp32(model_v1, input1, input2, bias)
2854-
2854+
28552855
input1 = torch.randn(8, 3, 768)
28562856
input2 = torch.randn(768, 2304)
28572857
model = EinsumAddScalar('bsh,ho->bso').eval()
@@ -2876,7 +2876,7 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
28762876
input4 = torch.randn(2, 4, 128, 768)
28772877
model_v2 = EinsumAdd("bnqd,bnkd->bnqk")
28782878
_test_fp32(model_v2, input3, input4, bias1)
2879-
2879+
28802880
bias1 = torch.randn(8, 1, 1, 128)
28812881
input3 = torch.randn(8, 4, 128, 768)
28822882
input4 = torch.randn(8, 4, 128, 768)
@@ -2900,34 +2900,34 @@ def _test_fp32(model_test, input1, input2, bias=None, kind_in_graph='ipex::einsu
29002900
input2 = torch.randn(1024, 768)
29012901
model_v2 = EinsumAdd("mc,cn->mn")
29022902
_test_fp32(model_v2, input1, input2, bias1)
2903-
2903+
29042904
bias1 = torch.randn(1024)
29052905
input1 = torch.randn(1024, 1024)
29062906
input2 = torch.randn(1024, 1024)
29072907
model_v2 = EinsumAdd("mc,cn->nm")
29082908
_test_fp32(model_v2, input1, input2, bias1)
2909-
2909+
29102910
bias1 = torch.randn(768)
29112911
input1 = torch.randn(2, 128, 1024)
29122912
input2 = torch.randn(1024, 23, 768)
29132913
model_v2 = EinsumAdd("bqc,chv->bqhv")
29142914
_test_fp32(model_v2, input1, input2, bias1)
2915-
2915+
29162916
bias = torch.randn(768)
29172917
input1 = torch.randn(2, 128, 16, 64)
29182918
input2 = torch.randn(16,64, 768)
29192919
model = EinsumAdd("bqhc,hco->bqo")
29202920
_test_fp32(model, input1, input2, bias)
2921-
2921+
29222922
bias = torch.randn(8)
29232923
input1 = torch.randn(8)
29242924
input2 = torch.randn(8)
29252925
model = EinsumAdd("i,i->")
29262926
_test_fp32(model, input1, input2, bias)
2927-
2928-
#the output of torch.einsum("ij,j") is tensor([])
2927+
2928+
#the output of torch.einsum("ij,j") is tensor([])
29292929
bias = torch.randn(1)
2930-
input1 = torch.randn(0, 3)
2930+
input1 = torch.randn(0, 3)
29312931
input2 = torch.randn(3)
29322932
model = EinsumAdd(("ij,j"))
29332933
_test_fp32(model, input1, input2, bias)
@@ -3042,7 +3042,7 @@ def forward(self, x):
30423042
x1 = self.eltwise(x1, **self.params_dict)
30433043
return x1
30443044

3045-
for eltwise in ['sigmoid', 'tanh', 'celu', 'elu', 'hardsigmoid', 'hardswish', 'hardtanh', 'leaky_relu', 'relu6', 'relu', 'rrelu', 'selu', 'silu']:
3045+
for eltwise in ['sigmoid', 'tanh', 'celu', 'elu', 'hardswish', 'hardtanh', 'leaky_relu', 'relu6', 'relu', 'rrelu', 'selu', 'silu']:
30463046
eltwise_fn_name = eltwise + '_'
30473047
if eltwise in ['sigmoid', 'tanh', 'celu', 'relu', 'rrelu', 'selu']:
30483048
#use torch.sigmoid_(x)
@@ -3137,6 +3137,22 @@ def forward(self, x):
31373137
kind_not_in_graph="aten::mul",
31383138
prec=0.1)
31393139

3140+
def test_hardsigmoid_mul(self):
3141+
class HardsigmoidMul(nn.Module):
3142+
def __init__(self) -> None:
3143+
super(HardsigmoidMul, self).__init__()
3144+
self.hard_sigmoid = nn.Hardsigmoid()
3145+
3146+
def forward(self, x):
3147+
return self.hard_sigmoid(x) * x
3148+
3149+
model = HardsigmoidMul().eval()
3150+
self._test_output(
3151+
model,
3152+
torch.randn(2, 3, 4, 5),
3153+
kind_in_graph="ipex::hardsigmoid",
3154+
kind_not_in_graph="aten::hardsigmoid")
3155+
31403156
if __name__ == '__main__':
31413157
torch.manual_seed(2020)
31423158
test = unittest.main()

0 commit comments

Comments
 (0)