Skip to content

Commit 0ba122e

Browse files
cache prepacked weight using global cache (#117)
* cache prepacked weight using global cache * change the code format
1 parent cb4f02e commit 0ba122e

File tree

1 file changed

+54
-53
lines changed

1 file changed

+54
-53
lines changed

torch_ipex/csrc/cpu/WeightPrepack.cpp

Lines changed: 54 additions & 53 deletions
Original file line number · Diff line number · Diff line change
@@ -2,6 +2,7 @@
22

33
#include "WeightPrepack.h"
44
#include "mkldnn/MKLDNNCommon.h"
5+
#include "torch_ipex/csrc/rw_lock.h"
56
#include "torch_ipex/csrc/utils.h"
67

78
namespace torch_ipex {
@@ -11,7 +12,30 @@ namespace {
1112

1213
using weakref_type = c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
1314
using val_blocked = std::tuple<weakref_type, ideep::tensor>;
14-
thread_local std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
15+
std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
16+
17+
using map_iter =
18+
std::unordered_map<c10::TensorImpl *, val_blocked>::const_iterator;
19+
20+
torch_ipex::ReadWriteMutex rwmutex;
21+
22+
// Look up the pre-packed (blocked) ideep tensor cached for `weight`,
// keyed by its TensorImpl pointer. Returns an empty ideep::tensor on a
// cache miss; callers test is_empty() to decide whether packing is
// still required.
// Thread-safe: holds the global read lock, so concurrent readers do
// not block one another.
ideep::tensor read_cached_weights(const at::Tensor &weight) {
  torch_ipex::UniqueReadLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
  ideep::tensor packed;
  auto found = cached_weights.find(weight.unsafeGetTensorImpl());
  if (found != cached_weights.end())
    packed = std::get<1>(found->second);
  return packed;
}
31+
32+
// Insert the pre-packed tensor `result` into the global cache, keyed by
// the TensorImpl of `weight`; a weak reference to `weight` is stored
// alongside it.
// `result` is taken by const reference — it is only copied into the
// cache, never modified (the original mutable reference was misleading).
// Thread-safe: holds the global write lock for the insertion.
// NOTE: emplace() is a no-op when the key already exists, so if two
// threads race to pack the same weight, the first insertion wins and the
// loser's tensor is simply discarded.
void write_cached_weights(const at::Tensor &weight,
                          const ideep::tensor &result) {
  torch_ipex::UniqueWriteLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
  cached_weights.emplace(
      weight.unsafeGetTensorImpl(),
      val_blocked{weakref_type(weight.getIntrusivePtr()), result});
}
38+
1539
} // namespace
1640

1741
ideep::tensor get_conv_prepacked_weight(
@@ -23,10 +47,8 @@ ideep::tensor get_conv_prepacked_weight(
2347
int64_t groups,
2448
const ideep::attr_t& attr,
2549
const at::MemoryFormat& mkldnn_memory_format) {
26-
auto it = cached_weights.find(weight.unsafeGetTensorImpl());
27-
if (it != cached_weights.end()) {
28-
return std::get<1>(it->second);
29-
} else {
50+
ideep::tensor cached_weight = read_cached_weights(weight);
51+
if (cached_weight.is_empty()) {
3052
auto weight_ = weight.contiguous(mkldnn_memory_format);
3153
ideep::tensor w = itensor_view_from_dense(weight_);
3254
// TODO: 3d check
@@ -61,14 +83,11 @@ ideep::tensor get_conv_prepacked_weight(
6183
input.sizes().vec(),
6284
attr);
6385
}
64-
ideep::tensor result;
65-
result.init(packed_desc);
66-
result.feed_from(w);
67-
cached_weights.emplace(
68-
weight.unsafeGetTensorImpl(),
69-
val_blocked{weakref_type(weight.getIntrusivePtr()), result});
70-
return result;
86+
cached_weight.init(packed_desc);
87+
cached_weight.feed_from(w);
88+
write_cached_weights(weight, cached_weight);
7189
}
90+
return cached_weight;
7291
}
7392

7493
ideep::tensor get_conv_prepacked_weight(
@@ -274,28 +293,21 @@ ideep::tensor get_linear_prepacked_weight(
274293
const at::Tensor& weight,
275294
const int64_t batch_size,
276295
const at::ScalarType src_dtype) {
277-
auto it = cached_weights.find(weight.unsafeGetTensorImpl());
278-
if (it != cached_weights.end()) {
279-
return std::get<1>(it->second);
280-
} else {
296+
ideep::tensor cached_weight = read_cached_weights(weight);
297+
if (cached_weight.is_empty()) {
281298
auto weight_ = weight.is_contiguous() ? weight : weight.contiguous();
282299
ideep::tensor w = itensor_view_from_dense(weight_);
283300
auto out_features = weight_.size(0);
284301
auto in_features = weight_.size(1);
285302
ideep::dims input_dims({batch_size, weight.size(1)});
286303
auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
287-
{weight.sizes().cbegin(), weight.sizes().cend()},
288-
input_dims,
289-
w.get_data_type(),
290-
get_mkldnn_dtype(src_dtype));
291-
ideep::tensor result;
292-
result.init(packed_desc);
293-
result.feed_from(w);
294-
cached_weights.emplace(
295-
weight.unsafeGetTensorImpl(),
296-
val_blocked{weakref_type(weight.getIntrusivePtr()), result});
297-
return result;
304+
{weight.sizes().cbegin(), weight.sizes().cend()}, input_dims,
305+
w.get_data_type(), get_mkldnn_dtype(src_dtype));
306+
cached_weight.init(packed_desc);
307+
cached_weight.feed_from(w);
308+
write_cached_weights(weight, cached_weight);
298309
}
310+
return cached_weight;
299311
}
300312

301313
// Create mkldnn memory view from ATen tensor
@@ -318,8 +330,8 @@ inline ideep::tensor get_mkldnn_tensor_view(
318330
}
319331

320332
// Returns true when a pre-packed (blocked) version of `weight` already
// exists in the global cache.
// Bug fix: read_cached_weights() yields an EMPTY tensor on a cache
// miss, so the predicate must be negated — the previous code returned
// the opposite of "is prepacked" (true on miss, false on hit), whereas
// the pre-refactor version returned true exactly when the weight was
// found in the cache.
bool is_prepacked(const at::Tensor& weight) {
  auto cached_weight = read_cached_weights(weight);
  return !cached_weight.is_empty();
}
324336

325337
std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight(
@@ -334,18 +346,14 @@ std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight(
334346
const ideep::tensor& src_iter_c,
335347
const ideep::tensor& bias,
336348
const bool reverse) {
337-
auto it_i = cached_weights.find(weight_ih.unsafeGetTensorImpl());
338-
auto it_h = cached_weights.find(weight_hh.unsafeGetTensorImpl());
339-
340-
bool all_in_cache = it_i != cached_weights.end() && it_h != cached_weights.end();
341-
bool all_miss = it_i == cached_weights.end() && it_h == cached_weights.end();
349+
auto cached_weight_ih = read_cached_weights(weight_ih);
350+
auto cached_weight_hh = read_cached_weights(weight_hh);
351+
bool all_in_cache =
352+
!cached_weight_ih.is_empty() && !cached_weight_hh.is_empty();
353+
bool all_miss = cached_weight_ih.is_empty() && cached_weight_hh.is_empty();
342354
TORCH_CHECK(all_in_cache || all_miss, "both of the weights of LSTM should be cached or neither should be cached");
343355

344-
if (it_i != cached_weights.end()) {
345-
ideep::tensor w_ih = std::get<1>(it_i->second);
346-
ideep::tensor w_hh = std::get<1>(it_h->second);
347-
return std::make_tuple(w_ih, w_hh);
348-
} else {
356+
if (cached_weight_ih.is_empty()) {
349357
auto w1 = get_mkldnn_tensor_view(weight_ih, {{1, 1, input_size, num_gates, hidden_size}, get_mkldnn_dtype(weight_ih.scalar_type()), ideep::format_tag::ldgoi});
350358
auto w2 = get_mkldnn_tensor_view(weight_hh, {{1, 1, hidden_size, num_gates, hidden_size}, get_mkldnn_dtype(weight_hh.scalar_type()), ideep::format_tag::ldgoi});
351359

@@ -369,23 +377,16 @@ std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight(
369377
if (packed_desc_ih.is_rnn_packed() || packed_desc_hh.is_rnn_packed()) {
370378
return std::make_tuple(w1, w2);
371379
}
380+
cached_weight_ih.init(packed_desc_ih);
381+
cached_weight_hh.init(packed_desc_hh);
372382

373-
ideep::tensor packed_ih {packed_desc_ih};
374-
ideep::tensor packed_hh {packed_desc_hh};
375-
376-
packed_ih.feed_from(w1);
377-
packed_hh.feed_from(w2);
378-
379-
cached_weights.emplace(
380-
weight_ih.unsafeGetTensorImpl(),
381-
val_blocked{weakref_type(weight_ih.getIntrusivePtr()), packed_ih});
382-
383-
cached_weights.emplace(
384-
weight_hh.unsafeGetTensorImpl(),
385-
val_blocked{weakref_type(weight_hh.getIntrusivePtr()), packed_hh});
383+
cached_weight_ih.feed_from(w1);
384+
cached_weight_hh.feed_from(w2);
386385

387-
return std::make_tuple(packed_ih, packed_hh);
386+
write_cached_weights(weight_ih, cached_weight_ih);
387+
write_cached_weights(weight_hh, cached_weight_hh);
388388
}
389+
return std::make_tuple(cached_weight_ih, cached_weight_hh);
389390
}
390391

391392
at::Tensor linear_weight_prepack(

0 commit comments

Comments
 (0)