11#include < ATen/ATen.h>
22#include < ATen/Context.h>
33#include < ATen/SparseTensorUtils.h>
4+ #include < ATen/autocast_mode.h>
45#include < ATen/native/BinaryOps.h>
56#include < ATen/native/TensorIterator.h>
67#include < ATen/record_function.h>
@@ -23,6 +24,9 @@ using namespace at::sparse;
2324
2425namespace at {
2526namespace AtenIpexTypeXPU {
27+ using autocast::cached_cast;
28+ using autocast::get_lower_precision_fp_from_device_type;
29+ using autocast::promote_type;
2630
2731std::tuple<Tensor, Tensor> sort (
2832 const Tensor& self,
@@ -38,7 +42,8 @@ static void mul_add_kernel_dpcpp(TensorIterator& iter, Scalar alpha_scalar) {
3842 iter.dtype (),
3943 " mul_add" ,
4044 [&]() {
41- auto alpha = alpha_scalar.to <scalar_t >();
45+ using accscalar_t = acc_type<scalar_t >;
46+ auto alpha = alpha_scalar.to <accscalar_t >();
4247 dpcpp_kernel_for_tensor_iter (
4348 iter, [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
4449 return a * b + alpha * c;
@@ -141,8 +146,9 @@ Tensor mul_scalar_add_scalar(
141146 iter.dtype (),
142147 " mul_scalar_add_scalar" ,
143148 [&]() {
144- auto add_scalar = alpha.to <scalar_t >() * accumu.to <scalar_t >();
145- auto other_scalar = other.to <scalar_t >();
149+ using accscalar_t = acc_type<scalar_t >;
150+ auto add_scalar = alpha.to <accscalar_t >() * accumu.to <accscalar_t >();
151+ auto other_scalar = other.to <accscalar_t >();
146152 dpcpp_kernel_for_tensor_iter (iter, [=](scalar_t a) -> scalar_t {
147153 return a * other_scalar + add_scalar;
148154 });
@@ -151,6 +157,15 @@ Tensor mul_scalar_add_scalar(
151157 return result;
152158}
153159
160+ Tensor mul_scalar_add_scalar_autocast (
161+ const Tensor& self,
162+ Scalar other,
163+ Scalar accumu,
164+ Scalar alpha) {
165+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::AutocastXPU);
166+ return mul_scalar_add_scalar (self, other, accumu, alpha);
167+ }
168+
154169Tensor mul_add_scalar (
155170 const Tensor& self,
156171 const Tensor& other,
@@ -174,7 +189,8 @@ Tensor mul_add_scalar(
174189 iter.dtype (),
175190 " mul_scalar_add_scalar" ,
176191 [&]() {
177- auto add_scalar = alpha.to <scalar_t >() * accumu.to <scalar_t >();
192+ using accscalar_t = acc_type<scalar_t >;
193+ auto add_scalar = alpha.to <accscalar_t >() * accumu.to <accscalar_t >();
178194 dpcpp_kernel_for_tensor_iter (
179195 iter, [=](scalar_t a, scalar_t b) -> scalar_t {
180196 return a * b + add_scalar;
@@ -184,6 +200,24 @@ Tensor mul_add_scalar(
184200 return result;
185201}
186202
203+ Tensor mul_add_scalar_autocast (
204+ const Tensor& self,
205+ const Tensor& other,
206+ Scalar accumu,
207+ Scalar alpha) {
208+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::AutocastXPU);
209+ auto to_type = promote_type (
210+ get_lower_precision_fp_from_device_type (c10::DeviceType::XPU),
211+ c10::DeviceType::XPU,
212+ self,
213+ other);
214+ return mul_add_scalar (
215+ cached_cast (to_type, self, c10::DeviceType::XPU),
216+ cached_cast (to_type, other, c10::DeviceType::XPU),
217+ accumu,
218+ alpha);
219+ }
220+
187221Tensor mul_scalar_add (
188222 const Tensor& self,
189223 Scalar other,
@@ -207,8 +241,9 @@ Tensor mul_scalar_add(
207241 iter.dtype (),
208242 " mul_scalar_add_scalar" ,
209243 [&]() {
210- auto alpha_scalar = alpha.to <scalar_t >();
211- auto other_scalar = other.to <scalar_t >();
244+ using accscalar_t = acc_type<scalar_t >;
245+ auto alpha_scalar = alpha.to <accscalar_t >();
246+ auto other_scalar = other.to <accscalar_t >();
212247 dpcpp_kernel_for_tensor_iter (
213248 iter, [=](scalar_t a, scalar_t b) -> scalar_t {
214249 return a * other_scalar + b * alpha_scalar;
@@ -218,6 +253,24 @@ Tensor mul_scalar_add(
218253 return result;
219254}
220255
256+ Tensor mul_scalar_add_autocast (
257+ const Tensor& self,
258+ Scalar other,
259+ const Tensor& accumu,
260+ Scalar alpha) {
261+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::AutocastXPU);
262+ auto to_type = promote_type (
263+ get_lower_precision_fp_from_device_type (c10::DeviceType::XPU),
264+ c10::DeviceType::XPU,
265+ self,
266+ accumu);
267+ return mul_scalar_add (
268+ cached_cast (to_type, self, c10::DeviceType::XPU),
269+ other,
270+ cached_cast (to_type, accumu, c10::DeviceType::XPU),
271+ alpha);
272+ }
273+
221274Tensor mul_add (
222275 const Tensor& self,
223276 const Tensor& other,
@@ -242,6 +295,25 @@ Tensor mul_add(
242295 return result;
243296}
244297
298+ Tensor mul_add_autocast (
299+ const Tensor& self,
300+ const Tensor& other,
301+ const Tensor& accumu,
302+ Scalar alpha) {
303+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::AutocastXPU);
304+ auto to_type = promote_type (
305+ get_lower_precision_fp_from_device_type (c10::DeviceType::XPU),
306+ c10::DeviceType::XPU,
307+ self,
308+ other,
309+ accumu);
310+ return mul_add (
311+ cached_cast (to_type, self, c10::DeviceType::XPU),
312+ cached_cast (to_type, other, c10::DeviceType::XPU),
313+ cached_cast (to_type, accumu, c10::DeviceType::XPU),
314+ alpha);
315+ }
316+
245317template <typename scalar_t >
246318static inline void packed_add_kernel (
247319 unsigned short * __restrict__ w_MSB,
@@ -428,10 +500,27 @@ Tensor packed_add(
428500
429501namespace {
430502IPEX_LIBRARY_FRAGMENT () {
431- IPEX_OP_REGISTER (" mul_add" , mul_add);
432- IPEX_OP_REGISTER (" mul_add.Scalar_Tensor" , mul_add_scalar);
433- IPEX_OP_REGISTER (" mul_add.Tensor_Scalar" , mul_scalar_add);
434- IPEX_OP_REGISTER (" mul_add.Scalar_Scalar" , mul_scalar_add_scalar);
503+ IPEX_OP_REGISTER_DISPATCH (" mul_add" , mul_add, c10::DispatchKey::XPU);
504+ IPEX_OP_REGISTER_DISPATCH (
505+ " mul_add" , mul_add_autocast, c10::DispatchKey::AutocastXPU);
506+ IPEX_OP_REGISTER_DISPATCH (
507+ " mul_add.Scalar_Tensor" , mul_scalar_add, c10::DispatchKey::XPU);
508+ IPEX_OP_REGISTER_DISPATCH (
509+ " mul_add.Scalar_Tensor" ,
510+ mul_scalar_add_autocast,
511+ c10::DispatchKey::AutocastXPU);
512+ IPEX_OP_REGISTER_DISPATCH (
513+ " mul_add.Tensor_Scalar" , mul_add_scalar, c10::DispatchKey::XPU);
514+ IPEX_OP_REGISTER_DISPATCH (
515+ " mul_add.Tensor_Scalar" ,
516+ mul_add_scalar_autocast,
517+ c10::DispatchKey::AutocastXPU);
518+ IPEX_OP_REGISTER_DISPATCH (
519+ " mul_add.Scalar_Scalar" , mul_scalar_add_scalar, c10::DispatchKey::XPU);
520+ IPEX_OP_REGISTER_DISPATCH (
521+ " mul_add.Scalar_Scalar" ,
522+ mul_scalar_add_scalar_autocast,
523+ c10::DispatchKey::AutocastXPU);
435524 IPEX_OP_REGISTER_DISPATCH (
436525 " packed_add" , at::AtenIpexTypeXPU::packed_add, c10::DispatchKey::XPU);
437526 IPEX_OP_REGISTER_DISPATCH (
0 commit comments