
Commit 56428f8

[block][storage save] fix storage save issue for block format (#2399) (#2539)
* Fix save acc drop error for block format
* Add include file for Context method
* Use at::AtenIpexTypeXPU::DPCPPTensorContext in Utils.h
* Change the dtype mapping: u8 <---> QUInt8, s8 <---> QInt8
* Change the deduced dtype mapping for quantized tensors

Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
Co-authored-by: Jinghui <jinghui.gu@intel.com>
1 parent a7c4a16 commit 56428f8
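
For context, the user-visible scenario this commit targets is saving and reloading a checkpoint of a model whose weights torch.xpu.optimize keeps in oneDNN block (opaque) format: before the fix, the save-time reorder to plain layout could run with the wrong meta dtype, and the restored weights lost accuracy. Below is a minimal reproduction sketch, assuming an XPU device with intel_extension_for_pytorch installed; the toy model and checkpoint path are illustrative, not taken from the commit.

import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa  (registers the "xpu" device and torch.xpu.optimize)

# illustrative toy model; any module with conv/linear weights works
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).to("xpu").train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# torch.xpu.optimize may convert weights to oneDNN block format
model, optimizer = torch.xpu.optimize(model=model, dtype=torch.bfloat16, optimizer=optimizer)

# saving forces block tensors back to plain layout; the fix makes the reorder
# use the dtype recorded in the tensor's DPCPPTensorContext
torch.save({"model_state_dict": model.state_dict()}, "./_checkpoint_example.pth.tar")

reloaded = torch.load("./_checkpoint_example.pth.tar", map_location="xpu")
for name, saved in model.state_dict().items():
    torch.testing.assert_close(
        saved.float().cpu(), reloaded["model_state_dict"][name].float().cpu()
    )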

File tree

5 files changed: +195 additions, -73 deletions

  csrc/gpu/aten/tensor/Context.cpp
  csrc/gpu/aten/tensor/Tensor.cpp
  csrc/gpu/aten/tensor/Tensor.h
  csrc/gpu/oneDNN/Utils.h
  tests/gpu/examples/test_save_load.py

csrc/gpu/aten/tensor/Context.cpp

Lines changed: 49 additions & 3 deletions
@@ -8,9 +8,39 @@ using namespace xpu::oneDNN;
 namespace at {
 namespace AtenIpexTypeXPU {
 
-at::Tensor DPCPPTensorConvertor::to_plain(const at::Tensor& from) {
-  if (!is_opaque_tensor(from))
-    return from;
+at::Tensor DPCPPTensorConvertor::to_plain(const at::Tensor& from_original) {
+  if (!is_opaque_tensor(from_original))
+    return from_original;
+
+  auto from = from_original;
+
+  // [watch out] The dtype in from_original's meta may not be equal to the
+  // dtype stored in its context. During save, the storage is regarded as a
+  // pure u8 tensor and then pickled down (the reorder happens while
+  // pickling:
+  // https://github.com/pytorch/pytorch/blob/0f4652f4989a2d196f36fe75e5c73cb88dc0800d/torch/serialization.py#L667)
+
+  // Before saving, the block tensor should be converted to plain to ensure
+  // correctness, so the tensor MUST be reconstructed on the given storage
+  // with the CORRECT meta dtype.
+  // When saving a block tensor, the tensor [from_original] is a pure u8
+  // one-dim tensor. Thus, the [from] tensor should be reconstructed with the
+  // correct meta dtype associated with the dtype stored in [from_original]'s
+  // context. After reordering, the plain u8 tensor should be recovered.
+
+  // Here is_equal == false means the dtype in the tensor meta is not equal
+  // to the dtype in the stored context, so reconstruction is needed for the
+  // reorder's correctness.
+  auto is_equal = check_equality_for_meta_dtype_and_ctx_dtype(from_original);
+  if (!is_equal) {
+    // Use opaqueTypeToScalarType to deduce the meta dtype from the context
+    // dtype, then reconstruct the tensor [from]
+    from = at::empty_like(
+        from_original,
+        from_original.options().dtype(opaqueTypeToScalarType(from_original)));
+
+    unsafe_get_and_set_data_ptr(from_original, from);
+  }
 
   // use native API to break recursive call resulted by opaque guard in aten itf
   auto to = from.is_quantized()
@@ -29,6 +59,22 @@ at::Tensor DPCPPTensorConvertor::to_plain(const at::Tensor& from) {
       to, to_meta.sizes_, to_meta.strides_, c10::nullopt);
   xpu::oneDNN::reorder(from, to_);
 
+  if (!is_equal) {
+    // reconstruct the [to] tensor with the original tensor meta
+    to = at::empty_like(from_original);
+
+    // release [to_]'s context and set it into [to]; now the tensor [to] is
+    // plain and has the same meta as the tensor [from_original]
+    unsafe_release_and_set_data_ptr(to_, to);
+
+    // manually free [from]'s context
+    from.unsafeGetTensorImpl()
+        ->storage()
+        .unsafeGetStorageImpl()
+        ->data_ptr()
+        .release_context();
+  }
+
   return to;
 }
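
The "pure u8 / pickled down" comment above refers to the fact that torch.save serializes a tensor's underlying storage as raw, untyped bytes, while the dtype lives only in the tensor metadata. A quick CPU-only illustration with plain PyTorch 2.x (independent of this commit):

import torch

t = torch.randn(4, dtype=torch.bfloat16)
storage = t.untyped_storage()
# torch.save pickles this storage as raw bytes, which is why a block-format
# tensor must carry the correct meta dtype before the save-time reorder to
# plain layout
print(type(storage).__name__, storage.nbytes())  # UntypedStorage 8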

csrc/gpu/aten/tensor/Tensor.cpp

Lines changed: 39 additions & 0 deletions
@@ -123,5 +123,44 @@ Tensor share_storage_and_set_strided_as(
   return result;
 }
 
+// get ctx from src and set (share) it to dst
+// [watch out] After this call, the src and the dst share the same raw memory
+// through different storages and contexts, thus the src's context MUST be
+// explicitly and manually released afterwards to avoid a double free when the
+// src and the dst end their life cycles.
+void unsafe_get_and_set_data_ptr(const Tensor& src, const Tensor& dst) {
+  auto src_cptr = (DPCPPTensorContext*)src.unsafeGetTensorImpl()
+                      ->storage()
+                      .unsafeGetStorageImpl()
+                      ->data_ptr()
+                      .get_context();
+  at::DataPtr src_dptr(
+      src_cptr->data(),
+      src_cptr,
+      getDeviceAllocator()->raw_deleter(),
+      dst.device());
+  dst.unsafeGetTensorImpl()->storage().unsafeGetStorageImpl()->set_data_ptr(
+      std::move(src_dptr));
+}
+
+// release ctx from src and set it to dst
+// [watch out] This call transfers the src's control of its raw memory to the
+// dst. After the call, the src's context has been released and cannot be used
+// any more.
+void unsafe_release_and_set_data_ptr(const Tensor& src, const Tensor& dst) {
+  auto src_cptr = (DPCPPTensorContext*)src.unsafeGetTensorImpl()
+                      ->storage()
+                      .unsafeGetStorageImpl()
+                      ->data_ptr()
+                      .release_context();
+  at::DataPtr dptr(
+      src_cptr->data(),
+      src_cptr,
+      getDeviceAllocator()->raw_deleter(),
+      src.device());
+  dst.unsafeGetTensorImpl()->storage().unsafeGetStorageImpl()->set_data_ptr(
+      std::move(dptr));
+}
+
 } // namespace AtenIpexTypeXPU
 } // namespace at

csrc/gpu/aten/tensor/Tensor.h

Lines changed: 3 additions & 0 deletions
@@ -20,5 +20,8 @@ Tensor share_storage_and_set_strided_as(
     IntArrayRef stride,
     c10::optional<int64_t> storage_offset_);
 
+void unsafe_get_and_set_data_ptr(const Tensor& src, const Tensor& dst);
+
+void unsafe_release_and_set_data_ptr(const Tensor& src, const Tensor& dst);
 } // namespace AtenIpexTypeXPU
 } // namespace at

csrc/gpu/oneDNN/Utils.h

Lines changed: 32 additions & 0 deletions
@@ -105,6 +105,38 @@ static bool is_supported_onednn_dtype(const at::Tensor& tensor) {
       : true;
 }
 
+// This function deduces the torch tensor meta dtype from the kept opaque
+// tensor context, for the case of saving a tensor
+static inline c10::ScalarType opaqueTypeToScalarType(const at::Tensor& tensor) {
+  auto is_quantized = tensor.is_quantized();
+  auto ctx = *(static_cast<at::AtenIpexTypeXPU::DPCPPTensorContext*>(
+      tensor.unsafeGetTensorImpl()->storage().data_ptr().get_context()));
+  switch (ctx.dtype()) {
+    case dnnl::memory::data_type::u8:
+      // For a quantized tensor, the meta dtype is QUInt8
+      return (is_quantized) ? at::ScalarType::QUInt8 : at::ScalarType::Byte;
+    case dnnl::memory::data_type::s8:
+      // For a quantized tensor, the meta dtype is QInt8
+      return (is_quantized) ? at::ScalarType::QInt8 : at::ScalarType::Char;
+    case dnnl::memory::data_type::f16:
+      return at::ScalarType::Half;
+    case dnnl::memory::data_type::f32:
+      return at::ScalarType::Float;
+    case dnnl::memory::data_type::bf16:
+      return at::ScalarType::BFloat16;
+    case dnnl::memory::data_type::f64:
+      return at::ScalarType::Double;
+    default:
+      TORCH_CHECK(false, "Cannot be translated to torch dtype");
+  };
+}
+
+static inline bool check_equality_for_meta_dtype_and_ctx_dtype(
+    const at::Tensor& tensor) {
+  auto ctx_dtype = opaqueTypeToScalarType(tensor);
+  return bool(ctx_dtype == tensor.scalar_type());
+}
+
 static inline fpmath_mode get_onednn_fpmath_mode() {
   auto math_mode = Settings::I().get_fp32_math_mode();
   switch (math_mode) {
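
For reference, the dtype deduction performed by the new switch can be summarized with the following Python sketch; it mirrors opaqueTypeToScalarType for illustration only and is not an API exposed by the extension.

import torch

# keys: oneDNN (dnnl::memory::data_type) names; values: (plain, quantized) torch dtypes
ONEDNN_TO_TORCH = {
    "u8": (torch.uint8, torch.quint8),
    "s8": (torch.int8, torch.qint8),
    "f16": (torch.float16, None),
    "f32": (torch.float32, None),
    "bf16": (torch.bfloat16, None),
    "f64": (torch.float64, None),
}

def deduce_meta_dtype(onednn_dtype: str, is_quantized: bool) -> torch.dtype:
    # a KeyError here plays the role of the TORCH_CHECK(false, ...) default case
    plain, quantized = ONEDNN_TO_TORCH[onednn_dtype]
    if is_quantized:
        if quantized is None:
            raise ValueError(f"no quantized torch dtype for oneDNN {onednn_dtype}")
        return quantized
    return plain

assert deduce_meta_dtype("u8", is_quantized=True) is torch.quint8   # quantized u8 -> QUInt8
assert deduce_meta_dtype("s8", is_quantized=False) is torch.int8    # plain s8 -> Char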

tests/gpu/examples/test_save_load.py

Lines changed: 72 additions & 70 deletions
@@ -5,39 +5,20 @@
 import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 import intel_extension_for_pytorch # noqa
+import torchvision.models as models
 import pytest
 import os
 
 cpu_device = torch.device("cpu")
 xpu_device = torch.device("xpu")
 
 batch_size = 128
-class_num = 1000
-input_channel = 512
-hidden_channel = 2048
-num_iter = 10
+input_channel = 3
+train_num_iter = 5
+eval_num_iter = 3
 lr = 0.01
 checkpoint_path_str = './_checkpoint.test.case.test_xpu_checkpoint_save_load_integrity_and_accuracy.pth.tar'
 
-class TrainingModel(nn.Module):
-    def __init__(self):
-        super(TrainingModel, self).__init__()
-        self.m = nn.Sequential(
-            nn.Conv2d(input_channel, hidden_channel, kernel_size=(1, 1), stride=(1, 1), bias=False),
-            nn.BatchNorm2d(hidden_channel, eps=1e-05, momentum=0.1),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
-        )
-        self.fc = nn.Linear(in_features=hidden_channel, out_features=class_num, bias=True)
-
-    def forward(self, x, indentity_for_mul, indentity_for_add):
-        x = self.m(x)
-        x = x * indentity_for_mul
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        x = x + indentity_for_add
-        return x
-
 class TestTorchMethod(TestCase):
     @pytest.mark.skipif(not torch.xpu.utils.has_fp64_dtype(), reason="fp64 not support by this device")
     def test_save_load(self):
@@ -65,31 +46,17 @@ def test_serialization_multi_map_location(self):
         self.assertEqual(b.device.__str__(), 'xpu:1')
 
     @pytest.mark.skipif(not torch.xpu.utils.has_fp64_dtype(), reason="fp64 not support by this device")
-    def test_xpu_checkpoint_save_load_integrity_and_accuracy(self, dtype=torch.bfloat16):
-        # create model
+    def test_xpu_checkpoint_save_load_integrity_and_accuracy(self):
         device = 'xpu'
-        model_xpu = TrainingModel()
-        model_xpu = model_xpu.to(device=device).train()
-        optimizer_xpu = torch.optim.SGD(model_xpu.parameters(), lr=lr)
-        criterion = nn.CrossEntropyLoss()
-
-        if os.path.exists(checkpoint_path_str):
-            os.remove(checkpoint_path_str)
-
-        # process torch.xpu.optimize
-        model_xpu, optimizer_xpu = torch.xpu.optimize(model=model_xpu, dtype=dtype, optimizer=optimizer_xpu)
-
-        def training_step(model_xpu, optimizer_xpu, criterion):
-            input = torch.randn(batch_size, input_channel, 7, 7)
-            target = torch.empty(batch_size, dtype=torch.long).random_(class_num)
+        def training_step(model_xpu, optimizer_xpu, criterion, dtype):
+            input = torch.randn(batch_size, input_channel, 224, 224)
+            target = torch.empty(batch_size, dtype=torch.long).random_(1000)
             input_xpu = input.clone().to(device=device).requires_grad_()
             target_xpu = target.to(device)
-            indentity_for_mul = torch.randn(batch_size, hidden_channel, 1, 1).to(device=device)
-            indentity_for_add = torch.randn(batch_size, class_num).to(device=device)
 
             # forward
             with torch.xpu.amp.autocast(enabled=True, dtype=dtype):
-                output_xpu = model_xpu(input_xpu, indentity_for_mul, indentity_for_add)
+                output_xpu = model_xpu(input_xpu)
                 loss_xpu = criterion(output_xpu, target_xpu)
 
             # optimizer
@@ -103,35 +70,70 @@ def training_step(model_xpu, optimizer_xpu, criterion):
             loss_xpu = loss_xpu.cpu()
             output_xpu = output_xpu.cpu()
 
-        def save_checkpoint(state, filename=checkpoint_path_str):
-            torch.save(state, filename)
+        def eval_step(model_xpu, dtype):
+            input = torch.randn(batch_size, input_channel, 224, 224)
+            target = torch.empty(batch_size, dtype=torch.long).random_(1000)
+            input_xpu = input.clone().to(device=device).requires_grad_()
+            target_xpu = target.to(device)
 
-        for _ in range(num_iter):
-            training_step(model_xpu, optimizer_xpu, criterion)
+            # forward
+            with torch.xpu.amp.autocast(enabled=True, dtype=dtype):
+                output_xpu = model_xpu(input_xpu)
+                loss_xpu = criterion(output_xpu, target_xpu)
+
+            loss_xpu = loss_xpu.cpu()
+            output_xpu = output_xpu.cpu()
 
-        save_checkpoint({'model_state_dict': model_xpu.state_dict(), 'optimizer_state_dict': optimizer_xpu.state_dict()})
-        if os.path.isfile(checkpoint_path_str):
-            # load checkpoint
-            checkpoint = torch.load(checkpoint_path_str, map_location='xpu')
-            print('load checkpoint')
+        def save_checkpoint(state, filename=checkpoint_path_str):
+            torch.save(state, filename)
 
+        for dtype in [torch.float32, torch.bfloat16]:
+            print('dtype = ', dtype)
             # create model
-            new_model = TrainingModel()
-            new_model = new_model.to(device=device).train()
-            print('create model')
-
-            # create optimizer
-            new_optimizer = torch.optim.SGD(new_model.parameters(), lr=lr)
-            print('create model')
-
-            # load state dict
-            new_model.load_state_dict(checkpoint['model_state_dict'])
-            new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-            print('load state dict')
-
-            # check
-            print('checking...')
-            self.assertEqual(model_xpu.state_dict(), new_model.state_dict(), atol=1e-6, rtol=1e-6)
-            self.assertEqual(optimizer_xpu.state_dict(), new_optimizer.state_dict(), atol=1e-6, rtol=1e-6)
-        else:
-            assert False, "save checkpoint failed for xpu model" # noqa B011
+            model_xpu = models.__dict__['resnet18'](pretrained=True).to(device=device).train()
+            optimizer_xpu = torch.optim.SGD(model_xpu.parameters(), lr=lr)
+            criterion = nn.CrossEntropyLoss()
+
+            if os.path.exists(checkpoint_path_str):
+                os.remove(checkpoint_path_str)
+
+            # process torch.xpu.optimize
+            model_xpu, optimizer_xpu = torch.xpu.optimize(model=model_xpu, dtype=dtype, optimizer=optimizer_xpu)
+
+            # mimic model train, then eval
+            for _ in range(train_num_iter):
+                training_step(model_xpu, optimizer_xpu, criterion, dtype)
+            model_xpu.eval()
+            for _ in range(eval_num_iter):
+                eval_step(model_xpu, dtype)
+            torch.xpu.synchronize()
+
+            save_checkpoint({'model_state_dict': model_xpu.state_dict(), 'optimizer_state_dict': optimizer_xpu.state_dict()})
+            if os.path.isfile(checkpoint_path_str):
+                # load checkpoint
+                checkpoint = torch.load(checkpoint_path_str, map_location=device)
+                print('load checkpoint')
+
+                # create model
+                new_model = models.__dict__['resnet18'](pretrained=False).to(device=device).train()
+                print('create model')
+
+                # create optimizer
+                new_optimizer = torch.optim.SGD(new_model.parameters(), lr=lr)
+                print('create optimizer')
+
+                # optimize
+                new_model, new_optimizer = torch.xpu.optimize(model=new_model, dtype=dtype, optimizer=new_optimizer)
+
+                # load state dict
+                new_model.load_state_dict(checkpoint['model_state_dict'])
+                new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                print('load state dict')
+
+                # check
+                print('checking...')
+                self.assertEqual(model_xpu.state_dict(), new_model.state_dict(), atol=1e-6, rtol=1e-6)
+                self.assertEqual(optimizer_xpu.state_dict(), new_optimizer.state_dict(), atol=1e-6, rtol=1e-6)
+                os.remove(checkpoint_path_str)
+            else:
+                assert False, "save checkpoint failed for xpu model" # noqa B011
