@@ -7,91 +7,103 @@ namespace jit {
 namespace fuser {
 namespace onednn {

+bool usedBySingleOp(Value* v) {
+  return v->uses().size() == 1;
+}
+
 class QuantLifter {
  private:
   std::shared_ptr<Graph> graph_;

  public:
   QuantLifter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

-  bool analyzeNode(Node* node) {
-    if (node->kind() != Symbol::aten("quantize_per_tensor") &&
-        node->kind() != aten::to) {
-      return false;
-    }
-
-    // TODO: only supported nb_uses to be 1 for now
-    auto* output_value = node->output(0);
-    auto& uses = output_value->uses();
-    if (uses.size() != 1) {
-      return false;
-    }
+  bool analyze(Block* block) {
+    bool changed = false;
+    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
+      auto node = *it;

-    auto user = uses[0].user;
-    auto target = node;
-
-    auto prev_node = node->input(0)->node();
-
-    bool could_lift_up = true;
-    while (could_lift_up) {
-      if (utils::isViewOp(target->input(0)->node())) {
-        target = target->input(0)->node();
-
-        // After lifting up, need to fix the output type
-        auto prev_target_type = target->output(0)->type()->expect<TensorType>();
-        auto new_scalar_type =
-            node->output(0)->type()->expect<TensorType>()->scalarType();
-        auto new_target_type =
-            prev_target_type->withScalarType(new_scalar_type);
-        target->output(0)->setType(new_target_type);
-      } else {
-        could_lift_up = false;
+      if (node->kind() != Symbol::aten("quantize_per_tensor") &&
+          node->kind() != aten::to) {
+        continue;
       }
-    }
-
-    // No possible lift up, return directly
-    if (target == node) {
-      return false;
-    }

-    // From:
-    // linear -> view (target) -> permute -> transpose -> to (node) -> quant
-    // To:
-    // linear -> to -> view (target) -> permute -> transpose -> quant
-    // Finally:
-    // linear -> to -> quant -> view -> permute -> transpose
-    WithInsertPoint guard(target);
-    auto g = target->owningGraph();
-
-    // Construct lifted up node
-    std::vector<Value*> input_values;
-    input_values.push_back(target->input(0));
-    for (size_t i = 1; i < node->inputs().size(); i++) {
-      input_values.push_back(node->input(i));
-    }
-    auto new_node = g->create(node->kind(), input_values)->insertBefore(target);
+      // TODO: only supported nb_uses to be 1 for now
+      auto* output_value = node->output(0);
+      auto& uses = output_value->uses();
+      if (uses.size() != 1) {
+        continue;
+      }

-    // Fix type of the output of lifted up node
-    auto insert_point_output_type =
-        target->input(0)->type()->expect<TensorType>();
-    auto old_node_type = node->input(0)->type()->expect<TensorType>();
-    auto new_node_type =
-        insert_point_output_type->withScalarType(old_node_type->scalarType());
-    new_node->output(0)->setType(new_node_type);
+      auto user = uses[0].user;
+      auto target = node;
+
+      auto prev_node = node->input(0)->node();
+
+      bool could_lift_up = true;
+      while (could_lift_up) {
+        auto* target_value = target->input(0);
+        if (utils::isViewOp(target_value->node()) &&
+            (usedBySingleOp(target_value))) {
+          target = target_value->node();
+
+          // After lifting up, need to fix the output type
+          auto prev_target_type =
+              target->output(0)->type()->expect<TensorType>();
+          auto new_scalar_type =
+              node->output(0)->type()->expect<TensorType>()->scalarType();
+          auto new_target_type =
+              prev_target_type->withScalarType(new_scalar_type);
+          target->output(0)->setType(new_target_type);
+        } else {
+          could_lift_up = false;
+        }
+      }

-    target->replaceInputWith(target->input(0), new_node->output(0));
-    user->replaceInputWith(node->output(0), prev_node->output(0));
+      // No possible lift up, return directly
+      if (target == node) {
+        continue;
+      }

-    return true;
+      // From:
+      // linear -> view (target) -> permute -> transpose -> to (node) -> quant
+      // To:
+      // linear -> to -> view (target) -> permute -> transpose -> quant
+      // Finally:
+      // linear -> to -> quant -> view -> permute -> transpose
+      WithInsertPoint guard(target);
+      auto g = target->owningGraph();
+
+      // Construct lifted up node
+      std::vector<Value*> input_values;
+      input_values.push_back(target->input(0));
+      for (size_t i = 1; i < node->inputs().size(); i++) {
+        input_values.push_back(node->input(i));
+      }
+      auto new_node =
+          g->create(node->kind(), input_values)->insertBefore(target);
+
+      // Fix type of the output of lifted up node
+      auto insert_point_output_type =
+          target->input(0)->type()->expect<TensorType>();
+      auto old_node_type = node->input(0)->type()->expect<TensorType>();
+      auto new_node_type =
+          insert_point_output_type->withScalarType(old_node_type->scalarType());
+      new_node->output(0)->setType(new_node_type);
+
+      target->replaceInputWith(target->input(0), new_node->output(0));
+      user->replaceInputWith(node->output(0), prev_node->output(0));
+
+      it.destroyCurrent();
+      changed = true;
+    }
+    return changed;
   }

   void run() {
     bool changed = true;
     while (changed) {
-      changed = false;
-      for (Node* node : graph_->block()->nodes()) {
-        changed |= analyzeNode(node);
-      }
+      changed = analyze(graph_->block());
     }
   }
 };
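For context, below is a minimal sketch of how this pass could be driven from a graph-pass entry point. Only QuantLifter's constructor and run() come from the diff above; the wrapper name LiftUpQuantCasts, the outer torch namespace, and the include path are assumptions for illustration, and the sketch assumes it lives in the same file after the QuantLifter definition.

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Hypothetical entry point (not part of the diff): repeatedly analyzes the
// top-level block via QuantLifter::run(), which keeps lifting aten::to /
// quantize_per_tensor nodes above view-like ops until nothing changes.
void LiftUpQuantCasts(std::shared_ptr<Graph>& graph) {
  QuantLifter(graph).run();
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch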