Commit 75eb59d

[1.12] bridge: only lift up viewOp if used by one single Op (#909)
* bridge: lift up viewOp if not used by Ops requiring type promotion
* only rewrite if viewOp is used by one single Op
Parent: e7b925a
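
A minimal eager-mode sketch of the hazard this guard avoids (illustrative only; the tensor names, scale, and zero point below are assumptions, not taken from this commit). When a view-like op's f32 output has a second consumer such as aten::add, hoisting the quantize above the view would hand that consumer a quantized tensor:

    import torch

    x = torch.randn(1, 8, 8, 8)
    y = torch.randn(1, 8, 8, 8)

    # A view-like op whose output has two uses: one quantized, one f32.
    z = x.permute(0, 3, 1, 2)

    # Use 1: the quantization consumer -- the candidate for lifting up.
    q = torch.quantize_per_tensor(z, scale=0.1, zero_point=0, dtype=torch.qint8)

    # Use 2: an f32 arithmetic consumer. If the quantize were hoisted above
    # the permute, this add would receive a quantized tensor, and aten::add
    # cannot promote quantized with float ("promoteTypes with quantized
    # numbers is not handled").
    w = z + y

With the single-use check, the pass simply leaves such patterns alone.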

2 files changed: +123 −67 lines

intel_extension_for_pytorch/csrc/jit/codegen/onednn/lift_up_quant.cpp

79 additions, 67 deletions
@@ -7,91 +7,103 @@ namespace jit {
 namespace fuser {
 namespace onednn {
 
+bool usedBySingleOp(Value* v) {
+  return v->uses().size() == 1;
+}
+
 class QuantLifter {
  private:
   std::shared_ptr<Graph> graph_;
 
  public:
   QuantLifter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
 
-  bool analyzeNode(Node* node) {
-    if (node->kind() != Symbol::aten("quantize_per_tensor") &&
-        node->kind() != aten::to) {
-      return false;
-    }
-
-    // TODO: only supported nb_uses to be 1 for now
-    auto* output_value = node->output(0);
-    auto& uses = output_value->uses();
-    if (uses.size() != 1) {
-      return false;
-    }
+  bool analyze(Block* block) {
+    bool changed = false;
+    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
+      auto node = *it;
 
-    auto user = uses[0].user;
-    auto target = node;
-
-    auto prev_node = node->input(0)->node();
-
-    bool could_lift_up = true;
-    while (could_lift_up) {
-      if (utils::isViewOp(target->input(0)->node())) {
-        target = target->input(0)->node();
-
-        // After lifting up, need to fix the output type
-        auto prev_target_type = target->output(0)->type()->expect<TensorType>();
-        auto new_scalar_type =
-            node->output(0)->type()->expect<TensorType>()->scalarType();
-        auto new_target_type =
-            prev_target_type->withScalarType(new_scalar_type);
-        target->output(0)->setType(new_target_type);
-      } else {
-        could_lift_up = false;
+      if (node->kind() != Symbol::aten("quantize_per_tensor") &&
+          node->kind() != aten::to) {
+        continue;
       }
-    }
-
-    // No possible lift up, return directly
-    if (target == node) {
-      return false;
-    }
 
-    // From:
-    // linear -> view (target) -> permute -> transpose -> to (node) -> quant
-    // To:
-    // linear -> to -> view (target) -> permute -> transpose -> quant
-    // Finally:
-    // linear -> to -> quant -> view -> permute -> transpose
-    WithInsertPoint guard(target);
-    auto g = target->owningGraph();
-
-    // Construct lifted up node
-    std::vector<Value*> input_values;
-    input_values.push_back(target->input(0));
-    for (size_t i = 1; i < node->inputs().size(); i++) {
-      input_values.push_back(node->input(i));
-    }
-    auto new_node = g->create(node->kind(), input_values)->insertBefore(target);
+      // TODO: only supported nb_uses to be 1 for now
+      auto* output_value = node->output(0);
+      auto& uses = output_value->uses();
+      if (uses.size() != 1) {
+        continue;
+      }
 
-    // Fix type of the output of lifted up node
-    auto insert_point_output_type =
-        target->input(0)->type()->expect<TensorType>();
-    auto old_node_type = node->input(0)->type()->expect<TensorType>();
-    auto new_node_type =
-        insert_point_output_type->withScalarType(old_node_type->scalarType());
-    new_node->output(0)->setType(new_node_type);
+      auto user = uses[0].user;
+      auto target = node;
+
+      auto prev_node = node->input(0)->node();
+
+      bool could_lift_up = true;
+      while (could_lift_up) {
+        auto* target_value = target->input(0);
+        if (utils::isViewOp(target_value->node()) &&
+            (usedBySingleOp(target_value))) {
+          target = target_value->node();
+
+          // After lifting up, need to fix the output type
+          auto prev_target_type =
+              target->output(0)->type()->expect<TensorType>();
+          auto new_scalar_type =
+              node->output(0)->type()->expect<TensorType>()->scalarType();
+          auto new_target_type =
+              prev_target_type->withScalarType(new_scalar_type);
+          target->output(0)->setType(new_target_type);
+        } else {
+          could_lift_up = false;
+        }
+      }
 
-    target->replaceInputWith(target->input(0), new_node->output(0));
-    user->replaceInputWith(node->output(0), prev_node->output(0));
+      // No possible lift up, return directly
+      if (target == node) {
+        continue;
+      }
 
-    return true;
+      // From:
+      // linear -> view (target) -> permute -> transpose -> to (node) -> quant
+      // To:
+      // linear -> to -> view (target) -> permute -> transpose -> quant
+      // Finally:
+      // linear -> to -> quant -> view -> permute -> transpose
+      WithInsertPoint guard(target);
+      auto g = target->owningGraph();
+
+      // Construct lifted up node
+      std::vector<Value*> input_values;
+      input_values.push_back(target->input(0));
+      for (size_t i = 1; i < node->inputs().size(); i++) {
+        input_values.push_back(node->input(i));
+      }
+      auto new_node =
+          g->create(node->kind(), input_values)->insertBefore(target);
+
+      // Fix type of the output of lifted up node
+      auto insert_point_output_type =
+          target->input(0)->type()->expect<TensorType>();
+      auto old_node_type = node->input(0)->type()->expect<TensorType>();
+      auto new_node_type =
+          insert_point_output_type->withScalarType(old_node_type->scalarType());
+      new_node->output(0)->setType(new_node_type);
+
+      target->replaceInputWith(target->input(0), new_node->output(0));
+      user->replaceInputWith(node->output(0), prev_node->output(0));
+
+      it.destroyCurrent();
+      changed = true;
+    }
+    return changed;
   }
 
   void run() {
     bool changed = true;
     while (changed) {
-      changed = false;
-      for (Node* node : graph_->block()->nodes()) {
-        changed |= analyzeNode(node);
-      }
+      changed = analyze(graph_->block());
     }
   }
 };
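
The From/To/Finally comments above describe the intended reordering. A rough eager-mode analogue of that reordering, assuming per-tensor quantization so the layout-only ops commute with quantize (illustrative only; the pass itself rewrites JIT IR nodes, not eager tensors, and the shapes, scale, and zero point here are made up):

    import torch

    def before(x):
        v = x.view(4, -1)        # view (target)
        p = v.permute(1, 0)      # permute
        t = p.to(torch.float32)  # to (node)
        return torch.quantize_per_tensor(t, 0.1, 0, torch.qint8)

    def after(x):
        t = x.to(torch.float32)  # `to` lifted above the view chain
        q = torch.quantize_per_tensor(t, 0.1, 0, torch.qint8)
        return q.view(4, -1).permute(1, 0)

    # Per-tensor quantization is elementwise with a fixed scale/zero point,
    # so reordering it around view/permute should not change values.
    x = torch.randn(4, 4)
    assert torch.equal(before(x).dequantize(), after(x).dequantize())

Note that the output types still have to be patched along the lifted path (the setType calls above), since every view op between the new insertion point and the original node now carries the quantized/converted scalar type.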

tests/cpu/test_ao_jit_llga_quantization_fuser.py

44 additions, 0 deletions
@@ -1019,6 +1019,50 @@ def forward(self, x, y):
         self.assertFused(graph, ['aten::dequantize', 'aten::linear', 'aten::matmul'])
         self.checkPatterns(graph, patterns)
 
+    def test_lift_up_quant_unsupported(self):
+        # Original graph:
+        #            |
+        #           view
+        #      (f32)/  \(f32)
+        #       quant   add
+        #         |
+
+        # Lifting up in this case will raise:
+        # promoteTypes with quantized numbers is not handled in aten::add;
+        #            |
+        #          quant
+        #            |
+        #           view
+        #     (int8)\  /(f32)
+        #            add
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(3, 8, 1)
+                self.conv2 = nn.Conv2d(8, 8, 1)
+
+            def forward(self, x, y):
+                x = self.conv1(x)
+                z1 = x.permute(0, 3, 1, 2)
+                z2 = self.conv2(z1)
+                z = z1 + y
+                output = z2 + z
+                return output
+
+        x = torch.randn(1, 3, 8, 8)
+        y = torch.randn(1, 8, 8, 8)
+        m = M()
+
+        patterns = [
+            ["aten::dequantize", "aten::_convolution"],
+            ["aten::dequantize", "aten::_convolution", "aten::add"],
+        ]
+
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
+        self.assertFused(graph, ['aten::_convolution', 'aten::dequantize'])
+        self.checkPatterns(graph, patterns)
+
     def test_wildcard(self):
         class M(nn.Module):
             def __init__(self):
