1 | 1 | import os |
2 | 2 | import time |
3 | 3 | from time import perf_counter |
4 | | -from typing import Any, List, Tuple, Union |
| 4 | +from typing import Any, List, Optional, Tuple, Union |
5 | 5 |
6 | 6 | import cv2 |
7 | 7 | import numpy as np |
8 | 8 | import onnxruntime |
9 | 9 | from PIL import Image |
10 | 10 |
11 | 11 | from inference.core.entities.requests.inference import InferenceRequestImage |
| 12 | +from inference.core.entities.responses.inference import InferenceResponseImage |
12 | 13 | from inference.core.env import ( |
13 | 14 | DISABLE_PREPROC_AUTO_ORIENT, |
14 | 15 | FIX_BATCH_SIZE, |
|
26 | 27 | ) |
27 | 28 | from inference.core.logger import logger |
28 | 29 | from inference.core.models.defaults import DEFAULT_CONFIDENCE, DEFAUlT_MAX_DETECTIONS |
| 30 | +from inference.core.models.instance_segmentation_base import ( |
| 31 | + InstanceSegmentationBaseOnnxRoboflowInferenceModel, |
| 32 | + InstanceSegmentationInferenceResponse, |
| 33 | + InstanceSegmentationPrediction, |
| 34 | + Point, |
| 35 | +) |
29 | 36 | from inference.core.models.object_detection_base import ( |
30 | 37 | ObjectDetectionBaseOnnxRoboflowInferenceModel, |
31 | 38 | ObjectDetectionInferenceResponse, |
|
38 | 45 | get_onnxruntime_execution_providers, |
39 | 46 | run_session_via_iobinding, |
40 | 47 | ) |
| 48 | +from inference.core.utils.postprocess import mask2poly |
41 | 49 | from inference.core.utils.preprocess import letterbox_image |
42 | 50 |
43 | 51 | if USE_PYTORCH_FOR_PREPROCESSING: |
@@ -536,3 +544,222 @@ def initialize_model(self, **kwargs) -> None: |
536 | 544 |
537 | 545 | def validate_model_classes(self) -> None: |
538 | 546 | pass |
| 547 | + |
| 548 | + |
| 549 | +class RFDETRInstanceSegmentation( |
| 550 | + RFDETRObjectDetection, InstanceSegmentationBaseOnnxRoboflowInferenceModel |
| 551 | +): |
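| | +    """RFDETR variant that returns instance segmentation masks in addition to boxes and scores."""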
| 552 | + def initialize_model(self, **kwargs) -> None: |
| 553 | + super().initialize_model(**kwargs) |
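| | +        # The third ONNX output holds the predicted masks; keep its spatial (H, W) shape for rescaling polygons later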
| 554 | + mask_shape = self.onnx_session.get_outputs()[2].shape |
| 555 | + self.mask_shape = mask_shape[2:] |
| 556 | + |
| 557 | +    def predict(self, img_in: ImageMetaType, **kwargs) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
| 558 | +        """Runs the RFDETR instance segmentation model on the given image using the ONNX session.
| 559 | +
| 560 | +        Args:
| 561 | +            img_in (ImageMetaType): Pre-processed input image.
| 562 | +
| 563 | +        Returns:
| 564 | +            Tuple[np.ndarray, np.ndarray, np.ndarray]: Predicted boxes, class logits, and masks.
| 565 | +        """
| 566 | + with self._session_lock: |
| 567 | + predictions = run_session_via_iobinding( |
| 568 | + self.onnx_session, self.input_name, img_in |
| 569 | + ) |
| 570 | + bboxes = predictions[0] |
| 571 | + logits = predictions[1] |
| 572 | + masks = predictions[2] |
| 573 | + |
| 574 | + return (bboxes, logits, masks) |
| 575 | + |
| 576 | + def postprocess( |
| 577 | + self, |
| 578 | + predictions: Tuple[np.ndarray, ...], |
| 579 | + preproc_return_metadata: PreprocessReturnMetadata, |
| 580 | + confidence: float = DEFAULT_CONFIDENCE, |
| 581 | + max_detections: int = DEFAUlT_MAX_DETECTIONS, |
| 582 | + **kwargs, |
| 583 | + ) -> List[InstanceSegmentationInferenceResponse]: |
| 584 | + bboxes, logits, masks = predictions |
| 585 | + bboxes = bboxes.astype(np.float32) |
| 586 | + logits = logits.astype(np.float32) |
| 587 | + |
| 588 | + batch_size, num_queries, num_classes = logits.shape |
| 589 | + logits_sigmoid = self.sigmoid_stable(logits) |
| 590 | + |
| 591 | + img_dims = preproc_return_metadata["img_dims"] |
| 592 | + |
| 593 | + processed_predictions = [] |
| 594 | + |
| 595 | + for batch_idx in range(batch_size): |
| 596 | + orig_h, orig_w = img_dims[batch_idx] |
| 597 | + |
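| | +            # Rank all (query, class) pairs by score and keep only the top max_detections candidates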
| 598 | + logits_flat = logits_sigmoid[batch_idx].reshape(-1) |
| 599 | + |
| 600 | +            # Use argpartition for better performance when max_detections is much smaller than the number of scores in logits_flat
| 601 | + partition_indices = np.argpartition(-logits_flat, max_detections)[ |
| 602 | + :max_detections |
| 603 | + ] |
| 604 | + sorted_indices = partition_indices[ |
| 605 | + np.argsort(-logits_flat[partition_indices]) |
| 606 | + ] |
| 607 | + topk_scores = logits_flat[sorted_indices] |
| 608 | + |
| 609 | + conf_mask = topk_scores > confidence |
| 610 | + sorted_indices = sorted_indices[conf_mask] |
| 611 | + topk_scores = topk_scores[conf_mask] |
| 612 | + |
| 613 | + topk_boxes = sorted_indices // num_classes |
| 614 | + topk_labels = sorted_indices % num_classes |
| 615 | + |
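| | +            # Drop background detections and re-index the remaining class labels to exclude the background class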
| 616 | + if self.is_one_indexed: |
| 617 | + class_filter_mask = topk_labels != self.background_class_index |
| 618 | + |
| 619 | + topk_labels[topk_labels > self.background_class_index] -= 1 |
| 620 | + topk_scores = topk_scores[class_filter_mask] |
| 621 | + topk_labels = topk_labels[class_filter_mask] |
| 622 | + topk_boxes = topk_boxes[class_filter_mask] |
| 623 | + |
| 624 | + selected_boxes = bboxes[batch_idx, topk_boxes] |
| 625 | + selected_masks = masks[batch_idx, topk_boxes] |
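| | +            # Binarize the masks by keeping only positions with positive mask scores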
| 626 | + selected_masks = selected_masks > 0 |
| 627 | + |
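| | +            # Convert boxes from (cx, cy, w, h) to (x_min, y_min, x_max, y_max)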
| 628 | + cxcy = selected_boxes[:, :2] |
| 629 | + wh = selected_boxes[:, 2:] |
| 630 | + xy_min = cxcy - 0.5 * wh |
| 631 | + xy_max = cxcy + 0.5 * wh |
| 632 | + boxes_xyxy = np.concatenate([xy_min, xy_max], axis=1) |
| 633 | + |
| 634 | + if self.resize_method == "Stretch to": |
| 635 | + scale_fct = np.array([orig_w, orig_h, orig_w, orig_h], dtype=np.float32) |
| 636 | + boxes_xyxy *= scale_fct |
| 637 | + else: |
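| | +                # Undo the letterbox: scale normalized boxes to the model input size, remove the padding, then rescale to the original image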
| 638 | + input_h, input_w = self.img_size_h, self.img_size_w |
| 639 | + |
| 640 | + scale = min(input_w / orig_w, input_h / orig_h) |
| 641 | + scaled_w = int(orig_w * scale) |
| 642 | + scaled_h = int(orig_h * scale) |
| 643 | + |
| 644 | + pad_x = (input_w - scaled_w) / 2 |
| 645 | + pad_y = (input_h - scaled_h) / 2 |
| 646 | + |
| 647 | + boxes_input = boxes_xyxy * np.array( |
| 648 | + [input_w, input_h, input_w, input_h], dtype=np.float32 |
| 649 | + ) |
| 650 | + |
| 651 | + boxes_input[:, 0] -= pad_x |
| 652 | + boxes_input[:, 1] -= pad_y |
| 653 | + boxes_input[:, 2] -= pad_x |
| 654 | + boxes_input[:, 3] -= pad_y |
| 655 | + |
| 656 | + boxes_xyxy = boxes_input / scale |
| 657 | + |
| 658 | + np.clip( |
| 659 | + boxes_xyxy, |
| 660 | + [0, 0, 0, 0], |
| 661 | + [orig_w, orig_h, orig_w, orig_h], |
| 662 | + out=boxes_xyxy, |
| 663 | + ) |
| 664 | + |
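| | +            # Prediction columns: x1, y1, x2, y2, confidence, placeholder, class id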
| 665 | + batch_predictions = np.column_stack( |
| 666 | + ( |
| 667 | + boxes_xyxy, |
| 668 | + topk_scores, |
| 669 | + np.zeros((len(topk_scores), 1), dtype=np.float32), |
| 670 | + topk_labels, |
| 671 | + ) |
| 672 | + ) |
| 673 | +            # Filter boxes and masks with the same validity mask so that
| 674 | +            # predictions and their masks stay aligned
| 675 | +            valid_class_mask = batch_predictions[:, 6] < len(self.class_names)
| 676 | +            batch_predictions = batch_predictions[valid_class_mask]
| 677 | +            selected_masks = selected_masks[valid_class_mask]
| 678 | +
| 679 | +            # Pair each prediction row with its binary mask
| 680 | +            outputs = []
| 681 | + for pred, mask in zip(batch_predictions, selected_masks): |
| 682 | + outputs.append(list(pred) + [mask]) |
| 683 | + |
| 684 | + processed_predictions.append(outputs) |
| 685 | + |
| 686 | + res = self.make_response(processed_predictions, img_dims, **kwargs) |
| 687 | + return res |
| 688 | + |
| 689 | + def make_response( |
| 690 | + self, |
| 691 | + predictions: List[List[float]], |
| 692 | + img_dims: List[Tuple[int, int]], |
| 693 | + class_filter: Optional[List[str]] = None, |
| 694 | + *args, |
| 695 | + **kwargs, |
| 696 | +    ) -> List[InstanceSegmentationInferenceResponse]:
| 697 | +        """Constructs instance segmentation response objects based on predictions.
| 698 | +
| 699 | +        Args:
| 700 | +            predictions (List[List[float]]): The list of predictions (box, score, class id, mask) per image.
| 701 | +            img_dims (List[Tuple[int, int]]): Dimensions of the images.
| 702 | +            class_filter (Optional[List[str]]): A list of class names to filter, if provided.
| 703 | +
| 704 | +        Returns:
| 705 | +            List[InstanceSegmentationInferenceResponse]: A list of response objects containing instance segmentation predictions.
| 706 | +        """
| 707 | + |
| 708 | + if isinstance(img_dims, dict) and "img_dims" in img_dims: |
| 709 | + img_dims = img_dims["img_dims"] |
| 710 | + |
| 711 | + predictions = predictions[ |
| 712 | + : len(img_dims) |
| 713 | + ] # If the batch size was fixed we have empty preds at the end |
| 714 | + |
| 715 | + batch_mask_preds = [] |
| 716 | + for image_ind in range(len(img_dims)): |
| 717 | + masks = [pred[7] for pred in predictions[image_ind]] |
| 718 | + orig_h, orig_w = img_dims[image_ind] |
| 719 | + prediction_h, prediction_w = self.mask_shape[0], self.mask_shape[1] |
| 720 | + |
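| | +            # Turn each binary mask into polygon points at mask resolution, then map them back to original image coordinates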
| 721 | + mask_preds = [] |
| 722 | + for mask in masks: |
| 723 | + points = mask2poly(mask.astype(np.uint8)) |
| 724 | + new_points = [] |
| 725 | + for point in points: |
| 726 | + if self.resize_method == "Stretch to": |
| 727 | + new_x = point[0] * (orig_w / prediction_w) |
| 728 | + new_y = point[1] * (orig_h / prediction_h) |
| 729 | + else: |
| 730 | + scale = max(orig_w / prediction_w, orig_h / prediction_h) |
| 731 | + pad_x = (orig_w - prediction_w * scale) / 2 |
| 732 | + pad_y = (orig_h - prediction_h * scale) / 2 |
| 733 | + new_x = point[0] * scale + pad_x |
| 734 | + new_y = point[1] * scale + pad_y |
| 735 | + new_points.append(np.array([new_x, new_y])) |
| 736 | + mask_preds.append(new_points) |
| 737 | + batch_mask_preds.append(mask_preds) |
| 738 | + |
| 739 | + responses = [ |
| 740 | + InstanceSegmentationInferenceResponse( |
| 741 | + predictions=[ |
| 742 | + InstanceSegmentationPrediction( |
| 743 | +                    # Passing args as a dictionary here since one of the args is 'class' (a reserved keyword in Python)
| 744 | + **{ |
| 745 | + "x": (pred[0] + pred[2]) / 2, |
| 746 | + "y": (pred[1] + pred[3]) / 2, |
| 747 | + "width": pred[2] - pred[0], |
| 748 | + "height": pred[3] - pred[1], |
| 749 | + "confidence": pred[4], |
| 750 | + "class": self.class_names[int(pred[6])], |
| 751 | + "class_id": int(pred[6]), |
| 752 | + "points": [Point(x=point[0], y=point[1]) for point in mask], |
| 753 | + } |
| 754 | + ) |
| 755 | + for pred, mask in zip(batch_predictions, batch_mask_preds[ind]) |
| 756 | + if not class_filter |
| 757 | + or self.class_names[int(pred[6])] in class_filter |
| 758 | + ], |
| 759 | + image=InferenceResponseImage( |
| 760 | + width=img_dims[ind][1], height=img_dims[ind][0] |
| 761 | + ), |
| 762 | + ) |
| 763 | + for ind, batch_predictions in enumerate(predictions) |
| 764 | + ] |
| 765 | + return responses |