
Commit e7b925a

fix TE issue that reports an UNSUPPORTED DTYPE error when calling to(bfloat16) (#910)
1 parent fb66cfa commit e7b925a

File tree

5 files changed: +163 -0 lines changed
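Before the per-file diffs, a note on what the bug looked like from the user side. The snippet below is a minimal Python sketch of the scenario named in the commit message (the module, shapes, and warm-up count are illustrative, not taken from the commit): a traced model whose forward casts a float result to bfloat16. Before this fix, the TensorExpr (NNC) fuser would pull the aten::to node into a fusion group and abort with an UNSUPPORTED DTYPE error; after it, the cast is kept outside the group.

import torch

# Illustrative module: float arithmetic followed by a bfloat16 cast.
class CastToBF16(torch.nn.Module):
    def forward(self, x):
        return (x + 1).to(torch.bfloat16)

x = torch.randn(5, 5)
with torch.no_grad():
    m = torch.jit.trace(CastToBF16().eval(), x)
    m = torch.jit.freeze(m)
    for _ in range(3):  # profiling runs give the JIT a chance to apply TE fusion
        out = m(x)
    print(m.graph_for(x))  # with the fix, the cast stays out of prim::TensorExprGroup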

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 6 additions & 0 deletions
@@ -71,6 +71,12 @@ void insertPrePackedConvTransposeOp(std::shared_ptr<Graph>& graph);
 void fuseConvTransposeWithEltwise(std::shared_ptr<Graph>& graph);
 
 void FusedEinsumPost(std::shared_ptr<Graph>& graph);
+
+// This code will be removed after the official PyTorch NNC fully supports
+// BFloat16.
+void replaceAtenToWithIPEXTo(std::shared_ptr<Graph>& graph);
+void replaceIPEXToWithAtenTo(std::shared_ptr<Graph>& graph);
+
 } // namespace graph_rewrite
 } // namespace jit
 } // namespace torch
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+#include <ATen/code_template.h>
+#include "graph_rewrite.h"
+
+namespace torch {
+namespace jit {
+namespace graph_rewrite {
+
+using namespace at::jit;
+
+// This code will be removed after the official PyTorch NNC fully supports
+// BFloat16.
+
+void replaceAtenToWithIPEXTo(Block* b) {
+  for (Node* n : b->nodes()) {
+    for (Block* block : n->blocks()) {
+      replaceAtenToWithIPEXTo(block);
+    }
+    if (n->kind() == aten::to) {
+      // skip aten::to.other
+      if (n->inputs().at(1)->type()->kind() == TypeKind::TensorType) {
+        continue;
+      }
+      if (n->inputs().size() == 5 || n->inputs().size() == 4) {
+        auto const& input_dtype =
+            n->inputs().at(0)->type()->cast<TensorType>()->scalarType();
+        auto const& output_dtype =
+            n->outputs().at(0)->type()->cast<TensorType>()->scalarType();
+        if (!input_dtype || !output_dtype) {
+          continue;
+        }
+        if (!(*input_dtype == c10::ScalarType::Float &&
+              *output_dtype == c10::ScalarType::BFloat16)) {
+          continue;
+        }
+        // device check?
+        WithInsertPoint guard(n);
+        auto graph = n->owningGraph();
+        Node* ipex_to_node =
+            graph->create(Symbol::fromQualString("ipex::to_dtype"));
+        for (auto i = 0; i < n->inputs().size(); ++i) {
+          Value* v = n->inputs().at(i);
+          ipex_to_node->addInput(v);
+        }
+        graph->insertNode(ipex_to_node);
+        n->output()->replaceAllUsesWith(ipex_to_node->output());
+      } else {
+        continue;
+      }
+    }
+  }
+  EliminateDeadCode(b);
+}
+
+void replaceIPEXToWithAtenTo(Block* b) {
+  for (Node* n : b->nodes()) {
+    for (Block* block : n->blocks()) {
+      replaceIPEXToWithAtenTo(block);
+    }
+    if (n->kind() == Symbol::fromQualString("ipex::to_dtype")) {
+      WithInsertPoint guard(n);
+      auto graph = n->owningGraph();
+      Node* aten_to_node = graph->create(aten::to);
+      for (auto i = 0; i < n->inputs().size(); ++i) {
+        Value* v = n->inputs().at(i);
+        aten_to_node->addInput(v);
+      }
+      graph->insertNode(aten_to_node);
+      n->output()->replaceAllUsesWith(aten_to_node->output());
+    }
+  }
+  EliminateDeadCode(b);
+}
+
+void replaceAtenToWithIPEXTo(std::shared_ptr<Graph>& graph) {
+  replaceAtenToWithIPEXTo(graph->block());
+  EliminateDeadCode(graph);
+}
+
+void replaceIPEXToWithAtenTo(std::shared_ptr<Graph>& graph) {
+  replaceIPEXToWithAtenTo(graph->block());
+  EliminateDeadCode(graph);
+}
+
+} // namespace graph_rewrite
+} // namespace jit
+} // namespace torch
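The pass above is deliberately narrow: it only rewrites aten::to nodes whose input dtype is Float and whose output dtype is BFloat16, only for the dtype overloads (4 or 5 inputs), and it skips aten::to.other (tensor second argument). Renaming the node to ipex::to_dtype hides the cast from NNC because the fuser only claims operators it recognizes; replaceIPEXToWithAtenTo then restores the original node so downstream passes see standard IR. A Python-side sketch of the overload the pass targets (shapes illustrative):

import torch

def f(x):
    return x.to(torch.bfloat16)

g = torch.jit.trace(f, torch.randn(2, 2)).graph
# Tracing produces the 5-input aten::to.dtype overload:
# (self, dtype, non_blocking, copy, memory_format)
print(g)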

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 40 additions & 0 deletions
@@ -1154,6 +1154,46 @@ RegisterOperators op({
           };
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "ipex::to_dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = at::native::to(
+                (std::move(peek(stack, 0, 5))).toTensor(),
+                (std::move(peek(stack, 1, 5))).toScalarType(),
+                (std::move(peek(stack, 2, 5))).toBool(),
+                (std::move(peek(stack, 3, 5))).toBool(),
+                (std::move(peek(stack, 4, 5))).toOptional<at::MemoryFormat>());
+            drop(stack, 5);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "ipex::to_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            const auto& input = (std::move(peek(stack, 0, 4))).toTensor();
+            const auto dtype =
+                (std::move(peek(stack, 1, 4))).toOptional<at::ScalarType>();
+            const auto copy = (std::move(peek(stack, 3, 4))).toBool();
+            at::Tensor result;
+            if (!dtype && !copy) {
+              result = input;
+            } else {
+              TORCH_CHECK(
+                  dtype,
+                  "dtype cannot be None when copy is True for ipex::to_dtype");
+              result = at::native::to(
+                  input, *dtype, (std::move(peek(stack, 2, 4))).toBool(), copy);
+            }
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
 });
 } // namespace jit
 } // namespace torch
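The two schemas mirror aten::to.dtype (five inputs, with memory_format) and aten::to.prim_dtype (four inputs, optional dtype), so the renamed node keeps the same semantics if it ever executes outside a fusion group. JIT custom ops registered this way are normally reachable through the torch.ops namespace; a direct call would look roughly like this (a sketch, assuming the extension is installed so the registration above runs on import):

import torch
import intel_extension_for_pytorch  # assumed import; registers the ipex::* JIT ops

x = torch.randn(4, 4)
# Matches the first schema: (self, dtype, non_blocking, copy, memory_format)
y = torch.ops.ipex.to_dtype(x, torch.bfloat16, False, False, None)
assert y.dtype == torch.bfloat16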

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 2 additions & 0 deletions
@@ -284,6 +284,7 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
   BatchMM(graph);
 
   if (tensorExprFuserEnabled()) {
+    graph_rewrite::replaceAtenToWithIPEXTo(graph);
     auto min_size = getFusionGroupInlining() ? 2 : 1;
     // Here we always get the first valid behavior per the global fusion
     // strategies configured by PyTorch (`getInstantiatedBailoutDepth` always
@@ -293,6 +294,7 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
     bool dyn_shapes = getCurrentBehavior(getInstantiatedBailoutDepth()) ==
        FusionBehavior::DYNAMIC;
     FuseTensorExprs(graph, min_size, /* composed op */ false, dyn_shapes);
+    graph_rewrite::replaceIPEXToWithAtenTo(graph);
   }
 
   // Apply IPEX inplace optimization/replacement
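The placement is the whole fix: replaceAtenToWithIPEXTo runs just before FuseTensorExprs so NNC never sees a Float-to-BFloat16 aten::to, and replaceIPEXToWithAtenTo runs just after so the rest of the pipeline sees ordinary IR again. Both are gated on tensorExprFuserEnabled(), which corresponds to a Python-side toggle. A sketch of exercising that gate (the passes themselves are internal C++ and not callable from Python):

import torch

was_enabled = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(True)  # FusionPass only runs the rewrites under this gate
try:
    m = torch.jit.trace(lambda x: (x * 2 + 1).to(torch.bfloat16), torch.randn(8, 8))
    x = torch.randn(8, 8)
    for _ in range(3):  # profiling runs so the fuser can specialize
        m(x)
    print(m.graph_for(x))
finally:
    torch._C._jit_set_texpr_fuser_enabled(was_enabled)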

tests/cpu/test_jit.py

Lines changed: 29 additions & 0 deletions
@@ -67,6 +67,7 @@
 
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.testing import FileCheck
 
 from common_utils import TestCase
 
@@ -3172,6 +3173,34 @@ def forward(self, x):
                           kind_in_graph="ipex::hardsigmoid",
                           kind_not_in_graph="aten::hardsigmoid")
 
+    # This test case will be removed after official PyTorch NNC supports bfloat16.
+    def test_TEfusion_with_to_dtype(self):
+        class TestTo(torch.nn.Module):
+            def __init__(self, dtype):
+                super(TestTo, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x):
+                return (x + 1).to(self.dtype)
+
+        X = torch.randn((5, 5))
+        with torch.no_grad():
+            # to(torch.bfloat16)
+            m = TestTo(torch.bfloat16).eval()
+            m = torch.jit.trace(m, X)
+            torch.jit.freeze(m)
+            out = m(X)
+            graph = m.graph_for(X)
+            FileCheck().check_not("prim::TensorExprGroup").run(graph)
+            # to(torch.long)
+            m = TestTo(torch.long).eval()
+            m = torch.jit.trace(m, X)
+            torch.jit.freeze(m)
+            out = m(X)
+            graph = m.graph_for(X)
+            FileCheck().check("prim::TensorExprGroup").run(graph)
+
 if __name__ == '__main__':
     torch.manual_seed(2020)
     test = unittest.main()
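To run just this case after applying the commit, something like python -m pytest tests/cpu/test_jit.py -k test_TEfusion_with_to_dtype should work (assuming pytest is available; the file also runs standalone through unittest.main()). The bfloat16 case expects no prim::TensorExprGroup in the graph, since the hidden cast leaves too little behind to fuse, while the torch.long case expects fusion to proceed as before.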
