Skip to content

Commit 4f0598b

Browse files
ganyi1996ppo1pikachuweishi-denggujinghui
authored
autocast: register mul_add and roialign to autocast on low precision policy (#2511) (#2540)
* add mul add to autocast support * register mul_add to autocast and run it with low precision * add roialign to autocast * using promote policy on mul_add in autocast scenario * using acc type as compute type in mul_add * using retrive device as autocast device * add bf16 and fp16 test on strict tolerance * add roialign into torchvision's autocast op --------- Co-authored-by: Du, Jun <jun.du@intel.com> Co-authored-by: Deng, Weishi <weishi.deng@intel.com> Co-authored-by: Jinghui <jinghui.gu@intel.com>
1 parent 1c1f13e commit 4f0598b

File tree

4 files changed

+170
-16
lines changed

4 files changed

+170
-16
lines changed

csrc/gpu/aten/operators/ROIAlign.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <runtime/Utils.h>
99
#include <utils/DPCPP.h>
1010

11+
#include <ATen/autocast_mode.h>
1112
#include "RandomEngine.h"
1213
#include "comm/ATDispatch.h"
1314
#include "comm/AccumulateType.h"
@@ -495,24 +496,57 @@ at::Tensor roi_align_backward_kernel(
495496
return grad_input;
496497
}
497498

499+
at::Tensor roi_align_forward_autocast(
500+
const at::Tensor& input,
501+
const at::Tensor& rois,
502+
double spatial_scale,
503+
int64_t pooled_height,
504+
int64_t pooled_width,
505+
int64_t sampling_ratio,
506+
bool aligned) {
507+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
508+
return roi_align_forward_kernel(
509+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::XPU),
510+
at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::XPU),
511+
spatial_scale,
512+
pooled_height,
513+
pooled_width,
514+
sampling_ratio,
515+
aligned)
516+
.to(input.scalar_type());
517+
}
518+
498519
} // namespace AtenIpexTypeXPU
499520
} // namespace at
500521

501522
namespace {
502523
IPEX_LIBRARY_FRAGMENT() {
503-
IPEX_OP_REGISTER(
504-
"roi_align.xpu", at::AtenIpexTypeXPU::roi_align_forward_kernel);
505-
IPEX_OP_REGISTER(
524+
IPEX_OP_REGISTER_DISPATCH(
525+
"roi_align.xpu",
526+
at::AtenIpexTypeXPU::roi_align_forward_kernel,
527+
c10::DispatchKey::XPU);
528+
IPEX_OP_REGISTER_DISPATCH(
506529
"_roi_align_backward.xpu",
507-
at::AtenIpexTypeXPU::roi_align_backward_kernel);
530+
at::AtenIpexTypeXPU::roi_align_backward_kernel,
531+
c10::DispatchKey::XPU);
532+
IPEX_OP_REGISTER_DISPATCH(
533+
"roi_align.xpu",
534+
at::AtenIpexTypeXPU::roi_align_forward_autocast,
535+
c10::DispatchKey::AutocastXPU);
508536
}
509537

510-
IPEX_TORCH_LIBRARY_IMPL(torchvision, XPU, m) {
538+
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
511539
m.impl(
512540
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
541+
c10::DispatchKey::XPU,
513542
TORCH_FN((&at::AtenIpexTypeXPU::roi_align_forward_kernel)));
514543
m.impl(
515544
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
545+
c10::DispatchKey::XPU,
516546
TORCH_FN((&at::AtenIpexTypeXPU::roi_align_backward_kernel)));
547+
m.impl(
548+
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
549+
c10::DispatchKey::AutocastXPU,
550+
TORCH_FN((&at::AtenIpexTypeXPU::roi_align_forward_autocast)));
517551
}
518552
} // namespace

csrc/gpu/aten/operators/TripleOps.cpp

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <ATen/ATen.h>
22
#include <ATen/Context.h>
33
#include <ATen/SparseTensorUtils.h>
4+
#include <ATen/autocast_mode.h>
45
#include <ATen/native/BinaryOps.h>
56
#include <ATen/native/TensorIterator.h>
67
#include <ATen/record_function.h>
@@ -23,6 +24,9 @@ using namespace at::sparse;
2324

2425
namespace at {
2526
namespace AtenIpexTypeXPU {
27+
using autocast::cached_cast;
28+
using autocast::get_lower_precision_fp_from_device_type;
29+
using autocast::promote_type;
2630

2731
std::tuple<Tensor, Tensor> sort(
2832
const Tensor& self,
@@ -38,7 +42,8 @@ static void mul_add_kernel_dpcpp(TensorIterator& iter, Scalar alpha_scalar) {
3842
iter.dtype(),
3943
"mul_add",
4044
[&]() {
41-
auto alpha = alpha_scalar.to<scalar_t>();
45+
using accscalar_t = acc_type<scalar_t>;
46+
auto alpha = alpha_scalar.to<accscalar_t>();
4247
dpcpp_kernel_for_tensor_iter(
4348
iter, [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
4449
return a * b + alpha * c;
@@ -141,8 +146,9 @@ Tensor mul_scalar_add_scalar(
141146
iter.dtype(),
142147
"mul_scalar_add_scalar",
143148
[&]() {
144-
auto add_scalar = alpha.to<scalar_t>() * accumu.to<scalar_t>();
145-
auto other_scalar = other.to<scalar_t>();
149+
using accscalar_t = acc_type<scalar_t>;
150+
auto add_scalar = alpha.to<accscalar_t>() * accumu.to<accscalar_t>();
151+
auto other_scalar = other.to<accscalar_t>();
146152
dpcpp_kernel_for_tensor_iter(iter, [=](scalar_t a) -> scalar_t {
147153
return a * other_scalar + add_scalar;
148154
});
@@ -151,6 +157,15 @@ Tensor mul_scalar_add_scalar(
151157
return result;
152158
}
153159

160+
Tensor mul_scalar_add_scalar_autocast(
161+
const Tensor& self,
162+
Scalar other,
163+
Scalar accumu,
164+
Scalar alpha) {
165+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
166+
return mul_scalar_add_scalar(self, other, accumu, alpha);
167+
}
168+
154169
Tensor mul_add_scalar(
155170
const Tensor& self,
156171
const Tensor& other,
@@ -174,7 +189,8 @@ Tensor mul_add_scalar(
174189
iter.dtype(),
175190
"mul_scalar_add_scalar",
176191
[&]() {
177-
auto add_scalar = alpha.to<scalar_t>() * accumu.to<scalar_t>();
192+
using accscalar_t = acc_type<scalar_t>;
193+
auto add_scalar = alpha.to<accscalar_t>() * accumu.to<accscalar_t>();
178194
dpcpp_kernel_for_tensor_iter(
179195
iter, [=](scalar_t a, scalar_t b) -> scalar_t {
180196
return a * b + add_scalar;
@@ -184,6 +200,24 @@ Tensor mul_add_scalar(
184200
return result;
185201
}
186202

203+
Tensor mul_add_scalar_autocast(
204+
const Tensor& self,
205+
const Tensor& other,
206+
Scalar accumu,
207+
Scalar alpha) {
208+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
209+
auto to_type = promote_type(
210+
get_lower_precision_fp_from_device_type(c10::DeviceType::XPU),
211+
c10::DeviceType::XPU,
212+
self,
213+
other);
214+
return mul_add_scalar(
215+
cached_cast(to_type, self, c10::DeviceType::XPU),
216+
cached_cast(to_type, other, c10::DeviceType::XPU),
217+
accumu,
218+
alpha);
219+
}
220+
187221
Tensor mul_scalar_add(
188222
const Tensor& self,
189223
Scalar other,
@@ -207,8 +241,9 @@ Tensor mul_scalar_add(
207241
iter.dtype(),
208242
"mul_scalar_add_scalar",
209243
[&]() {
210-
auto alpha_scalar = alpha.to<scalar_t>();
211-
auto other_scalar = other.to<scalar_t>();
244+
using accscalar_t = acc_type<scalar_t>;
245+
auto alpha_scalar = alpha.to<accscalar_t>();
246+
auto other_scalar = other.to<accscalar_t>();
212247
dpcpp_kernel_for_tensor_iter(
213248
iter, [=](scalar_t a, scalar_t b) -> scalar_t {
214249
return a * other_scalar + b * alpha_scalar;
@@ -218,6 +253,24 @@ Tensor mul_scalar_add(
218253
return result;
219254
}
220255

256+
Tensor mul_scalar_add_autocast(
257+
const Tensor& self,
258+
Scalar other,
259+
const Tensor& accumu,
260+
Scalar alpha) {
261+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
262+
auto to_type = promote_type(
263+
get_lower_precision_fp_from_device_type(c10::DeviceType::XPU),
264+
c10::DeviceType::XPU,
265+
self,
266+
accumu);
267+
return mul_scalar_add(
268+
cached_cast(to_type, self, c10::DeviceType::XPU),
269+
other,
270+
cached_cast(to_type, accumu, c10::DeviceType::XPU),
271+
alpha);
272+
}
273+
221274
Tensor mul_add(
222275
const Tensor& self,
223276
const Tensor& other,
@@ -242,6 +295,25 @@ Tensor mul_add(
242295
return result;
243296
}
244297

298+
Tensor mul_add_autocast(
299+
const Tensor& self,
300+
const Tensor& other,
301+
const Tensor& accumu,
302+
Scalar alpha) {
303+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
304+
auto to_type = promote_type(
305+
get_lower_precision_fp_from_device_type(c10::DeviceType::XPU),
306+
c10::DeviceType::XPU,
307+
self,
308+
other,
309+
accumu);
310+
return mul_add(
311+
cached_cast(to_type, self, c10::DeviceType::XPU),
312+
cached_cast(to_type, other, c10::DeviceType::XPU),
313+
cached_cast(to_type, accumu, c10::DeviceType::XPU),
314+
alpha);
315+
}
316+
245317
template <typename scalar_t>
246318
static inline void packed_add_kernel(
247319
unsigned short* __restrict__ w_MSB,
@@ -428,10 +500,27 @@ Tensor packed_add(
428500

429501
namespace {
430502
IPEX_LIBRARY_FRAGMENT() {
431-
IPEX_OP_REGISTER("mul_add", mul_add);
432-
IPEX_OP_REGISTER("mul_add.Scalar_Tensor", mul_add_scalar);
433-
IPEX_OP_REGISTER("mul_add.Tensor_Scalar", mul_scalar_add);
434-
IPEX_OP_REGISTER("mul_add.Scalar_Scalar", mul_scalar_add_scalar);
503+
IPEX_OP_REGISTER_DISPATCH("mul_add", mul_add, c10::DispatchKey::XPU);
504+
IPEX_OP_REGISTER_DISPATCH(
505+
"mul_add", mul_add_autocast, c10::DispatchKey::AutocastXPU);
506+
IPEX_OP_REGISTER_DISPATCH(
507+
"mul_add.Scalar_Tensor", mul_scalar_add, c10::DispatchKey::XPU);
508+
IPEX_OP_REGISTER_DISPATCH(
509+
"mul_add.Scalar_Tensor",
510+
mul_scalar_add_autocast,
511+
c10::DispatchKey::AutocastXPU);
512+
IPEX_OP_REGISTER_DISPATCH(
513+
"mul_add.Tensor_Scalar", mul_add_scalar, c10::DispatchKey::XPU);
514+
IPEX_OP_REGISTER_DISPATCH(
515+
"mul_add.Tensor_Scalar",
516+
mul_add_scalar_autocast,
517+
c10::DispatchKey::AutocastXPU);
518+
IPEX_OP_REGISTER_DISPATCH(
519+
"mul_add.Scalar_Scalar", mul_scalar_add_scalar, c10::DispatchKey::XPU);
520+
IPEX_OP_REGISTER_DISPATCH(
521+
"mul_add.Scalar_Scalar",
522+
mul_scalar_add_scalar_autocast,
523+
c10::DispatchKey::AutocastXPU);
435524
IPEX_OP_REGISTER_DISPATCH(
436525
"packed_add", at::AtenIpexTypeXPU::packed_add, c10::DispatchKey::XPU);
437526
IPEX_OP_REGISTER_DISPATCH(

tests/gpu/examples/test_fusion.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,15 @@ def model_check(model):
15241524
modelJit(m1_dpcpp, m2_dpcpp, add1_dpcpp)
15251525
print(modelJit.graph_for(m1_dpcpp, m2_dpcpp, add1_dpcpp))
15261526
real = modelJit(m1_dpcpp, m2_dpcpp, add2_dpcpp)
1527-
self.assertEqual(raw, real.to(cpu_device))
1527+
self.assertEqual(raw, real.to(cpu_device))
1528+
1529+
with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
1530+
autocast_arg1 = modelJit(m1_dpcpp, m2_dpcpp, add1_dpcpp)
1531+
self.assertEqual(raw, autocast_arg1.to(device=cpu_device, dtype=torch.float), atol=1e-5, rtol=1e-5)
1532+
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
1533+
autocast_arg1 = modelJit(m1_dpcpp, m2_dpcpp, add1_dpcpp)
1534+
self.assertEqual(raw, autocast_arg1.to(device=cpu_device, dtype=torch.float), atol=1e-5, rtol=1e-5)
1535+
15281536
del modelJit
15291537
model_check(MulAdd())
15301538
model_check(MulAddScalar())

tests/gpu/examples/test_roi_align.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,30 @@ def roi_align_forward_(self, dtype_):
9797
tol = 1e-2 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
9898
torch.testing.assert_close(gt_y.cpu(), y.cpu(), rtol=tol, atol=tol)
9999

100+
def roi_align_autocast_forward_(self, dtype_):
101+
device = torch.device('xpu')
102+
pool_size = 5
103+
n_channels = 2 * (pool_size**2)
104+
x = torch.rand(2, n_channels, 10, 10, dtype=dtype_, device=device)
105+
rois = torch.tensor(
106+
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
107+
dtype=torch.float,
108+
device=device,
109+
)
110+
pool_h, pool_w = pool_size, pool_size
111+
112+
with torch.xpu.amp.autocast(enabled=True, dtype=dtype_):
113+
y = torch.xpu.roi_align(x, rois, [pool_h, pool_w], spatial_scale=1, sampling_ratio=-1)
114+
gt_y = expected_fn(
115+
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=torch.float
116+
)
117+
tol = 1e-2 if dtype_ is torch.float16 else 1e-1
118+
torch.testing.assert_close(gt_y.cpu(), y.to(torch.float).cpu(), rtol=tol, atol=tol)
119+
100120
def test_roi_align_forward(self):
101121
for dtype in [torch.float, torch.half]:
102122
print('testing dtype:', dtype)
103123
self.roi_align_forward_(dtype)
124+
for dtype in [torch.float16, torch.bfloat16]:
125+
print('testing dtype in autocast: ', dtype)
126+
self.roi_align_autocast_forward_(dtype)

0 commit comments

Comments
 (0)