Commit 951444b

Optimize PyTorch-LLGA integration overhead (#132)
* [LLGA] register additional profiling node
* [LLGA] handle profiling node for fusion pattern
* [LLGA] add type guard
* [LLGA] turn on the profiling mode for UTs
* [LLGA] add custom guard node
* [LLGA] add rw_mutex and refactor kernel class private members
* [LLGA] only guard shape when profiling mode is on
* [LLGA] cache inputSpecs_
* [LLGA] TypeCheck: remove check on grad due to benchmark throughput issue
* [LLGA] add test for ThroughputBenchmark with llga
* [LLGA] use rwlock in IPEX
* [LLGA] only support profiling mode
* fix clang format
* [LLGA] add note on TypeCheck rule
* [LLGA] TypeCheck: push false if input is not a tensor
* [LLGA] fix rwlock
* [LLGA] add no_grad to throughput benchmark UT
* [LLGA] use matchTensor after fixing throughput benchmark GradMode
* [LLGA] use symbol instead of schema to register op
* [LLGA] add no_grad in UT to ensure pass GradMode check
* [LLGA] remove hard-coded name string
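The guard introduced by this commit sits in front of each LLGA fusion group and decides, per call, whether the cached fused kernel may run. As a hedged, simplified Python restatement of what such a tensor type check compares (the real check is the C++ TypeCheck / ipex::LlgaFusionGuard in the diffs below; the function here is illustrative only):

import torch

def types_match(t, expected):
    # Simplified stand-in for the guard: non-tensor inputs fail the check outright,
    # and tensors must match the profiled shape, dtype, device and grad flag.
    return (isinstance(t, torch.Tensor)
            and t.shape == expected.shape
            and t.dtype == expected.dtype
            and t.device == expected.device
            and t.requires_grad == expected.requires_grad)

profiled = torch.randn(2, 3)                        # spec recorded during warm-up
print(types_match(torch.randn(2, 3), profiled))     # True: fused kernel may run
print(types_match(torch.randn(4, 3), profiled))     # False: fallback path runs
print(types_match([1, 2, 3], profiled))             # False: not a tensor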
1 parent 0447135 commit 951444b

File tree

12 files changed (+596, -182 lines)

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 57 additions & 6 deletions
@@ -31,13 +31,15 @@ def get_eltwise_fn(name):
 def llga_test_env(func):
     @wraps(func)
     def wrapTheFunction(*args):
-        torch._C._jit_set_profiling_mode(False)
-        torch._C._jit_set_profiling_executor(False)
+        # make sure that the profiling mode is turned on
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_set_profiling_executor(True)
+
+        ipex.core._jit_set_llga_enabled(True)
         ipex.core.disable_jit_opt()
         func(*args)
         ipex.core.enable_jit_opt()
-        torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_set_profiling_executor(True)
+        ipex.core._jit_set_llga_enabled(False)
     return wrapTheFunction

 class TestOp(JitLlgaTestCase):
@@ -79,7 +81,7 @@ def test_conv2d_int8_in_f32_out(self):
         ]
         #TODO: enable torch.per_tensor_symmetric case.
         for qscheme in [torch.per_tensor_affine]:
-            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d", qscheme=qscheme)
+            graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(5, in_channels * g, spatial, spatial, requires_grad=False)], atol=2e-1, config_name="conv2d", qscheme=qscheme)
             self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
             self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel'])
             self.checkPatterns(graph, patterns)
@@ -308,7 +310,7 @@ def forward(self, x):
             ["aten::dequantize"]
         ]
         for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
+            graph = self.checkQuantizeTrace(m, [x], x_var=[torch.rand(2, 28, requires_grad=False)], atol=1e-1, config_name="linear_eltwise", qscheme=qscheme)
             self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
             self.assertFused(graph, ['aten::' + eltwise])
             self.checkPatterns(graph, patterns)
@@ -423,6 +425,55 @@ def forward(self, x):
             self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)

+class TestShapeFallback(JitLlgaTestCase):
+    @unittest.skipIf(True, 'Size peephole optimization not enabled yet')
+    @llga_test_env
+    def test_view_permute(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x):
+                new_x_shape = x.size()[:-1] + (3, 5)
+                x = x.view(*new_x_shape)
+                return x.permute(0, 2, 1, 3)
+
+        x = torch.randn(5, 10, 15)
+        m = M()
+
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], config_name="view_permute", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, "aten::size", 0)
+            self.assertGraphContainsExactly(graph, "prim::ListConstruct", 0)
+
+            # change the size of the input
+            x2 = torch.randn(6, 4, 15)
+            # Bailout get triggered here
+            y2 = m(x2)
+
+    @llga_test_env
+    def test_conv_reshape(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=True)
+                self.conv2 = nn.Conv2d(4, 32, 3, padding=1, bias=True)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.conv2(x).reshape(x.size(0), 4, -1)
+                return x
+
+        x = torch.randn(15, 4, 28, 28)
+        # change the size of the input, check the fallback
+        x_var = torch.randn(7, 4, 16, 16)
+        m = M()
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [x], x_var = [x_var], atol=2e-1, config_name="conv_reshape", qscheme=qscheme)
+
+            # TODO: enable this check when size peephole optimization is enabled
+            # self.assertGraphContainsExactly(graph, "aten::size", 0)
+
 class TestModel(JitLlgaTestCase):
     @skipIfNoTorchVision
     @llga_test_env
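The TestShapeFallback cases above rely on the profiling executor's bailout mechanism: a graph specialized for one input shape must still compute correct results when a differently shaped input arrives. A self-contained sketch of that behavior in plain TorchScript, with no IPEX involved (the module and shapes are illustrative, not taken from the commit):

import torch

torch._C._jit_set_profiling_mode(True)
torch._C._jit_set_profiling_executor(True)

class M(torch.nn.Module):
    def forward(self, x):
        # size/reshape pattern similar to the tests above
        return x.reshape(x.size(0), -1).relu()

m = torch.jit.script(M().eval())
with torch.no_grad():
    x = torch.randn(15, 4, 28, 28)
    m(x)                                # warm-up: profile and specialize for this shape
    m(x)                                # optimized plan with shape guards is in place
    y2 = m(torch.randn(7, 4, 16, 16))   # different shape: the guard fails and the fallback path runs
    assert y2.shape == (7, 4 * 16 * 16)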
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+from functools import wraps
+
+import torch
+from torch.utils import ThroughputBenchmark
+from torch.testing import assert_allclose
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+import intel_pytorch_extension as ipex
+from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP
+from test_jit_llga_quantization_fuser import llga_test_env
+
+class LinearEltwise(torch.nn.Module):
+    def __init__(self, D_in, H, D_out):
+        super(LinearEltwise, self).__init__()
+        self.linear1 = torch.nn.Linear(D_in, H)
+        self.eltwise = torch.nn.ReLU()
+        self.linear2 = torch.nn.Linear(H, D_out)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.eltwise(x)
+        x = self.linear2(x)
+        return x
+
+def freeze(model):
+    return torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))
+
+class TestThroughputBenchmark(JitLlgaTestCase):
+    @llga_test_env
+    def test_linear_eltwise(self):
+        with torch.no_grad():
+            D_in = 10
+            H = 5
+            D_out = 15
+            B = 8
+
+            m = LinearEltwise(D_in, H, D_out)
+            x = torch.randn(B, D_in)
+
+            graph, m_llga, m_cpu = self.prepareModel(m, [x])
+
+            ipex.core._jit_set_llga_enabled(False)
+            module_result = m_cpu(x)
+            ipex.core._jit_set_llga_enabled(True)
+
+            bench = ThroughputBenchmark(m_llga)
+            bench.add_input(x)
+            bench_result = bench.run_once(x)
+
+            assert_allclose(bench_result, module_result, atol=1e-1, rtol=1e-2)
+
+            stats = bench.benchmark(
+                num_calling_threads=4,
+                num_warmup_iters=100,
+                num_iters=1000
+            )
+
+            print(stats)
+
+if __name__ == '__main__':
+    run_tests()
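For context, torch.utils.ThroughputBenchmark works the same way on a plain scripted module without IPEX; a minimal sketch using the same API calls as the test above (the module, thread count, and iteration counts are illustrative). It stays under torch.no_grad(), matching the test's requirement that the benchmark runs with autograd off:

import torch
from torch.utils import ThroughputBenchmark

m = torch.jit.script(torch.nn.Linear(10, 5).eval())
x = torch.randn(8, 10)

with torch.no_grad():
    bench = ThroughputBenchmark(m)
    bench.add_input(x)
    print(bench.run_once(x).shape)          # single call, returns the module output
    stats = bench.benchmark(num_calling_threads=2,
                            num_warmup_iters=10,
                            num_iters=100)
    print(stats)                            # latency statistics across the calling threads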

tests/cpu/test_jit_llga_utils.py

Lines changed: 30 additions & 8 deletions
@@ -73,7 +73,30 @@ def assertFused(self, graph, fused_patterns):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)

-    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, folding=False, remove_dropout=False, config_name="", qscheme=torch.per_tensor_affine):
+    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, folding=False, remove_dropout=False, config_name="", x_var=None, qscheme=torch.per_tensor_affine):
+        graph, model, fp32_model_with_quant_dequant = self.prepareModel(model, x, folding, remove_dropout, config_name, qscheme)
+        with torch.no_grad():
+            # calculate after getting the graph
+            y_llga = model(*x)
+
+            # disable llga for fp32 path
+            ipex.core._jit_set_llga_enabled(False)
+            y = fp32_model_with_quant_dequant(*x)
+            # test Fallback when input shape changes:
+            if x_var:
+                y_var = fp32_model_with_quant_dequant(*x_var)
+            ipex.core._jit_set_llga_enabled(True)
+
+            self.assertEqual(y, y_llga, atol=atol, rtol=rtol)
+
+            # test Fallback when input shape changes:
+            if x_var:
+                y_var_llga = model(*x_var)
+                self.assertEqual(y_var, y_var_llga, atol=atol, rtol=rtol)
+
+        return graph
+
+    def prepareModel(self, model, x, folding=False, remove_dropout=False, config_name="", qscheme=torch.per_tensor_affine):
         model.eval()
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             # fold conv bn
@@ -105,14 +128,13 @@ def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, folding=False, remo
             # freeze the module
             model = freeze(model)

-            # apply llga optimization pass
-            ipex.core._jit_llga_fuser(model.graph)
-
-            y = fp32_model_with_quant_dequant(*x)
-            y_llga = model(*x)
+            # warm up run
+            y0 = model(*x)

-            self.assertEqual(y, y_llga, atol=atol, rtol=rtol)
-            return model.graph
+            # get the graph at the second run after freezing
+            graph = model.graph_for(*x)
+
+            return graph, model, fp32_model_with_quant_dequant

     def checkPatterns(self, graph, patterns):
         fusion_groups = findFusionGroups(graph)
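A hedged sketch of why prepareModel above now does a warm-up run and then calls model.graph_for(*x): under the profiling executor, .graph is the generic unoptimized graph, while graph_for(...) returns the plan actually executed for those inputs after warm-up, which is where fusion groups and guards appear. Plain-PyTorch illustration (the module and shapes are placeholders, not from the commit):

import torch

torch._C._jit_set_profiling_mode(True)
torch._C._jit_set_profiling_executor(True)

m = torch.jit.script(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval())
m = torch.jit.freeze(m)           # analogous in spirit to the freeze() helper in these tests
x = torch.randn(2, 4)
with torch.no_grad():
    m(x)                          # warm-up run records input profiles
    m(x)                          # second run executes the optimized plan
print(m.graph)                    # generic graph of forward()
print(m.graph_for(x))             # graph specialized/optimized for this input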

torch_ipex/csrc/jit/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ LIST(APPEND DPCPP_JIT_SRCS
   ${DPCPP_ROOT}/jit/fusion_pass.cpp
   ${DPCPP_ROOT}/jit/register_dnnl_jit_ops.cpp
   ${DPCPP_ROOT}/jit/graph_rewrite.cpp
+  ${DPCPP_ROOT}/jit/codegen/onednn/register_interface.cpp
+  ${DPCPP_ROOT}/jit/codegen/onednn/guard_shape.cpp
 )

 # Pass to parent

torch_ipex/csrc/jit/codegen/onednn/fusion_group_name.cpp

Lines changed: 6 additions & 1 deletion
@@ -5,11 +5,16 @@ namespace jit {
 namespace fuser {
 namespace onednn {

-const std::string& LlgaFusionGroupName() {
+const std::string &LlgaFusionGroupName() {
   static const std::string _LlgaFusionGroupName = "ipex::LlgaFusionGroup";
   return _LlgaFusionGroupName;
 }

+const std::string &LlgaGuardName() {
+  static const std::string LlgaGuardName = "ipex::LlgaFusionGuard";
+  return LlgaGuardName;
+}
+
 } // namespace onednn
 } // namespace fuser
 } // namespace jit

torch_ipex/csrc/jit/codegen/onednn/fusion_group_name.h

Lines changed: 6 additions & 3 deletions
@@ -8,9 +8,12 @@ namespace fuser {
 namespace onednn {

 // Workaround here. Once the PR of PyTorch LLGA bridge code has been landed
-// into the stock PyTorch, we could directly use the Symbol: prim::LlgaFusionGroup
-// instead of Symbol::fromQualString(LlgaFusionGroupName())
-extern const std::string& LlgaFusionGroupName();
+// into the stock PyTorch, we could directly use the Symbol:
+// prim::LlgaFusionGroup and prim::LlgaFusionGuard instead of
+// Symbol::fromQualString(LlgaFusionGroupName()) and
+// Symbol::fromQualString(LlgaGuardName())
+extern const std::string &LlgaFusionGroupName();
+extern const std::string &LlgaGuardName();

 } // namespace onednn
 } // namespace fuser
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+#include "jit/codegen/onednn/guard_shape.h"
+#include "jit/codegen/onednn/fusion_group_name.h"
+
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
+#include <torch/csrc/jit/runtime/graph_executor.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace onednn {
+
+using tensor_type_converter_t =
+    c10::function_ref<TensorTypePtr(const TensorTypePtr &t)>;
+
+void insertTypeGuardForFusionGroup(Node *guarded_node,
+                                   tensor_type_converter_t type_converter,
+                                   Symbol kind) {
+  GRAPH_DEBUG("Inserting a typecheck guard for a node", *guarded_node);
+  auto subgraph = guarded_node->g(attr::Subgraph);
+
+  // Fixup types of the subgraph inputs
+  std::vector<Value *> inputs_to_check;
+  std::vector<TypePtr> guard_types;
+  for (Value *input : guarded_node->inputs()) {
+    // We only check inputs of the guarded nodes and expect user to infer
+    // intermediates and outputs shapes
+    if (!input->type()->cast<TensorType>()) {
+      continue;
+    }
+
+    // fusion outputs are already guarded
+    if (input->node()->kind() == prim::Constant ||
+        input->node()->kind() ==
+            Symbol::fromQualString(LlgaFusionGroupName())) {
+      continue;
+    }
+    inputs_to_check.push_back(input);
+    guard_types.push_back(type_converter(input->type()->expect<TensorType>()));
+  }
+  if (!inputs_to_check.size()) {
+    return;
+  }
+
+  // Add ipex::LlgaFusionGuard node
+  //
+  // ipex::LlgaFusionGuard nodes look like the following:
+  //   %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool =
+  //       ipex::LlgaFusionGuard(%inp1 : Tensor, %inp2 : Tensor)
+  //
+  // They have N inputs whose types we are going to check and N+1 outputs. The
+  // first N outputs specify expected types and N+1-th output holds the result
+  // of the check (bool).
+  Node *typecheck_node =
+      guarded_node->owningGraph()
+          ->create(kind, inputs_to_check, inputs_to_check.size() + 1)
+          ->insertBefore(guarded_node);
+  typecheck_node->tys_(attr::types, guard_types);
+  Value *typecheck_result = typecheck_node->output(inputs_to_check.size());
+
+  std::unordered_map<Value *, Value *> typechecked_inputs;
+  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
+    typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i);
+  }
+
+  // Fixup types of the typecheck node outputs, which are used by the op in
+  // execution
+  typecheck_node->output(inputs_to_check.size())->setType(BoolType::get());
+  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
+    typecheck_node->output(i)->setType(typecheck_node->input(i)->type());
+  }
+
+  // Insert if
+  auto versioning_if =
+      guarded_node->owningGraph()
+          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
+          ->insertAfter(typecheck_node);
+  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
+    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
+    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
+  }
+  auto true_block = versioning_if->addBlock();
+  auto false_block = versioning_if->addBlock();
+
+  // Fill in the false block. It should contain the unoptimized
+  // copy of the fused subgraph.
+  WithInsertPoint guard(false_block->return_node());
+  const auto subgraph_outputs = insertGraph(*guarded_node->owningGraph(),
+                                            *subgraph, guarded_node->inputs());
+  for (Value *output : subgraph_outputs) {
+    false_block->registerOutput(output);
+  }
+
+  // types get copied to the fallback graph, so remove specializations before
+  // replacing
+  removeTensorTypeSpecializations(false_block);
+  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());
+
+  // Fill in the true block. It has all inputs type-checked and its
+  // body should be the fusion group node.
+  guarded_node->moveBefore(true_block->return_node());
+  for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) {
+    if (typechecked_inputs.count(guarded_node->input(idx))) {
+      guarded_node->replaceInput(
+          idx, typechecked_inputs.at(guarded_node->input(idx)));
+    }
+  }
+  for (Value *output : guarded_node->outputs()) {
+    true_block->registerOutput(output);
+  }
+}
+
+//! [ Note -- prepareFusionGroupAndGuardOutputs implementation ]
+//! shamelessly copying code from NNC (tensorexpr_fuser) with very little
+//! modification, original code at:
+//! `torch/csrc/jit/passes/tensorexpr_fuser.cpp:prepareFusionGroupAndGuardOutputs`
+//!
+//! We have the assumption that LLGA does not have operators
+//! depending on the content of the tensor.
+void prepareFusionGroupAndGuardOutputs(Block *block) {
+  std::vector<Node *> fusion_groups;
+  for (Node *n : block->nodes()) {
+    for (Block *b : n->blocks()) {
+      prepareFusionGroupAndGuardOutputs(b);
+    }
+    if (n->kind() == Symbol::fromQualString(LlgaFusionGroupName())) {
+      fusion_groups.push_back(n);
+    }
+  }
+  for (Node *fusion_group : fusion_groups) {
+    // TODO: add further optimization pass to removeOutputsUsedOnlyInSize,
+    // refer to
+    // `torch/csrc/jit/passes/tensorexpr_fuser.cpp:removeOutputsUsedOnlyInSize`
+    // removeOutputsUsedOnlyInSize(fusion_group);
+    insertTypeGuardForFusionGroup(
+        fusion_group, [](const TensorTypePtr &t) { return t; },
+        Symbol::fromQualString(fuser::onednn::LlgaGuardName()));
+  }
+}
+
+} // namespace onednn
+} // namespace fuser
+} // namespace jit
+} // namespace torch
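For orientation, a schematic of the graph shape this pass produces around each fusion group, extending the IR notation used in the comment inside insertTypeGuardForFusionGroup (value names and types are illustrative): the guard's boolean output drives a prim::If whose true branch runs the fused group on the type-checked values and whose false branch is the unoptimized fallback installed via replaceBlockWithFallbackGraph.

%t1 : Float(2, 3), %t2 : Float(2, 3), %ok : bool = ipex::LlgaFusionGuard(%inp1, %inp2)
%y : Tensor = prim::If(%ok)
  block0():   # types match: run the fused subgraph on the checked values
    %y1 : Tensor = ipex::LlgaFusionGroup(%t1, %t2)
    -> (%y1)
  block1():   # mismatch: fall back to the unoptimized copy of the subgraph
    %y2 : Tensor = prim::FallbackGraph(%inp1, %inp2)
    -> (%y2)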
