|
2 | 2 | #include <ATen/Config.h> |
3 | 3 | #include <ATen/NativeFunctions.h> |
4 | 4 |
|
| 5 | +#include <ATen/record_function.h> |
5 | 6 | #include <oneDNN/oneDNN.h> |
| 7 | +#include <torch/autograd.h> |
| 8 | +#include <torch/custom_class.h> |
| 9 | +#include <utils/SimpleTrace.h> |
6 | 10 | #include "Norm.h" |
| 11 | +#include "comm/ATDispatch.h" |
7 | 12 | #include "comm/RegistrationDeclarations.h" |
8 | 13 | #include "utils/CustomOperatorRegistration.h" |
9 | 14 |
|
10 | 15 | using namespace torch_ipex::xpu::dpcpp; |
| 16 | +using namespace torch::autograd; |
11 | 17 | using namespace at::AtenIpexTypeXPU::normalization; |
12 | 18 |
|
13 | 19 | namespace at { |
14 | 20 | namespace AtenIpexTypeXPU { |
15 | 21 |
|
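| | +// Forward declarations for the RMSNorm forward/backward entry points; their |
| | +// definitions follow below and are wrapped by the IPEXRmsNormOp autograd op. |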
| 22 | +std::tuple<Tensor, Tensor> rms_norm_fw( |
| 23 | + const Tensor& input, |
| 24 | + at::IntArrayRef normalized_shape, |
| 25 | + const Tensor& weight, |
| 26 | + double epsilon); |
| 27 | + |
| 28 | +std::tuple<Tensor, Tensor> rms_norm_bw( |
| 29 | + const Tensor& grad_output, |
| 30 | + const Tensor& input, |
| 31 | + at::IntArrayRef normalized_shape, |
| 32 | + const Tensor& rstd, |
| 33 | + const Tensor& weight, |
| | + const Tensor& output, |
| 34 | + std::array<bool, 2> grad_input_mask); |
| 35 | + |
16 | 36 | template <typename scalar_t, typename mean_t, typename weight_t> |
17 | 37 | class RMSNormForward : public NormForward<scalar_t, mean_t, weight_t, true> { |
18 | 38 | public: |
@@ -337,12 +357,13 @@ void RMSNormKernelImpl( |
337 | 357 | X.scalar_type(), |
338 | 358 | "RMSNormKernelImpl", |
339 | 359 | [&]() { |
340 | | - rstd = at::empty({M}, X.options().dtype(kFloat)); |
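| | + // Allocate rstd to match mean_t: float statistics when gamma is float, |
| | + // otherwise the input dtype, so the backward kernel reads it consistently. |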
341 | 360 | if (gamma.scalar_type() == kFloat) { |
| 361 | + rstd = at::empty({M}, X.options().dtype(kFloat)); |
342 | 362 | RMSNormKernelImplInternal<scalar_t, float, float>( |
343 | 363 | X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd); |
344 | 364 | } else { |
345 | | - RMSNormKernelImplInternal<scalar_t, float, scalar_t>( |
| 365 | + rstd = at::empty({M}, X.options()); |
| 366 | + RMSNormKernelImplInternal<scalar_t, scalar_t, scalar_t>( |
346 | 367 | X, gamma, M, N, static_cast<acc_type<scalar_t>>(eps), Y, rstd); |
347 | 368 | } |
348 | 369 | }); |
@@ -374,11 +395,251 @@ std::tuple<Tensor, Tensor> rms_norm_fw( |
374 | 395 | return std::make_tuple(output.reshape(input.sizes()), rstd); |
375 | 396 | } |
376 | 397 |
|
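| | +// Backward kernel: dX is produced by the fused RMSNormBackward kernel (or the |
| | +// two-pass rowwise-moments + update path), and dgamma, when requested, by |
| | +// reducing output * dY over the leading dimensions. |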
| 398 | +template <typename scalar_t, typename mean_t, typename weight_t> |
| 399 | +void RmsNormBackwardKernelImplInternal( |
| 400 | + const Tensor& dY, |
| 401 | + const Tensor& X, |
| 402 | + const Tensor& rstd, |
| 403 | + const Tensor& gamma, |
| 404 | + int64_t M, |
| 405 | + int64_t N, |
| 406 | + Tensor& dX, |
| 407 | + Tensor& dgamma, |
| 408 | + const Tensor& output, |
| 409 | + std::array<bool, 2> grad_input_mask) { |
| 410 | + TORCH_CHECK(dY.numel() == M * N); |
| 411 | + TORCH_CHECK(rstd.numel() == M); |
| 412 | + |
| 413 | + using accscalar_t = acc_type<scalar_t>; |
| 414 | + mean_t* var_data = rstd.data_ptr<mean_t>(); |
| 415 | + weight_t* gamma_data = gamma.defined() ? gamma.data_ptr<weight_t>() : nullptr; |
| 416 | + |
| 417 | + if (grad_input_mask[0]) { |
| 418 | + // backward data |
| 419 | + scalar_t* X_data = X.data_ptr<scalar_t>(); |
| 420 | + scalar_t* dY_data = dY.data_ptr<scalar_t>(); |
| 421 | + scalar_t* dX_data = dX.data_ptr<scalar_t>(); |
| 422 | + |
| 423 | + auto config = NormConfig(M, N, 1, sizeof(scalar_t)); |
| 424 | + bool can_use_32bit_index = canUse32BitIndexMath(X) && |
| 425 | + canUse32BitIndexMath(dY) && canUse32BitIndexMath(dX); |
| 426 | + |
| 427 | + // TODO: force it to use fused_norm_kernel |
| 428 | + config.workgroup_num_foreach = 1; |
| 429 | + config.WGPlane = config.Plane; |
| 430 | + |
| 431 | + if (config.workgroup_num_foreach == 1) { |
| 432 | + RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward( |
| 433 | + X_data, dY_data, dX_data, var_data, gamma_data, M, N); |
| 434 | + launch_vectorized_fused_norm_kernel< |
| 435 | + scalar_t, |
| 436 | + mean_t, |
| 437 | + weight_t, |
| 438 | + RMSNormBackward, |
| 439 | + true>(rms_norm_backward, config, can_use_32bit_index); |
| 440 | + } else { |
| 441 | + const auto kAccType = |
| 442 | + (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16) |
| 443 | + ? kFloat |
| 444 | + : X.scalar_type(); |
| 445 | + Tensor a = at::empty({M}, X.options().dtype(kAccType)); |
| 446 | + accscalar_t* a_data = a.data_ptr<accscalar_t>(); |
| 447 | + |
| 448 | + RMSNormBackward<scalar_t, mean_t, weight_t> rms_norm_backward( |
| 449 | + X_data, dY_data, dX_data, var_data, gamma_data, a_data, M, N); |
| 450 | + Tensor semaphores, scratchpad; |
| 451 | + config.template init_global_reduce<accscalar_t>( |
| 452 | + X, semaphores, scratchpad); |
| 453 | + RowwiseMomentsDPCPPKernelImpl< |
| 454 | + scalar_t, |
| 455 | + mean_t, |
| 456 | + weight_t, |
| 457 | + RMSNormBackward, |
| 458 | + true>(rms_norm_backward, config, can_use_32bit_index); |
| 459 | + NormUpdateKernelImpl<scalar_t, mean_t, weight_t, RMSNormBackward, true>( |
| 460 | + rms_norm_backward, config, can_use_32bit_index); |
| 461 | + } |
| 462 | + } |
| 463 | + |
| 464 | + if (grad_input_mask[1]) { |
| 465 | + // backward weight |
| 466 | + Tensor sum_tmp = at::mul(output, dY); |
| 467 | + at::sum_out(dgamma, sum_tmp, at::IntArrayRef{0, 1}); |
| 468 | + } |
| 469 | +} |
| 470 | + |
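| | +// Type dispatch for the backward kernel. The mean_t/weight_t choice mirrors |
| | +// the forward pass: float statistics when gamma is float, otherwise the |
| | +// input dtype. |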
| 471 | +void RmsNormBackwardKernelImpl( |
| 472 | + const Tensor& dY, |
| 473 | + const Tensor& X, |
| 474 | + const Tensor& rstd, |
| 475 | + const Tensor& gamma, |
| 476 | + int64_t M, |
| 477 | + int64_t N, |
| 478 | + Tensor& dX, |
| 479 | + Tensor& dgamma, |
| 480 | + const Tensor& output, |
| 481 | + std::array<bool, 2> grad_input_mask) { |
| 482 | + IPEX_DISPATCH_FLOATING_TYPES_AND2( |
| 483 | + at::ScalarType::Half, |
| 484 | + at::ScalarType::BFloat16, |
| 485 | + X.scalar_type(), |
| 486 | + "RmsNormBackwardKernelImpl", |
| 487 | + [&]() { |
| 488 | + using accscalar_t = acc_type<scalar_t>; |
| 489 | + if (gamma.scalar_type() == kFloat) { |
| 490 | + RmsNormBackwardKernelImplInternal<scalar_t, float, float>( |
| 491 | + dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask); |
| 492 | + } else { |
| 493 | + RmsNormBackwardKernelImplInternal<scalar_t, scalar_t, scalar_t>( |
| 494 | + dY, X, rstd, gamma, M, N, dX, dgamma, output, grad_input_mask); |
| 495 | + } |
| 496 | + }); |
| 497 | +} |
| 498 | + |
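| | +// Backward entry point: validates shapes via _check_layer_norm_inputs, |
| | +// allocates grad_input/grad_weight according to grad_input_mask, reshapes |
| | +// 1-D operands to {M, N}, and runs RmsNormBackwardKernelImpl on contiguous |
| | +// tensors. |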
| 499 | +std::tuple<Tensor, Tensor> rms_norm_bw( |
| 500 | + const Tensor& grad_output, |
| 501 | + const Tensor& input, |
| 502 | + at::IntArrayRef normalized_shape, |
| 503 | + const Tensor& rstd, |
| 504 | + const Tensor& weight, |
| 505 | + const Tensor& output, |
| 506 | + std::array<bool, 2> grad_input_mask) { |
| 507 | + RECORD_FUNCTION("ipex::rms_norm_bw", std::vector<c10::IValue>({grad_output})); |
| 508 | + auto M_N = |
| 509 | + _check_layer_norm_inputs(input, normalized_shape, weight, Tensor()); |
| 510 | + auto M = M_N.first; |
| 511 | + auto N = M_N.second; |
| 512 | + |
| 513 | + Tensor grad_input; |
| 514 | + Tensor grad_weight; |
| 515 | + |
| 516 | + if (grad_input_mask[0]) { |
| 517 | + grad_input = at::native::empty_like( |
| 518 | + input, |
| 519 | + c10::nullopt /* dtype */, |
| 520 | + c10::nullopt /* layout */, |
| 521 | + c10::nullopt /* device */, |
| 522 | + c10::nullopt /* pin_memory */, |
| 523 | + LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| 524 | + } |
| 525 | + |
| 526 | + if (grad_input_mask[1]) { |
| 527 | + grad_weight = M > 0 ? at::native::empty_like( |
| 528 | + weight, |
| 529 | + c10::nullopt /* dtype */, |
| 530 | + c10::nullopt /* layout */, |
| 531 | + c10::nullopt /* device */, |
| 532 | + c10::nullopt /* pin_memory */, |
| 533 | + LEGACY_CONTIGUOUS_MEMORY_FORMAT) |
| 534 | + : at::native::zeros_like( |
| 535 | + weight, |
| 536 | + c10::nullopt /* dtype */, |
| 537 | + c10::nullopt /* layout */, |
| 538 | + c10::nullopt /* device */, |
| 539 | + c10::nullopt /* pin_memory */, |
| 540 | + LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| 541 | + } |
| 542 | + |
| 543 | + if (input.numel() != 0 && grad_output.numel() != 0) { |
| 544 | + Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; |
| 545 | + Tensor grad_output_ = |
| 546 | + (grad_output.dim() == 1) ? grad_output.reshape({M, N}) : grad_output; |
| 547 | + Tensor weight_ = |
| 548 | + (weight.defined() && weight.dim() == 1) ? weight.reshape({N}) : weight; |
| 549 | + Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output; |
| 550 | + |
| 551 | + input_ = input_.contiguous(); |
| 552 | + grad_output_ = grad_output_.contiguous(); |
| 553 | + output_ = output_.contiguous(); |
| 554 | + weight_ = weight_.defined() ? weight_.contiguous() : weight_; |
| 555 | + |
| 556 | + RmsNormBackwardKernelImpl( |
| 557 | + grad_output_, |
| 558 | + input_, |
| 559 | + rstd, |
| 560 | + weight_, |
| 561 | + M, |
| 562 | + N, |
| 563 | + grad_input, |
| 564 | + grad_weight, |
| 565 | + output_, |
| 566 | + grad_input_mask); |
| 567 | + } |
| 568 | + return std::make_tuple( |
| 569 | + grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input, |
| 570 | + grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight); |
| 571 | +} |
| 572 | + |
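| | +// Custom autograd Function: forward saves input, weight, output and rstd; |
| | +// backward routes the incoming gradient through rms_norm_bw and returns |
| | +// gradients for input and weight only (normalized_shape and epsilon get |
| | +// no gradient). |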
| 573 | +class IPEXRmsNormOp : public Function<IPEXRmsNormOp> { |
| 574 | + public: |
| 575 | + static variable_list forward( |
| 576 | + AutogradContext* ctx, |
| 577 | + const Tensor& input, |
| 578 | + at::IntArrayRef normalized_shape, |
| 579 | + const Tensor& weight, |
| 580 | + double epsilon) { |
| 581 | +#ifdef BUILD_SIMPLE_TRACE |
| 582 | + SimpleTrace trace( |
| 583 | + "IPEXRmsNormOp forward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::forward"); |
| 584 | +#endif |
| 585 | + ctx->saved_data["input_requires_grad"] = input.requires_grad(); |
| 586 | + ctx->saved_data["weight_requires_grad"] = weight.requires_grad(); |
| 587 | + ctx->saved_data["normalized_shape"] = normalized_shape; |
| 588 | + auto outputs = rms_norm_fw(input, normalized_shape, weight, epsilon); |
| 589 | + |
| 590 | + ctx->save_for_backward( |
| 591 | + {input, weight, std::get<0>(outputs), std::get<1>(outputs)}); |
| 592 | + variable_list result = {std::get<0>(outputs), std::get<1>(outputs)}; |
| 593 | + return result; |
| 594 | + } |
| 595 | + |
| 596 | + static variable_list backward( |
| 597 | + AutogradContext* ctx, |
| 598 | + variable_list grad_outputs) { |
| 599 | +#ifdef BUILD_SIMPLE_TRACE |
| 600 | + SimpleTrace trace( |
| 601 | + "IPEXRmsNormOp backward -> at::AtenIpexTypeXPU::IPEXRmsNormOp::backward"); |
| 602 | +#endif |
| 603 | + auto weight_requires_grad = |
| 604 | + ctx->saved_data["weight_requires_grad"].toBool(); |
| 605 | + auto input_requires_grad = ctx->saved_data["input_requires_grad"].toBool(); |
| 606 | + auto saved = ctx->get_saved_variables(); |
| 607 | + Tensor input = saved[0]; |
| 608 | + Tensor weight = saved[1]; |
| 609 | + Tensor output = saved[2]; |
| 610 | + Tensor rstd = saved[3]; |
| 611 | + auto normalized_shape = weight.sizes(); |
| 612 | + |
| 613 | + auto grad_inputs = rms_norm_bw( |
| 614 | + grad_outputs[0], |
| 615 | + input, |
| 616 | + normalized_shape, |
| 617 | + rstd, |
| 618 | + weight, |
| 619 | + output, |
| 620 | + {input_requires_grad, weight_requires_grad}); |
| 621 | + return { |
| 622 | + std::get<0>(grad_inputs), Tensor(), std::get<1>(grad_inputs), Tensor()}; |
| 623 | + } |
| 624 | +}; |
| 625 | + |
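| | +// User-facing op: runs the autograd-aware RMSNorm and returns only the |
| | +// normalized output (the saved rstd is used internally for backward). |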
| 626 | +Tensor rms_norm_impl( |
| 627 | + const Tensor& input, |
| 628 | + at::IntArrayRef normalized_shape, |
| 629 | + const Tensor& weight, |
| 630 | + double epsilon) { |
| 631 | + auto output = IPEXRmsNormOp::apply(input, normalized_shape, weight, epsilon); |
| 632 | + return output[0]; |
| 633 | +} |
377 | 634 | } // namespace AtenIpexTypeXPU |
378 | 635 | } // namespace at |
379 | 636 |
|
380 | 637 | namespace { |
381 | 638 | IPEX_LIBRARY_FRAGMENT() { |
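| | + // Register the autograd-aware entry under the AutogradXPU dispatch key and |
| | + // keep the plain forward kernel exposed as "rms_norm.xpu". |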
| 639 | + IPEX_OP_REGISTER_DISPATCH( |
| 640 | + "rms_norm_impl", |
| 641 | + at::AtenIpexTypeXPU::rms_norm_impl, |
| 642 | + c10::DispatchKey::AutogradXPU); |
382 | 643 | IPEX_OP_REGISTER("rms_norm.xpu", at::AtenIpexTypeXPU::rms_norm_fw); |
383 | 644 | } |
384 | 645 | } // namespace |