Commit c2794b1

modify Batch NMS's output to make it jit traceable (#123)

* enable jit trace of NMS for different batch sizes
* change the UT test_batch_nms_result's output check to be compatible with the new NMS output
* add an NMS jit trace test case
* clean up formatting
1 parent 0ba122e commit c2794b1
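Why the output format matters: torch.jit.trace records one graph with a fixed output structure, so the previous return value (a Python-level list of per-image tuples whose length equals the batch size) could only replay correctly for the batch size used during tracing. A fixed 4-tuple of tensors, with per-image detection counts carried in the 4th tensor, keeps the output structure independent of BS. Below is a minimal, hypothetical stand-in (FlatNMSOutput is made up for illustration, not the real kernel) showing the traceable output shape:

    import torch
    import torch.nn as nn

    class FlatNMSOutput(nn.Module):
        """Toy stand-in for batch_score_nms: a fixed 4-tuple of tensors."""
        def forward(self, boxes, scores):
            # Concatenated results plus a per-image length tensor: the tuple
            # arity never depends on the batch size, unlike a list of
            # per-image tuples, whose length the tracer would freeze.
            keep_boxes = boxes.reshape(-1, 4)
            keep_scores = scores.reshape(-1)
            keep_labels = torch.zeros_like(keep_scores, dtype=torch.long)
            lengths = torch.tensor([boxes.size(1)] * boxes.size(0), dtype=torch.int32)
            return keep_boxes, keep_labels, keep_scores, lengths

    traced = torch.jit.trace(FlatNMSOutput(),
                             (torch.randn(1, 8, 4), torch.randn(1, 8)))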


3 files changed: +98 −32 lines


tests/cpu/test_nms.py

Lines changed: 57 additions & 1 deletion
@@ -137,7 +137,63 @@ def test_batch_nms_result(self):
             bbox = bbox.squeeze(0)
             prob = prob.squeeze(0)
             output.append(self.decode_single(bbox, prob, criteria, max_output))
-        output2 = batch_score_nms(bboxes_clone, probs_clone, criteria, max_output)
+        output2_raw = batch_score_nms(bboxes_clone, probs_clone, criteria, max_output)
+
+        # Re-assemble the result
+        output2 = []
+        idx = 0
+        for i in range(output2_raw[3].size(0)):
+            output2.append((output2_raw[0][idx:idx+output2_raw[3][i]],
+                            output2_raw[1][idx:idx+output2_raw[3][i]],
+                            output2_raw[2][idx:idx+output2_raw[3][i]]))
+            idx += output2_raw[3][i]
+
+        for i in range(batch_size):
+            loc, label, prob = [r for r in output[i]]
+            loc2, label2, prob2 = [r for r in output2[i]]
+            self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))
+            self.assertEqual(label, label2)
+            self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))
+
+    def test_jit_trace_batch_nms(self):
+        class Batch_NMS(nn.Module):
+            def __init__(self, criteria, max_output):
+                super(Batch_NMS, self).__init__()
+                self.criteria = criteria
+                self.max_output = max_output
+            def forward(self, bboxes_clone, probs_clone):
+                return batch_score_nms(bboxes_clone, probs_clone, self.criteria, self.max_output)
+        batch_size = 1
+        number_boxes = 15130
+        scale_xy = 0.1
+        scale_wh = 0.2
+        criteria = 0.50
+        max_output = 200
+        predicted_loc = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt"))  # sizes: [1, 15130, 4]
+        predicted_score = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt"))  # sizes: [1, 15130, 81]
+        dboxes_xywh = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt"))
+        bboxes, probs = parallel_scale_back_batch(predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh)
+        bboxes_clone = bboxes.clone()
+        probs_clone = probs.clone()
+
+        output = []
+        for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)):
+            bbox = bbox.squeeze(0)
+            prob = prob.squeeze(0)
+            output.append(self.decode_single(bbox, prob, criteria, max_output))
+
+        batch_score_nms_module = Batch_NMS(criteria, max_output)
+        model_decode = torch.jit.trace(batch_score_nms_module, (bboxes_clone, probs_clone))
+        output2_raw = model_decode(bboxes_clone, probs_clone)
+
+        # Re-assemble the result
+        output2 = []
+        idx = 0
+        for i in range(output2_raw[3].size(0)):
+            output2.append((output2_raw[0][idx:idx+output2_raw[3][i]],
+                            output2_raw[1][idx:idx+output2_raw[3][i]],
+                            output2_raw[2][idx:idx+output2_raw[3][i]]))
+            idx += output2_raw[3][i]
 
         for i in range(batch_size):
             loc, label, prob = [r for r in output[i]]
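The re-assembly loop above recovers the old per-image tuples by slicing the three flat tensors with the counts in output2_raw[3]. An equivalent formulation (a sketch, not part of this commit) uses Tensor.split:

    bboxes_flat, labels_flat, scores_flat, lengths = output2_raw
    sizes = [int(n) for n in lengths]  # per-image detection counts
    output2 = list(zip(bboxes_flat.split(sizes),
                       labels_flat.split(sizes),
                       scores_flat.split(sizes)))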

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 14 additions & 10 deletions
@@ -62,21 +62,25 @@ class AtenIpexTypeExt {
   /// \brief Perform batch non-maximum suppression.
   ///
   /// C++ version of Encoder::decode_single.
-  /// Refer to https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py.
+  /// Refer to
+  /// https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py.
   ///
-  /// \param dets: predicted loc in ltrb format, size [BS, number_boxes, 4], for example: [1, 15130, 4].
-  /// \param scores: predicted score, size [BS, number_boxes, class_number], for example: [1, 15130, 81].
-  /// \param threshold: IOU threshold(scalar) to suppress bboxs which has the IOU val larger than the threshold.
-  /// \param max_output: the max number of output bbox.
+  /// \param dets: predicted loc in ltrb format, size [BS, number_boxes, 4], for
+  /// example: [1, 15130, 4]. \param scores: predicted score, size [BS,
+  /// number_boxes, class_number], for example: [1, 15130, 81]. \param
+  /// threshold: IOU threshold (scalar) to suppress bboxes whose IOU is larger
+  /// than the threshold. \param max_output: the max number of output
+  /// bboxes.
   ///
-  /// \return result is a list of tuple. In each tuple, there are 3 tensors:
+  /// \return result is a tuple of 4 tensors; the first 3 are concatenated
+  /// across the images of the batch and the 4th records per-image counts:
   /// bboxes_out_: the selected out bboxes coordinate, size [max_output, 4].
   /// labels_out_: the label of each selected out bboxes, size [max_output].
   /// scores_out_: the score of each selected out bboxes, size [max_output].
-  static std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms(const at::Tensor& dets,
-                                                                                     const at::Tensor& scores,
-                                                                                     const double threshold,
-                                                                                     const int64_t max_output);
+  /// length_out_: the number of detected bboxes per image, size [BS].
+  static std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+  batch_score_nms(const at::Tensor &dets, const at::Tensor &scores,
+                  const double threshold, const int64_t max_output);
 
   /// \brief Perform batch non-maximum suppression (NMS) for MaskRCNN RPN part.
   ///
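For reference, the shape contract implied by the new signature can be checked from Python (an illustrative sketch, with dets and scores as described in the doc comment above):

    bboxes_out, labels_out, scores_out, length_out = batch_score_nms(dets, scores, 0.5, 200)
    assert length_out.numel() == dets.size(0)           # one count per image
    assert bboxes_out.size(0) == int(length_out.sum())  # rows concatenated across images
    assert labels_out.size(0) == scores_out.size(0) == bboxes_out.size(0)
    assert bboxes_out.size(1) == 4                      # ltrb coordinates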

torch_ipex/csrc/cpu/nms.cpp

Lines changed: 27 additions & 21 deletions
@@ -292,9 +292,10 @@ std::vector<at::Tensor> remove_empty(std::vector<at::Tensor>& candidate, int64_t
 }
 
 template <typename scalar_t>
-std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kernel(const at::Tensor& batch_dets,
-                                                                                   const at::Tensor& batch_scores,
-                                                                                   const float threshold, const int max_output=200) {
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+batch_score_nms_kernel(const at::Tensor &batch_dets,
+                       const at::Tensor &batch_scores, const float threshold,
+                       const int max_output = 200) {
   // Reference to: https://github.com/mlcommons/inference/blob/0f096a18083c3fd529c1fbf97ebda7bc3f1fda70/others/cloud/single_stage_detector/pytorch/utils.py#L163
   // batch_dets: (batchsize, num_bbox, 4) For example: batch_dets: (1, 15130, 4)
   // batch_scores: (batchsize, num_bbox, label_num) For example: batch_scores: (1, 15130, 81)
@@ -351,7 +352,10 @@ std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kern
     labels_out[index] = at::empty({keep.sizes()}).fill_(i);
   }
 
-  std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> output(nbatch);
+  std::vector<at::Tensor> output_bboxes_(nbatch);
+  std::vector<at::Tensor> output_labels_(nbatch);
+  std::vector<at::Tensor> output_scores_(nbatch);
+  std::vector<at::Tensor> output_length_(nbatch);
 #ifdef _OPENMP
 #if (_OPENMP >= 201307)
 # pragma omp parallel for simd schedule(static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
@@ -372,11 +376,14 @@ std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_kern
     std::tuple<at::Tensor, at::Tensor> sort_result = scores_out_.sort(0);
     at::Tensor max_ids = std::get<1>(sort_result);
     max_ids = max_ids.slice(/*dim*/0, /*start*/std::max(max_ids.size(0) - max_output, static_cast<int64_t>(0)), /*end*/max_ids.size(0));
-    output[bs] = std::tuple<at::Tensor, at::Tensor, at::Tensor>(bboxes_out_.index_select(/*dim*/0, /*index*/max_ids),
-                                                                labels_out_.index_select(/*dim*/0, /*index*/max_ids),
-                                                                scores_out_.index_select(/*dim*/0, /*index*/max_ids));
+    output_bboxes_[bs] = bboxes_out_.index_select(/*dim*/ 0, /*index*/ max_ids);
+    output_labels_[bs] = labels_out_.index_select(/*dim*/ 0, /*index*/ max_ids);
+    output_scores_[bs] = scores_out_.index_select(/*dim*/ 0, /*index*/ max_ids);
+    output_length_[bs] = torch::tensor(max_ids.size(0), {torch::kInt32});
   }
-  return output;
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
+      at::cat(output_bboxes_), at::cat(output_labels_), at::cat(output_scores_),
+      at::stack(output_length_));
 }
 
 template <typename scalar_t>
@@ -526,11 +533,10 @@ at::Tensor nms_cpu(const at::Tensor& dets,
   return result;
 }
 
-std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms_cpu(const at::Tensor& dets,
-                                                                                const at::Tensor& scores,
-                                                                                const float threshold,
-                                                                                const int max_output) {
-  std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> result;
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+batch_score_nms_cpu(const at::Tensor &dets, const at::Tensor &scores,
+                    const float threshold, const int max_output) {
+  std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> result;
   AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "batch_score_nms", [&] {
     result = batch_score_nms_kernel<scalar_t>(dets, scores, threshold, max_output);
   });
@@ -581,10 +587,11 @@ at::Tensor AtenIpexTypeExt::nms(const at::Tensor& dets,
   return result;
 }
 
-std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> AtenIpexTypeExt::batch_score_nms(const at::Tensor& dets,
-                                                                                             const at::Tensor& scores,
-                                                                                             const double threshold,
-                                                                                             const int64_t max_output) {
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+AtenIpexTypeExt::batch_score_nms(const at::Tensor &dets,
+                                 const at::Tensor &scores,
+                                 const double threshold,
+                                 const int64_t max_output) {
 #if defined(IPEX_DISP_OP)
   printf("IpexExternal::batch_score_nms\n");
 #endif
@@ -758,10 +765,9 @@ at::Tensor nms(const at::Tensor& dets,
   return op.call(cpu_cached_cast(at::kFloat, dets), cpu_cached_cast(at::kFloat, scores), threshold, sorted);
 }
 
-std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms(const at::Tensor& dets,
-                                                                            const at::Tensor& scores,
-                                                                            const double threshold,
-                                                                            const int64_t max_output) {
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+batch_score_nms(const at::Tensor &dets, const at::Tensor &scores,
+                const double threshold, const int64_t max_output) {
   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torch_ipex::batch_score_nms", "")
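Note the asymmetry in the kernel's final assembly: the per-box tensors are joined with at::cat along dim 0, while the per-image lengths, which torch::tensor creates as 0-dim tensors here, are joined with at::stack, since cat cannot concatenate 0-dim tensors. In Python terms (illustrative only, with made-up counts):

    import torch
    lengths = [torch.tensor(200, dtype=torch.int32),
               torch.tensor(187, dtype=torch.int32)]  # 0-dim per-image counts
    boxes = [torch.randn(200, 4), torch.randn(187, 4)]
    flat_boxes = torch.cat(boxes)      # shape [387, 4]
    length_out = torch.stack(lengths)  # shape [2]; torch.cat(lengths) would error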
