
Commit 5b44996

quantization: make lstm quantizable when has state input (#1473)
Parent: 2624a27

3 files changed (+5, -5 lines)

intel_extension_for_pytorch/quantization/_recipe.py (1 addition, 1 deletion)

```diff
@@ -66,7 +66,7 @@ def _default_recipe_init(nodes):
 
         # For LSTM, if it's input is a PackedSequence, we don't support ot now.
         # TODO: support PackedSequence input for quantization LSTM.
-        if node.type in rnn_ops and len(node.input_tensor_infos) > 2:
+        if node.type in rnn_ops and len(node.input_tensor_infos) > 2 and node.input_tensor_infos[1].orig_dtype == torch.int64:
             for idx, tensor_info in enumerate(node.input_tensor_infos):
                 if tensor_info is not None:
                     tensor_info.inf_dtype = tensor_info.orig_dtype
```
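The added dtype check is what distinguishes the two cases: when an LSTM consumes a PackedSequence, the second tensor feeding the node is the PackedSequence's `batch_sizes`, which is always `torch.int64`, while an explicit `(h0, c0)` state tuple is floating point. The old `len(node.input_tensor_infos) > 2` condition alone therefore also caught LSTMs that merely receive a state input and wrongly excluded them from quantization. A minimal sketch of the distinction in plain PyTorch (illustration only, not part of the patch):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# A PackedSequence carries an int64 batch_sizes tensor, which is what
# the new `orig_dtype == torch.int64` test keys on.
seqs = torch.randn(5, 3, 8)  # (seq_len, batch, features)
packed = pack_padded_sequence(seqs, lengths=torch.tensor([5, 4, 2]))
assert packed.batch_sizes.dtype == torch.int64

# An explicit initial state, by contrast, is floating point, so an LSTM
# called as lstm(x, (h0, c0)) is no longer mistaken for the packed case.
h0 = torch.zeros(1, 3, 16)
c0 = torch.zeros(1, 3, 16)
assert h0.dtype == torch.float32
```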

intel_extension_for_pytorch/quantization/_utils.py (1 addition, 1 deletion)

```diff
@@ -435,7 +435,7 @@ def _reset_post_node_input_infos(node):
             _reset_post_node_input_infos(node)
     else:
         # TODO: enable PackedSequence input for LSTM.
-        if not (node.type in [nn.LSTM] and len(node.input_tensor_infos) > 2):
+        if not (node.type in [nn.LSTM] and len(node.input_tensor_infos) > 2 and node.input_tensor_infos[1].orig_dtype == torch.int64):
             if node.input_tensor_force_inf_dtype[0] in [torch.qint8, torch.quint8] and not post_node_are_quantized:
                 node.output_tensor_infos[0].inf_dtype = node.input_tensor_force_inf_dtype[0]
                 node.insert_fake_quant_after_outputs[0] = True
```
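`_utils.py` applies the same refined condition on the output side, so the decision to insert a fake quant after the node stays consistent with the recipe. A hypothetical helper (not in the patch) capturing the predicate the two sites now share, assuming the `node.type`, `input_tensor_infos`, and `orig_dtype` attributes shown in the diffs:

```python
import torch
import torch.nn as nn

def _is_packed_sequence_lstm(node, rnn_ops=(nn.LSTM,)):
    # More than two input tensor infos means the LSTM received an extra
    # input beyond the plain data tensor; an int64 second input marks the
    # batch_sizes of a PackedSequence, while a float one is a hidden state.
    return (
        node.type in rnn_ops
        and len(node.input_tensor_infos) > 2
        and node.input_tensor_infos[1].orig_dtype == torch.int64
    )
```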

tests/cpu/test_ao_jit_ipex_quantization.py (3 additions, 3 deletions)

```diff
@@ -257,8 +257,6 @@ def _lstm_params_list():
                 x = torch.randn(batch_size, seq_len, input_size)
             else:
                 x = torch.randn(seq_len, batch_size, input_size)
-            h = torch.randn(num_layers * num_directions, batch_size, hidden_size)
-            c = torch.randn(num_layers * num_directions, batch_size, hidden_size)
             m = M(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
             graph = self.checkQuantizeTrace(m, [x], atol=3e-2, rtol=1e-1)
             self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)
@@ -303,8 +301,10 @@ def forward(self, input, hid=None):
 
         model = M().eval()
         seq = torch.randn(24, 1, 512)
+        h0 = torch.zeros((2, 1, 256), dtype=seq.dtype)
+        hid = (h0, h0)
 
-        graph = self.checkQuantizeTrace(model, [seq], atol=3e-2, rtol=1e-1)
+        graph = self.checkQuantizeTrace(model, [seq, hid], atol=3e-2, rtol=1e-1)
         self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)
         self.assertGraphContainsExactly(graph, 'aten::lstm', 0)
 
```
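The first hunk deletes `h` and `c` tensors that were never passed to `checkQuantizeTrace`; the second now traces the model with an explicit initial state, which is exactly the case the recipe change unblocks. A self-contained sketch of that scenario; the real `M` is defined earlier in the test file, and the hyperparameters here (`input_size=512`, `hidden_size=256`, `num_layers=2`) are assumptions inferred from the `(24, 1, 512)` input and `(2, 1, 256)` state shapes:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    # Sketch of the test module; forward matches the
    # `def forward(self, input, hid=None)` signature in the hunk header.
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=512, hidden_size=256, num_layers=2)

    def forward(self, input, hid=None):
        return self.lstm(input, hid)

model = M().eval()
seq = torch.randn(24, 1, 512)                   # (seq_len, batch, input_size)
h0 = torch.zeros((2, 1, 256), dtype=seq.dtype)  # (layers * dirs, batch, hidden)
out, (hn, cn) = model(seq, (h0, h0))            # runs with an explicit state
```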
