@@ -13,15 +13,19 @@ namespace cpu {
1313DEFINE_DISPATCH (roi_align_forward_kernel_stub);
1414DEFINE_DISPATCH (roi_align_backward_kernel_stub);
1515
16- at::Tensor IPEXROIAlignOp::_forward (
16+ at::Tensor ROIAlign_forward_impl (
1717 const at::Tensor& input,
1818 const at::Tensor& rois,
1919 double spatial_scale,
2020 int64_t pooled_height,
2121 int64_t pooled_width,
2222 int64_t sampling_ratio,
2323 bool aligned) {
24- RECORD_FUNCTION (" IPEXROIAlignOp::_forward" , c10::ArrayRef<c10::IValue>({}));
24+ #if defined(IPEX_DISP_OP)
25+ printf (" torch_ipex::ROIAlign_forward\n " );
26+ #endif
27+ RECORD_FUNCTION (
28+ " torch_ipex::ROIAlign_forward" , c10::ArrayRef<c10::IValue>({}));
2529
2630 return roi_align_forward_kernel_stub (
2731 kCPU ,
@@ -34,6 +38,66 @@ at::Tensor IPEXROIAlignOp::_forward(
3438 aligned);
3539}
3640
// ROIAlign_backward: entry point for the ROIAlign gradient computation.
//
// The function does no computation of its own. It optionally prints a trace
// line, opens a RECORD_FUNCTION scope, and then hands every argument,
// unchanged and in the same order, to roi_align_backward_kernel_stub with
// kCPU as the first argument.
//
// Parameters (all forwarded verbatim to the kernel stub):
//   grad             - incoming gradient tensor (presumably the gradient with
//                      respect to the pooled output - TODO confirm in kernel).
//   rois             - regions-of-interest tensor; its layout is not visible
//                      in this function.
//   spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned
//                    - ROIAlign configuration values.
//   batch_size, channels, height, width
//                    - presumably the dimensions of the forward input, i.e.
//                      of the gradient to be produced - confirm with caller.
//   is_channels_last - flag forwarded to the kernel; presumably selects the
//                      memory layout of the result - confirm in kernel.
//
// Returns: whatever tensor the kernel stub returns.
41+ at::Tensor ROIAlign_backward (
42+ const at::Tensor& grad,
43+ const at::Tensor& rois,
44+ double spatial_scale,
45+ int64_t pooled_height,
46+ int64_t pooled_width,
47+ int64_t batch_size,
48+ int64_t channels,
49+ int64_t height,
50+ int64_t width,
51+ int64_t sampling_ratio,
52+ bool aligned,
53+ bool is_channels_last) {
// Debug aid: the trace line is compiled in only when IPEX_DISP_OP is defined.
54+ #if defined(IPEX_DISP_OP)
55+ printf (" torch_ipex::ROIAlign_backward\n " );
56+ #endif
// Profiler scope for this call; an empty argument list is recorded.
57+ RECORD_FUNCTION (
58+ " torch_ipex::ROIAlign_backward" , c10::ArrayRef<c10::IValue>({}));
59+
// Pure pass-through: argument order here must match the kernel stub signature.
60+ return roi_align_backward_kernel_stub (
61+ kCPU ,
62+ grad,
63+ rois,
64+ spatial_scale,
65+ pooled_height,
66+ pooled_width,
67+ batch_size,
68+ channels,
69+ height,
70+ width,
71+ sampling_ratio,
72+ aligned,
73+ is_channels_last);
74+ }
75+
// IPEXROIAlignOp::_forward: runs the ROIAlign forward pass through the
// dispatcher rather than calling a kernel stub directly.
//
// The operator handle is looked up once (function-local static) by schema
// name and typed with the function type of ROIAlign_forward; the call then
// forwards all seven arguments unchanged.
//
// Parameters: input, rois, spatial_scale, pooled_height, pooled_width,
// sampling_ratio, aligned - passed verbatim to op.call.
// Returns: the tensor produced by op.call.
//
// NOTE(review): judging by its name, findSchemaOrThrow raises when the schema
// is not registered, so this path depends on the ROIAlign_forward schema
// being defined before the first call - confirm registration order.
76+ at::Tensor IPEXROIAlignOp::_forward (
77+ const at::Tensor& input,
78+ const at::Tensor& rois,
79+ double spatial_scale,
80+ int64_t pooled_height,
81+ int64_t pooled_width,
82+ int64_t sampling_ratio,
83+ bool aligned) {
// Guard object g is never read; it is kept alive until the function returns.
// Presumably it masks autograd-level dispatch keys so op.call below reaches
// the backend kernels - confirm against the PyTorch dispatcher docs.
84+ at::AutoDispatchBelowADInplaceOrView g;
85+ RECORD_FUNCTION (" IPEXROIAlignOp::_forward" , c10::ArrayRef<c10::IValue>({}));
86+
// Resolved once and cached in a static: later calls reuse the same handle.
87+ static auto op = torch::Dispatcher::singleton ()
88+ .findSchemaOrThrow (" torch_ipex::ROIAlign_forward" , " " )
89+ .typed <decltype (ROIAlign_forward)>();
90+
91+ return op.call (
92+ input,
93+ rois,
94+ spatial_scale,
95+ pooled_height,
96+ pooled_width,
97+ sampling_ratio,
98+ aligned);
99+ }
100+
37101at::Tensor IPEXROIAlignOp::forward (
38102 torch::autograd::AutogradContext* ctx,
39103 const at::Tensor& input,
@@ -45,7 +109,7 @@ at::Tensor IPEXROIAlignOp::forward(
45109 bool aligned) {
46110 RECORD_FUNCTION (" IPEXROIAlignOp::forward" , c10::ArrayRef<c10::IValue>({}));
47111
48- ctx->saved_data [" input_shape" ] = input.sizes ();
112+ ctx->saved_data [" input_shape" ] = input.sym_sizes ();
49113 ctx->saved_data [" spatial_scale" ] = spatial_scale;
50114 ctx->saved_data [" pooled_height" ] = pooled_height;
51115 ctx->saved_data [" pooled_width" ] = pooled_width;
@@ -55,8 +119,7 @@ at::Tensor IPEXROIAlignOp::forward(
55119 input.is_contiguous (at::MemoryFormat::ChannelsLast);
56120 ctx->save_for_backward ({rois});
57121
58- return roi_align_forward_kernel_stub (
59- kCPU ,
122+ return _forward (
60123 input,
61124 rois,
62125 spatial_scale,
@@ -81,8 +144,11 @@ torch::autograd::variable_list IPEXROIAlignOp::backward(
81144 auto saved = ctx->get_saved_variables ();
82145 at::Tensor rois = saved[0 ];
83146
84- at::Tensor grad_input = roi_align_backward_kernel_stub (
85- kCPU ,
147+ static auto op = torch::Dispatcher::singleton ()
148+ .findSchemaOrThrow (" torch_ipex::ROIAlign_backward" , " " )
149+ .typed <decltype (ROIAlign_backward)>();
150+
151+ auto grad_input = op.call (
86152 grad_outputs[0 ],
87153 rois,
88154 spatial_scale,
@@ -134,45 +200,26 @@ at::Tensor ROIAlign_forward(
134200 aligned);
135201}
136202
137- } // namespace cpu
138- } // namespace torch_ipex
139-
140- namespace torch_ipex {
141- namespace autocast {
142-
143- at::Tensor roi_align_autocast (
// ROIAlign_forward_meta: shape-only counterpart of the ROIAlign forward op.
//
// No pooling is computed. The function reads rois.sym_size(0) and
// input.sym_size(1) and returns at::empty_symint with shape
// {num_rois, channels, pooled_height, pooled_width} and input.options().
//
// Parameters:
//   input  - only its dimension-1 size and its tensor options are read.
//   rois   - only its dimension-0 size is read.
//   pooled_height, pooled_width - the last two output dimensions.
//   spatial_scale, sampling_ratio, aligned - accepted but not read here.
// Returns: an empty tensor with the shape described above.
//
// The lines marked with a minus sign below are the removed body of the old
// roi_align_autocast function, which this diff interleaves with the new code.
//
// NOTE(review): the result is built from input.options() alone; whether a
// channels-last input should produce a channels-last result is not handled
// in this function - confirm against the CPU kernel.
203+ at::Tensor ROIAlign_forward_meta (
144204 const at::Tensor& input,
145205 const at::Tensor& rois,
146206 double spatial_scale,
147207 int64_t pooled_height,
148208 int64_t pooled_width,
149209 int64_t sampling_ratio,
150210 bool aligned) {
151- c10::impl::ExcludeDispatchKeyGuard no_autocastCPU (DispatchKey::AutocastCPU);
152- static auto op = torch::Dispatcher::singleton ()
153- .findSchemaOrThrow (" torchvision::roi_align" , " " )
154- .typed <decltype (torch_ipex::cpu::ROIAlign_forward)>();
155- if (input.scalar_type () == at::ScalarType::BFloat16) {
156- return op.call (
157- input,
158- cpu_cached_cast (at::kFloat , rois),
159- spatial_scale,
160- pooled_height,
161- pooled_width,
162- sampling_ratio,
163- aligned);
164- } else {
165- return op.call (
166- input,
167- cpu_cached_cast (input.scalar_type (), rois),
168- spatial_scale,
169- pooled_height,
170- pooled_width,
171- sampling_ratio,
172- aligned);
173- }
// Symbolic sizes are used so the shape can stay symbolic under tracing.
211+ auto num_rois = rois.sym_size (0 );
212+ auto channels = input.sym_size (1 );
213+ return at::empty_symint (
214+ {num_rois, channels, pooled_height, pooled_width}, input.options ());
174215}
175216
217+ } // namespace cpu
218+ } // namespace torch_ipex
219+
220+ namespace torch_ipex {
221+ namespace autocast {
222+
176223at::Tensor ROIAlign_forward (
177224 const at::Tensor& input,
178225 const at::Tensor& rois,
@@ -222,6 +269,21 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
222269 " ROIAlign_forward" ,
223270 c10::DispatchKey::AutocastCPU,
224271 torch_ipex::autocast::ROIAlign_forward);
// Backend kernels for the ROIAlign_forward op: the CPU dispatch key runs
// ROIAlign_forward_impl and the Meta dispatch key runs ROIAlign_forward_meta.
272+ m.impl (
273+ " ROIAlign_forward" ,
274+ c10::DispatchKey::CPU,
275+ torch_ipex::cpu::ROIAlign_forward_impl);
276+ m.impl (
277+ " ROIAlign_forward" ,
278+ c10::DispatchKey::Meta,
279+ torch_ipex::cpu::ROIAlign_forward_meta);
// Backward op: the schema is declared here and bound to a CPU kernel only.
// NOTE(review): no Meta kernel is registered for ROIAlign_backward in the
// visible lines, unlike the forward op - confirm whether one is needed.
280+ // bw
281+ m.def (
282+ " ROIAlign_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned, bool is_channels_last) -> Tensor" );
283+ m.impl (
284+ " ROIAlign_backward" ,
285+ c10::DispatchKey::CPU,
286+ torch_ipex::cpu::ROIAlign_backward);
225287}
226288
227289IPEX_TORCH_LIBRARY_FRAGMENT (torchvision, m) {
@@ -232,7 +294,7 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torchvision, m) {
232294 m.impl (
233295 " roi_align" ,
234296 c10::DispatchKey::AutocastCPU,
235- torch_ipex::autocast::roi_align_autocast );
297+ torch_ipex::autocast::ROIAlign_forward );
236298}
237299
238300} // namespace
0 commit comments