@@ -292,9 +292,10 @@ std::vector<at::Tensor> remove_empty(std::vector<at::Tensor>& candidate, int64_t
292292}
293293
294294template <typename scalar_t >
295- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kernel (const at::Tensor& batch_dets,
296- const at::Tensor& batch_scores,
297- const float threshold, const int max_output=200 ) {
295+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
296+ batch_score_nms_kernel (const at::Tensor &batch_dets,
297+ const at::Tensor &batch_scores, const float threshold,
298+ const int max_output = 200 ) {
298299 // Reference to: https://github.com/mlcommons/inference/blob/0f096a18083c3fd529c1fbf97ebda7bc3f1fda70/others/cloud/single_stage_detector/pytorch/utils.py#L163
299300 // batch_dets: (batchsize, num_bbox, 4) For example: batch_dets: (1, 15130, 4)
300301 // batch_scores: (batchsize, num_bbox, label_num) For example: batch_scores: (1, 15130, 81)
@@ -351,7 +352,10 @@ std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kern
351352 labels_out[index] = at::empty ({keep.sizes ()}).fill_ (i);
352353 }
353354
354- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> output (nbatch);
355+ std::vector<at::Tensor> output_bboxes_ (nbatch);
356+ std::vector<at::Tensor> output_labels_ (nbatch);
357+ std::vector<at::Tensor> output_scores_ (nbatch);
358+ std::vector<at::Tensor> output_length_ (nbatch);
355359#ifdef _OPENMP
356360#if (_OPENMP >= 201307)
357361# pragma omp parallel for simd schedule(static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
@@ -372,11 +376,14 @@ std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kern
372376 std::tuple<at::Tensor, at::Tensor> sort_result = scores_out_.sort (0 );
373377 at::Tensor max_ids = std::get<1 >(sort_result);
374378 max_ids = max_ids.slice (/* dim*/ 0 , /* start*/ std::max (max_ids.size (0 ) - max_output, static_cast <int64_t >(0 )), /* end*/ max_ids.size (0 ));
375- output[bs] = std::tuple<at::Tensor, at::Tensor, at::Tensor>(bboxes_out_.index_select (/* dim*/ 0 , /* index*/ max_ids),
376- labels_out_.index_select (/* dim*/ 0 , /* index*/ max_ids),
377- scores_out_.index_select (/* dim*/ 0 , /* index*/ max_ids));
379+ output_bboxes_[bs] = bboxes_out_.index_select (/* dim*/ 0 , /* index*/ max_ids);
380+ output_labels_[bs] = labels_out_.index_select (/* dim*/ 0 , /* index*/ max_ids);
381+ output_scores_[bs] = scores_out_.index_select (/* dim*/ 0 , /* index*/ max_ids);
382+ output_length_[bs] = torch::tensor (max_ids.size (0 ), {torch::kInt32 });
378383 }
379- return output;
384+ return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
385+ at::cat (output_bboxes_), at::cat (output_labels_), at::cat (output_scores_),
386+ at::stack (output_length_));
380387}
381388
382389template <typename scalar_t >
@@ -526,11 +533,10 @@ at::Tensor nms_cpu(const at::Tensor& dets,
526533 return result;
527534}
528535
529- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_cpu (const at::Tensor& dets,
530- const at::Tensor& scores,
531- const float threshold,
532- const int max_output) {
533- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> result;
536+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
537+ batch_score_nms_cpu (const at::Tensor &dets, const at::Tensor &scores,
538+ const float threshold, const int max_output) {
539+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> result;
534540 AT_DISPATCH_FLOATING_TYPES (dets.scalar_type (), " batch_score_nms" , [&] {
535541 result = batch_score_nms_kernel<scalar_t >(dets, scores, threshold, max_output);
536542 });
@@ -581,10 +587,11 @@ at::Tensor AtenIpexTypeExt::nms(const at::Tensor& dets,
581587 return result;
582588}
583589
584- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> AtenIpexTypeExt::batch_score_nms (const at::Tensor& dets,
585- const at::Tensor& scores,
586- const double threshold,
587- const int64_t max_output) {
590+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
591+ AtenIpexTypeExt::batch_score_nms (const at::Tensor &dets,
592+ const at::Tensor &scores,
593+ const double threshold,
594+ const int64_t max_output) {
588595#if defined(IPEX_DISP_OP)
589596 printf (" IpexExternal::batch_score_nms\n " );
590597#endif
@@ -758,10 +765,9 @@ at::Tensor nms(const at::Tensor& dets,
758765 return op.call (cpu_cached_cast (at::kFloat , dets), cpu_cached_cast (at::kFloat , scores), threshold, sorted);
759766}
760767
761- std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms (const at::Tensor& dets,
762- const at::Tensor& scores,
763- const double threshold,
764- const int64_t max_output) {
768+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
769+ batch_score_nms (const at::Tensor &dets, const at::Tensor &scores,
770+ const double threshold, const int64_t max_output) {
765771 c10::impl::ExcludeDispatchKeyGuard no_autocastCPU (DispatchKey::AutocastCPU);
766772 static auto op = torch::Dispatcher::singleton ()
767773 .findSchemaOrThrow (" torch_ipex::batch_score_nms" , " " )
0 commit comments