diff --git a/main.py b/main.py index 1dee322..3bed4a8 100644 --- a/main.py +++ b/main.py @@ -50,6 +50,9 @@ def post_processing(self, pred_boxes, scores, pred_classes, pred_masks, im_hw, p heights = pred_boxes[:, 3] - pred_boxes[:, 1] keep = (widths > threshold) & (heights > threshold) + condation = np.where(scores > self.confThreshold, True, False) + keep &= condation + pred_boxes = pred_boxes[keep] scores = scores[keep] pred_classes = pred_classes[keep]