From 2bdcbc2a7ef77634c123d8935f41391f5f6a559e Mon Sep 17 00:00:00 2001 From: Andreas Schallwig <8491849+andypotato@users.noreply.github.com> Date: Mon, 4 Aug 2025 12:11:29 +0800 Subject: [PATCH] Added support for category_allowlist and category_denylist --- .../object_detection/object_detection.py | 30 ++++++++++++++++++- .../object_detection/detector.py | 8 ++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/node_wrappers/mediapipe_vision/object_detection/object_detection.py b/node_wrappers/mediapipe_vision/object_detection/object_detection.py index c88a76c..5f905ae 100644 --- a/node_wrappers/mediapipe_vision/object_detection/object_detection.py +++ b/node_wrappers/mediapipe_vision/object_detection/object_detection.py @@ -53,6 +53,22 @@ def INPUT_TYPES(cls): "INT", {"default": 5, "min": 1, "max": 50, "step": 1, "tooltip": "Maximum number of objects to detect"}, ), + "category_allowlist": ( + "STRING", + { + "default": "", + "multiline": False, + "tooltip": "Comma-separated list of categories to detect (e.g., 'person,cat'). Empty means all categories.", + }, + ), + "category_denylist": ( + "STRING", + { + "default": "", + "multiline": False, + "tooltip": "Comma-separated list of categories to exclude. Empty means not to exclude any categories.", + }, + ), } ) @@ -64,6 +80,8 @@ def detect( model_info: dict, min_confidence: float, max_results: int, + category_allowlist: str, + category_denylist: str, running_mode: str, delegate: str, ): @@ -75,9 +93,19 @@ def detect( # Initialize or update detector detector = self.initialize_or_update_detector(model_path) + # Parse the allow / deny list string into a list, or None if empty + allowed_categories = [cat.strip() for cat in category_allowlist.split(',') if cat.strip()] or None + denied_categories = [cat.strip() for cat in category_denylist.split(',') if cat.strip()] or None + # Perform detection with all parameters batch_results = detector.detect( - image, score_threshold=min_confidence, max_results=max_results, running_mode=running_mode, delegate=delegate + image, + score_threshold=min_confidence, + max_results=max_results, + category_allowlist=allowed_categories, + category_denylist=denied_categories, + running_mode=running_mode, + delegate=delegate ) return (batch_results,) diff --git a/src/mediapipe_vision/object_detection/detector.py b/src/mediapipe_vision/object_detection/detector.py index d5f9cfc..7ee58ce 100644 --- a/src/mediapipe_vision/object_detection/detector.py +++ b/src/mediapipe_vision/object_detection/detector.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Optional import mediapipe as mp import numpy as np @@ -40,6 +40,8 @@ def _create_detector_options(self, base_options: python.BaseOptions, running_mode=mode_enum, score_threshold=kwargs.get('score_threshold', 0.5), max_results=kwargs.get('max_results', 5), + category_allowlist=kwargs.get('category_allowlist', None), + category_denylist=kwargs.get('category_denylist', None), ) def _create_detector_instance(self, options: vision.ObjectDetectorOptions) -> vision.ObjectDetector: @@ -89,6 +91,8 @@ def detect( image: torch.Tensor, score_threshold: float = 0.5, max_results: int = 5, + category_allowlist: Optional[List[str]] = None, + category_denylist: Optional[List[str]] = None, running_mode: str = "video", delegate: str = "cpu", ) -> List[List[ObjectDetectionResult]]: @@ -99,4 +103,6 @@ def detect( delegate=delegate, score_threshold=score_threshold, max_results=max_results, + category_allowlist=category_allowlist, + category_denylist=category_denylist, )