@@ -80,10 +80,10 @@ namespace {
8080
8181template <typename T>
8282struct PreCalc {
83- int pos1;
84- int pos2;
85- int pos3;
86- int pos4;
83+ int64_t pos1;
84+ int64_t pos2;
85+ int64_t pos3;
86+ int64_t pos4;
8787 T w1;
8888 T w2;
8989 T w3;
@@ -94,27 +94,27 @@ template <typename T, typename ACC_T>
9494inline void roi_align_single_framework_forward (
9595 const T* input,
9696 const ACC_T count,
97- int channels,
98- int height,
99- int width,
100- int pooled_height,
101- int pooled_width,
102- int roi_bin_grid_h,
103- int roi_bin_grid_w,
97+ int64_t channels,
98+ int64_t height,
99+ int64_t width,
100+ int64_t pooled_height,
101+ int64_t pooled_width,
102+ int64_t roi_bin_grid_h,
103+ int64_t roi_bin_grid_w,
104104 const std::vector<PreCalc<ACC_T>>& pre_calc,
105105 T* output);
106106
107107template <typename T, typename ACC_T>
108108inline void roi_align_single_framework_channels_last_forward (
109109 const T* input,
110110 const ACC_T count,
111- int channels,
112- int height,
113- int width,
114- int pooled_height,
115- int pooled_width,
116- int roi_bin_grid_h,
117- int roi_bin_grid_w,
111+ int64_t channels,
112+ int64_t height,
113+ int64_t width,
114+ int64_t pooled_height,
115+ int64_t pooled_width,
116+ int64_t roi_bin_grid_h,
117+ int64_t roi_bin_grid_w,
118118 const std::vector<PreCalc<ACC_T>>& pre_calc,
119119 T* output);
120120
@@ -124,13 +124,13 @@ inline void roi_align_single_framework_channels_last_forward<
124124 float >(
125125 const at::BFloat16* input,
126126 const float count,
127- int channels,
128- int height,
129- int width,
130- int pooled_height,
131- int pooled_width,
132- int roi_bin_grid_h,
133- int roi_bin_grid_w,
127+ int64_t channels,
128+ int64_t height,
129+ int64_t width,
130+ int64_t pooled_height,
131+ int64_t pooled_width,
132+ int64_t roi_bin_grid_h,
133+ int64_t roi_bin_grid_w,
134134 const std::vector<PreCalc<float >>& pre_calc,
135135 at::BFloat16* output);
136136
@@ -141,57 +141,57 @@ template <typename T, typename ACC_T>
141141inline void roi_align_single_framework_backward (
142142 const T* grad_output,
143143 const ACC_T count,
144- int channels,
145- int height,
146- int width,
147- int pooled_height,
148- int pooled_width,
149- int roi_bin_grid_h,
150- int roi_bin_grid_w,
144+ int64_t channels,
145+ int64_t height,
146+ int64_t width,
147+ int64_t pooled_height,
148+ int64_t pooled_width,
149+ int64_t roi_bin_grid_h,
150+ int64_t roi_bin_grid_w,
151151 const std::vector<PreCalc<ACC_T>>& pre_calc,
152152 T* grad_input);
153153
154154template <typename T, typename ACC_T>
155155inline void roi_align_single_framework_channels_last_backward (
156156 const T* grad_output,
157157 const ACC_T count,
158- int channels,
159- int height,
160- int width,
161- int pooled_height,
162- int pooled_width,
163- int roi_bin_grid_h,
164- int roi_bin_grid_w,
158+ int64_t channels,
159+ int64_t height,
160+ int64_t width,
161+ int64_t pooled_height,
162+ int64_t pooled_width,
163+ int64_t roi_bin_grid_h,
164+ int64_t roi_bin_grid_w,
165165 const std::vector<PreCalc<ACC_T>>& pre_calc,
166166 T* grad_input);
167167
168168template <typename T, typename ACC_T>
169169void roi_align_forward_kernel_body (
170- int n_rois,
170+ int64_t n_rois,
171171 const T* input,
172172 const ACC_T& spatial_scale,
173- int channels,
174- int height,
175- int width,
176- int pooled_height,
177- int pooled_width,
178- int sampling_ratio,
173+ int64_t channels,
174+ int64_t height,
175+ int64_t width,
176+ int64_t pooled_height,
177+ int64_t pooled_width,
178+ int64_t sampling_ratio,
179179 bool aligned,
180180 const ACC_T* rois,
181181 T* output,
182182 bool is_channels_last);
183183
184184template <typename T, typename ACC_T>
185185void roi_align_backward_kernel_body (
186- int n_rois,
186+ int64_t n_rois,
187187 const T* grad_output,
188188 const ACC_T& spatial_scale,
189- int channels,
190- int height,
191- int width,
192- int pooled_height,
193- int pooled_width,
194- int sampling_ratio,
189+ int64_t channels,
190+ int64_t height,
191+ int64_t width,
192+ int64_t pooled_height,
193+ int64_t pooled_width,
194+ int64_t sampling_ratio,
195195 bool aligned,
196196 T* grad_input,
197197 const ACC_T* rois,
0 commit comments