Skip to content

Commit 2af4ca8

Browse files
jiayisunx and EikanWang authored
align FuseTensorExprs with PyTorch (#883) (#893)
* align FuseTensorExprs with PyTorch * add UT Co-authored-by: Wang Weihan <eikan.wang@intel.com>
1 parent 0d4a314 commit 2af4ca8

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,30 @@ bool isQuantized(const std::shared_ptr<Graph>& graph) {
227227
return checkQuantization(graph->block());
228228
}
229229

230+
// Returns the fusion behavior (STATIC vs DYNAMIC) that applies at the given
// remaining bailout depth, per the globally configured fusion strategies.
//
// Strategies are consumed front-to-back as recompilations burn depth, so we
// walk the strategy list from the back, accumulating each entry's depth
// budget until the accumulated total covers `remaining_depth`.
//
// @param remaining_depth  how much bailout depth is still unspent.
// @return the FusionBehavior of the strategy entry that owns that depth;
//         STATIC (with a warning) if the strategy list no longer covers it.
FusionBehavior getCurrentBehavior(size_t remaining_depth) {
  size_t curr_depth = 0;
  FusionStrategy strategies = getFusionStrategy();
  for (int i = static_cast<int>(strategies.size()) - 1; i >= 0; i--) {
    curr_depth += strategies[i].second;
    if (remaining_depth <= curr_depth) {
      return strategies[i].first;
    }
  }
  // should never get here: the strategy list must have been changed while an
  // invocation was in flight. Fixed typo in the warning ("Stratgy").
  TORCH_WARN("Strategy changed mid-invocation, NYI");
  return FusionBehavior::STATIC;
}
243+
244+
size_t getInstantiatedBailoutDepth() {
245+
// Initialize bailout_depth from command-line flag.
246+
size_t depth = 0;
247+
FusionStrategy fusion_strategy_ = getFusionStrategy();
248+
for (const auto& pair : fusion_strategy_) {
249+
depth += pair.second;
250+
}
251+
return depth;
252+
}
253+
230254
void FusionPass(std::shared_ptr<Graph>& graph) {
231255
GRAPH_DUMP(
232256
"Before RemoveProfileNodesAndSpecializeTypes. Beginning of "
@@ -260,7 +284,15 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
260284
BatchMM(graph);
261285

262286
if (tensorExprFuserEnabled()) {
263-
FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
287+
auto min_size = getFusionGroupInlining() ? 2 : 1;
288+
// Here we always get the first valid behavior per the global fusion
289+
// strategies configured by PyTorch (`getInstantiatedBailoutDepth` always
290+
// returns the maximum configured depth). This is because IPEX TE fusion is
291+
// only called the first time of the compilation while the later
292+
// re-compilations are triggered from inside PyTorch.
293+
bool dyn_shapes = getCurrentBehavior(getInstantiatedBailoutDepth()) ==
294+
FusionBehavior::DYNAMIC;
295+
FuseTensorExprs(graph, min_size, /* composed op*/ false, dyn_shapes);
264296
}
265297

266298
// Apply IPEX inplace optimization/replacement

tests/cpu/test_jit.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,13 @@ def __init__(self, equation):
917917
def forward(self, input1, input2, bias):
918918
return bias.add_(torch.einsum(self.equation, input1, input2))
919919

920+
class AddMulDiv(nn.Module):
    """Elementwise (x * (x + 3)) / 6.

    Written with explicit torch.add / torch.mul / torch.div calls so the
    traced JIT graph contains a chain of aten ops for the TE fuser to fuse.
    """

    def __init__(self):
        super(AddMulDiv, self).__init__()

    def forward(self, input):
        shifted = torch.add(input, 3)
        product = torch.mul(input, shifted)
        return torch.div(product, 6)
926+
920927
class Tester(TestCase):
921928
@contextlib.contextmanager
922929
def _texpr_enable(self, strategy):
@@ -3137,6 +3144,18 @@ def forward(self, x):
31373144
kind_not_in_graph="aten::mul",
31383145
prec=0.1)
31393146

3147+
def test_TEfusion_with_dynamic_input(self):
    # Trace/freeze once at a fixed shape, then check that the TE-fused
    # graph still matches eager results across varying (incl. empty)
    # batch sizes, exercising the dynamic-shape fusion path.
    model = AddMulDiv().eval()
    with torch.no_grad():
        traced_model = torch.jit.trace(model, torch.randn(11, 3, 20, 20)).eval()
        traced_model = torch.jit.freeze(traced_model)

        # NOTE(review): indentation of the loop relative to no_grad was lost
        # in extraction; grad mode does not affect this pure-arithmetic model.
        for batch_size in range(5):
            sample = torch.randn(batch_size, 3, 20, 20)
            traced_out = traced_model(sample)
            eager_out = model(sample)
            self.assertEqual(traced_out, eager_out)
3158+
31403159
def test_hardsigmoid_mul(self):
31413160
class HardsigmoidMul(nn.Module):
31423161
def __init__(self) -> None:

0 commit comments

Comments
 (0)