from torch.fx.node import map_aggregate
from torch.ao.quantization import PlaceholderObserver
from torch.quantization.qconfig import QConfig
+from torch.nn.utils.rnn import PackedSequence

from ._utils import get_torch_function_hook_type, HookType, get_module_hook_type, OpQuantizeabilityType, \
    attach_op_convert_info_to_model, save_quant_state, attach_scale_zp_values_to_model, convert_quant_state_map_to_nodes, \
@@ -36,6 +37,25 @@ def _check_add_has_scalar_input(args):
        return True
    return False

+def _convert_PackedSequence_to_tuple_lstm(args):
+    if isinstance(args, tuple) and len(args) == 2:  # (PackedSequence, hx)
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices, args[-1])
+    elif isinstance(args, tuple) and len(args) == 1:  # (PackedSequence,)
+        input, batch_sizes, sorted_indices, unsorted_indices = args[0]
+        args = (input, batch_sizes, sorted_indices, unsorted_indices)
+    else:
+        assert False, "_convert_PackedSequence_to_tuple_lstm args should be a tuple of size 1 or 2 whose first element is a PackedSequence"
+    return args
+
+def _convert_tuple_to_PackedSequence_lstm(args):
+    assert isinstance(args, tuple) and 4 <= len(args) <= 5, "_convert_tuple_to_PackedSequence_lstm input should be a tuple of size 4 or 5"
+    if len(args) == 4:
+        return (PackedSequence(*args),)
+    else:
+        return (PackedSequence(*args[:-1]), args[-1])
+
+
def auto_prepare(
    model: torch.nn.Module,
    configure: QConfig,
@@ -212,7 +232,9 @@ def _patched_module_call(self, *args, **kwargs):
            old_global_disable_torch_function_override = \
                global_disable_torch_function_override
            global_disable_torch_function_override = True
-
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
            if first_call:
                # mypy ignore is used instead of assert because this
                # runs on every forward and assert has a performance cost
@@ -226,19 +248,28 @@ def _patched_module_call(self, *args, **kwargs):
                args, kwargs = parent_qstate.op_prepare_before_hook(
                    cur_module, args, kwargs)  # type: ignore[arg-type]

+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
+
            # original forward
            output = orig_module_call(self, *args, **kwargs)
            # Re-enable the overrides.
            global_disable_torch_function_override = \
                old_global_disable_torch_function_override

            # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
            if first_call:
                output = parent_qstate.first_call_op_prepare_after_hook(
                    cur_module, output, args, qtensor_id, OpQuantizeabilityType.QUANTIZEABLE)
            else:
                output = parent_qstate.op_prepare_after_hook(
                    cur_module, output, args, global_op_idx)
+
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
+
            parent_qstate.mark_cur_op_complete(cur_module)
        elif hook_type is HookType.MODULE_IO_HOOKS:
            cur_qstate = cur_module._auto_quant_state
@@ -500,17 +531,25 @@ def _patched_module_call(self, *args, **kwargs):
            old_global_disable_torch_function_override = \
                global_disable_torch_function_override
            global_disable_torch_function_override = True
+            is_lstm_packed_input = isinstance(cur_module, torch.nn.LSTM) and isinstance(args[0], PackedSequence)
+            if is_lstm_packed_input:
+                args = _convert_PackedSequence_to_tuple_lstm(args)
            _, args, kwargs = qstate.op_convert_before_hook(
                cur_module, args, kwargs, cur_module)
+            if is_lstm_packed_input:
+                args = _convert_tuple_to_PackedSequence_lstm(args)
            if type(cur_module) in quantized_modules_has_weights:
                weights = qstate.op_weight_convert_before_hook(cur_module)
                output = module_call_to_function_call(self, args, weights)
            else:
                output = orig_module_call(self, *args, **kwargs)
            # after hooks
+            if is_lstm_packed_input:
+                output = _convert_PackedSequence_to_tuple_lstm(output)
            output = qstate.op_convert_after_hook(
                cur_module, output)
-
+            if is_lstm_packed_input:
+                output = _convert_tuple_to_PackedSequence_lstm(output)
            # Re-enable the override.
            global_disable_torch_function_override = \
                old_global_disable_torch_function_override
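
For reference, here is a small standalone sketch (not part of the patch) of what the two conversion helpers do around an nn.LSTM call: the PackedSequence input is flattened into its four constituent tensors, which is the form the prepare/convert hooks can observe, and then rebuilt before the real forward. The helper bodies are lightly adapted copies of the ones added above; the tensor shapes, sequence lengths, and LSTM sizes are arbitrary illustration values, and only public torch.nn.utils.rnn APIs are assumed.

# Standalone sketch (not part of the patch); helpers adapted from the diff above.
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

def _convert_PackedSequence_to_tuple_lstm(args):
    # Flatten (PackedSequence, [hx]) into plain tensors so per-tensor hooks can see them.
    if isinstance(args, tuple) and len(args) == 2:  # (PackedSequence, hx)
        data, batch_sizes, sorted_indices, unsorted_indices = args[0]
        return (data, batch_sizes, sorted_indices, unsorted_indices, args[-1])
    elif isinstance(args, tuple) and len(args) == 1:  # (PackedSequence,)
        data, batch_sizes, sorted_indices, unsorted_indices = args[0]
        return (data, batch_sizes, sorted_indices, unsorted_indices)
    raise AssertionError("expected a tuple of size 1 or 2 starting with a PackedSequence")

def _convert_tuple_to_PackedSequence_lstm(args):
    # Rebuild the PackedSequence (and trailing hx, if present) before calling the LSTM.
    assert isinstance(args, tuple) and 4 <= len(args) <= 5
    if len(args) == 4:
        return (PackedSequence(*args),)
    return (PackedSequence(*args[:-1]), args[-1])

# Two sequences of lengths 3 and 2, feature size 4, packed as the LSTM expects.
padded = torch.randn(2, 3, 4)
packed = pack_padded_sequence(padded, lengths=[3, 2], batch_first=True)

flat = _convert_PackedSequence_to_tuple_lstm((packed,))   # tensors visible to observers
(rebuilt,) = _convert_tuple_to_PackedSequence_lstm(flat)  # PackedSequence restored

lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
output, (h, c) = lstm(rebuilt)
print(type(output).__name__, h.shape)  # PackedSequence torch.Size([1, 2, 8])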