
Commit ad0b753

opaque context: avoid recursive invocation in the context conversion routine by bypassing ATen dispatch. (#2566) (#2573)

* Check opaque u8 in to_plain
* Add comment for case analysis

Co-authored-by: Zhiwei <zhiwei.yan@intel.com>

1 parent de83c49 commit ad0b753

2 files changed (+85, -2 lines)


csrc/gpu/aten/tensor/Context.cpp

Lines changed: 9 additions & 2 deletions
@@ -32,7 +32,14 @@ at::Tensor DPCPPTensorConvertor::to_plain(const at::Tensor& from_original) {
   // dtype in stored context, so the reconstruction is needed for reorder's
   // correctness.
   auto is_equal = check_equality_for_meta_dtype_and_ctx_dtype(from_original);
-  if (!is_equal) {
+  auto is_opaque_u8_qtensor = is_opaque_u8(from_original);
+  // Case 1:
+  // The tensor ctx has a real dtype (like f32 or s8) but the meta has byte
+  // dtype (u8); this is for pickling the tensor. Run the following if statement.
+  // Case 2:
+  // An opaque u8 qtensor has QUInt8 meta but an s8 ctx. No need to pickle it;
+  // bypass the following if statement.
+  if (!is_equal && !is_opaque_u8_qtensor) {
     // Here use opaqueTypeToScalarType to deduce the meta dtype
     // [from] the context dtype, then reconstruct the tensor [from]
     from = at::empty_like(
@@ -59,7 +66,7 @@ at::Tensor DPCPPTensorConvertor::to_plain(const at::Tensor& from_original) {
       to, to_meta.sizes_, to_meta.strides_, c10::nullopt);
   xpu::oneDNN::reorder(from, to_);
 
-  if (!is_equal) {
+  if (!is_equal && !is_opaque_u8_qtensor) {
     // reconstruct the [to] tensor with the original tensor meta
     to = at::empty_like(from_original);
 
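To make the case analysis concrete, here is a minimal, self-contained Python toy of the control flow (all names, such as OpaqueTensor and dispatched_empty_like, are illustrative stand-ins, not IPEX code). Reconstruction through the dispatcher is what Case 1 needs, but doing the same for a Case 2 tensor re-enters the conversion routine and recurses; the added guard short-circuits exactly that path.

class OpaqueTensor:
    """Toy stand-in for a tensor backed by an opaque oneDNN context."""
    def __init__(self, meta_dtype, ctx_dtype):
        self.meta_dtype = meta_dtype  # dtype the ATen metadata reports
        self.ctx_dtype = ctx_dtype    # dtype stored in the opaque context


def dispatched_empty_like(t, depth):
    # Stand-in for at::empty_like going through ATen dispatch: for an
    # opaque u8 qtensor the dispatched path converts to plain layout
    # first, re-entering to_plain.
    if (t.meta_dtype, t.ctx_dtype) == ("quint8", "s8"):
        return to_plain(t, depth + 1)
    return OpaqueTensor(t.ctx_dtype, t.ctx_dtype)


def to_plain(t, depth=0):
    assert depth < 8, "unbounded recursion without the guard"
    is_equal = t.meta_dtype == t.ctx_dtype
    is_opaque_u8_qtensor = (t.meta_dtype, t.ctx_dtype) == ("quint8", "s8")
    # Case 1 (e.g. u8 meta over an f32 ctx, used for pickling): reconstruct.
    # Case 2 (quint8 meta over an s8 ctx): bypass, avoiding the re-entry.
    if not is_equal and not is_opaque_u8_qtensor:
        t = dispatched_empty_like(t, depth)
    return OpaqueTensor(t.ctx_dtype, t.ctx_dtype)  # "reordered" to plain


print(to_plain(OpaqueTensor("quint8", "s8")).ctx_dtype)  # s8, no re-entry

Dropping is_opaque_u8_qtensor from the condition makes the toy trip the depth assertion immediately, mirroring the recursive invocation the commit title describes.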

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase
import intel_extension_for_pytorch  # noqa

from torch.quantization.quantize_jit import (
    convert_jit,
    prepare_jit,
)
import pytest
import time

def trace_int8_model(model, device, test_input):
    model = model.to(device)
    modelJit = torch.jit.trace(model, test_input.to(device))
    modelJit.eval()
    modelJit.to(device)
    print(modelJit)
    print("finish jit tracing...")

    print("start ", device, " calibration ...")
    qconfig_u8 = torch.quantization.QConfig(
        activation=torch.quantization.observer.MinMaxObserver.with_args(
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            dtype=torch.quint8
        ),
        weight=torch.quantization.default_weight_observer
    )

    modelJit = prepare_jit(modelJit, {'': qconfig_u8}, True)

    # do calibration
    test_input = test_input.to(device)
    with torch.no_grad():
        for i in range(1):
            calib_input = test_input
            modelJit(calib_input)
    print("start ", device, " convert...")
    modelJit = convert_jit(modelJit, True)
    # inference
    print("start ", device, " inference ...")
    with torch.no_grad():
        for i in range(1):
            start = time.time()
            output_cpu = modelJit(test_input)
            end = time.time()
            print("iter.{} ... {time:.3f}ms".format(i, time=(end - start) * 1000))
        print("print ", device, " jit graph ....")
        print(modelJit.graph_for(test_input))

        print("get ", device, " test input result....")
        output = modelJit(test_input)
        print("finish ", device, " testing.......")
        return output

class SimpleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 3, 1, 1)
        self.instance_norm = nn.InstanceNorm2d(6, eps=1e-5, affine=True, momentum=0.1)

    def forward(self, x):
        x = self.conv(x)
        x = self.instance_norm(x)
        return x


class TestQTensortoPlain(TestCase):
    def test_q_to_plain(self):
        mod = SimpleModule()
        test_input = torch.randn(3, 3, 16, 16)
        with torch.no_grad():
            with torch.xpu.onednn_layout():
                trace_int8_model(mod, "xpu", test_input)
