Commit 3223ca7
Complete the pattern matching for matmul + div and modify the out parameter's semantics for this pattern (#862) (#903)
* Complete the pattern matching for matmul + div and modify the out parameter's semantics for this pattern
* clang-format
* clang-format
* Remove OpFuser
* Explicitly mark the out parameter and the return value with alias annotations in the fused kernel's signature; add a UT to clarify that when the div is the out-of-place version and writing to the out parameter would be an observable side effect, the pattern is not replaced by our fusion
* Remove the ill-illustrated UT
* Complete UT coverage
* Add two UTs to demonstrate that we are free of side effects when we replace div with div_ in this pattern
1 parent b7b9359 commit 3223ca7
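The out-parameter semantics change can be seen in the Matmul.cpp diff below: when an out tensor is supplied, the fused kernel now divides it in place, so out ends up holding the final matmul-divided result and the return value aliases it (previously out held only the raw matmul product). The following standalone sketch is written against the public libtorch C++ API, not the IPEX kernel itself; it only mirrors the call sequence of dil_matmul_div's out path, and the shapes and divisor are arbitrary.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Mirror of the dil_matmul_div "out" path as changed in this commit:
  // matmul into the user-provided buffer, then divide it in place.
  at::Tensor a = torch::randn({2, 3});
  at::Tensor b = torch::randn({3, 4});
  at::Tensor out = torch::empty({2, 4});

  at::matmul_out(out, a, b);          // out <- a @ b
  at::Tensor result = out.div_(2.0);  // in-place: out now holds (a @ b) / 2

  // The return value aliases the out buffer, which is what the Tensor(a!)
  // annotations in the fused op's schema declare to the JIT.
  std::cout << std::boolalpha
            << (result.data_ptr() == out.data_ptr()) << std::endl;  // true
  return 0;
}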

File tree

6 files changed (+174, -278 lines)

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp

Lines changed: 2 additions & 2 deletions

@@ -84,10 +84,10 @@ at::Tensor dil_matmul_div(
 
   if (out.defined()) {
     at::matmul_out(out, tensor1, tensor2);
-    return out.div(div_input);
+    return out.div_(div_input);
   }
   auto output = at::matmul(tensor1, tensor2);
-  return output.div(div_input);
+  return output.div_(div_input);
 }
 
 /**
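On the no-out path, the intermediate produced by at::matmul is a fresh tensor whose only consumer is the division, so dividing it in place is observationally equivalent to the out-of-place div. This is the property the new unit tests in this PR are meant to demonstrate; the sketch below illustrates the same argument with plain libtorch calls, independent of the IPEX kernel.

#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor t1 = torch::randn({4, 8});
  at::Tensor t2 = torch::randn({8, 4});
  at::Tensor z = torch::full({}, 3.0);

  // Out-of-place reference result.
  at::Tensor ref = at::matmul(t1, t2).div(z);

  // Shape of the fused pattern: the matmul output is a temporary that nothing
  // else references, so mutating it with div_ cannot be observed elsewhere.
  at::Tensor tmp = at::matmul(t1, t2);
  at::Tensor fused_like = tmp.div_(z);

  std::cout << std::boolalpha
            << torch::allclose(ref, fused_like) << std::endl;  // true
  return 0;
}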

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 37 additions & 0 deletions

@@ -218,6 +218,43 @@ void FuseAddLayerNorm(std::shared_ptr<Graph>& graph) {
   rewriter_aten.runOnGraph(graph);
 }
 
+void FuseMatmulDiv(std::shared_ptr<Graph>& graph) {
+  const std::string div_str = R"(div)";
+  const std::string div_inplace_str = R"(div_)";
+  std::vector<std::string> div_ops = {div_str, div_inplace_str};
+
+  auto aten_pattern = at::jit::CodeTemplate(R"(
+     graph(%x, %y, %z):
+        %mm_res = aten::matmul(%x, %y)
+        %div_res = aten::${div_op}(%mm_res, %z)
+        return (%div_res) )");
+
+  auto aten_pattern_with_out = at::jit::CodeTemplate(R"(
+     graph(%x, %y, %z, %out):
+        %mm_res = aten::matmul(%x, %y, %out)
+        %div_res = aten::${div_op}(%mm_res, %z)
+        return (%div_res) )");
+
+  std::string fused_matmul_div = R"(
+     graph(%x, %y, %z):
+        %r = ipex::matmul_div(%x, %y, %z)
+        return (%r) )";
+  std::string fused_matmul_div_with_out = R"(
+     graph(%x, %y, %z, %out):
+        %r = ipex::matmul_div(%x, %y, %out, %z)
+        return (%r) )";
+  for (auto const& it : div_ops) {
+    at::jit::TemplateEnv env;
+    env.s("div_op", it);
+
+    SubgraphRewriter rewriter;
+    rewriter.RegisterRewritePattern(aten_pattern.format(env), fused_matmul_div);
+    rewriter.RegisterRewritePattern(
+        aten_pattern_with_out.format(env), fused_matmul_div_with_out);
+    rewriter.runOnGraph(graph);
+  }
+}
+
 // MHA fusion covers aten::softmax, ipex::softmax and ipex::softmax_:
 // (1) MHA obviously shows better performance than aten div/matmul/add/softmax.
 // (2) MHA also shows better performance than aten add + matmul_div fusion
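FuseMatmulDiv is a declarative pattern rewrite: each matmul followed by div (or div_) that matches one of the template graphs is replaced with a single ipex::matmul_div node. For readers unfamiliar with the mechanism, here is a minimal, self-contained sketch using the stock torch::jit::SubgraphRewriter and parseIR; it is not the IPEX pass, it drops the ${div_op} templating, and my::matmul_div is a placeholder symbol for illustration only, not the registered IPEX op.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <iostream>

int main() {
  using namespace torch::jit;

  // A toy graph containing the aten::matmul + aten::div pattern.
  auto graph = std::make_shared<Graph>();
  parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor, %z : Tensor):
      %mm : Tensor = aten::matmul(%x, %y)
      %res : Tensor = aten::div(%mm, %z)
      return (%res))IR",
          graph.get());

  // Pattern and replacement, in the same textual form FuseMatmulDiv builds
  // from its CodeTemplate strings.
  const std::string pattern = R"(
    graph(%x, %y, %z):
      %mm = aten::matmul(%x, %y)
      %res = aten::div(%mm, %z)
      return (%res))";
  const std::string fused = R"(
    graph(%x, %y, %z):
      %res = my::matmul_div(%x, %y, %z)
      return (%res))";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, fused);
  rewriter.runOnGraph(graph);

  // The matmul/div pair should now appear as a single my::matmul_div node.
  std::cout << *graph;
  return 0;
}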

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph);
 void fuseLinearAddRelu(std::shared_ptr<Graph>& graph);
 
 void FuseAddLayerNorm(std::shared_ptr<Graph>& graph);
+void FuseMatmulDiv(std::shared_ptr<Graph>& graph);
 void FuseConcatBnRelu(std::shared_ptr<Graph>& graph);
 
 void insertPrePackedConvTransposeOp(std::shared_ptr<Graph>& graph);

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 4 additions & 4 deletions

@@ -729,8 +729,8 @@ RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
     Operator(
-        "ipex::matmul_div(Tensor left, Tensor right, Tensor? out_opt, Tensor "
-        "div_input) -> Tensor",
+        "ipex::matmul_div(Tensor left, Tensor right, Tensor(a!) out_opt, Tensor "
+        "div_input) -> Tensor(a!)",
         [](const Node* node) -> Operation {
           return [](Stack* stack) {
             auto result = dil_matmul_div(
@@ -746,8 +746,8 @@ RegisterOperators op({
         aliasAnalysisFromSchema()),
 
     Operator(
-        "ipex::matmul_div(Tensor left, Tensor right, Tensor? out_opt, Scalar "
-        "div_input) -> Tensor",
+        "ipex::matmul_div(Tensor left, Tensor right, Tensor(a!) out_opt, Scalar "
+        "div_input) -> Tensor(a!)",
         [](const Node* node) -> Operation {
           return [](Stack* stack) {
             auto result = dil_matmul_div(
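The only change here is in the schema strings: out_opt goes from Tensor? to Tensor(a!) and the return type from Tensor to Tensor(a!), telling the JIT's alias analysis that the op writes into out_opt and that the returned tensor aliases it. As a hedged illustration of what that annotation means (not IPEX code), the snippet below parses the same schema text with torch::jit::parseSchema and checks that the op is now flagged as mutating its input.

#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <iostream>

int main() {
  // The fused op's schema as registered in this commit. "(a!)" names an
  // alias set `a` and "!" marks it as written to; reusing the same set on
  // the return value says the output aliases out_opt.
  auto schema = torch::jit::parseSchema(
      "ipex::matmul_div(Tensor left, Tensor right, Tensor(a!) out_opt, "
      "Tensor div_input) -> Tensor(a!)");

  std::cout << schema << std::endl;
  // True: at least one argument (out_opt) may be mutated by the op, so the
  // JIT must not reorder or eliminate later reads of that buffer.
  std::cout << std::boolalpha << schema.is_mutable() << std::endl;
  return 0;
}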

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 1 addition & 262 deletions

@@ -40,263 +40,6 @@ struct hash<std::pair<Symbol, Symbol>> {
 
 namespace torch {
 namespace jit {
-//
-// The main goal of MKL-DNN fusion is to limit bandwidth wasting.
-// MKL-DNN provided post ops to fuse ops in its output stage
-// What we could do is listed inside RuleTab.
-//
-class OpFuser {
-  Block* block_;
-  std::unique_ptr<AliasDb> aliasDb_;
-  std::shared_ptr<Graph> graph_;
-  using Symbols = std::vector<Symbol>;
-  using RuleTab = std::unordered_map<::std::pair<Symbol, Symbol>, Symbol>;
-  using Rule = RuleTab::iterator;
-  static RuleTab dnnlRules;
-
- public:
-  OpFuser(Block* block, std::shared_ptr<Graph> graph)
-      : block_(block), graph_(std::move(graph)) {}
-
-  void run() {
-    bool any_changed = true;
-    while (any_changed) {
-      any_changed = false;
-      refreshAliasDb();
-      for (auto it = block_->nodes().begin(); it != block_->nodes().end();) {
-        bool changed;
-        std::tie(it, changed) = processNode(*it);
-        any_changed |= changed;
-      }
-    }
-
-    refreshAliasDb();
-
-    for (Node* node : block_->nodes()) {
-      for (Block* sub : node->blocks()) {
-        OpFuser(sub, graph_).run();
-      }
-    }
-  }
-
-  c10::optional<Rule> isFusable(Node* curr, Node* prev) const {
-    // Is it happening in our case ???
-    if (curr->owningBlock() != block_)
-      return c10::nullopt;
-
-    auto choice = dnnlRules.find({prev->kind(), curr->kind()});
-    if (choice != dnnlRules.end())
-      return choice;
-
-    return c10::nullopt;
-  }
-
-  void refreshAliasDb() {
-    aliasDb_ = std::make_unique<AliasDb>(graph_);
-  }
-
-  Node* fuseOpsWithNewKind(Node* curr, Value* v, Graph* g, NodeKind kind) {
-    auto newNode = g->create(kind);
-    auto prev = v->node();
-    newNode->insertBefore(prev);
-    newNode->setScope(prev->scope());
-    newNode->copyAttributes(*prev);
-
-    for (auto input : prev->inputs()) {
-      newNode->addInput(input);
-    }
-
-    for (auto input : curr->inputs()) {
-      if (input != v) {
-        newNode->addInput(input);
-      }
-    }
-
-    // Copy curr or prev?
-    newNode->output()->copyMetadata(prev->output());
-    newNode->output()->setType(prev->output()->type());
-
-    v->replaceAllUsesWith(newNode->output());
-    curr->replaceAllUsesWith(newNode);
-
-    prev->destroy();
-    curr->destroy();
-
-    return newNode;
-  }
-
-  Node* fuseNodes(Node* curr, Value* path, Rule rule) {
-    return fuseOpsWithNewKind(curr, path, curr->owningGraph(), rule->second);
-  }
-
-  bool aliasIsSafeForSquashingValue(Node* node, Value* v) {
-    bool safe = false;
-    auto prev = v->node();
-    if (aliasDb_->moveAfterTopologicallyValid(node, prev)) {
-      if (v->uses().size() == 1 ||
-          aliasDb_->mayAlias /* mustAlias */ (v, node->output())) {
-        safe = true;
-      }
-    }
-    return safe;
-  }
-
-  //
-  // Check whether we could change specific input to be inplace with output
-  // Any use topologically after node will fail it.
-  // XXX: haven't considered loop
-  //
-  bool aliasIsSafeForInplaceValue(Node* node, Value* v) {
-    for (auto use : v->uses())
-      if (use.user->isAfter(node))
-        return false;
-
-    return true;
-  }
-
-  const FunctionSchema& matchSchemaForFusion(
-      c10::Symbol symbol,
-      Node* prev,
-      Node* node) {
-    auto ops = getAllOperatorsFor(symbol);
-
-    for (auto& op : ops) {
-      auto& schema = op->schema();
-      if (schema.arguments().size() ==
-              prev->inputs().size() + node->inputs().size() - 1 &&
-          schema.returns().size() == node->outputs().size())
-        return schema;
-    }
-
-    // throw
-    auto er = ErrorReport(node->sourceRange());
-    er << "Schema not found for fusion process. \n";
-    er << "Prev: " << *prev << "\n";
-    er << "Node: " << *node << "\n";
-
-    if (ops.size() > 0) {
-      er << "\ncandidates were:\n";
-      for (auto& op : ops)
-        er << "  " << op->schema() << "\n";
-    } else {
-      er << "\nno candidates found\n";
-    }
-    er << "within the graph:\n";
-    er << *node->owningGraph() << "\n";
-    throw er;
-  }
-
-  bool aliasIsSafeForFusion(Node* node, Value* v, c10::optional<Rule> r) {
-    bool safe = false;
-    // Returns false if the two nodes to be fused do not have the same owning
-    // block
-    if (node->owningBlock() != v->node()->owningBlock()) {
-      return safe;
-    }
-    // TODO: it might be flawed because we don't have 'alias must' information
-    //
-    // Simple fusion, unary ops:
-    // Example: conv2d -> relu to conv2d_relu
-    //
-    // To maintain equivalence before and after fusion, we have some rules:
-    // 1. Op could be moved safely right after the op it fuse to.
-    // 2. If one of node's input and output are alias must (relu_?), we could
-    // replace all uses of input to use output, which remove the use that might
-    // clogging the fuse path which is to be squashed.
-    // 3. If there is no alias between input and output, we can only fuse the
-    // case when there is only use.
-    //
-    // Y-merge (conv-sum-relu?)
-    // 4. We aquire alias info from resulted op schema, check whether the fusion
-    // is not breaking any computational semantics.
-    //
-    // A Y-merge fusion, like:
-    //    conv2d_inputs      | or |   conv2d_inputs
-    //      /    |           |    |      |     \
-    //     x   conv2d        |    |   conv2d    x
-    //      \   /            |    |      \     /
-    //       add             |    |       add
-    //        |              |    |        |
-    //        y              |    |        y
-    //
-    // both to:
-    //
-    //    conv2d_inputs   x(a!)
-    //          \         /
-    //          conv2d_sum
-    //               |
-    //             y(a!)
-    //
-    // Which y is alias to x, we check whether later is equivalent to formal.
-    // The params convention when we do Y-merge: arguments from both ops comes
-    // to new op in topological order. So in the exmaple conv2d's inputs comes
-    // first then sum's inputs (without the input which is squashed).
-    //
-    safe = aliasIsSafeForSquashingValue(node, v);
-
-    //
-    // Y-merge like case
-    //
-    if (safe && node->inputs().size() > 1) {
-      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(r);
-      auto rule = *r.value();
-      auto& schema = matchSchemaForFusion(rule.second, v->node(), node);
-      auto o_schema = node->schema();
-
-      auto pos = v->node()->inputs().size();
-
-      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-          schema.arguments().size() == pos + node->inputs().size() - 1);
-
-      for (int i = 0; i < node->inputs().size(); ++i) {
-        if (node->input(i) != v) { /* avoid squashing path */
-          auto aliasInfo = schema.arguments()[pos++].alias_info();
-          if (!aliasInfo)
-            continue;
-
-          // Introdued new alias write to
-          if (aliasInfo->isWrite()) {
-            auto old_info = o_schema.arguments()[i].alias_info();
-            if (!old_info || !old_info->isWrite()) {
-              // Introduced new written to alias
-              safe = safe && aliasIsSafeForInplaceValue(node, node->input(i));
-            }
-          }
-        }
-      }
-
-      // XXX: Do we have to handle output alias change case?
-    }
-    return safe;
-  }
-
-  std::pair<graph_node_list::iterator, bool> processNode(Node* node) {
-    Node* pos = node;
-    bool changed = false;
-
-    //
-    // Check whether we could fuse to one certain value path
-    //
-    for (auto* v : node->inputs()) {
-      auto prev = v->node();
-      auto fuseRule = isFusable(node, prev);
-
-      // We can fuse only one path
-      if (fuseRule && aliasIsSafeForFusion(node, v, fuseRule)) {
-        pos = fuseNodes(node, v, fuseRule.value());
-        changed = true;
-        break;
-      }
-    }
-    return std::make_pair(++pos->iterator(), changed);
-  }
-};
-
-// TODO: These rules should be more scalable
-OpFuser::RuleTab OpFuser::dnnlRules = {
-    {{aten::matmul, aten::div}, ipex::matmul_div},
-};
-
 // Including in-place optimizations that try to (conditionally)
 // replace the origin op with in-place opted one for better performance.
 // This in-place optimized ops may come from either oneDNN or aten
@@ -440,11 +183,7 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
 
   // Fuse operators as shuffle
   graph_rewrite::FuseShuffle(graph);
-
-  // Pattern based fusion was lack of alias analysis
-  // ??? It may either be too conservative or too aggressive ???
-  // getSubgraphRewriter().runOnGraph(graph);
-  OpFuser(graph->block(), graph).run();
+  graph_rewrite::FuseMatmulDiv(graph);
 
   // replace aten max_pool2d with ipex max_pool2d
   graph_rewrite::replaceAtenMaxPool2dWithIpexMaxPool2d(graph);
