22
33#include " WeightPrepack.h"
44#include " mkldnn/MKLDNNCommon.h"
5+ #include " torch_ipex/csrc/rw_lock.h"
56#include " torch_ipex/csrc/utils.h"
67
78namespace torch_ipex {
@@ -11,7 +12,30 @@ namespace {
1112
1213using weakref_type = c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
1314using val_blocked = std::tuple<weakref_type, ideep::tensor>;
14- thread_local std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
15+ std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
16+
17+ using map_iter =
18+ std::unordered_map<c10::TensorImpl *, val_blocked>::const_iterator;
19+
20+ torch_ipex::ReadWriteMutex rwmutex;
21+
22+ ideep::tensor read_cached_weights (const at::Tensor &weight) {
23+ torch_ipex::UniqueReadLock<torch_ipex::ReadWriteMutex> lock (rwmutex);
24+ ideep::tensor cached_weight;
25+ auto it = cached_weights.find (weight.unsafeGetTensorImpl ());
26+ if (it != cached_weights.end ()) {
27+ cached_weight = std::get<1 >(it->second );
28+ }
29+ return cached_weight;
30+ }
31+
32+ void write_cached_weights (const at::Tensor &weight, ideep::tensor &result) {
33+ torch_ipex::UniqueWriteLock<torch_ipex::ReadWriteMutex> lock (rwmutex);
34+ cached_weights.emplace (
35+ weight.unsafeGetTensorImpl (),
36+ val_blocked{weakref_type (weight.getIntrusivePtr ()), result});
37+ }
38+
1539} // namespace
1640
1741ideep::tensor get_conv_prepacked_weight (
@@ -23,10 +47,8 @@ ideep::tensor get_conv_prepacked_weight(
2347 int64_t groups,
2448 const ideep::attr_t & attr,
2549 const at::MemoryFormat& mkldnn_memory_format) {
26- auto it = cached_weights.find (weight.unsafeGetTensorImpl ());
27- if (it != cached_weights.end ()) {
28- return std::get<1 >(it->second );
29- } else {
50+ ideep::tensor cached_weight = read_cached_weights (weight);
51+ if (cached_weight.is_empty ()) {
3052 auto weight_ = weight.contiguous (mkldnn_memory_format);
3153 ideep::tensor w = itensor_view_from_dense (weight_);
3254 // TODO: 3d check
@@ -61,14 +83,11 @@ ideep::tensor get_conv_prepacked_weight(
6183 input.sizes ().vec (),
6284 attr);
6385 }
64- ideep::tensor result;
65- result.init (packed_desc);
66- result.feed_from (w);
67- cached_weights.emplace (
68- weight.unsafeGetTensorImpl (),
69- val_blocked{weakref_type (weight.getIntrusivePtr ()), result});
70- return result;
86+ cached_weight.init (packed_desc);
87+ cached_weight.feed_from (w);
88+ write_cached_weights (weight, cached_weight);
7189 }
90+ return cached_weight;
7291}
7392
7493ideep::tensor get_conv_prepacked_weight (
@@ -274,28 +293,21 @@ ideep::tensor get_linear_prepacked_weight(
274293 const at::Tensor& weight,
275294 const int64_t batch_size,
276295 const at::ScalarType src_dtype) {
277- auto it = cached_weights.find (weight.unsafeGetTensorImpl ());
278- if (it != cached_weights.end ()) {
279- return std::get<1 >(it->second );
280- } else {
296+ ideep::tensor cached_weight = read_cached_weights (weight);
297+ if (cached_weight.is_empty ()) {
281298 auto weight_ = weight.is_contiguous () ? weight : weight.contiguous ();
282299 ideep::tensor w = itensor_view_from_dense (weight_);
283300 auto out_features = weight_.size (0 );
284301 auto in_features = weight_.size (1 );
285302 ideep::dims input_dims ({batch_size, weight.size (1 )});
286303 auto packed_desc = ideep::inner_product_forward::expected_weights_desc (
287- {weight.sizes ().cbegin (), weight.sizes ().cend ()},
288- input_dims,
289- w.get_data_type (),
290- get_mkldnn_dtype (src_dtype));
291- ideep::tensor result;
292- result.init (packed_desc);
293- result.feed_from (w);
294- cached_weights.emplace (
295- weight.unsafeGetTensorImpl (),
296- val_blocked{weakref_type (weight.getIntrusivePtr ()), result});
297- return result;
304+ {weight.sizes ().cbegin (), weight.sizes ().cend ()}, input_dims,
305+ w.get_data_type (), get_mkldnn_dtype (src_dtype));
306+ cached_weight.init (packed_desc);
307+ cached_weight.feed_from (w);
308+ write_cached_weights (weight, cached_weight);
298309 }
310+ return cached_weight;
299311}
300312
301313// Create mkldnn memory view from ATen tensor
@@ -318,8 +330,8 @@ inline ideep::tensor get_mkldnn_tensor_view(
318330}
319331
320332bool is_prepacked (const at::Tensor& weight) {
321- auto it = cached_weights. find (weight. unsafeGetTensorImpl () );
322- return it == cached_weights. end () ? false : true ;
333+ auto cached_weight = read_cached_weights (weight);
334+ return cached_weight. is_empty () ;
323335}
324336
325337std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight (
@@ -334,18 +346,14 @@ std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight(
334346 const ideep::tensor& src_iter_c,
335347 const ideep::tensor& bias,
336348 const bool reverse) {
337- auto it_i = cached_weights. find (weight_ih. unsafeGetTensorImpl () );
338- auto it_h = cached_weights. find (weight_hh. unsafeGetTensorImpl () );
339-
340- bool all_in_cache = it_i != cached_weights. end () && it_h != cached_weights. end ();
341- bool all_miss = it_i == cached_weights. end () && it_h == cached_weights. end ();
349+ auto cached_weight_ih = read_cached_weights (weight_ih);
350+ auto cached_weight_hh = read_cached_weights (weight_hh);
351+ bool all_in_cache =
352+ !cached_weight_ih. is_empty () && !cached_weight_hh. is_empty ();
353+ bool all_miss = cached_weight_ih. is_empty () && cached_weight_hh. is_empty ();
342354 TORCH_CHECK (all_in_cache || all_miss, " both of the weights of LSTM should be cached or neither should be cached" );
343355
344- if (it_i != cached_weights.end ()) {
345- ideep::tensor w_ih = std::get<1 >(it_i->second );
346- ideep::tensor w_hh = std::get<1 >(it_h->second );
347- return std::make_tuple (w_ih, w_hh);
348- } else {
356+ if (cached_weight_ih.is_empty ()) {
349357 auto w1 = get_mkldnn_tensor_view (weight_ih, {{1 , 1 , input_size, num_gates, hidden_size}, get_mkldnn_dtype (weight_ih.scalar_type ()), ideep::format_tag::ldgoi});
350358 auto w2 = get_mkldnn_tensor_view (weight_hh, {{1 , 1 , hidden_size, num_gates, hidden_size}, get_mkldnn_dtype (weight_hh.scalar_type ()), ideep::format_tag::ldgoi});
351359
@@ -369,23 +377,16 @@ std::tuple<ideep::tensor, ideep::tensor> get_lstm_prepacked_weight(
369377 if (packed_desc_ih.is_rnn_packed () || packed_desc_hh.is_rnn_packed ()) {
370378 return std::make_tuple (w1, w2);
371379 }
380+ cached_weight_ih.init (packed_desc_ih);
381+ cached_weight_hh.init (packed_desc_hh);
372382
373- ideep::tensor packed_ih {packed_desc_ih};
374- ideep::tensor packed_hh {packed_desc_hh};
375-
376- packed_ih.feed_from (w1);
377- packed_hh.feed_from (w2);
378-
379- cached_weights.emplace (
380- weight_ih.unsafeGetTensorImpl (),
381- val_blocked{weakref_type (weight_ih.getIntrusivePtr ()), packed_ih});
382-
383- cached_weights.emplace (
384- weight_hh.unsafeGetTensorImpl (),
385- val_blocked{weakref_type (weight_hh.getIntrusivePtr ()), packed_hh});
383+ cached_weight_ih.feed_from (w1);
384+ cached_weight_hh.feed_from (w2);
386385
387- return std::make_tuple (packed_ih, packed_hh);
386+ write_cached_weights (weight_ih, cached_weight_ih);
387+ write_cached_weights (weight_hh, cached_weight_hh);
388388 }
389+ return std::make_tuple (cached_weight_ih, cached_weight_hh);
389390}
390391
391392at::Tensor linear_weight_prepack (
0 commit comments