Commit a89ab64

quantization: fix scales/zps not being updated after a quantized model loads a qconf_summary and re-runs calibration (#1245)
1 parent 2e3618d commit a89ab64
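
The bug fixed here shows up in the following flow, sketched minimally below based on the test added in this commit (the model class, input shape, and qconfig are illustrative stand-ins, not part of the commit):

import torch
import intel_extension_for_pytorch as ipex

# MyModel is a hypothetical eval-mode float model; any static PTQ flow applies.
model = MyModel().eval()
x_new = torch.rand(1, 3, 2, 2)

# Assumed stand-in for the test's static_qconfig[0].
qconfig = ipex.quantization.default_static_qconfig
prepared = ipex.quantization.prepare(model, qconfig, example_inputs=x_new, inplace=False)

# Load previously saved scales/zero-points, then calibrate again on new data.
prepared.load_qconf_summary("configure.json")
prepared(x_new)  # observers record fresh min/max statistics

# Before this fix, save_qconf_summary() kept the stale scales/zps from the
# loaded json instead of recomputing them from the freshly run observers.
prepared.save_qconf_summary("configure_new.json")
converted = ipex.quantization.convert(prepared)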

File tree: 3 files changed, +100 −13 lines


intel_extension_for_pytorch/quantization/_quantize_utils.py

20 additions & 7 deletions

@@ -10,7 +10,7 @@
 from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
     attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
     sync_pool_and_lstm_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, \
-    load_qconf_summary_to_model, get_fqn_valid_for_module_dict_key
+    load_qconf_summary_to_model, get_fqn_valid_for_module_dict_key, check_model_obsever_has_run
 from ._quantization_state import AutoQuantizationState, AutoQuantizationStateModuleDict, init_model_quant_state
 from ._recipe import get_default_recipe
 from ._module_swap_utils import swap_child_modules

@@ -322,6 +322,13 @@ def save_qconf_summary(self, qconf_summary):
             # pooling and lstm's input and output should have same scale_zp.
             sync_pool_and_lstm_input_output_scale_zp(quant_state_map, nodes)
             get_default_recipe(nodes)
+        else:
+            if check_model_obsever_has_run(model):
+                # Re-compute the scales and zps if the user loaded a json file and re-ran the calibration step.
+                attach_scale_zp_values_to_model(model)
+            else:
+                # Do nothing if the user only loaded a json file and did not re-run the calibration step.
+                pass
         # Set the model's qconf_summary attr, which makes it easy to check whether the scales/zps have been computed.
         self._qconf_summary = qconf_summary
         save_quant_state(quant_state_map, qconf_summary)

@@ -550,8 +557,8 @@ def unwrap_proxy(a):
     torch.nn.Module.__call__ = orig_module_call
     torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]

-    # If module doesn't have a configure_file attr, we can say that user has run save_qconf_summary method which have
-    # computed the scales and zp, or use the user's setting from a given json file(load_qconf_summary), we need to compute
+    # If module doesn't have a configure_file attr, the user neither ran the save_qconf_summary method (which
+    # computes the scales and zps) nor loaded settings from a given json file (load_qconf_summary), so we need to compute
     # the scale and zp here.
     if not hasattr(module, '_qconf_summary'):
         quant_state_map = module._fqn_to_auto_quant_state_map

@@ -562,10 +569,16 @@ def unwrap_proxy(a):
         sync_pool_and_lstm_input_output_scale_zp(quant_state_map, nodes)
         get_default_recipe(nodes)
     else:
-        # Clear observer if module have, this will works when the user's json setting is loaded.
-        for _, v in module._fqn_to_auto_quant_state_map.items():
-            v.tensor_id_to_observer.clear()
-            v.weight_tensor_id_to_observer.clear()
+        if check_model_obsever_has_run(module):
+            # Re-compute the scales and zps if the user loaded a json file and re-ran the calibration step.
+            attach_scale_zp_values_to_model(module)
+        else:
+            # Clear the observers if the module has any; this applies when the user's json settings
+            # were loaded and the calibration step was not re-run.
+            for _, v in module._fqn_to_auto_quant_state_map.items():
+                v.tensor_id_to_observer.clear()
+                v.weight_tensor_id_to_observer.clear()
+
     # Attach quant_info to parent each module
     attach_op_convert_info_to_model(module)
     swap_child_modules(module)
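
In short, the calibration-exit paths above now resolve scales/zps with a three-way decision. A condensed sketch of that logic (the helper names resolve_scales_zps and compute_from_observers are hypothetical stand-ins for the surrounding IPEX code paths; only check_model_obsever_has_run and attach_scale_zp_values_to_model are real names from this commit):

def resolve_scales_zps(model, qconf_summary_loaded):
    if not qconf_summary_loaded:
        # No json was loaded: derive scales/zps from the observers, as before.
        compute_from_observers(model)
    elif check_model_obsever_has_run(model):
        # A json was loaded AND calibration was re-run: the observers hold
        # fresh statistics, so recompute and overwrite the loaded scales/zps.
        attach_scale_zp_values_to_model(model)
    else:
        # A json was loaded and no new calibration ran: keep the loaded values.
        pass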

intel_extension_for_pytorch/quantization/_utils.py

41 additions & 0 deletions

@@ -191,16 +191,57 @@ def attach_scale_zp_values_to_model(
             if observer.dtype in quantized_dtype:
                 scale, zp = observer.calculate_qparams()
                 qstate.tensor_id_to_scale_zp[int(tensor_id)] = (scale, zp)
+            else:
+                assert False, "The observer's dtype can only be torch.quint8 or torch.qint8"
         for tensor_id, observer in qstate.weight_tensor_id_to_observer.items():
             if observer.dtype in quantized_dtype:
                 scale, zp = observer.calculate_qparams()
                 qstate.weight_tensor_id_to_scale_zp[tensor_id] = (scale, zp)
+            else:
+                assert False, "The observer's dtype can only be torch.quint8 or torch.qint8"
         qstate.tensor_id_to_observer.clear()
         qstate.weight_tensor_id_to_observer.clear()

     for _, child in module.named_children():
         attach_scale_zp_values_to_model(child)


+def _check_observer_has_run(observer):
+    if observer.min_val.numel() == 0 or observer.max_val.numel() == 0:
+        return False
+    if (observer.min_val.dim() == 0 or observer.max_val.dim() == 0) and \
+            observer.min_val == float("inf") and observer.max_val == float("-inf"):
+        return False
+    return True
+
+
+def check_model_obsever_has_run(
+    module: torch.nn.Module,
+) -> None:
+    """
+    Check whether the module's observers have been run, by checking whether the
+    observers' min_val and max_val still hold their init values.
+    """
+    if hasattr(module, '_auto_quant_state'):
+        qstate: AutoQuantizationState = module._auto_quant_state  # type: ignore[assignment]
+        quantized_dtype = [torch.quint8, torch.qint8]
+        for tensor_id, observer in qstate.tensor_id_to_observer.items():
+            if observer.dtype in quantized_dtype:
+                return _check_observer_has_run(observer)
+            else:
+                assert False, "The observer's dtype can only be torch.quint8 or torch.qint8"
+        for tensor_id, observer in qstate.weight_tensor_id_to_observer.items():
+            if observer.dtype in quantized_dtype:
+                return _check_observer_has_run(observer)
+            else:
+                assert False, "The observer's dtype can only be torch.quint8 or torch.qint8"
+
+    for _, child in module.named_children():
+        check_model_obsever_has_run(child)
+
+    return True
+
+
 def attach_op_convert_info_to_model(
     module: torch.nn.Module,
 ) -> None:
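
For context, the init-value checks in _check_observer_has_run line up with how PyTorch's stock observers initialize their buffers: a per-tensor MinMaxObserver starts with min_val=inf and max_val=-inf, while per-channel observers start with empty tensors. A small standalone demo (plain torch.ao.quantization observers, which is what these qstate entries typically hold):

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8)
print(obs.min_val, obs.max_val)  # tensor(inf) tensor(-inf): never run
obs(torch.rand(4))               # record statistics, as a calibration pass would
print(obs.min_val, obs.max_val)  # finite min/max: the observer has run

pc = PerChannelMinMaxObserver(dtype=torch.qint8)
print(pc.min_val.numel())        # 0: per-channel buffers start empty ("not run")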

tests/cpu/test_ao_jit_ipex_quantization.py

39 additions & 6 deletions

@@ -307,19 +307,21 @@ def forward(self, x):
         prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=False)
         prepared_model(x)
         with tempfile.TemporaryDirectory() as tmp:
+            # case1: save qconf and load qconf.
             path = os.path.join(tmp, "configure.json")
             prepared_model.save_qconf_summary(path)
             convert_model = ipex.quantization.convert(prepared_model)
-            traced_model = torch.jit.trace(convert_model, x).eval()
-            traced_model = torch.jit.freeze(traced_model)
-            y_before = traced_model(x)
+            traced_model_ref = torch.jit.trace(convert_model, x).eval()
+            traced_model_ref = torch.jit.freeze(traced_model_ref)
             # load the saved qconf
             prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x, inplace=False)
             prepared_model.load_qconf_summary(path)
             convert_model = ipex.quantization.convert(prepared_model)
             traced_model = torch.jit.trace(convert_model, x).eval()
             traced_model = torch.jit.freeze(traced_model)
-            y_after = traced_model(x)
+            for i in range(2):
+                y_before = traced_model_ref(x)
+                y_after = traced_model(x)
             self.assertEqual(y_before, y_after)
             # save and load qconf again to make sure we didn't lose something
             path2 = os.path.join(tmp, "configure_new.json")

@@ -329,14 +331,45 @@ def forward(self, x):
             convert_model = ipex.quantization.convert(prepared_model)
             traced_model = torch.jit.trace(convert_model, x).eval()
             traced_model = torch.jit.freeze(traced_model)
-            y_after = traced_model(x)
+            for i in range(2):
+                y_after = traced_model(x)
             self.assertEqual(y_before, y_after)
             # make sure the newly saved json is the same as the old one.
             with open(path, 'r') as f:
                 old_json = json.load(f)
             with open(path2, 'r') as f:
                 new_json = json.load(f)
-            self.assertTrue(old_json == new_json)
+            self.assertTrue(old_json == new_json)
+
+            # case2: load qconf and re-do calibration, make sure the scales/zps are updated.
+            x_new = torch.rand(1, 3, 2, 2) * 10
+            # do ref quantization
+            prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x_new, inplace=False)
+            prepared_model(x_new)
+            ref_path = os.path.join(tmp, "configure_ref.json")
+            prepared_model.save_qconf_summary(ref_path)
+            convert_model = ipex.quantization.convert(prepared_model)
+            traced_model_ref = torch.jit.trace(convert_model, x_new).eval()
+            traced_model_ref = torch.jit.freeze(traced_model_ref)
+            # load qconf, and re-do calibration
+            prepared_model = ipex.quantization.prepare(m, static_qconfig[0], example_inputs=x_new, inplace=False)
+            prepared_model.load_qconf_summary(path2)
+            prepared_model(x_new)
+            new_path = os.path.join(tmp, "configure_new.json")
+            prepared_model.save_qconf_summary(new_path)
+            traced_model_new = torch.jit.trace(convert_model, x_new).eval()
+            traced_model_new = torch.jit.freeze(traced_model_new)
+            for i in range(2):
+                y_ref = traced_model_ref(x_new)
+                y_new = traced_model_new(x_new)
+            self.assertEqual(y_ref, y_new)
+            # make sure the newly saved json is the same as the ref one.
+            with open(ref_path, 'r') as f:
+                old_json = json.load(f)
+            with open(new_path, 'r') as f:
+                new_json = json.load(f)
+            self.assertTrue(old_json == new_json)


 class TestRemoveMutate(JitLlgaTestCase):
     def test_mutated_value_alive_after_inplace_op(self):
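
Taken together, case1 checks that a load-without-recalibration round trip is lossless, while case2 exercises the fixed path: after load_qconf_summary(path2) plus a calibration pass on x_new, the re-saved summary must match one produced by calibrating from scratch on the same data. The final assertion boils down to this check (file names as in the test above):

import json

# A summary saved after "load + re-calibrate" must equal one from a fresh
# calibration on the same data, i.e. the loaded scales/zps were recomputed.
with open("configure_ref.json") as f:
    ref = json.load(f)
with open("configure_new.json") as f:
    new = json.load(f)
assert ref == new, "scales/zps were not recomputed after re-calibration"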
