Commit 5d14f80

Replace at::layer_norm with ipex::layernorm (#129)
* Replace at::layer_norm with ipex::layernorm
* Add torch_ipex/csrc/cpu/LayerNorm.h and torch_ipex/csrc/cpu/LayerNorm.cpp
* Add comment
* Fix conflict build issue
* Fix clang-format issue
* Update CustomOPs.cpp: add the reason for the layer_norm performance regression and the condition for removing the layer_norm workaround
* Update CustomOPs.cpp
* Fix clang issue
1 parent c4bc4f4 commit 5d14f80

9 files changed: +238 −4 lines changed


tests/cpu/test_jit.py

Lines changed: 19 additions & 0 deletions
@@ -436,6 +436,14 @@ def __init__(self, dim=-1):
     def forward(self, x):
         return self.softmax(x)

+class IPEXLayerNorm(torch.nn.Module):
+    def __init__(self):
+        super(IPEXLayerNorm, self).__init__()
+        self.layernorm = torch.nn.LayerNorm(4)
+    def forward(self, x):
+        return self.layernorm(x)
+
+

 class Tester(TestCase):

@@ -947,6 +955,17 @@ def test_ipex_softmax(self):
             torch.rand(3, 4, 4, dtype=torch.bfloat16),
             kind_in_graph="ipex::softmax",
             prec=5e-3)
+    def test_ipex_layernorm(self):
+        self._test_output(
+            IPEXLayerNorm(),
+            torch.rand(8, 3, 4),
+            kind_in_graph="ipex::layernorm")
+        self._test_output_bf16(
+            IPEXLayerNorm(),
+            torch.rand(8, 3, 4, dtype=torch.bfloat16),
+            kind_in_graph="ipex::layernorm",
+            prec=5e-2)
+

 if __name__ == '__main__':
     torch.manual_seed(2020)
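
In effect, the new test asserts that a traced IPEXLayerNorm module ends up with an ipex::layernorm node in its optimized graph. A rough standalone equivalent of that check, as a sketch only: the helper name below is hypothetical, the real semantics of _test_output live in the test harness, and the ipex::layernorm node only appears when the IPEX extension is loaded so its fusion pass actually runs.

import torch

def contains_node_kind(model, example_input, kind):
    # Trace the module, run it once so the JIT profiling/optimization passes
    # (including the IPEX fusion pass, when the extension is loaded) kick in,
    # then scan the optimized graph for a node of the requested kind.
    traced = torch.jit.trace(model.eval(), example_input)
    with torch.no_grad():
        traced(example_input)
    graph = traced.graph_for(example_input)
    return any(node.kind() == kind for node in graph.nodes())

# e.g. contains_node_kind(IPEXLayerNorm(), torch.rand(8, 3, 4), "ipex::layernorm")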

torch_ipex/csrc/cpu/CustomOPs.cpp

Lines changed: 113 additions & 2 deletions
@@ -1,10 +1,11 @@
 #include "torch_ipex/csrc/cpu/CustomOPs.h"
-#include "torch_ipex/csrc/utils.h"
 #include "Conv.h"
+#include "LayerNorm.h"
 #include "Linear.h"
-#include "Pooling.h"
 #include "Matmul.h"
+#include "Pooling.h"
 #include "Softmax.h"
+#include "torch_ipex/csrc/utils.h"

 #include <ATen/Context.h>
 #include <ATen/InferSize.h>
@@ -357,5 +358,115 @@ at::Tensor AtenIpexJITDev::dil_softmax(
   return softmax_impl(input, dim);
 }

+/**
+ * Prepare inputs for dil_layernorm.
+ *
+ * @param input: the source tensor for layernorm
+ * @param normalized_shape: input shape from an expected input of size
+ * @param weight: scale tensor for layernorm
+ * @param bias: shift tensor for layernorm
+ *
+ * @return inputs for dil_layernorm.
+ **/
+std::tuple<at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t>
+_prepare_layer_norm_inputs(const at::Tensor &input,
+                           at::IntArrayRef normalized_shape,
+                           const at::Tensor &weight /* optional */,
+                           const at::Tensor &bias /* optional */) {
+
+  const int normalized_ndim = normalized_shape.size();
+  TORCH_CHECK(normalized_ndim >= 1,
+              "Expected normalized_shape to be at least 1-dimensional, i.e., ",
+              "containing at least one element, but got normalized_shape = ",
+              normalized_shape);
+  TORCH_CHECK(
+      !weight.defined() || weight.sizes().equals(normalized_shape),
+      "Expected weight to be of same shape as normalized_shape, but got ",
+      "weight of shape ", weight.sizes(),
+      " and normalized_shape = ", normalized_shape);
+  TORCH_CHECK(!bias.defined() || bias.sizes().equals(normalized_shape),
+              "Expected bias to be of same shape as normalized_shape, but got ",
+              "bias of shape ", bias.sizes(),
+              " and normalized_shape = ", normalized_shape);
+
+  const auto input_shape = input.sizes();
+  const auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size" << input_shape;
+    AT_ERROR(ss.str());
+  }
+
+  // M: product of the leading (non-normalized) dims; N: product of the
+  // normalized dims.
+  const int axis = input_ndim - normalized_ndim;
+  const int64_t M =
+      std::accumulate(input_shape.cbegin(), input_shape.cbegin() + axis,
+                      static_cast<int64_t>(1), std::multiplies<int64_t>());
+  const int64_t N =
+      std::accumulate(input_shape.cbegin() + axis, input_shape.cend(),
+                      static_cast<int64_t>(1), std::multiplies<int64_t>());
+
+  const auto &X = input.is_contiguous() ? input : input.contiguous();
+  const auto &gamma = weight.is_contiguous() ? weight : weight.contiguous();
+  const auto &beta = bias.is_contiguous() ? bias : bias.contiguous();
+  return std::make_tuple(X, gamma, beta, M, N);
+}
+
+/**
+ * at::layer_norm performance dropped due to
+ * PR https://github.com/pytorch/pytorch/pull/59987.
+ * This is a workaround for the layernorm regression:
+ * replace at::layer_norm with ipex::layernorm in the JIT pass for inference.
+ * For now, we only use the oneDNN kernel when both weight and bias are
+ * provided.
+ * TODO: cover more scenarios with oneDNN, or remove this pass once
+ * at::layer_norm performance is back to where it was before
+ * https://github.com/pytorch/pytorch/pull/59987 was merged.
+ *
+ * @param input: the source tensor for layernorm
+ * @param normalized_shape: input shape from an expected input of size
+ * @param weight_opt: scale tensor for layernorm
+ * @param bias_opt: shift tensor for layernorm
+ * @param eps: a value added to the denominator for numerical stability.
+ *             Default: 1e-5
+ *
+ * @return output of layernorm
+ */
+at::Tensor AtenIpexJITDev::dil_layernorm(
+    const at::Tensor &input, at::IntArrayRef normalized_shape,
+    const c10::optional<at::Tensor> &weight_opt,
+    const c10::optional<at::Tensor> &bias_opt, float eps, bool cudnn_enable) {
+
+  if (weight_opt.has_value() && bias_opt.has_value()) {
+#if defined(IPEX_PROFILE_OP)
+    RECORD_FUNCTION("AtenIpexJITDev::dil_layernorm",
+                    std::vector<c10::IValue>({}));
+#endif
+    auto inputs = _prepare_layer_norm_inputs(
+        input, normalized_shape, weight_opt.value(), bias_opt.value());
+    auto X = std::get<0>(inputs);
+    auto gamma = std::get<1>(inputs);
+    auto beta = std::get<2>(inputs);
+    auto M = std::get<3>(inputs);
+    auto N = std::get<4>(inputs);
+    return std::get<0>(dil_native_layer_norm_impl(X, gamma, beta, M, N, eps));
+  }
+  // Fall back to the stock kernel when weight or bias is missing.
+  c10::MaybeOwned<at::Tensor> weight_maybe_owned =
+      at::borrow_from_optional_tensor(weight_opt);
+  const at::Tensor &weight = *weight_maybe_owned;
+  c10::MaybeOwned<at::Tensor> bias_maybe_owned =
+      at::borrow_from_optional_tensor(bias_opt);
+  const at::Tensor &bias = *bias_maybe_owned;
+  return std::get<0>(
+      at::native_layer_norm(input, normalized_shape, weight, bias, eps));
+}
+
 } // namespace cpu
 } // namespace torch_ipex
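
To make the M/N split concrete, here is a small Python sketch (illustrative only, not part of the commit) that mirrors what _prepare_layer_norm_inputs computes for the shapes used in the new test:

import math

def split_outer_inner(input_shape, normalized_shape):
    # M = product of the leading (non-normalized) dims,
    # N = product of the normalized (trailing) dims.
    axis = len(input_shape) - len(normalized_shape)
    assert list(input_shape[axis:]) == list(normalized_shape)
    M = math.prod(input_shape[:axis])
    N = math.prod(input_shape[axis:])
    return M, N

# Input (8, 3, 4) with LayerNorm(4), i.e. normalized_shape = (4,):
print(split_outer_inner((8, 3, 4), (4,)))  # (24, 4): 24 rows, each normalized over 4 elements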

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@ namespace ipex {

 static auto max_pool2d = Symbol::fromQualString("ipex::max_pool2d");
 static auto softmax = Symbol::fromQualString("ipex::softmax");
+static auto layernorm = Symbol::fromQualString("ipex::layernorm");

 // n-dims tensor op.
 static auto convolution_nd_weight_base =
@@ -186,6 +187,11 @@ class AtenIpexJITDev {
       at::IntArrayRef kernel_size, int64_t groups, int64_t output_channel,
       bool weight_channels_last, bool weight_prepacked, at::Tensor &accumu,
       at::Scalar alpha);
+  static at::Tensor dil_layernorm(const at::Tensor &input,
+                                  at::IntArrayRef normalized_shape,
+                                  const c10::optional<at::Tensor> &weight_opt,
+                                  const c10::optional<at::Tensor> &bias_opt,
+                                  float eps, bool cudnn_enable);
 };

 } // namespace cpu

torch_ipex/csrc/cpu/LayerNorm.cpp

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+#include "LayerNorm.h"
+#include "mkldnn/MKLDNNCommon.h"
+
+namespace torch_ipex {
+namespace cpu {
+
+/**
+ * layer_norm kernel for inference mode with the oneDNN implementation
+ *
+ * @param X: input tensor for layernorm
+ * @param gamma: scale for layernorm
+ * @param beta: shift for layernorm
+ * @param M: product of the leading (non-normalized) dims of X
+ * @param N: product of the normalized dims of X
+ * @param eps: a value added to the denominator for numerical stability
+ **/
+std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_native_layer_norm_impl(
+    const at::Tensor &X, const at::Tensor &gamma /* optional */,
+    const at::Tensor &beta /* optional */, int64_t M, int64_t N, double eps) {
+  ideep::tensor x = itensor_view_from_dense(X);
+  auto gamma_fp32 = gamma.to(at::kFloat);
+  auto beta_fp32 = beta.to(at::kFloat);
+  const ideep::tensor scale = itensor_view_from_dense(gamma_fp32);
+  const ideep::tensor shift = itensor_view_from_dense(beta_fp32);
+  // Keep the leading dims (whose product is M) and collapse the normalized
+  // dims into a single trailing dim of size N, so oneDNN normalizes over the
+  // last axis.
+  int64_t i = 0;
+  auto dim = at::maybe_wrap_dim(0, X.dim(), false);
+  auto j = X.sizes()[dim];
+  std::vector<int64_t> input_size;
+  while (j <= M) {
+    dim = at::maybe_wrap_dim(i++, X.dim(), false);
+    input_size.push_back(X.sizes()[dim]);
+    dim = at::maybe_wrap_dim(i, X.dim(), false);
+    j *= X.sizes()[dim];
+  }
+  input_size.push_back(N);
+  auto src = x.reshape(input_size);
+  at::Tensor Y = at::native::empty_like(X);
+  at::Tensor mean = at::empty({M}, X.options());
+  at::Tensor variance = at::empty({M}, X.options());
+  auto onednn_Y = itensor_view_from_dense(Y);
+  auto onednn_mean = itensor_view_from_dense(mean);
+  auto onednn_variance = itensor_view_from_dense(variance);
+  ideep::layer_normalization_forward::compute(
+      src, scale, shift, onednn_Y, onednn_mean, onednn_variance, eps);
+  return std::make_tuple(Y, mean, variance);
+}
+
+} // namespace cpu
+} // namespace torch_ipex
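
For reference, the per-row computation the oneDNN kernel performs (mean and variance over the last N elements, then scale and shift) can be cross-checked with a plain NumPy sketch; this is illustrative only and not part of the commit:

import numpy as np

def layer_norm_reference(x, gamma, beta, N, eps=1e-5):
    # Flatten to (M, N), normalize each row, then apply scale and shift.
    rows = x.reshape(-1, N)
    mean = rows.mean(axis=1, keepdims=True)
    var = rows.var(axis=1, keepdims=True)  # biased variance, as layer norm uses
    y = (rows - mean) / np.sqrt(var + eps)
    y = y * gamma.reshape(-1) + beta.reshape(-1)
    return y.reshape(x.shape), mean.ravel(), var.ravel()

# e.g. x of shape (8, 3, 4) with gamma, beta of shape (4,) and N = 4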

torch_ipex/csrc/cpu/LayerNorm.h

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+
+#include "ideep/ideep.hpp"
+
+namespace torch_ipex {
+namespace cpu {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_native_layer_norm_impl(
+    const at::Tensor &X, const at::Tensor &gamma /* optional */,
+    const at::Tensor &beta /* optional */, int64_t M, int64_t N, double eps);
+} // namespace cpu
+} // namespace torch_ipex

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ void FusionPass(std::shared_ptr<Graph> &graph) {

   // replace aten::softmax with ipex::softmax
   graph_rewrite::replaceAtenLinearWithIpexSoftmax(graph);
-
+  graph_rewrite::replaceAtenLayerNormWithIpexLayerNorm(graph);
   // TODO: Some post processing?? ECS/EDC/Peephole???
   ConstantPropagation(graph);
 }

torch_ipex/csrc/jit/graph_rewrite.cpp

Lines changed: 16 additions & 0 deletions
@@ -551,6 +551,22 @@ void replaceAtenLinearWithIpexSoftmax(std::shared_ptr<Graph>& graph) {
   rewriter_aten.runOnGraph(graph);

 }
+// replace aten::layer_norm with ipex::layernorm during the jit pass;
+// this is just a workaround for the layernorm performance regression
+void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph> &graph) {
+  std::string aten_layernorm = R"(
+    graph(%a, %shape:int[], %w, %b, %eps:float, %cudnn_enable:bool):
+      %r = aten::layer_norm(%a, %shape, %w, %b, %eps, %cudnn_enable)
+      return (%r) )";
+  std::string ipex_layernorm = R"(
+    graph(%a, %shape:int[], %w, %b, %eps:float, %cudnn_enable:bool):
+      %r = ipex::layernorm(%a, %shape, %w, %b, %eps, %cudnn_enable)
+      return (%r) )";
+  SubgraphRewriter rewriter_aten;
+  rewriter_aten.RegisterRewritePattern(aten_layernorm, ipex_layernorm);
+  rewriter_aten.runOnGraph(graph);
+}
+
 } // namespace graph_rewrite
 } // namespace jit
 } // namespace torch
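
To see the node this pattern matches, one can print the TorchScript graph of a plain LayerNorm module; the traced graph contains an aten::layer_norm call with the six inputs the pattern binds, which the SubgraphRewriter then swaps for ipex::layernorm. A sketch follows; the exact printed value names vary by PyTorch version.

import torch

traced = torch.jit.trace(torch.nn.LayerNorm(4), torch.rand(8, 3, 4))
# Expect a line along the lines of:
#   %y = aten::layer_norm(%x, %normalized_shape, %weight, %bias, %eps, %cudnn_enable)
print(traced.graph)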

torch_ipex/csrc/jit/graph_rewrite.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ void FuseShuffle(std::shared_ptr<Graph>& graph);
 void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
 void replaceAtenLinearWithIpexLinear(std::shared_ptr<Graph>& graph);
 void replaceAtenLinearWithIpexSoftmax(std::shared_ptr<Graph>& graph);
+void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph> &graph);
 } // namespace graph_rewrite_helper
 } // namespace jit
 } // namespace torch

torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp

Lines changed: 20 additions & 1 deletion
@@ -400,9 +400,28 @@ RegisterOperators op(
             return 0;
           };
         },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::layernorm(Tensor a, int[] normalized_shape, Tensor ? "
+        "weight_opt, Tensor ? bias_opt, float eps, bool cudnn_enable) -> "
+        "Tensor",
+        [](const Node *node) -> Operation {
+          return [](Stack *stack) {
+            auto result = AtenIpexJITDev::dil_layernorm(
+                (std::move(peek(stack, 0, 6))).toTensor(),
+                (std::move(peek(stack, 1, 6))).toIntVector(),
+                toOptionalTensor(std::move(peek(stack, 2, 6))),
+                toOptionalTensor(std::move(peek(stack, 3, 6))),
+                (std::move(peek(stack, 4, 6))).toDouble(),
+                (std::move(peek(stack, 5, 6))).toBool());
+            drop(stack, 6);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
         aliasAnalysisFromSchema())

     });
-
 } // namespace jit
 } // namespace torch
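
Once the extension's shared library is loaded and this registration has run, the operator should also be reachable from Python through the JIT operator registry. The following is an unverified sketch; the extension import name is an assumption, and the arguments follow the schema above:

import torch
# import intel_pytorch_extension  # assumed: loading the extension registers ipex::layernorm

x = torch.rand(8, 3, 4)
weight = torch.ones(4)
bias = torch.zeros(4)
# Schema: ipex::layernorm(Tensor a, int[] normalized_shape, Tensor? weight_opt,
#                         Tensor? bias_opt, float eps, bool cudnn_enable) -> Tensor
y = torch.ops.ipex.layernorm(x, [4], weight, bias, 1e-5, False)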
