Skip to content

Commit 023c104

Browse files
authored
[release/2.0] Fix rnn_packed format check (#1592)
* add UT for lstm weight reorder * use is_opaque instead of is_rnn_packed * update ideep commit to include is_opaque API
1 parent 6beb3d4 commit 023c104

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

csrc/cpu/aten/WeightPack.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ bool is_packed(const at::Tensor& weight) {
6161
std::tuple<ideep::tensor, ideep::tensor> CommonLstmWeightDesc::
6262
get_and_save_lstm_packed_weight() {
6363
ideep::tensor cached_weight_ih, cached_weight_hh;
64-
// Don't pack when the weight is of rnn_packed format
64+
// Don't pack when the weight is of opaque format (rnn_packed format).
6565
// When the weight is of rnn_packed format, if the seq_lens of
6666
// the input changes, the format of weight also changes.
6767
// oneDNN does not support reorder from rnn_packed back to public
6868
// format. LSTM based on BRGEMM kernel (on AVX512 and newest ISAs) will
6969
// use blocked format for weight of LSTM, which won't change when the
7070
// input seq_lens changes.
71-
if (packed_desc_ih_.is_rnn_packed() || packed_desc_hh_.is_rnn_packed()) {
71+
if (packed_desc_ih_.is_opaque() || packed_desc_hh_.is_opaque()) {
7272
return std::make_tuple(w1_src_, w2_src_);
7373
}
7474

tests/cpu/test_weight_prepack.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,35 @@ def forward(self, x, h=None):
12951295
os.remove('origin_checkpoint.pth')
12961296
os.remove('ipex_checkpoint.pth')
12971297

1298+
def test_lstm_weight_reorder(self):
1299+
class Lstm(torch.nn.Module):
1300+
def __init__(self, input_size, hidden_size, num_layers, bidirectional, bias, dropout, batch_first):
1301+
super(Lstm, self).__init__()
1302+
self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
1303+
1304+
def forward(self, x, h=None):
1305+
x, h = self.lstm(x, h)
1306+
return x, h
1307+
1308+
test_dtypes = []
1309+
if core.onednn_has_bf16_support():
1310+
test_dtypes.append(torch.bfloat16)
1311+
for dtype in test_dtypes:
1312+
m = Lstm(2, 3, 1, False, False, 0, False)
1313+
x = torch.randn(2, 1, 2)
1314+
x_var = torch.randn(5, 1, 2)
1315+
1316+
origin_model = copy.deepcopy(m).eval()
1317+
ipex_model = ipex.optimize(origin_model, dtype=dtype)
1318+
1319+
with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
1320+
# run with 2 different shapes to verify weight prepack works when weight format changes
1321+
y = ipex_model(x)
1322+
y_var = ipex_model(x_var)
1323+
1324+
y_ref = origin_model(x_var)
1325+
self.assertEqual(y_var, y_ref)
1326+
12981327
if __name__ == '__main__':
12991328
torch.manual_seed(2020)
13001329
test = unittest.main()

third_party/ideep

0 commit comments

Comments
 (0)