Commit 31875e7

Change the OOB blocked format solution (#1628)
* Change the blocked-format policy for our oneDNN ops and expand the blocked-format semantics of IPEX_XPU_ONEDNN_LAYOUT.

1. Training: the blocked format is chosen only when IPEX_XPU_ONEDNN_LAYOUT=1; otherwise all ops use the plain format.
   * For now, this PR makes no change to the oneDNN backward integration code.
2. Inference:
   * Conv triggers the blocked format when IPEX_XPU_ONEDNN_LAYOUT=1, when src is already blocked, or when running ATSM inference.
   * Matmul chooses the blocked format only when IPEX_XPU_ONEDNN_LAYOUT=1 or src is already blocked.
   * Other ops choose the blocked format only when src is already blocked.

* This PR adds helper functions that encapsulate the block-suggestion conditions.

Signed-off-by: Chen, Zejun <zejun.chen@intel.com>
1 parent b86fb88 commit 31875e7
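
In short, the commit replaces the single global Settings::I().is_onednn_layout_enabled() switch with per-op block-format suggestion helpers. The following is a minimal sketch of the resulting policy, condensed from the helpers added in csrc/oneDNN/Utils.h below; it is not the literal implementation, and is_atsm_inference is a hypothetical stand-in for the FP64-capability probe used there.

// Sketch only: the real helpers are use_blocked_format_for_conv and
// use_blocked_format_for_matmul in csrc/oneDNN/Utils.h below.
bool suggest_block_for_conv(
    const at::Tensor& src, bool layout_env_on, bool is_atsm_inference) {
  if (!src.defined() || src.is_sparse())
    return false; // undefined/sparse tensors always stay plain
  if (layout_env_on)
    return true; // IPEX_XPU_ONEDNN_LAYOUT=1 forces the blocked format
  return is_atsm_inference; // ATSM inference also prefers the blocked format
}

bool suggest_block_for_matmul(const at::Tensor& src, bool layout_env_on) {
  if (!src.defined() || src.is_sparse())
    return false;
  if (layout_env_on)
    return true;
  // otherwise choose block only if src already carries a blocked oneDNN layout
  return !DPCPPTensorContext::get_tensor_ctx(src).is_plain();
}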

File tree

csrc/aten/operators/OpaqueTensorFactories.cpp
csrc/oneDNN/Conv.h
csrc/oneDNN/Eltwise.h
csrc/oneDNN/Matmul.h
csrc/oneDNN/Pooling.h
csrc/oneDNN/Utils.h

6 files changed (+74, −29 lines)

csrc/aten/operators/OpaqueTensorFactories.cpp

Lines changed: 4 additions & 6 deletions
@@ -94,9 +94,6 @@ Tensor empty_opaque_qtensor(
 }
 
 inline bool need_to_plain(const Tensor& tensor) {
-  if (!Settings::I().is_onednn_layout_enabled())
-    return false;
-
   if (!tensor.defined())
     return false;
 
@@ -108,6 +105,10 @@ inline bool need_to_plain(const Tensor& tensor) {
   if (tensor.is_sparse())
     return false;
 
+  auto tensor_ctx = DPCPPTensorContext::get_tensor_ctx(tensor);
+  if (tensor_ctx.is_plain())
+    return false;
+
   return true;
 }
 
@@ -133,9 +134,6 @@ Tensor to_plain_if_needed_(const Tensor& tensor) {
 }
 
 std::vector<Tensor> to_plain_if_needed(TensorList tensors) {
-  if (!Settings::I().is_onednn_layout_enabled())
-    return tensors.vec();
-
   std::vector<Tensor> _tensors;
   for (auto tensor : tensors) {
     _tensors.push_back(to_plain_if_needed(tensor));
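
Net effect of this change: need_to_plain now keys off the tensor's own layout context rather than the global IPEX_XPU_ONEDNN_LAYOUT flag, so a block-formatted tensor is reordered back to plain even when the flag is off. A hedged sketch of the resulting predicate (simplified; the real function carries additional checks in the unchanged context between the hunks):

// Sketch: reorder to plain whenever the tensor actually carries a blocked
// oneDNN layout, independent of the IPEX_XPU_ONEDNN_LAYOUT setting.
bool needs_plain_sketch(const at::Tensor& t) {
  if (!t.defined() || t.is_sparse())
    return false;
  return !DPCPPTensorContext::get_tensor_ctx(t).is_plain();
}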

csrc/oneDNN/Conv.h

Lines changed: 7 additions & 7 deletions
@@ -217,7 +217,8 @@ static at::Tensor convolution(
       padding_back_bottom_right,
       stride,
       dilation);
-  if (!Settings::I().is_onednn_layout_enabled() && !dst.defined()) {
+  auto is_suggested_block = use_blocked_format_for_conv(src);
+  if (!is_suggested_block && !dst.defined()) {
     auto dst_opt = src.options();
     if (src.is_quantized()) {
       dst_opt = attr.get_dst_dtype();
@@ -288,7 +289,7 @@ static at::Tensor convolution(
       : memory::desc();
 
   // block combination
-  if (Settings::I().is_onednn_layout_enabled()) {
+  if (is_suggested_block) {
     // In the blocked format scenario, oneDNN accepts the src in plain format
     // when src ic = 3
     if (ic == 3) {
@@ -396,7 +397,7 @@ static at::Tensor convolution(
       ? memory::desc(wgh_tz, wei_usr_data_t, fmt_wgh)
       : wgh_ctx.meta();
 
-  if (!Settings::I().is_onednn_layout_enabled()) {
+  if (!is_suggested_block) {
     src_usr_md = memory::desc(src_tz, src_data_t, fmt_src);
     dst_usr_md = memory::desc(dst_tz, dst_data_t, fmt_src);
   } else {
@@ -444,7 +445,7 @@ static at::Tensor convolution(
 
   auto weight_cache_optimization = [&]() {
     bool onoff = false;
-    onoff |= Settings::I().is_onednn_layout_enabled();
+    onoff |= is_suggested_block;
     onoff |= onednn_conv_use_channels_last(src, wgh);
     onoff &= !at::GradMode::is_enabled();
     return onoff;
@@ -494,7 +495,7 @@ static at::Tensor convolution(
   auto expected_dst_md = conv_fwd_pd.dst_desc();
   auto dst_m = dpcpp_onednn_memory(dst_usr_md, engine, dst.data_ptr());
   if (dst_usr_md != expected_dst_md) {
-    if (Settings::I().is_onednn_layout_enabled() && dst.is_quantized()) {
+    if (is_suggested_block && dst.is_quantized()) {
       auto quantizer = dpcpp_make_per_tensor_affine_quantizer(
           (get_onednn_dtype_include_double(dst) == memory::data_type::u8 &&
            dst.q_zero_point() == 128)
@@ -580,8 +581,7 @@ static at::Tensor convolution(
       {DNNL_ARG_DST, dst_m}});
 #endif
 
-  if (Settings::I().is_onednn_layout_enabled() &&
-      dst_.data_ptr() != dst.data_ptr()) {
+  if (is_suggested_block && dst_.data_ptr() != dst.data_ptr()) {
     auto blk_ctx = DPCPPTensorContext::release_tensor_ctx(dst_);
     DPCPPTensorContext::set_tensor_ctx(dst, std::move(blk_ctx));
   }

csrc/oneDNN/Eltwise.h

Lines changed: 4 additions & 4 deletions
@@ -42,13 +42,13 @@ static inline void eltwise(
       src.is_contiguous(at::MemoryFormat::ChannelsLast3d)));
   auto src_md = memory::desc({src_tz}, data_t, format_data);
 
+  auto src_ctx = at::AtenIpexTypeXPU::DPCPPTensorContext::get_tensor_ctx(src);
+
   memory src_memory;
-  if (!Settings::I().is_onednn_layout_enabled() ||
-      src.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+  if (src_ctx.is_plain() || src.is_contiguous(at::MemoryFormat::ChannelsLast) ||
       src.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     src_memory = dpcpp_onednn_memory(src_md, engine, src.data_ptr());
   } else {
-    auto src_ctx = at::AtenIpexTypeXPU::DPCPPTensorContext::get_tensor_ctx(src);
     src_md = src_ctx.is_plain() ? src_md : src_ctx.meta();
     src_memory = dpcpp_onednn_memory(src_md, engine, src.data_ptr());
   }
@@ -69,7 +69,7 @@ static inline void eltwise(
       eltwise_forward::primitive_desc(eltwise_eltwiseFwd_desc, attr, engine);
 
   memory dst_memory;
-  if (!Settings::I().is_onednn_layout_enabled()) {
+  if (src_ctx.is_plain()) {
     if (!dst.defined()) {
       dst = src.is_contiguous(at::MemoryFormat::ChannelsLast)
           ? at::empty_like(src, at::MemoryFormat::ChannelsLast)

csrc/oneDNN/Matmul.h

Lines changed: 5 additions & 5 deletions
@@ -281,13 +281,14 @@ static inline void matmul(
 #endif
 
   auto matmul_desc = matmul::desc(m1_md, m2_md, dst_md);
+  auto is_suggested_block = use_blocked_format_for_matmul(m1);
 
   if (with_bias && (!m1.is_quantized()) && (!m2.is_quantized())) {
     // ensure getting a valid oneDNN bias md here
     b_md = memory::desc(
         get_onednn_dims(b), get_onednn_dtype(b), get_onednn_strides(b));
 
-    if (dims == 2 && Settings::I().is_onednn_layout_enabled()) {
+    if (dims == 2 && is_suggested_block) {
       // attr + blk
 #ifdef USE_PRIMITIVE_CACHE
       create_key(
@@ -310,7 +311,7 @@ static inline void matmul(
       matmul_desc = matmul::desc(m1_md, m2_md, b_md, dst_md);
     }
   } else {
-    if (dims == 2 && Settings::I().is_onednn_layout_enabled()) {
+    if (dims == 2 && is_suggested_block) {
       // no attr + blk
 #ifdef USE_PRIMITIVE_CACHE
       create_key(
@@ -374,7 +375,7 @@ static inline void matmul(
 
   auto weight_cache_optimization = [&]() {
     bool onoff = false;
-    onoff |= Settings::I().is_onednn_layout_enabled();
+    onoff |= is_suggested_block;
     onoff &= c10::InferenceMode::is_enabled();
     return onoff;
   }();
@@ -441,8 +442,7 @@ static inline void matmul(
     });
   }
 
-  if (Settings::I().is_onednn_layout_enabled() && dst_m != dst_usr_m &&
-      dims == 2) {
+  if (is_suggested_block && dst_m != dst_usr_m && dims == 2) {
     auto blk_ctx = DPCPPTensorContext::release_tensor_ctx(dst_);
     DPCPPTensorContext::set_tensor_ctx(dst, std::move(blk_ctx));
   }

csrc/oneDNN/Pooling.h

Lines changed: 4 additions & 7 deletions
@@ -151,7 +151,7 @@ static at::Tensor pooling(
       pooling_forward::primitive_desc(pooling_fwd_desc, engine);
 
   memory src_m, dst_m;
-  if (!Settings::I().is_onednn_layout_enabled() || is_smf_channels_last(src)) {
+  if (src_ctx.is_plain()) {
     src_m = dpcpp_onednn_memory(src_md, engine, src.data_ptr());
     dst_m = dpcpp_onednn_memory(dst_md, engine, dst.data_ptr());
   } else {
@@ -310,8 +310,7 @@ static std::tuple<at::Tensor, at::Tensor> pooling(
   auto expected_dst_md = pooling_fwd_pd.dst_desc();
 
   memory src_usr_m, dst_usr_m;
-  if (!Settings::I().is_onednn_layout_enabled() ||
-      onednn_pool_use_channels_last(src)) {
+  if (src_ctx.is_plain() || onednn_pool_use_channels_last(src)) {
     src_usr_m = dpcpp_onednn_memory(src_md, engine, src.data_ptr());
     dst_usr_m = dpcpp_onednn_memory(dst_md, engine, dst.data_ptr());
   } else {
@@ -340,8 +339,7 @@ static std::tuple<at::Tensor, at::Tensor> pooling(
   if (prop_kind == dnnl::prop_kind::forward_training) {
     at::Tensor idx_;
     memory idx_m;
-    if (!Settings::I().is_onednn_layout_enabled() ||
-        onednn_pool_use_channels_last(src)) {
+    if (src_ctx.is_plain() || onednn_pool_use_channels_last(src)) {
       idx_ = at::empty({dst_tz}, at::TensorOptions(at::kXPU).dtype(at::kInt));
       idx_m = dpcpp_onednn_memory(idx_md, engine, idx_.data_ptr());
     } else {
@@ -366,8 +364,7 @@ static std::tuple<at::Tensor, at::Tensor> pooling(
         {DNNL_ARG_DST, dst_m},
         {DNNL_ARG_WORKSPACE, idx_m}});
 
-    if (!Settings::I().is_onednn_layout_enabled() ||
-        onednn_pool_use_channels_last(src)) {
+    if (src_ctx.is_plain() || onednn_pool_use_channels_last(src)) {
       dtype_convert_by_scalar(
           idx.data_ptr<int64_t>(), idx_.data_ptr<int32_t>(), idx_.numel());
     } else {

csrc/oneDNN/Utils.h

Lines changed: 50 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include <core/MemoryFormat.h>
 #include <core/detail/TensorInfo.h>
 #include <oneapi/dnnl/dnnl.hpp>
+#include <runtime/Utils.h>
 #include <tensor/Context.h>
 #include <utils/Macros.h>
 #include <utils/Settings.h>
@@ -470,6 +471,55 @@ static inline bool cat_valid(const TensorList& tensors) {
   return true;
 }
 
+// Judge whether to use the blocked format for Conv
+static inline bool use_blocked_format_for_conv(const at::Tensor& src) {
+  if (!src.defined() || src.is_sparse()) {
+    // suggest plain
+    return false;
+  }
+
+  if (Settings::I().is_onednn_layout_enabled()) {
+    // suggest block
+    return true;
+  }
+
+  // Inference workloads on the ATSM platform use the blocked format for conv;
+  // double (FP64) support is used to distinguish whether the device is ATSM.
+  auto is_auto_transpose = !dpcppSupportFP64();
+  auto suggest_weight_block = is_auto_transpose &&
+      (c10::InferenceMode::is_enabled() || !at::GradMode::is_enabled()) &&
+      !is_smf_channels_last(src);
+  if (suggest_weight_block) {
+    // suggest block
+    return true;
+  }
+
+  // suggest plain
+  return false;
+}
+
+// Judge whether to use the blocked format for Matmul
+static inline bool use_blocked_format_for_matmul(const at::Tensor& src) {
+  if (!src.defined() || src.is_sparse()) {
+    // suggest plain
+    return false;
+  }
+
+  if (Settings::I().is_onednn_layout_enabled()) {
+    // suggest block
+    return true;
+  }
+
+  auto src_ctx = at::AtenIpexTypeXPU::DPCPPTensorContext::get_tensor_ctx(src);
+  if (!src_ctx.is_plain()) {
+    // suggest block
+    return true;
+  }
+
+  // suggest plain
+  return false;
+}
+
 static inline std::vector<int64_t> gen_dummy_input_size_for(
     const at::IntArrayRef weight_sizes,
     const int64_t groups) {
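
A note on the ATSM probe in use_blocked_format_for_conv: ATSM-class devices lack native FP64, so !dpcppSupportFP64() (the runtime query brought in by the new <runtime/Utils.h> include) doubles as the platform check. Restated in isolation as a hedged sketch, assuming the repo's dpcppSupportFP64 and is_smf_channels_last utilities:

// Sketch of the ATSM-inference condition used above: no FP64 support,
// an inference-style autograd state, and a non-channels-last src.
bool atsm_inference_candidate(const at::Tensor& src) {
  bool no_fp64 = !dpcppSupportFP64(); // ATSM-class parts have no FP64
  bool inference =
      c10::InferenceMode::is_enabled() || !at::GradMode::is_enabled();
  return no_fp64 && inference && !is_smf_channels_last(src);
}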
