
Commit 07b9ae7

enable quantize AdaptiveAvgPool2d and flatten fusion group in ipex side (#140)

* enable quantize AdaptiveAvgPool2d and flatten fusion group in ipex side
* add test case
* record flatten data flow for ipex int8 fusion path
* change code format
* refine code
* fix undefined symbol error when debug build
1 parent 40a248a commit 07b9ae7

11 files changed (+269 -44 lines)

intel_pytorch_extension_py/conf.py

Lines changed: 23 additions & 2 deletions
@@ -42,6 +42,7 @@ def save(self, configure_file, default_recipe=True):
     def get_default_recipe(self, configures):
         elt_wise = ['relu', 'sigmoid', 'gelu']
         inplace_ops = ['relu_', 'add_']
+        shape_ops = ['flatten']
         # get default recipe,
         # q+dq+conv+q+dq+relu => q+dq+conv+relu
         # q+dq+op1+q+dq+q+dq+op2+q+dq => q+dq+op1+q+dq+op2+q+dq
@@ -75,6 +76,19 @@ def get_default_recipe(self, configures):
                     default_configures[cur_id]['inputs_quantized'][i_num] = False
                     if cur_op == 'add':
                         pre_ops[i_num] = pre_op
+                    if cur_op in shape_ops:
+                        # For the pooling case, the input and output always have the same scale and
+                        # zero point; if the pooling's post op is flatten, sync flatten's input and
+                        # output scale and zero point to the pooling's.
+                        if pre_op in ['max_pool2d', 'adaptive_avg_pool2d']:
+                            default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
+                            default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
+                            default_configures[cur_id]['output_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
+                            default_configures[cur_id]['output_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
+                    if pre_op in shape_ops:
+                        # If the pre op is flatten, sync the input's scale and zero point to flatten's.
+                        default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
+                        default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
                 # conv   op        conv   op
                 #    \   /            \   /
                 #     q   q            \  q
@@ -98,10 +112,17 @@ def get_default_recipe(self, configures):
         # post process for add, linear: if the cur op has no post quantized op, i.e. 'outputs_quantized' is True,
         # the default recipe for good performance is:
         # int8_input -> op -> q -> dq will be converted to int8_input -> op.
-        post_process_ops = ['add', 'linear', 'conv2d']
+        ops_remove_q_dq_after = ['add', 'linear', 'conv2d']
+        # post process for flatten: if flatten's pre op and post op are fp32 ops, there is no need
+        # to add q and dq before and after it.
+        ops_remove_q_dq_before_after = ['flatten']
         for cur_id in range(num_ops):
             cur_op = default_configures[cur_id]['name']
-            if cur_op in post_process_ops and default_configures[cur_id]['outputs_quantized'][0]:
+            if cur_op in ops_remove_q_dq_after and default_configures[cur_id]['outputs_quantized'][0]:
+                default_configures[cur_id]['outputs_quantized'][0] = False
+            if cur_op in ops_remove_q_dq_before_after and default_configures[cur_id]['inputs_quantized'][0] \
+                    and default_configures[cur_id]['outputs_quantized'][0]:
+                default_configures[cur_id]['inputs_quantized'][0] = False
                 default_configures[cur_id]['outputs_quantized'][0] = False

         return default_configures
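
To make the effect of these recipe rules concrete, here is a minimal, illustrative Python sketch (not the real implementation). It operates on a hypothetical, simplified list of per-op config dicts standing in for default_configures, with single-input/single-output fields.

# Illustrative sketch only -- a stripped-down stand-in for default_configures,
# not the real IPEX data structure.
def apply_flatten_recipe(configs):
    pool_ops = ['max_pool2d', 'adaptive_avg_pool2d']
    # Rule 1: a flatten fed by a pooling op reuses the pool's output scale and zero point,
    # because pooling keeps input and output quantization parameters identical.
    for cfg in configs:
        pre = cfg.get('pre_op')          # hypothetical reference to the producer's config
        if cfg['name'] == 'flatten' and pre is not None and pre['name'] in pool_ops:
            cfg['input_scale'] = cfg['output_scale'] = pre['output_scale']
            cfg['input_zero_point'] = cfg['output_zero_point'] = pre['output_zero_point']
    # Rule 2: a flatten sitting between fp32 ops does not need q/dq inserted
    # around it, so both quantization flags are cleared.
    for cfg in configs:
        if cfg['name'] == 'flatten' and cfg['input_quantized'] and cfg['output_quantized']:
            cfg['input_quantized'] = False
            cfg['output_quantized'] = False
    return configs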
Lines changed: 85 additions & 0 deletions (new file)
@@ -0,0 +1,85 @@
+import unittest
+import itertools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.testing import FileCheck
+
+from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP, llga_test_env
+
+import intel_pytorch_extension as ipex
+
+
+class TestIpexOps(JitLlgaTestCase):
+    @llga_test_env
+    def test_adaptive_avg_pool2d(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d((5,7))
+
+            def forward(self, x):
+                x = self.adaptive_avg_pool2d(x)
+                return x
+
+        m = M()
+        x = torch.rand(1, 32, 28, 28)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="adaptive_avg_pool2d", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
+
+    @llga_test_env
+    def test_flatten_int8(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(3, 3, 2, padding=1, bias=True)
+                self.pool = nn.MaxPool2d(2)
+                self.flatten = nn.Flatten(1)
+                self.linear = nn.Linear(147, 32)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.pool(x)
+                x = self.flatten(x)
+                x = self.linear(x)
+                return x
+
+        m = M()
+        x = torch.rand(1, 3, 14, 14)
+        patterns = [
+            ["aten::quantize_per_channel", "aten::dequantize", "aten::_convolution"],
+            ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
+            ["aten::quantize_per_channel", "aten::dequantize", "aten::linear"],
+        ]
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="flatten", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    def test_flatten_fp32(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.flatten = nn.Flatten(1)
+
+            def forward(self, x):
+                x = self.flatten(x)
+                return x
+
+        m = M()
+        x = torch.rand(1, 3, 14, 14)
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], config_name="flatten", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
+            FileCheck().check_not("aten::quantize_per_tensor") \
+                .check_not("at::dequantize") \
+                .check("aten::flatten") \
+                .run(graph)
+
+
+if __name__ == '__main__':
+    run_tests()
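
The behaviour test_flatten_int8 leans on can also be seen in plain eager mode: flatten is a pure shape op on a quantized tensor, so it preserves the producer's scale and zero point. The snippet below is only an illustration with arbitrary example quantization parameters; it is not part of the test suite.

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 14, 14)
scale, zero_point = 0.05, 64                    # arbitrary example parameters
x_q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

pooled_q = F.max_pool2d(x_q, 2)                 # int8 in, int8 out, same scale/zero_point
flat_q = torch.flatten(pooled_q, 1)             # shape op: quantization parameters unchanged

assert flat_q.is_quantized
assert flat_q.q_scale() == pooled_q.q_scale()
assert flat_q.q_zero_point() == pooled_q.q_zero_point()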

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 3 additions & 39 deletions
@@ -1,11 +1,9 @@
 import unittest
 import itertools
-from functools import wraps
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP
+from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP, llga_test_env
 from torch.testing._internal.common_utils import TEST_SCIPY

 import intel_pytorch_extension as ipex
@@ -27,21 +25,6 @@ def get_eltwise_fn(name):
     else:
         raise NameError('Eltwise function %s not found' % name)

-# For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
-def llga_test_env(func):
-    @wraps(func)
-    def wrapTheFunction(*args):
-        # make sure that the profiling mode is turned on
-        torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_set_profiling_executor(True)
-
-        ipex.core._jit_set_llga_enabled(True)
-        ipex.core.disable_jit_opt()
-        func(*args)
-        ipex.core.enable_jit_opt()
-        ipex.core._jit_set_llga_enabled(False)
-    return wrapTheFunction
-
 class TestOp(JitLlgaTestCase):
     @llga_test_env
     def test_conv2d_int8_in_f32_out(self):
@@ -162,25 +145,6 @@ def test_max_pool2d(self):
         self.assertFused(graph, ['aten::max_pool2d'])
         self.checkPatterns(graph, patterns)

-    @llga_test_env
-    @unittest.skipIf(True, 'int8 adaptive_avg_pool2d is not supported in the backend')
-    def test_adaptive_avg_pool2d(self):
-        m = nn.AdaptiveAvgPool2d((1, 1))
-        N = torch.randint(3, 10, (1,)).item()
-        C = torch.randint(3, 10, (1,)).item()
-        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
-
-        patterns = [
-            ["aten::quantize_per_tensor"],
-            ["aten::dequantize", "aten::adaptive_avg_pool2d", "aten::quantize_per_tensor"],
-            ["aten::dequantize"]
-        ]
-        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d", qscheme=qscheme)
-            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
-            self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
-            self.checkPatterns(graph, patterns)
-
 class TestFusionPattern(JitLlgaTestCase):
     @llga_test_env
     def test_conv2d_eltwise(self):
@@ -408,7 +372,7 @@ def forward(self, x):
                 new_x_shape = x.size()[:-1] + (3, 5)
                 x = x.view(*new_x_shape)
                 return x.permute(0, 2, 1, 3)
-
+
         x = torch.randn(5, 10, 15)
         m = M()

@@ -434,7 +398,7 @@ def forward(self, x):
             x = self.conv1(x)
            x = self.conv2(x).reshape(x.size(0), 4, -1)
             return x
-
+
         x = torch.randn(15, 4, 28, 28)
         # change the size of the input, check the fallback
         x_var = torch.randn(7, 4, 16, 16)
tests/cpu/test_jit_llga_utils.py

Lines changed: 18 additions & 2 deletions
@@ -1,8 +1,9 @@
 import os
 import copy
 import tempfile
-
 import torch
+
+from functools import wraps
 from torch.testing._internal.jit_utils import JitTestCase, warmup_backward, \
     get_execution_plan
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, \
@@ -14,6 +15,21 @@

 LLGA_FUSION_GROUP = 'ipex::LlgaFusionGroup'

+# For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
+def llga_test_env(func):
+    @wraps(func)
+    def wrapTheFunction(*args):
+        # make sure that the profiling mode is turned on
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_set_profiling_executor(True)
+
+        ipex.core._jit_set_llga_enabled(True)
+        ipex.core.disable_jit_opt()
+        func(*args)
+        ipex.core.enable_jit_opt()
+        ipex.core._jit_set_llga_enabled(False)
+    return wrapTheFunction
+
 def all_backward_graphs(module):
     ge_state = module.get_debug_state()
     fwd_plan = get_execution_plan(ge_state)
@@ -133,7 +149,7 @@ def prepareModel(self, model, x, folding=False, remove_dropout=False, config_name

         # get the graph at the second run after freezing
         graph = model.graph_for(*x)
-
+
         return graph, model, fp32_model_with_quant_dequant

     def checkPatterns(self, graph, patterns):

torch_ipex/csrc/autocast_kernel.cpp

Lines changed: 11 additions & 0 deletions
@@ -242,5 +242,16 @@ lstm_aten(const at::Tensor &_input, at::TensorList hx, at::TensorList _params,
   return at::lstm(_input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first);
 }

+at::Tensor flatten(const at::Tensor &input, int64_t start_dim,
+                   int64_t end_dim) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
+  auto target_type = get_autocast_dtype();
+  if (at::ScalarType::Char == target_type) {
+    return int8::flatten(input, start_dim, end_dim);
+  }
+  // Fall Through.
+  return at::flatten(input, start_dim, end_dim);
+}
+
 } // autocast
 } // torch_ipex

torch_ipex/csrc/autocast_kernel.hpp

Lines changed: 2 additions & 0 deletions
@@ -54,5 +54,7 @@ lstm_aten(const at::Tensor &_input, at::TensorList hx, at::TensorList _params,
          bool has_biases, int64_t num_layers, double dropout_p, bool train,
          bool bidirectional, bool batch_first);

+at::Tensor flatten(const at::Tensor &input, int64_t start_dim, int64_t end_dim);
+
 } // autocast
 } // torch_ipex

torch_ipex/csrc/autocast_mode.cpp

Lines changed: 2 additions & 0 deletions
@@ -772,6 +772,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
          TORCH_FN((&torch_ipex::autocast::gelu)));
   m.impl(TORCH_SELECTIVE_NAME("aten::lstm.input"),
          TORCH_FN((&torch_ipex::autocast::lstm_aten)));
+  m.impl(TORCH_SELECTIVE_NAME("aten::flatten.using_ints"),
+         TORCH_FN((&torch_ipex::autocast::flatten)));
 }

 } // namespace autocast

torch_ipex/csrc/jit/codegen/onednn/interface.cpp

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,8 @@
 #include "jit/codegen/onednn/layout_propagation.h"
 #include "jit/codegen/onednn/prepare_binary.h"
 #include "jit/codegen/onednn/prepare_dequant.h"
+#include "jit/codegen/onednn/quantization_patterns.h"
+
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/decompose_ops.h>
@@ -76,8 +78,10 @@ void fuseGraph(std::shared_ptr<Graph> &g) {
       g);
   RemoveTensorTypeSpecializations(g);
   GRAPH_DUMP(
-      "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
+      "After RemoveTensorTypeSpecializations. Before IPEX optimization pass",
       g);
+  IpexQuantFusion(g);
+  GRAPH_DUMP("After IpexQuantFusion. End of IPEX optimization pass", g);
 }
 }

Lines changed: 68 additions & 0 deletions (new file)
@@ -0,0 +1,68 @@
+#include <string>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+namespace torch {
+namespace jit {
+
+struct FusionInfo {
+  std::string quantized_op_name;
+  std::string pattern;
+  std::string replacement;
+  std::vector<MatchFilter> filters = {};
+};
+
+namespace {
+
+std::string getArgList(std::vector<std::string> extra_args) {
+  return std::accumulate(
+      extra_args.begin(), extra_args.end(), std::string(),
+      [](std::string acc, const std::string &arg) { return acc + ", " + arg; });
+}
+
+FusionInfo getIpexFusionInfo(const std::string &fp_op_name,
+                             const std::string &q_op_name,
+                             const std::vector<std::string> &fp_extra_args,
+                             const std::vector<std::string> &q_extra_args) {
+  const auto &fp_extra_arg_list = getArgList(fp_extra_args);
+  const auto &q_extra_arg_list = getArgList(q_extra_args);
+
+  std::string op_pattern = "graph(%a_quant" + fp_extra_arg_list +
+                           ", %r_scale, %r_zero_point, %r_dtype):" + R"(
+        %a_dequant = aten::dequantize(%a_quant)
+        %r = )" + fp_op_name + "(" + "%a_dequant" + fp_extra_arg_list + ")" + R"(
+        %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
+        return (%r_quant) )";
+
+  std::string aten_op_pattern = "graph(%a_quant" + fp_extra_arg_list +
+                                ", %r_scale, %r_zero_point, %r_dtype):" + R"(
+        %r_quant = )" + q_op_name + "(%a_quant" + q_extra_arg_list + ")" + R"(
+        return (%r_quant) )";
+
+  return {q_op_name, op_pattern, aten_op_pattern};
+}
+
+} // namespace
+
+void IpexQuantFusion(std::shared_ptr<Graph> &graph) {
+  std::vector<FusionInfo> patterns;
+  auto adaptive_avg_pool2d_patten = getIpexFusionInfo(
+      "aten::adaptive_avg_pool2d", "aten::adaptive_avg_pool2d",
+      {"%output_size"}, {"%output_size"});
+  auto flatten_patten =
+      getIpexFusionInfo("aten::flatten", "aten::flatten",
+                        {"%start_dim, %end_dim"}, {"%start_dim, %end_dim"});
+  patterns.emplace_back(adaptive_avg_pool2d_patten);
+  patterns.emplace_back(flatten_patten);
+  for (const auto &info : patterns) {
+    SubgraphRewriter rewriter;
+    rewriter.RegisterRewritePattern(info.pattern, info.replacement);
+    rewriter.runOnGraph(graph, info.filters);
+  }
+}
+
+} // namespace jit
+} // namespace torch
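
The pass above works on the TorchScript IR: for each op it registers a dequantize -> op -> quantize_per_tensor pattern together with a replacement that runs the op directly on the quantized tensor, then lets SubgraphRewriter rewrite the graph. Purely as an illustration of the same pattern/replacement idea on the Python side (this is not what IPEX ships), an analogous rewrite for the flatten case could be sketched with torch.fx:

import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern

# Pattern: dequantize -> flatten -> quantize_per_tensor
def pattern(x_q, start_dim, end_dim, scale, zero_point, dtype):
    x = torch.dequantize(x_q)
    x = torch.flatten(x, start_dim, end_dim)
    return torch.quantize_per_tensor(x, scale, zero_point, dtype)

# Replacement: run flatten directly on the quantized tensor (a pure shape op)
def replacement(x_q, start_dim, end_dim, scale, zero_point, dtype):
    return torch.flatten(x_q, start_dim, end_dim)

# A toy module containing the pattern, used only to demonstrate the rewrite
class M(torch.nn.Module):
    def forward(self, x_q, start_dim, end_dim, scale, zero_point, dtype):
        x = torch.dequantize(x_q)
        x = torch.flatten(x, start_dim, end_dim)
        return torch.quantize_per_tensor(x, scale, zero_point, dtype)

gm = symbolic_trace(M())
replace_pattern(gm, pattern, replacement)
print(gm.graph)  # the dequantize/quantize_per_tensor pair around flatten is gone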
