Commit b8848b5

fix CPU UTs (#2531)
1 parent 2fcd381 commit b8848b5

5 files changed: +102 additions, -24 deletions

intel_extension_for_pytorch/quantization/_quantization_state_utils.py

Lines changed: 12 additions & 12 deletions

@@ -24,10 +24,10 @@
     F.conv3d,
     torch.conv2d,
     torch.conv3d,
-    #F.conv_transpose2d, #TODO
-    #F.conv_transpose3d, #TODO
-    #torch.conv_transpose2d, #TODO
-    #torch.conv_transpose3d, #TODO
+    F.conv_transpose2d,
+    F.conv_transpose3d,
+    torch.conv_transpose2d,
+    torch.conv_transpose3d,
     torch.relu,
     F.relu,
     #torch.sigmoid, # TODO
@@ -50,8 +50,8 @@
 module_types_supported_by_quantization = set([
     torch.nn.Conv2d,
     torch.nn.Conv3d,
-    #torch.nn.ConvTranspose2d,
-    #torch.nn.ConvTranspose3d,
+    torch.nn.ConvTranspose2d,
+    torch.nn.ConvTranspose3d,
     torch.nn.Linear,
     torch.nn.MaxPool2d,
     torch.nn.MaxPool3d,
@@ -90,10 +90,10 @@
     str(F.conv3d),
     str(torch.conv2d),
     str(torch.conv3d),
-    #str(F.conv_transpose2d),
-    #str(F.conv_transpose3d),
-    #str(torch.conv_transpose2d),
-    #str(torch.conv_transpose3d),
+    str(F.conv_transpose2d),
+    str(F.conv_transpose3d),
+    str(torch.conv_transpose2d),
+    str(torch.conv_transpose3d),
     str(F.linear),
     str(torch._C._nn.linear),
 ]
@@ -102,8 +102,8 @@
     #str(torch.nn.Conv1d) # it will be enabled at next step.
     str(torch.nn.Conv2d),
     str(torch.nn.Conv3d),
-    #str(torch.nn.ConvTranspose2d),
-    #str(torch.nn.ConvTranspose3d),
+    str(torch.nn.ConvTranspose2d),
+    str(torch.nn.ConvTranspose3d),
     str(torch.nn.Linear),
 ]
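With the transposed-convolution entries uncommented in every supported-op list above, deconvolution layers now enter the static quantization flow instead of being skipped. A minimal sketch of exercising that path, assuming the package's public prepare/convert entry points and default_static_qconfig (none of which are part of this diff):

import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

# Toy model whose only compute op is a 2D transposed convolution.
model = torch.nn.Sequential(torch.nn.ConvTranspose2d(3, 8, kernel_size=2, stride=2)).eval()
x = torch.randn(1, 3, 16, 16)

qconfig = ipex.quantization.default_static_qconfig
prepared = prepare(model, qconfig, example_inputs=x)   # ConvTranspose2d is now observed
prepared(x)                                            # calibration pass
quantized = convert(prepared)
traced = torch.jit.trace(quantized, x)                 # typical follow-up step in the int8 flow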

intel_extension_for_pytorch/quantization/_quantize.py

Lines changed: 4 additions & 0 deletions

@@ -30,6 +30,10 @@ def prepare(
         torch.nn.Module
     """
     assert isinstance(model, torch.nn.Module), "Only support nn.Module prepare for quantization path"
+    # auto model channels_last memory format conversion
+    from ..frontend import auto_channels_last, _convert_convNd_weight_memory_format
+    if auto_channels_last:
+        _convert_convNd_weight_memory_format(model)
     try:
         prepare_model = optimization.fuse(model, inplace=inplace)
         prepare_model = linear_bn_fuse(prepare_model, inplace=inplace)
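The helper imported from ..frontend is not shown in this commit, so the snippet below is only a hypothetical illustration of what an automatic channels-last conversion of convolution weights can look like; the function name is made up for the sketch:

import torch

def _sketch_convert_convNd_weight_memory_format(model: torch.nn.Module) -> None:
    # Hypothetical stand-in: move convolution weights to the channels-last
    # layouts that oneDNN kernels generally prefer on CPU.
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            m.weight.data = m.weight.data.to(memory_format=torch.channels_last)
        elif isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
            m.weight.data = m.weight.data.to(memory_format=torch.channels_last_3d)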

intel_extension_for_pytorch/quantization/_quantize_utils.py

Lines changed: 41 additions & 2 deletions

@@ -6,6 +6,7 @@
 from torch.fx.node import map_aggregate
 from torch.ao.quantization import PlaceholderObserver
 from torch.quantization.qconfig import QConfig
+from torch.nn.utils.rnn import PackedSequence

 from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
     attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
@@ -36,6 +37,25 @@ def _check_add_has_scalar_input(args):
            return True
    return False

+def _convert_PackedSequence_to_tuple_lstm(args):
+    if isinstance(args, tuple) and len(args) == 2:  # (PackedSequence, hx)
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices, args[-1])
+    elif isinstance(args, tuple) and len(args) == 1:  # (PackedSequence, )
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices)
+    else:
+        assert False, "_convert_PackedSequence_to_tuple args should be a tuple with size 2 or PackedSequence"
+    return args
+
+def _convert_tuple_to_PackedSequence_lstm(args):
+    assert isinstance(args, tuple) and len(args) >= 4 and len(args) <= 5, "_convert_tuple_to_PackedSequence input should be a tuple(5=<size >=4)"
+    if len(args) == 4:
+        return (PackedSequence(*args),)
+    else:
+        return (PackedSequence(*args[:-1]), args[-1])
+
+
 def auto_prepare(
     model : torch.nn.Module,
     configure: QConfig,
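For reference, the two helpers simply flatten a (PackedSequence, hx) argument tuple into plain tensors and rebuild it afterwards, so the quantization hooks only ever see tensors. A small round-trip check, assuming the private helpers are importable from this module:

import torch
from torch.nn.utils.rnn import pack_padded_sequence
from intel_extension_for_pytorch.quantization._quantize_utils import (
    _convert_PackedSequence_to_tuple_lstm,
    _convert_tuple_to_PackedSequence_lstm,
)

packed = pack_padded_sequence(torch.randn(5, 2, 4), lengths=[5, 3])
hx = (torch.zeros(1, 2, 8), torch.zeros(1, 2, 8))

flat = _convert_PackedSequence_to_tuple_lstm((packed, hx))
assert len(flat) == 5   # data, batch_sizes, sorted_indices, unsorted_indices, hx
restored, hx_back = _convert_tuple_to_PackedSequence_lstm(flat)
assert torch.equal(restored.data, packed.data)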
@@ -212,7 +232,9 @@ def _patched_module_call(self, *args, **kwargs):
            old_global_disable_torch_function_override = \
                global_disable_torch_function_override
            global_disable_torch_function_override = True
-
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
            if first_call:
                # mypy ignore is used instead of assert because this
                # runs on every forward and assert has a performance cost
@@ -226,19 +248,28 @@ def _patched_module_call(self, *args, **kwargs):
                args, kwargs = parent_qstate.op_prepare_before_hook(
                    cur_module, args, kwargs)  # type: ignore[arg-type]

+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
+
            # original forward
            output = orig_module_call(self, *args, **kwargs)
            # Re-enable the overrides.
            global_disable_torch_function_override = \
                old_global_disable_torch_function_override

            # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
            if first_call:
                output = parent_qstate.first_call_op_prepare_after_hook(
                    cur_module, output, args, qtensor_id, OpQuantizeabilityType.QUANTIZEABLE)
            else:
                output = parent_qstate.op_prepare_after_hook(
                    cur_module, output, args, global_op_idx)
+
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
+
            parent_qstate.mark_cur_op_complete(cur_module)
        elif hook_type is HookType.MODULE_IO_HOOKS:
            cur_qstate = cur_module._auto_quant_state
@@ -500,17 +531,25 @@ def _patched_module_call(self, *args, **kwargs):
            old_global_disable_torch_function_override = \
                global_disable_torch_function_override
            global_disable_torch_function_override = True
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
            _, args, kwargs = qstate.op_convert_before_hook(
                cur_module, args, kwargs, cur_module)
+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
            if type(cur_module) in quantized_modules_has_weights:
                weights = qstate.op_weight_convert_before_hook(cur_module)
                output = module_call_to_function_call(self, args, weights)
            else:
                output = orig_module_call(self, *args, **kwargs)
            # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
            output = qstate.op_convert_after_hook(
                cur_module, output)
-
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
            # Re-enable the override.
            global_disable_torch_function_override = \
                old_global_disable_torch_function_override

intel_extension_for_pytorch/quantization/_recipe.py

Lines changed: 37 additions & 4 deletions

@@ -16,6 +16,9 @@
 conv_gemm_ops = [str(F.conv2d), str(nn.Conv2d), str(F.conv3d), str(nn.Conv3d), str(torch.conv2d), str(torch.conv3d), \
     str(F.conv_transpose2d), str(torch.nn.ConvTranspose2d), str(F.conv_transpose3d), str(torch.nn.ConvTranspose3d),
     str(torch.conv_transpose2d), str(torch.conv_transpose2d), str(F.linear), str(nn.Linear), str(torch.matmul), str(torch.Tensor.matmul)]
+conv_ops = [str(F.conv2d), str(nn.Conv2d), str(F.conv3d), str(nn.Conv3d), str(torch.conv2d), str(torch.conv3d), \
+    str(F.conv_transpose2d), str(torch.nn.ConvTranspose2d), str(F.conv_transpose3d), str(torch.nn.ConvTranspose3d),
+    str(torch.conv_transpose2d), str(torch.conv_transpose2d)]
 rnn_ops = [str(torch.nn.LSTM)]

 # Those ops only support s8->s8 path, and also require the qscheme is per_tensor_symmetric.
@@ -60,6 +63,17 @@ def _default_recipe_init(nodes):
                    tensor_info.inf_dtype = tensor_info.orig_dtype
                    node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype

+        # For LSTM, if it's input is a PackedSequence, we don't support ot now.
+        # TODO: support PackedSequence input for quantization LSTM.
+        if node.type in rnn_ops and len(node.input_tensor_infos) > 2:
+            for idx, tensor_info in enumerate(node.input_tensor_infos):
+                if tensor_info is not None:
+                    tensor_info.inf_dtype = tensor_info.orig_dtype
+                    node.input_tensor_force_inf_dtype[idx] = tensor_info.inf_dtype
+            for idx, tensor_info in enumerate(node.weight_tensor_infos):
+                if tensor_info is not None:
+                    tensor_info.inf_dtype = tensor_info.orig_dtype
+
 #TODO: making fusion pattern check more general.
 def _find_fused_node_with_cur_elt_wise(node, ops):
     r"""
@@ -198,6 +212,20 @@ def _check_has_quantizable_node_before_node(node):
    # for none ipex customer op, if have a qconfig, we can say it is a quantizable op.
    return True

+def _check_has_quantizable_node_after_node(node):
+    r"""
+    This function is about check whether all quantizable nodes after the given node,
+    which is used to check whether insert fake quant before one quantizable node or not.
+    """
+    if len(node.post_nodes) > 0:
+        output = True
+        for i in range(len(node.post_nodes)):
+            if node.post_nodes[i].qconfig is None:
+                output = False
+        return output
+    else:
+        return False
+
 def _add_recipe(node):
     '''
     Case1: add has pre gemm node.
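_check_has_quantizable_node_after_node returns True only when the node has successors and every one of them carries a qconfig. A toy check with stand-in node objects (SimpleNamespace is used purely for illustration, and the private helper is assumed importable from this module):

from types import SimpleNamespace
from intel_extension_for_pytorch.quantization._recipe import _check_has_quantizable_node_after_node

quantizable = SimpleNamespace(qconfig=object())      # any non-None qconfig
non_quantizable = SimpleNamespace(qconfig=None)

all_quantizable_after = SimpleNamespace(post_nodes=[quantizable, quantizable])
mixed_after = SimpleNamespace(post_nodes=[quantizable, non_quantizable])
no_successors = SimpleNamespace(post_nodes=[])

assert _check_has_quantizable_node_after_node(all_quantizable_after) is True
assert _check_has_quantizable_node_after_node(mixed_after) is False
assert _check_has_quantizable_node_after_node(no_successors) is False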
@@ -233,6 +261,7 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
            node.input_tensor_force_inf_dtype[input_idx] = node.input_tensor_infos[input_idx].inf_dtype

    conv_gemm_node = _find_fused_node_with_cur_add(node, conv_gemm_ops)
+    conv_node = _find_fused_node_with_cur_add(node, conv_ops)
    if conv_gemm_node is None:
        # If pre_nodes don't have gemm node, need to check whether have quantizable node before it,
        # if does't have quantizable node before it, we will not insert fake quant before add.
@@ -255,13 +284,17 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
        if node.input_tensor_infos[0] is not None and node.input_tensor_infos[0] in conv_gemm_node.output_tensor_infos:
            node.input_tensor_infos[0].inf_dtype = node.input_tensor_infos[0].orig_dtype
            node.input_tensor_force_inf_dtype[0] = node.input_tensor_infos[0].inf_dtype
-            # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
-            reset_input_inf_dtype_to_orig_dtype(node, 1)
+            # TODO: set another input's dtype for conv nodes when oneDNN is ready.
+            if conv_node is None or not _check_has_quantizable_node_after_node(node):
+                # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
+                reset_input_inf_dtype_to_orig_dtype(node, 1)
        elif node.input_tensor_infos[1] is not None and node.input_tensor_infos[1] in conv_gemm_node.output_tensor_infos:
            node.input_tensor_infos[1].inf_dtype = node.input_tensor_infos[1].orig_dtype
            node.input_tensor_force_inf_dtype[1] = node.input_tensor_infos[1].inf_dtype
-            # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
-            reset_input_inf_dtype_to_orig_dtype(node, 0)
+            # TODO: set another input's dtype for conv nodes when oneDNN is ready.
+            if conv_node is None or not _check_has_quantizable_node_after_node(node):
+                # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
+                reset_input_inf_dtype_to_orig_dtype(node, 0)

 # get a default recipe
 def get_default_recipe(nodes):

intel_extension_for_pytorch/quantization/_utils.py

Lines changed: 8 additions & 6 deletions

@@ -403,7 +403,7 @@ def set_node_output_quantized(nodes):
    # output's infe dtype is not int8, set it and also set insert_fake_quant_after_output to True.
    """
    def _reset_post_node_input_infos(node):
-        # make sure the post node will node insert fake quant if we add fake quant by cur node' output
+        # make sure the post node will insert fake quant if we add fake quant by cur node' output
        if len(node.post_nodes) > 0:
            for post_node in node.post_nodes:
                if post_node.qconfig is not None:
@@ -434,10 +434,12 @@ def _reset_post_node_input_infos(node):
            node.insert_fake_quant_after_outputs[0] = True
            _reset_post_node_input_infos(node)
        else:
-            if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
-                node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
-                node.insert_fake_quant_after_outputs[0] = True
-                _reset_post_node_input_infos(node)
+            # TODO: enable PackedSequence input for LSTM.
+            if not (node.type in [nn.LSTM] and len(node.input_tensor_infos) > 2):
+                if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
+                    node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
+                    node.insert_fake_quant_after_outputs[0] = True
+                    _reset_post_node_input_infos(node)

 qscheme_dict = {
    str(torch.per_tensor_affine): torch.per_tensor_affine,
@@ -794,7 +796,7 @@ def module_call_to_function_call(module, args, weights):
        output = F.embedding_bag(args[0], weights[0], args[1], module.max_norm, \
            module.norm_type, module.scale_grad_by_freq, module.mode, module.sparse,
            args[2] if len(args) == 3 else None, module.include_last_offset, module.padding_idx)
-    elif isinstance(module, torch.nn.ConvTranspose2d) or isinstance(module, torch.nn.ConvTranspose2d):
+    elif isinstance(module, torch.nn.ConvTranspose2d) or isinstance(module, torch.nn.ConvTranspose3d):
        if module.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
        assert isinstance(module.padding, tuple)
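The last change fixes a copy-paste slip: the old condition tested ConvTranspose2d twice, so ConvTranspose3d modules never took this branch of module_call_to_function_call. A short illustration of why the duplicated isinstance check was dead:

import torch

m = torch.nn.ConvTranspose3d(2, 2, kernel_size=1)
# Old check: both operands test the same class, so it is False for 3D modules.
old = isinstance(m, torch.nn.ConvTranspose2d) or isinstance(m, torch.nn.ConvTranspose2d)
# Fixed check from this commit: True for both 2D and 3D transposed convolutions.
new = isinstance(m, torch.nn.ConvTranspose2d) or isinstance(m, torch.nn.ConvTranspose3d)
assert old is False and new is True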
