Commit 9cffeb9

add the implementation for rmsnorm bwd (#4531) (#4679)
1 parent ecffd3d · commit 9cffeb9

File tree — 3 files changed (+313, −2 lines):

  csrc/gpu/aten/operators/RMSNorm.cpp
  intel_extension_for_pytorch/xpu/intrinsic/__init__.py
  tests/gpu/examples/test_rms_norm.py

csrc/gpu/aten/operators/RMSNorm.cpp — 263 additions & 2 deletions
@@ -2,17 +2,37 @@
 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>

+#include <ATen/record_function.h>
 #include <oneDNN/oneDNN.h>
+#include <torch/autograd.h>
+#include <torch/custom_class.h>
+#include <utils/SimpleTrace.h>
 #include "Norm.h"
+#include "comm/ATDispatch.h"
 #include "comm/RegistrationDeclarations.h"
 #include "utils/CustomOperatorRegistration.h"

 using namespace torch_ipex::xpu::dpcpp;
+using namespace torch::autograd;
 using namespace at::AtenIpexTypeXPU::normalization;

 namespace at {
 namespace AtenIpexTypeXPU {

+std::tuple<Tensor, Tensor> rms_norm_fw(
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& weight,
+    double epsilon);
+
+std::tuple<Tensor, Tensor> rms_norm_bw(
+    const Tensor& grad_output,
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& rstd,
+    const Tensor& weight,
+    std::array<bool, 2> grad_input_mask);
+
 template <typename scalar_t, typename mean_t, typename weight_t>
 class RMSNormForward : public NormForward<scalar_t, mean_t, weight_t, true> {
  public:
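For reference, the (output, rstd) pair declared for rms_norm_fw corresponds to the standard RMSNorm formulation. A minimal eager-mode sketch of the same computation (rms_norm_fw_ref is an illustrative name, not the fused XPU kernel):

    import torch

    def rms_norm_fw_ref(x: torch.Tensor, weight: torch.Tensor, eps: float):
        # x: [M, N], rows are the flattened leading dims; weight: [N]
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1) + eps)  # [M], saved for backward
        y = x * rstd.unsqueeze(-1) * weight              # [M, N]
        return y, rstd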
@@ -337,12 +357,13 @@ void RMSNormKernelImpl(
       X.scalar_type(),
       "RMSNormKernelImpl",
       [&]() {
-        rstd = at::empty({M}, X.options().dtype(kFloat));
         if (gamma.scalar_type() == kFloat) {
+          rstd = at::empty({M}, X.options().dtype(kFloat));
           RMSNormKernelImplInternal<scalar_t, float, float>(
               X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
         } else {
-          RMSNormKernelImplInternal<scalar_t, float, scalar_t>(
+          rstd = at::empty({M}, X.options());
+          RMSNormKernelImplInternal<scalar_t, scalar_t, scalar_t>(
               X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd);
         }
       });
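After this hunk, rstd is allocated to match the template combination being instantiated: float32 when gamma is float, the input's own dtype otherwise. A one-line sketch of that selection (illustrative helper, not part of the commit):

    import torch

    def rstd_dtype_for(x: torch.Tensor, gamma: torch.Tensor) -> torch.dtype:
        # mirrors the <scalar_t, mean_t, weight_t> dispatch above
        return torch.float32 if gamma.dtype == torch.float32 else x.dtype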
@@ -374,11 +395,251 @@ std::tuple<Tensor, Tensor> rms_norm_fw(
   return std::make_tuple(output.reshape(input.sizes()), rstd);
 }

+template <typename scalar_t, typename mean_t, typename weight_t>
+void RmsNormBackwardKernelImplInternal(
+    const Tensor& dY,
+    const Tensor& X,
+    const Tensor& rstd,
+    const Tensor& gamma,
+    int64_t M,
+    int64_t N,
+    Tensor& dX,
+    Tensor& dgamma,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  TORCH_CHECK(dY.numel() == M * N);
+  TORCH_CHECK(rstd.numel() == M);
+
+  using accscalar_t = acc_type<scalar_t>;
+  mean_t* var_data = rstd.data_ptr<mean_t>();
+  weight_t* gamma_data = gamma.defined() ? gamma.data_ptr<weight_t>() : nullptr;
+
+  if (grad_input_mask[0]) {
+    // backward data
+    scalar_t* X_data = X.data_ptr<scalar_t>();
+    scalar_t* dY_data = dY.data_ptr<scalar_t>();
+    scalar_t* dX_data = dX.data_ptr<scalar_t>();
+
+    auto config = NormConfig(M, N, 1, sizeof(scalar_t));
+    bool can_use_32bit_index = canUse32BitIndexMath(X) &&
+        canUse32BitIndexMath(dY) && canUse32BitIndexMath(dX);
+
+    // TODO: force it to use fused_norm_kernel
+    config.workgroup_num_foreach = 1;
+    config.WGPlane = config.Plane;
+
+    if (config.workgroup_num_foreach == 1) {
+      RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
+          X_data, dY_data, dX_data, var_data, gamma_data, M, N);
+      launch_vectorized_fused_norm_kernel<
+          scalar_t,
+          mean_t,
+          weight_t,
+          RMSNormBackward,
+          true>(rms_norm_backward, config, can_use_32bit_index);
+    } else {
+      const auto kAccType =
+          (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
+          ? kFloat
+          : X.scalar_type();
+      Tensor a = at::empty({M}, X.options().dtype(kAccType));
+      accscalar_t* a_data = a.data_ptr<accscalar_t>();
+
+      RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward(
+          X_data, dY_data, dX_data, var_data, gamma_data, a_data, M, N);
+      Tensor semaphores, scratchpad;
+      config.template init_global_reduce<accscalar_t>(
+          X, semaphores, scratchpad);
+      RowwiseMomentsDPCPPKernelImpl<
+          scalar_t,
+          mean_t,
+          weight_t,
+          RMSNormBackward,
+          true>(rms_norm_backward, config, can_use_32bit_index);
+      NormUpdateKernelImpl<scalar_t, mean_t, weight_t, RMSNormBackward, true>(
+          rms_norm_backward, config, can_use_32bit_index);
+    }
+  }
+
+  if (grad_input_mask[1]) {
+    // backward weight
+    Tensor sum_tmp = at::mul(output, dY);
+    at::sum_out(dgamma, sum_tmp, at::IntArrayRef{0, 1});
+  }
+}
+
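The fused kernel above implements the usual RMSNorm gradients. As a sanity-check sketch under the standard formulation y = g · x · rstd with rstd = (mean(x²) + eps)^(-1/2) (a reference in plain PyTorch, not the XPU kernel itself):

    import torch

    def rms_norm_bw_ref(dY, X, rstd, g):
        # dY, X: [M, N]; rstd: [M]; g: [N]
        r = rstd.unsqueeze(-1)                     # [M, 1]
        x_hat = X * r                              # normalized input
        dY_g = dY * g                              # fold the weight into dY
        c = (dY_g * X).mean(dim=-1, keepdim=True)  # contribution through rstd
        dX = r * dY_g - (r ** 3) * X * c
        dgamma = (dY * x_hat).sum(dim=0)           # reduce over all rows
        return dX, dgamma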
+void RmsNormBackwardKernelImpl(
+    const Tensor& dY,
+    const Tensor& X,
+    const Tensor& rstd,
+    const Tensor& gamma,
+    int64_t M,
+    int64_t N,
+    Tensor& dX,
+    Tensor& dgamma,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  IPEX_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      X.scalar_type(),
+      "RmsNormBackwardKernelImpl",
+      [&]() {
+        using accscalar_t = acc_type<scalar_t>;
+        if (gamma.scalar_type() == kFloat) {
+          RmsNormBackwardKernelImplInternal<scalar_t, float, float>(
+              dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
+        } else {
+          RmsNormBackwardKernelImplInternal<scalar_t, scalar_t, scalar_t>(
+              dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask);
+        }
+      });
+}
+
+std::tuple<Tensor, Tensor> rms_norm_bw(
+    const Tensor& grad_output,
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& rstd,
+    const Tensor& weight,
+    const Tensor& output,
+    std::array<bool, 2> grad_input_mask) {
+  RECORD_FUNCTION("ipex::rms_norm_bw", std::vector<c10::IValue>({grad_output}));
+  auto M_N =
+      _check_layer_norm_inputs(input, normalized_shape, weight, Tensor());
+  auto M = M_N.first;
+  auto N = M_N.second;
+
+  Tensor grad_input;
+  Tensor grad_weight;
+
+  if (grad_input_mask[0]) {
+    grad_input = at::native::empty_like(
+        input,
+        c10::nullopt /* dtype */,
+        c10::nullopt /* layout */,
+        c10::nullopt /* device */,
+        c10::nullopt /* pin_memory */,
+        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+
+  if (grad_input_mask[1]) {
+    grad_weight = M > 0 ? at::native::empty_like(
+                              weight,
+                              c10::nullopt /* dtype */,
+                              c10::nullopt /* layout */,
+                              c10::nullopt /* device */,
+                              c10::nullopt /* pin_memory */,
+                              LEGACY_CONTIGUOUS_MEMORY_FORMAT)
+                        : at::native::zeros_like(
+                              weight,
+                              c10::nullopt /* dtype */,
+                              c10::nullopt /* layout */,
+                              c10::nullopt /* device */,
+                              c10::nullopt /* pin_memory */,
+                              LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+
+  if (input.numel() != 0 && grad_output.numel() != 0) {
+    Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input;
+    Tensor grad_output_ =
+        (grad_output.dim() == 1) ? grad_output.reshape({M, N}) : grad_output;
+    Tensor weight_ =
+        (weight.defined() && weight.dim() == 1) ? weight.reshape({N}) : weight;
+    Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output;
+
+    input_ = input_.contiguous();
+    grad_output_ = grad_output_.contiguous();
+    output_ = output_.contiguous();
+    weight_ = weight_.defined() ? weight_.contiguous() : weight_;
+
+    RmsNormBackwardKernelImpl(
+        grad_output_,
+        input_,
+        rstd,
+        weight_,
+        M,
+        N,
+        grad_input,
+        grad_weight,
+        output_,
+        grad_input_mask);
+  }
+  return std::make_tuple(
+      grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input,
+      grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight);
+}
+
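_check_layer_norm_inputs reduces every input to a 2-D [M, N] problem: N is the product of normalized_shape, M the product of all leading dims. A small sketch of that mapping (illustrative helper name):

    import math

    def flatten_to_MN(sizes, normalized_shape):
        k = len(normalized_shape)
        assert list(sizes[-k:]) == list(normalized_shape)
        N = math.prod(normalized_shape)
        M = math.prod(sizes[:-k]) if len(sizes) > k else 1
        return M, N

    # e.g. an input of shape [4, 1024, 2048] with normalized_shape [2048]
    # becomes M = 4096 rows of N = 2048 elements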
+class IPEXRmsNormOp : public Function<IPEXRmsNormOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const Tensor& input,
+      at::IntArrayRef normalized_shape,
+      const Tensor& weight,
+      double epsilon) {
+#ifdef BUILD_SIMPLE_TRACE
+    SimpleTrace trace(
+        "IPEXRmsNormOp forward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::forward");
+#endif
+    ctx->saved_data["input_requires_grad"] = input.requires_grad();
+    ctx->saved_data["weight_requires_grad"] = weight.requires_grad();
+    ctx->saved_data["normalized_shape"] = normalized_shape;
+    auto outputs = rms_norm_fw(input, normalized_shape, weight, epsilon);
+
+    ctx->save_for_backward(
+        {input, weight, std::get<0>(outputs), std::get<1>(outputs)});
+    variable_list result = {std::get<0>(outputs), std::get<1>(outputs)};
+    return result;
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_outputs) {
+#ifdef BUILD_SIMPLE_TRACE
+    SimpleTrace trace(
+        "IPEXRmsNormOp backward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::backward");
+#endif
+    auto weight_requires_grad =
+        ctx->saved_data["weight_requires_grad"].toBool();
+    auto input_requires_grad = ctx->saved_data["input_requires_grad"].toBool();
+    auto saved = ctx->get_saved_variables();
+    Tensor input = saved[0];
+    Tensor weight = saved[1];
+    Tensor output = saved[2];
+    Tensor rstd = saved[3];
+    auto normalized_shape = weight.sizes();
+
+    auto grad_inputs = rms_norm_bw(
+        grad_outputs[0],
+        input,
+        normalized_shape,
+        rstd,
+        weight,
+        output,
+        {input_requires_grad, weight_requires_grad});
+    return {
+        std::get<0>(grad_inputs), Tensor(), std::get<1>(grad_inputs), Tensor()};
+  }
+};
+
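IPEXRmsNormOp follows the stock torch::autograd::Function pattern: stash requires_grad flags in saved_data, save tensors for backward, and return one gradient slot per forward argument. The same structure in Python, reusing the reference helpers sketched above (illustrative, not the shipped op):

    import torch
    from torch.autograd import Function

    class RmsNormFnRef(Function):
        @staticmethod
        def forward(ctx, input, normalized_shape, weight, epsilon):
            y, rstd = rms_norm_fw_ref(input, weight, epsilon)
            ctx.save_for_backward(input, weight, y, rstd)
            return y

        @staticmethod
        def backward(ctx, grad_out):
            input, weight, y, rstd = ctx.saved_tensors
            M, N = input.numel() // weight.numel(), weight.numel()
            dX, dg = rms_norm_bw_ref(
                grad_out.reshape(M, N), input.reshape(M, N), rstd, weight)
            # one slot per forward arg: input, normalized_shape, weight, epsilon
            return dX.reshape(input.shape), None, dg, None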
+Tensor rms_norm_impl(
+    const Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const Tensor& weight,
+    double epsilon) {
+  auto output = IPEXRmsNormOp::apply(input, normalized_shape, weight, epsilon);
+  return output[0];
+}
 } // namespace AtenIpexTypeXPU
 } // namespace at

 namespace {
 IPEX_LIBRARY_FRAGMENT() {
+  IPEX_OP_REGISTER_DISPATCH(
+      "rms_norm_impl",
+      at::AtenIpexTypeXPU::rms_norm_impl,
+      c10::DispatchKey::AutogradXPU);
   IPEX_OP_REGISTER("rms_norm.xpu", at::AtenIpexTypeXPU::rms_norm_fw);
 }
 } // namespace

intel_extension_for_pytorch/xpu/intrinsic/__init__.py — 5 additions & 0 deletions
@@ -25,6 +25,7 @@
     "copy_blocks",
     "swap_blocks",
     "IpexPaged_attention",
+    "IpexRmsNorm",
 ]


@@ -164,6 +165,10 @@ def IpexSDP_dropout(
     )


+def IpexRmsNorm(input, normalized_shape, weight, epsilon) -> Tensor:
+    return torch.ops.torch_ipex.rms_norm_impl(input, normalized_shape, weight, epsilon)
+
+
 def varlen_fwd(
     query,  # [total_q, num_head, head_size]
     key,  # [total_k, num_head_k, head_size]
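With this wrapper in place, the op is reachable as torch.xpu.IpexRmsNorm. A hypothetical end-to-end call (assumes an XPU device and an intel_extension_for_pytorch build containing this commit):

    import torch
    import intel_extension_for_pytorch  # noqa: F401 (registers the XPU ops)

    x = torch.rand(4, 1024, 2048, device="xpu", requires_grad=True)
    w = torch.ones(2048, device="xpu", requires_grad=True)
    y = torch.xpu.IpexRmsNorm(x, [2048], w, 1e-5)
    y.backward(torch.rand_like(y))  # exercises the new rms_norm_bw path
    print(x.grad.shape, w.grad.shape)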

tests/gpu/examples/test_rms_norm.py — 45 additions & 0 deletions
@@ -40,11 +40,56 @@ def test_rms_norm_fw_xpu(dtype):
             w = model.weight.xpu()
             output = torch.ops.torch_ipex.rms_norm(input_case, [hsz], w, 1e-5)
             output1 = ipex.llm.functional.rms_norm(input_case, w, 1e-5)
+            output2 = torch.xpu.IpexRmsNorm(input_case, [hsz], w, 1e-5)
             # diff = (output.cpu() - output_ref).abs().max().item()
             # print('diff', diff)
             # assert diff < 1e-2
             self.assertEqual(output[0].cpu(), output_ref, atol=1e-2, rtol=1e-2)
             self.assertEqual(output1.cpu(), output_ref, atol=1e-2, rtol=1e-2)
+            self.assertEqual(output2.cpu(), output_ref, atol=1e-2, rtol=1e-2)

         test_rms_norm_fw_xpu(torch.float)
         test_rms_norm_fw_xpu(torch.bfloat16)
+
+    def test_rms_norm_bw(self):
+        def test_rms_norm_fwd_bwd(dtype):
+            print("test_rms_norm_fw_bw", dtype)
+            torch.manual_seed(13)
+            modelb = RMSNormRef(64)
+            model0 = RMSNormRef(768)
+            model1 = RMSNormRef(2048)
+            model2 = RMSNormRef(4096)
+            model3 = RMSNormRef(16384)
+            model4 = RMSNormRef(16384 * 4 + 123)
+            hszs = [64, 768, 2048, 4096, 16384, 16384 * 4 + 123]
+            ls = [modelb, model0, model1, model2, model3, model4]
+            for i, model in enumerate(ls):
+                model = model.to(dtype)
+                hsz = hszs[i]
+                input_case = torch.rand(4, 1024, hsz).to(dtype)
+                input_case.requires_grad_(True)
+                grad = torch.rand(4, 1024, hsz).to(dtype)
+                output_ref = model(input_case)
+                output_ref.backward(grad)
+                grad_wei = model.weight.grad.clone()
+                input_grad_cpu = input_case.grad.clone()
+                w = model.weight.clone()
+
+                input_case_xpu = input_case.clone().xpu()
+                input_case_xpu.retain_grad()
+                input_case_xpu.requires_grad_(True)
+                grad_xpu = grad.xpu()
+                w = w.xpu()
+                w.retain_grad()
+                w.requires_grad_(True)
+                output1 = torch.xpu.IpexRmsNorm(input_case_xpu, [hsz], w, 1e-5)
+                output1.backward(grad_xpu)
+                grad_wei_xpu = w.grad
+
+                self.assertEqual(grad_wei_xpu.cpu(), grad_wei, atol=10e-2, rtol=10e-2)
+                self.assertEqual(
+                    input_case_xpu.grad.cpu(), input_grad_cpu, atol=10e-2, rtol=10e-2
+                )
+
+        test_rms_norm_fwd_bwd(torch.bfloat16)
+        test_rms_norm_fwd_bwd(torch.float)
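The test depends on an RMSNormRef module defined earlier in the file and not shown in this diff. A typical definition matching how the test uses it would be (assumed here for context):

    import torch

    class RMSNormRef(torch.nn.Module):
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states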
