@@ -40,263 +40,6 @@ struct hash<std::pair<Symbol, Symbol>> {
4040
4141namespace torch {
4242namespace jit {
43- //
44- // The main goal of MKL-DNN fusion is to limit bandwidth wasting.
45- // MKL-DNN provided post ops to fuse ops in its output stage
46- // What we could do is listed inside RuleTab.
47- //
48- class OpFuser {
49- Block* block_;
50- std::unique_ptr<AliasDb> aliasDb_;
51- std::shared_ptr<Graph> graph_;
52- using Symbols = std::vector<Symbol>;
53- using RuleTab = std::unordered_map<::std::pair<Symbol, Symbol>, Symbol>;
54- using Rule = RuleTab::iterator;
55- static RuleTab dnnlRules;
56-
57- public:
58- OpFuser (Block* block, std::shared_ptr<Graph> graph)
59- : block_(block), graph_(std::move(graph)) {}
60-
61- void run () {
62- bool any_changed = true ;
63- while (any_changed) {
64- any_changed = false ;
65- refreshAliasDb ();
66- for (auto it = block_->nodes ().begin (); it != block_->nodes ().end ();) {
67- bool changed;
68- std::tie (it, changed) = processNode (*it);
69- any_changed |= changed;
70- }
71- }
72-
73- refreshAliasDb ();
74-
75- for (Node* node : block_->nodes ()) {
76- for (Block* sub : node->blocks ()) {
77- OpFuser (sub, graph_).run ();
78- }
79- }
80- }
81-
82- c10::optional<Rule> isFusable (Node* curr, Node* prev) const {
83- // Is it happening in our case ???
84- if (curr->owningBlock () != block_)
85- return c10::nullopt ;
86-
87- auto choice = dnnlRules.find ({prev->kind (), curr->kind ()});
88- if (choice != dnnlRules.end ())
89- return choice;
90-
91- return c10::nullopt ;
92- }
93-
94- void refreshAliasDb () {
95- aliasDb_ = std::make_unique<AliasDb>(graph_);
96- }
97-
98- Node* fuseOpsWithNewKind (Node* curr, Value* v, Graph* g, NodeKind kind) {
99- auto newNode = g->create (kind);
100- auto prev = v->node ();
101- newNode->insertBefore (prev);
102- newNode->setScope (prev->scope ());
103- newNode->copyAttributes (*prev);
104-
105- for (auto input : prev->inputs ()) {
106- newNode->addInput (input);
107- }
108-
109- for (auto input : curr->inputs ()) {
110- if (input != v) {
111- newNode->addInput (input);
112- }
113- }
114-
115- // Copy curr or prev?
116- newNode->output ()->copyMetadata (prev->output ());
117- newNode->output ()->setType (prev->output ()->type ());
118-
119- v->replaceAllUsesWith (newNode->output ());
120- curr->replaceAllUsesWith (newNode);
121-
122- prev->destroy ();
123- curr->destroy ();
124-
125- return newNode;
126- }
127-
128- Node* fuseNodes (Node* curr, Value* path, Rule rule) {
129- return fuseOpsWithNewKind (curr, path, curr->owningGraph (), rule->second );
130- }
131-
132- bool aliasIsSafeForSquashingValue (Node* node, Value* v) {
133- bool safe = false ;
134- auto prev = v->node ();
135- if (aliasDb_->moveAfterTopologicallyValid (node, prev)) {
136- if (v->uses ().size () == 1 ||
137- aliasDb_->mayAlias /* mustAlias */ (v, node->output ())) {
138- safe = true ;
139- }
140- }
141- return safe;
142- }
143-
144- //
145- // Check whether we could change specific input to be inplace with output
146- // Any use topologically after node will fail it.
147- // XXX: haven't considered loop
148- //
149- bool aliasIsSafeForInplaceValue (Node* node, Value* v) {
150- for (auto use : v->uses ())
151- if (use.user ->isAfter (node))
152- return false ;
153-
154- return true ;
155- }
156-
157- const FunctionSchema& matchSchemaForFusion (
158- c10::Symbol symbol,
159- Node* prev,
160- Node* node) {
161- auto ops = getAllOperatorsFor (symbol);
162-
163- for (auto & op : ops) {
164- auto & schema = op->schema ();
165- if (schema.arguments ().size () ==
166- prev->inputs ().size () + node->inputs ().size () - 1 &&
167- schema.returns ().size () == node->outputs ().size ())
168- return schema;
169- }
170-
171- // throw
172- auto er = ErrorReport (node->sourceRange ());
173- er << " Schema not found for fusion process. \n " ;
174- er << " Prev: " << *prev << " \n " ;
175- er << " Node: " << *node << " \n " ;
176-
177- if (ops.size () > 0 ) {
178- er << " \n candidates were:\n " ;
179- for (auto & op : ops)
180- er << " " << op->schema () << " \n " ;
181- } else {
182- er << " \n no candidates found\n " ;
183- }
184- er << " within the graph:\n " ;
185- er << *node->owningGraph () << " \n " ;
186- throw er;
187- }
188-
189- bool aliasIsSafeForFusion (Node* node, Value* v, c10::optional<Rule> r) {
190- bool safe = false ;
191- // Returns false if the two nodes to be fused do not have the same owning
192- // block
193- if (node->owningBlock () != v->node ()->owningBlock ()) {
194- return safe;
195- }
196- // TODO: it might be flawed because we don't have 'alias must' information
197- //
198- // Simple fusion, unary ops:
199- // Example: conv2d -> relu to conv2d_relu
200- //
201- // To maintain equivalence before and after fusion, we have some rules:
202- // 1. Op could be moved safely right after the op it fuses to.
203- // 2. If one of node's input and output are alias must (relu_?), we could
204- // replace all uses of input with output, which removes the use that might
205- // clog the fuse path which is to be squashed.
206- // 3. If there is no alias between input and output, we can only fuse the
207- // case when there is only use.
208- //
209- // Y-merge (conv-sum-relu?)
210- // 4. We acquire alias info from the resulting op schema, check whether the fusion
211- // is not breaking any computational semantics.
212- //
213- // A Y-merge fusion, like:
214- // conv2d_inputs | or | conv2d_inputs
215- // / | | \
216- // x conv2d | | conv2d x
217- // \ / | | \ /
218- // add | | add
219- // | | | |
220- // y | | y
221- //
222- // both to:
223- //
224- // conv2d_inputs x(a!)
225- // \ /
226- // conv2d_sum
227- // |
228- // y(a!)
229- //
230- // Where y is an alias of x; we check whether the latter is equivalent to the former.
231- // The params convention when we do Y-merge: arguments from both ops comes
232- // to new op in topological order. So in the example conv2d's inputs come
233- // first then sum's inputs (without the input which is squashed).
234- //
235- safe = aliasIsSafeForSquashingValue (node, v);
236-
237- //
238- // Y-merge like case
239- //
240- if (safe && node->inputs ().size () > 1 ) {
241- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (r);
242- auto rule = *r.value ();
243- auto & schema = matchSchemaForFusion (rule.second , v->node (), node);
244- auto o_schema = node->schema ();
245-
246- auto pos = v->node ()->inputs ().size ();
247-
248- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (
249- schema.arguments ().size () == pos + node->inputs ().size () - 1 );
250-
251- for (int i = 0 ; i < node->inputs ().size (); ++i) {
252- if (node->input (i) != v) { /* avoid squashing path */
253- auto aliasInfo = schema.arguments ()[pos++].alias_info ();
254- if (!aliasInfo)
255- continue ;
256-
257- // Introduced new alias write to
258- if (aliasInfo->isWrite ()) {
259- auto old_info = o_schema.arguments ()[i].alias_info ();
260- if (!old_info || !old_info->isWrite ()) {
261- // Introduced new written to alias
262- safe = safe && aliasIsSafeForInplaceValue (node, node->input (i));
263- }
264- }
265- }
266- }
267-
268- // XXX: Do we have to handle output alias change case?
269- }
270- return safe;
271- }
272-
273- std::pair<graph_node_list::iterator, bool > processNode (Node* node) {
274- Node* pos = node;
275- bool changed = false ;
276-
277- //
278- // Check whether we could fuse to one certain value path
279- //
280- for (auto * v : node->inputs ()) {
281- auto prev = v->node ();
282- auto fuseRule = isFusable (node, prev);
283-
284- // We can fuse only one path
285- if (fuseRule && aliasIsSafeForFusion (node, v, fuseRule)) {
286- pos = fuseNodes (node, v, fuseRule.value ());
287- changed = true ;
288- break ;
289- }
290- }
291- return std::make_pair (++pos->iterator (), changed);
292- }
293- };
294-
295- // TODO: These rules should be more scalable
296- OpFuser::RuleTab OpFuser::dnnlRules = {
297- {{aten::matmul, aten::div}, ipex::matmul_div},
298- };
299-
30043// Including in-place optimizations that try to (conditionally)
30144// replace the origin op with in-place opted one for better performance.
30245// This in-place optimized ops may come from either oneDNN or aten
@@ -440,11 +183,7 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
440183
441184 // Fuse operators as shuffle
442185 graph_rewrite::FuseShuffle (graph);
443-
444- // Pattern based fusion was lack of alias analysis
445- // ??? It may either be too conservative or too aggressive ???
446- // getSubgraphRewriter().runOnGraph(graph);
447- OpFuser (graph->block (), graph).run ();
186+ graph_rewrite::FuseMatmulDiv (graph);
448187
449188 // replace aten max_pool2d with ipex max_pool2d
450189 graph_rewrite::replaceAtenMaxPool2dWithIpexMaxPool2d (graph);
0 commit comments