Commit 15678cc

quantization: introduce inplace parameter for convert function (#855)
1 parent 7768124 · commit 15678cc

File tree: 5 files changed (+123, -48 lines)

intel_extension_for_pytorch/ao/quantization/_quantize.py

Lines changed: 41 additions & 12 deletions
@@ -1,26 +1,28 @@
-import torch
 import copy
 from typing import Tuple, Any
+import warnings

-import intel_extension_for_pytorch._C as core
+import torch
+from torch.ao.quantization import PlaceholderObserver
 import torch.fx.experimental.optimization as optimization
-from ._quantize_utils import auto_prepare, auto_convert
-import warnings
+
+import intel_extension_for_pytorch._C as core
+from ._quantize_utils import auto_prepare, auto_convert, copy_prepared_model
 from ... import nn

 def prepare(
         model,
         configure,
         example_inputs,
-        inplace=True):
+        inplace=False):
     r"""
     Prepare an FP32 torch.nn.Module model to do calibration or to convert to quantized model.
     Args:
         model (torch.nn.Module): The FP32 model to be prepared.
         configure (torch.quantization.qconfig.QConfig): The observer settings about activation and weight.
         example_inputs (tuple or torch.Tensor): A tuple of example inputs that
-            will be passed to the function while running to init quantizaiton state.
-        inplace: (bool): It will do overide the original model.
+            will be passed to the function while running to init quantization state.
+        inplace: (bool): It will change the given model in-place if True. The default value is ``False``.
     Returns:
         torch.nn.Module
     """
@@ -43,20 +45,47 @@ def prepare(
         example_inputs = tuple(example_inputs)
     return auto_prepare(prepare_model, configure, example_inputs)

-def convert(model):
+def convert(
+        model,
+        inplace=False):
     r"""
     Convert an FP32 prepared model to a model which will automatically insert fake quant
     before a quantizable module or operator.
     Args:
         model (torch.nn.Module): The FP32 model to be convert.
+        inplace: (bool): It will change the given model in-place if True. The default value is ``False``.
     Returns:
         torch.torch.nn.Module
     """
-
     assert isinstance(model, torch.nn.Module), "Only support nn.Module convert for quantization path"
-    # Vonvert linear and weight's dtype when use autocast, which will reduce the dtype conversion.
+    assert hasattr(model, 'q_config'), "Please do prepare the model before doing convert"
+
+    if inplace:
+        convert_model = model
+    else:
+        try:
+            convert_model = copy_prepared_model(model)
+        except:
+            assert False, "The model's copy is failed, please try set inplace to True to do the convert"
+
+    # If the module's activation's qconfig is PlaceholderObserver,
+    # we can say that the module want to run dynamic quantization path.
+    if isinstance(convert_model.q_config.activation(), PlaceholderObserver):
+        qconfig_spec = {
+            torch.nn.Linear : convert_model.q_config,
+            torch.nn.LSTM : convert_model.q_config,
+            torch.nn.GRU : convert_model.q_config,
+            torch.nn.LSTMCell : convert_model.q_config,
+            torch.nn.RNNCell : convert_model.q_config,
+            torch.nn.GRUCell : convert_model.q_config,
+        }
+        return torch.quantization.quantize_dynamic(convert_model, qconfig_spec=qconfig_spec, inplace=True)
+
+    # Convert linear, conv, and Embedding's weight dtype when use autocast,
+    # which will reduce the dtype conversion.
     # TODO: check whether can be removed or not?
     if torch.is_autocast_cpu_enabled() and core.get_autocast_dtype() == torch.bfloat16:
-        model = nn.utils._model_convert.convert_module_data_type(model, torch.bfloat16)
-    convert_model = auto_convert(model)
+        convert_model = nn.utils._model_convert.convert_module_data_type(convert_model, torch.bfloat16)
+
+    convert_model = auto_convert(convert_model)
     return convert_model
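
Usage note (not part of the commit): a minimal sketch of the prepare/calibrate/convert flow with the new inplace flag. The toy model, the example input, and the use of ipex.quantization.default_static_qconfig are illustrative assumptions, not code from this change:

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

# Hypothetical toy model, for illustration only.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 1)

    def forward(self, x):
        return self.linear(x)

model = TinyNet().eval()
x = torch.rand(1, 128)

# Assumed static qconfig helper; any torch QConfig accepted by prepare() should work here.
qconfig = ipex.quantization.default_static_qconfig

# inplace=False (now the default for both calls) leaves `model` untouched and works on a copy.
prepared = ipex.quantization.prepare(model, qconfig, example_inputs=x, inplace=False)
prepared(x)  # one calibration pass

# After this commit convert() also accepts inplace; False copies the prepared model first.
converted = ipex.quantization.convert(prepared, inplace=False)

with torch.no_grad():
    traced = torch.jit.trace(converted, x)
    traced = torch.jit.freeze(traced)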

intel_extension_for_pytorch/ao/quantization/_quantize_utils.py

Lines changed: 24 additions & 15 deletions
@@ -1,4 +1,5 @@
 import os
+import copy
 from typing import List, Dict, Tuple, Any, Optional
 import torch
 import torch.nn.functional as F
@@ -8,7 +9,8 @@

 from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
     attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
-    sync_pool_and_lstm_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, load_qconf_summary_to_model
+    sync_pool_and_lstm_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, \
+    load_qconf_summary_to_model, get_fqn_valid_for_module_dict_key
 from ._quantization_state import AutoQuantizationState, AutoQuantizationStateModuleDict, init_model_quant_state
 from ._recipe import get_default_recipe
 from ._module_swap_utils import swap_child_modules
@@ -343,7 +345,26 @@ def load_qconf_summary(self, qconf_summary):
         model(*example_inputs)
     return model

-def auto_convert(module : torch.nn.Module) -> torch.nn.Module:
+def copy_prepared_model(model):
+    copied_model = copy.deepcopy(model)
+    copied_model.q_config = model.q_config
+    if isinstance(copied_model.q_config.activation(), PlaceholderObserver):
+        return copied_model
+    copied_model._fqn_to_auto_quant_state_map = copy.deepcopy(model._fqn_to_auto_quant_state_map)
+    named_modules = list(copied_model.named_modules())
+    for fqn, v in named_modules:
+        fqn_to_use_for_key = get_fqn_valid_for_module_dict_key(fqn)
+        if fqn_to_use_for_key in copied_model._fqn_to_auto_quant_state_map:
+            auto_quant_state = copied_model._fqn_to_auto_quant_state_map[fqn_to_use_for_key]
+            object.__setattr__(v, '_auto_quant_state', auto_quant_state)
+    if hasattr(model, '_qconf_summary'):
+        copied_model._qconf_summary = copy.deepcopy(model._qconf_summary)
+    copied_model.__class__ = model.__class__
+    return copied_model
+
+def auto_convert(
+        module : torch.nn.Module,
+    ) -> torch.nn.Module:
     def convert_to_dispatch_proxy(x):
         if isinstance(x, torch.Tensor):
             return x.as_subclass(QuantizationConvertTensorProxy)  # type: ignore[arg-type]
@@ -528,19 +549,7 @@ def unwrap_proxy(a):
     finally:
         torch.nn.Module.__call__ = orig_module_call
         torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]
-
-    # If the module's activation's qconfig is PlaceholderObserver, we can say that the module want to run dynamic quantization path.
-    if isinstance(module.q_config.activation(), PlaceholderObserver):
-        qconfig_spec = {
-            torch.nn.Linear : module.q_config,
-            torch.nn.LSTM : module.q_config,
-            torch.nn.GRU : module.q_config,
-            torch.nn.LSTMCell : module.q_config,
-            torch.nn.RNNCell : module.q_config,
-            torch.nn.GRUCell : module.q_config,
-        }
-        return torch.quantization.quantize_dynamic(module, qconfig_spec=qconfig_spec)
-
+
     # If module doesn't have a configure_file attr, we can say that user has run save_qconf_summary method which have
     # computed the scales and zp, or use the user's setting from a given json file(load_qconf_summary), we need to compute
     # the scale and zp here.
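
Reference note (not part of the commit): the dynamic-quantization dispatch removed from auto_convert here now lives in convert() (see the _quantize.py diff above). For orientation, a standalone sketch of that style of dispatch using stock PyTorch only; the toy model and the use of torch.quantization.default_dynamic_qconfig are assumptions for illustration:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic, default_dynamic_qconfig

# Hypothetical toy model containing the module types the dispatch table targets.
class TinyRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(8, 8)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        y, _ = self.lstm(x)
        return self.fc(y)

model = TinyRNN().eval()

# Map module types to a dynamic QConfig, mirroring the qconfig_spec built in convert()
# when the activation observer is a PlaceholderObserver.
qconfig_spec = {nn.Linear: default_dynamic_qconfig, nn.LSTM: default_dynamic_qconfig}
dyn_model = quantize_dynamic(model, qconfig_spec=qconfig_spec, inplace=False)

print(type(dyn_model.fc))    # dynamically quantized Linear wrapper
print(type(dyn_model.lstm))  # dynamically quantized LSTM wrapper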

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 50 additions & 15 deletions
@@ -247,6 +247,25 @@ def _lstm_params_list():
         self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)

 class TestIpexQuantizationConvertAPI(JitLlgaTestCase):
+    def test_inplace_preapre(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.linear = nn.Linear(128,1)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        x = torch.rand(1,128)
+        for inplace in [False, True]:
+            m = M()
+            prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=inplace)
+            if inplace:
+                self.assertEqual(m.linear.weight.data_ptr(), prepared_model.linear.weight.data_ptr())
+            else:
+                self.assertNotEqual(m.linear.weight.data_ptr(), prepared_model.linear.weight.data_ptr())
+
     def test_inplace_convert(self):
         class M(nn.Module):
             def __init__(self):
@@ -264,7 +283,8 @@ def forward(self, x):
             for inplace in [False, True]:
                 orgin_model_weight_dtype = m_.linear.weight.dtype
                 orgin_model_bias_dtype = m_.linear.bias.dtype
-                _, _, ori_model = self.prepareModel(m_, x, qconfig=static_qconfig[1], int8_bf16=int8_bf16, inplace=inplace)
+                _, _, ori_model = self.prepareModel(m_, x, qconfig=static_qconfig[1], int8_bf16=int8_bf16,
+                                                    prepare_inplace=True, convert_inplace=inplace)
                 if inplace and int8_bf16:
                     if m_.linear.weight.dtype == orgin_model_weight_dtype or m_.linear.bias.dtype == orgin_model_bias_dtype:
                         print("model should have changed")
@@ -291,20 +311,20 @@ def forward(self, x):
         prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=False)
         prepared_model(x)
         with tempfile.TemporaryDirectory() as tmp:
-                path = os.path.join(tmp, "configure.json")
-                prepared_model.save_qconf_summary(path)
-                convert_model = ipex.quantization.convert(prepared_model)
-                traced_model = torch.jit.trace(convert_model, x).eval()
-                traced_model = torch.jit.freeze(traced_model)
-                y_before = traced_model(x)
-                # load the saved qconf
-                prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=False)
-                prepared_model.load_qconf_summary(path)
-                convert_model = ipex.quantization.convert(prepared_model)
-                traced_model = torch.jit.trace(convert_model, x).eval()
-                traced_model = torch.jit.freeze(traced_model)
-                y_after = traced_model(x)
-                self.assertEqual(y_before, y_after)
+            path = os.path.join(tmp, "configure.json")
+            prepared_model.save_qconf_summary(path)
+            convert_model = ipex.quantization.convert(prepared_model)
+            traced_model = torch.jit.trace(convert_model, x).eval()
+            traced_model = torch.jit.freeze(traced_model)
+            y_before = traced_model(x)
+            # load the saved qconf
+            prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=False)
+            prepared_model.load_qconf_summary(path)
+            convert_model = ipex.quantization.convert(prepared_model)
+            traced_model = torch.jit.trace(convert_model, x).eval()
+            traced_model = torch.jit.freeze(traced_model)
+            y_after = traced_model(x)
+            self.assertEqual(y_before, y_after)

 class TestRemoveMutate(JitLlgaTestCase):
     def test_mutated_value_alive_after_inplace_op(self):
@@ -373,6 +393,21 @@ def forward(self, x):
         graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
         FileCheck().check_not("aten:linear").check("quantized::linear_dynamic").run(graph)

+    def test_linear_dynamic_bf16(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.linear = nn.Linear(3, 3)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        x = torch.randn(3, 3)
+        m = M().eval()
+        graph, _, _ = self.prepareModel(m, [x], qconfig=dynamic_qconfig[0], int8_bf16=True)
+        FileCheck().check_not("aten:linear").check("quantized::linear_dynamic").run(graph)
+
     def test_lstm_dynamic(self):
         class M(nn.Module):
             def __init__(self):
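
Side note (not part of the commit): the tests above use weight.data_ptr() equality as the signal for in-place behaviour, since an in-place prepare/convert hands back the same parameter storage while a copy gets fresh storage. A minimal, IPEX-free illustration of that check:

import copy
import torch.nn as nn

lin = nn.Linear(4, 2)

alias = lin                 # "in-place": the very same module object and weight storage
clone = copy.deepcopy(lin)  # "not in-place": a deep copy with its own weight storage

assert alias.weight.data_ptr() == lin.weight.data_ptr()
assert clone.weight.data_ptr() != lin.weight.data_ptr()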

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 1 addition & 1 deletion
@@ -181,14 +181,14 @@ def forward(self, x):

         for bias in [True]: # TODO:[True, False] when supported in backend
             x = torch.randn(2, 15)
-            m = M(bias)

             patterns = [
                 ["aten::to", "aten::quantize_per_tensor"],
                 ["aten::dequantize", "aten::to", "aten::linear"],
             ]

             for qconfig in static_qconfig:
+                m = M(bias)
                 graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig, int8_bf16=True)
                 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                 # single aten::to won't be rewritten by llga backend

tests/cpu/test_ao_jit_llga_utils.py

Lines changed: 7 additions & 5 deletions
@@ -102,7 +102,8 @@ def assertFused(self, graph, fused_patterns):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)

-    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, remove_dropout=False, x_var=None, qconfig=default_static_qconfig, int8_bf16=False):
+    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, remove_dropout=False, x_var=None,
+                           qconfig=default_static_qconfig, int8_bf16=False):
         graph, traced_model, fp32_model = self.prepareModel(model, x, remove_dropout, qconfig, int8_bf16)
         with torch.no_grad():
             y = fp32_model(*x)
@@ -119,23 +120,24 @@ def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, remove_dropout=Fals

         return graph

-    def prepareModel(self, model, x, remove_dropout=False, qconfig=default_static_qconfig, int8_bf16=False, inplace=False):
+    def prepareModel(self, model, x, remove_dropout=False, qconfig=default_static_qconfig,
+                     int8_bf16=False, prepare_inplace=True, convert_inplace=True,):
         model.eval()
         fp32_model = copy.deepcopy(model)
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             # fold conv bn
             if remove_dropout:
                 ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
-            model = ipex.quantization.prepare(model, qconfig, x, inplace=inplace)
+            model = ipex.quantization.prepare(model, qconfig, x, inplace=prepare_inplace)
             # do calibration
             y = model(*x)
             # jit trace to insert quant/dequant
             if int8_bf16:
                 with torch.cpu.amp.autocast():
-                    convert_model = ipex.quantization.convert(model)
+                    convert_model = ipex.quantization.convert(model, inplace=convert_inplace)
                     traced_model = torch.jit.trace(convert_model, x)
             else:
-                convert_model = ipex.quantization.convert(model)
+                convert_model = ipex.quantization.convert(model, inplace=convert_inplace)
                 traced_model = torch.jit.trace(convert_model, x)
             traced_model = torch.jit.freeze(traced_model)
