1+ #include " jit/codegen/onednn/guard_shape.h"
2+ #include " jit/codegen/onednn/fusion_group_name.h"
3+
4+ #include < torch/csrc/jit/jit_log.h>
5+ #include < torch/csrc/jit/passes/tensorexpr_fuser.h>
6+ #include < torch/csrc/jit/runtime/graph_executor.h>
7+
8+ namespace torch {
9+ namespace jit {
10+ namespace fuser {
11+ namespace onednn {
12+
// Callback that maps a profiled TensorType of a fusion-group input to the
// type the inserted runtime guard should check against. LLGA uses the
// identity mapping (see prepareFusionGroupAndGuardOutputs below).
using tensor_type_converter_t =
    c10::function_ref<TensorTypePtr(const TensorTypePtr &t)>;
15+
16+ void insertTypeGuardForFusionGroup (Node *guarded_node,
17+ tensor_type_converter_t type_converter,
18+ Symbol kind) {
19+ GRAPH_DEBUG (" Inserting a typecheck guard for a node" , *guarded_node);
20+ auto subgraph = guarded_node->g (attr::Subgraph);
21+
22+ // Fixup types of the subgraph inputs
23+ std::vector<Value *> inputs_to_check;
24+ std::vector<TypePtr> guard_types;
25+ for (Value *input : guarded_node->inputs ()) {
26+ // We only check inputs of the guarded nodes and expect user to infer
27+ // intermediates and outputs shapes
28+ if (!input->type ()->cast <TensorType>()) {
29+ continue ;
30+ }
31+
32+ // fusion outputs are already guarded
33+ if (input->node ()->kind () == prim::Constant ||
34+ input->node ()->kind () ==
35+ Symbol::fromQualString (LlgaFusionGroupName ())) {
36+ continue ;
37+ }
38+ inputs_to_check.push_back (input);
39+ guard_types.push_back (type_converter (input->type ()->expect <TensorType>()));
40+ }
41+ if (!inputs_to_check.size ()) {
42+ return ;
43+ }
44+
45+ // Add ipex::LlgaFusionGuard node
46+ //
47+ // ipex::LlgaFusionGuard nodes look like the following:
48+ // %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool =
49+ // ipex::LlgaFusionGuard(%inp1 : Tensor, %inp2 : Tensor)
50+ //
51+ // They have N inputs whose types we are going to check and N+1 outputs. The
52+ // first N outputs specify expected types and N+1-th output holds the result
53+ // of the check (bool).
54+ Node *typecheck_node =
55+ guarded_node->owningGraph ()
56+ ->create (kind, inputs_to_check, inputs_to_check.size () + 1 )
57+ ->insertBefore (guarded_node);
58+ typecheck_node->tys_ (attr::types, guard_types);
59+ Value *typecheck_result = typecheck_node->output (inputs_to_check.size ());
60+
61+ std::unordered_map<Value *, Value *> typechecked_inputs;
62+ for (size_t i = 0 ; i < typecheck_node->inputs ().size (); ++i) {
63+ typechecked_inputs[typecheck_node->input (i)] = typecheck_node->output (i);
64+ }
65+
66+ // Fixup types of the typecheck node outputs, which are used by the op in
67+ // execution
68+ typecheck_node->output (inputs_to_check.size ())->setType (BoolType::get ());
69+ for (size_t i = 0 ; i < typecheck_node->inputs ().size (); ++i) {
70+ typecheck_node->output (i)->setType (typecheck_node->input (i)->type ());
71+ }
72+
73+ // Insert if
74+ auto versioning_if =
75+ guarded_node->owningGraph ()
76+ ->create (prim::If, {typecheck_result}, guarded_node->outputs ().size ())
77+ ->insertAfter (typecheck_node);
78+ for (size_t idx = 0 ; idx < guarded_node->outputs ().size (); ++idx) {
79+ versioning_if->output (idx)->setType (guarded_node->output (idx)->type ());
80+ guarded_node->output (idx)->replaceAllUsesWith (versioning_if->output (idx));
81+ }
82+ auto true_block = versioning_if->addBlock ();
83+ auto false_block = versioning_if->addBlock ();
84+
85+ // Fill in the false block. It should contain the unoptimized
86+ // copy of the fused subgraph.
87+ WithInsertPoint guard (false_block->return_node ());
88+ const auto subgraph_outputs = insertGraph (*guarded_node->owningGraph (),
89+ *subgraph, guarded_node->inputs ());
90+ for (Value *output : subgraph_outputs) {
91+ false_block->registerOutput (output);
92+ }
93+
94+ // types get copied to the fallback graph, so remove specializations before
95+ // replacing
96+ removeTensorTypeSpecializations (false_block);
97+ replaceBlockWithFallbackGraph (false_block, guarded_node->inputs ());
98+
99+ // Fill in the true block. It has all inputs type-checked and its
100+ // body should be the fusion group node.
101+ guarded_node->moveBefore (true_block->return_node ());
102+ for (size_t idx = 0 ; idx < guarded_node->inputs ().size (); ++idx) {
103+ if (typechecked_inputs.count (guarded_node->input (idx))) {
104+ guarded_node->replaceInput (
105+ idx, typechecked_inputs.at (guarded_node->input (idx)));
106+ }
107+ }
108+ for (Value *output : guarded_node->outputs ()) {
109+ true_block->registerOutput (output);
110+ }
111+ }
112+
113+ // ! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
114+ // ! shamelessly copying code from NNC (tensorexpr_fuser) with very little
115+ // ! modification, original code at:
116+ // ! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
117+ // !
118+ // ! We have the assumption that LLGA does not have operators
119+ // ! depending on the content of the tensor.
120+ void prepareFusionGroupAndGuardOutputs (Block *block) {
121+ std::vector<Node *> fusion_groups;
122+ for (Node *n : block->nodes ()) {
123+ for (Block *b : n->blocks ()) {
124+ prepareFusionGroupAndGuardOutputs (b);
125+ }
126+ if (n->kind () == Symbol::fromQualString (LlgaFusionGroupName ())) {
127+ fusion_groups.push_back (n);
128+ }
129+ }
130+ for (Node *fusion_group : fusion_groups) {
131+ // TODO: add further optimization pass to removeOutputsUsedOnlyInSize,
132+ // refer to
133+ // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
134+ // removeOutputsUsedOnlyInSize(fusion_group);
135+ insertTypeGuardForFusionGroup (
136+ fusion_group, [](const TensorTypePtr &t) { return t; },
137+ Symbol::fromQualString (fuser::onednn::LlgaGuardName ()));
138+ }
139+ }
140+
141+ } // namespace onednn
142+ } // namespace fuser
143+ } // namespace jit
144+ } // namespace torch