From 899541ffa9b13c65f974d438ba3eb6ba7c8cbbc0 Mon Sep 17 00:00:00 2001 From: LTLA Date: Thu, 8 Jan 2026 17:55:52 +1100 Subject: [PATCH 1/6] Turn match() into a generic for specialization by other BiocPy classes. Introduce a create_match_index generic that builds the matching index from a given 'targets' so that it can be easily reused across different 'x'. An example of a compatible index is provided by the new MatchIndex class. Also added dedicated optimization for Factor matching, which exploits the fact that unique levels are already available. --- src/biocutils/__init__.py | 2 +- src/biocutils/match.py | 240 +++++++++++++++++++++++++++++++------- tests/test_match.py | 88 +++++++++++--- 3 files changed, 273 insertions(+), 57 deletions(-) diff --git a/src/biocutils/__init__.py b/src/biocutils/__init__.py index 81ce3ae..5b98ccb 100644 --- a/src/biocutils/__init__.py +++ b/src/biocutils/__init__.py @@ -28,7 +28,7 @@ from .is_list_of_type import is_list_of_type from .is_missing_scalar import is_missing_scalar from .map_to_index import map_to_index -from .match import match +from .match import match, create_match_index, MatchIndex from .normalize_subscript import normalize_subscript, SubscriptTypes from .print_truncated import print_truncated, print_truncated_dict, print_truncated_list from .print_wrapped_table import create_floating_names, print_type, print_wrapped_table, truncate_strings diff --git a/src/biocutils/match.py b/src/biocutils/match.py index 96e74de..3b7c955 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -1,68 +1,222 @@ -from typing import Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union, Literal +from functools import singledispatch import numpy -from .map_to_index import DUPLICATE_METHOD, map_to_index +from .is_missing_scalar import is_missing_scalar +class MatchIndex: + """ + An index for matching one or more ``x`` against different ``targets``. + This is typically constructed by :py:func:`~create_match_index`. + """ + + def __init__( + self, + targets: Any, + duplicate_method: Literal["first", "last", "any"] = "first", + dtype: Optional[numpy.dtype] = None, + fail_missing: Optional[bool] = None + ): + """ + Args: + targets: + Targets to be matched against, see :py:func:`~match` for details. + + duplicate_method: + How to handle duplicate entries in ``targets``, see :py:func:`~match` for details. + + dtype: + NumPy type of the output array, see :py:func:`~match` for details. + + fail_missing: + Whether to raise an error if a value cannot be found in ``targets``, see :py:func:`~match` for details. + """ + + from .Factor import Factor + + if isinstance(targets, dict): + # Back-compatible behavior. + self._map = targets + + elif isinstance(targets, Factor): + # Optimized method when both x and targets are factors. + target_index = [None] * len(targets.get_levels()) + first_tie = (duplicate_method == "first" or duplicate_method == "any") + for i, code in enumerate(targets.get_codes()): + if code < 0: + continue + if not first_tie or target_index[code] is None: + target_index[code] = i + + mapping = {} + for i, lev in enumerate(targets.get_levels()): + candidate = target_index[i] + if candidate is not None: + mapping[lev] = candidate + self._map = mapping + + else: + first_tie = duplicate_method == "first" or duplicate_method == "any" + mapping = {} + for i, val in enumerate(targets): + if not is_missing_scalar(val): + if not first_tie or val not in mapping: + mapping[val] = i + self._map = mapping + + if dtype is None: + dtype = numpy.min_scalar_type(-len(targets)) # get a signed type + self._dtype = dtype + + if fail_missing is None: + fail_missing = numpy.issubdtype(dtype, numpy.unsignedinteger) + self._fail_missing = fail_missing + + def match(self, x: Any) -> numpy.ndarray: + """ + Args: + x: + Values to match against ``targets``. + + Returns: + NumPy array of length equal to ``x``, containing the integer position of each entry of ``x`` inside ``targets``; + see :py:func:`~match` for more details. + """ + + from .Factor import Factor + indices = numpy.zeros(len(x), dtype=self._dtype) + + if not isinstance(x, Factor): + # Separate loops to reduce branching in the tight inner loop. + if not self._fail_missing: + for i, y in enumerate(x): + if y in self._map: + indices[i] = self._map[y] + else: + indices[i] = -1 + else: + for i, y in enumerate(x): + if y not in self._map: + raise ValueError("cannot find '" + str(y) + "' in 'targets'") + indices[i] = self._map[y] + + else: + x_index = [-1] * len(x.get_levels()) + for i, lev in enumerate(x.get_levels()): + if lev in self._map: + candidate = self._map[lev] + if candidate is not None: + x_index[i] = candidate + + # Separate loops to reduce branching in the tight inner loop. + if self._fail_missing: + for i, code in enumerate(x.get_codes()): + candidate = -1 + if code >= 0: + candidate = x_index[code] + if candidate < 0: + raise ValueError("cannot find '" + x[i] + "' in 'targets'") + indices[i] = candidate + else: + for i, code in enumerate(x.get_codes()): + if code >= 0: + indices[i] = x_index[code] + else: + indices[i] = -1 + + return indices + + +@singledispatch +def create_match_index( + targets: Any, + duplicate_method: Literal["first", "last", "any"] = "first", + dtype: Optional[numpy.dtype] = None, + fail_missing: Optional[bool] = None +) -> MatchIndex: + """ + Create a index for matching an arbitrary sequence against ``targets``. + Calling ``create_match_index(targets, ...).match(x)`` is equivalent to ``match(x, targets, ...)``. + + Args: + targets: + Targets to be matched against, see :py:func:`~match` for details. + + duplicate_method: + How to handle duplicate entries in ``targets``, see :py:func:`~match` for details. + + dtype: + NumPy type of the output array, see :py:func:`~match` for details. + + fail_missing: + Whether to raise an error if a value cannot be found in ``targets``, see :py:func:`~match` for details. + + Returns: + A ``MatchIndex``. + Other implementations of ``create_match_index()`` may return any object that has a ``match()`` method. + + Examples: + >>> import biocutils + >>> mobj = biocutils.create_match_index(["A", "B", "C", "D"]) + >>> mobj.match(["A", "B", "B", "C", "C", "D", "E"]) + >>> + >>> ft = biocutils.Factor.from_sequence(["a", "B", "c", "D", "e", "B", "D"]) + >>> fobj = biocutils.create_match_index(ft) + >>> fx = biocutils.Factor.from_sequence(["A", "B", "B", "C", "C", "D", "E"]) + >>> fobj.match(fx) + """ + + return MatchIndex(targets, duplicate_method=duplicate_method, dtype=dtype, fail_missing=fail_missing) + + +@singledispatch def match( - x: Sequence, - targets: Union[dict, Sequence], - duplicate_method: DUPLICATE_METHOD = "first", + x: Any, + targets: Any, + duplicate_method: Literal["first", "last", "any"] = "first", dtype: Optional[numpy.dtype] = None, fail_missing: Optional[bool] = None, ) -> numpy.ndarray: - """Find a matching value of each element of ``x`` in ``target``. + """ + Find a matching value of each element of ``x`` in ``targets``. + Calling ``match(x, targets, ...)`` should be equivalent to ``create_match_index(targets, ...).match(x)``. Args: x: - Sequence of values to match. + Values to match against ``targets``. targets: - Sequence of targets to be matched against. Alternatively, a - dictionary generated by passing a sequence of targets to - :py:meth:`~biocutils.map_to_index.map_to_index`. + Targets to be matched against. + It is not strictly necessary that ``x`` is of the same type as ``targets``, + but entries of ``x`` should be capable of being equal to entries of ``x``. duplicate_method: - How to handle duplicate entries in ``targets``. Matches can - be reported to the first or last occurrence of duplicates. + How to handle duplicate entries in ``targets``. + Either the first, last or any occurrence of each target is reported. dtype: - NumPy type of the output array. This should be an integer type; if - missing values are expected, the type should be a signed integer. - If None, a suitable signed type is automatically determined. + NumPy type of the output array. + This should be an integer type; if missing values are expected, the type should be a signed integer. + If ``None``, a suitable signed type is automatically determined. fail_missing: Whether to raise an error if ``x`` cannot be found in ``targets``. - If ``None``, this defaults to ``True`` if ``dtype`` is an unsigned - type, otherwise it defaults to ``False``. + If ``None``, this defaults to ``True`` if ``dtype`` is an unsigned type, otherwise it defaults to ``False``. Returns: - Array of length equal to ``x``, containing the integer position of each - entry of ``x`` inside ``target``; or -1, if the entry of ``x`` is - None or cannot be found in ``target``. - """ - if not isinstance(targets, dict): - targets = map_to_index(targets, duplicate_method=duplicate_method) - - if dtype is None: - dtype = numpy.min_scalar_type(-len(targets)) # get a signed type - indices = numpy.zeros(len(x), dtype=dtype) + NumPy array of length equal to ``x``, containing the integer position of each entry of ``x`` inside ``targets``; + or -1, if the entry of ``x`` is ``None`` or cannot be found in ``targets``. - if fail_missing is None: - fail_missing = numpy.issubdtype(dtype, numpy.unsignedinteger) + Examples: + >>> import biocutils + >>> biocutils.match(["A", "B", "B", "C", "D", "D", "E"], ["A", "B", "C", "D"]) + >>> + >>> fx = biocutils.Factor.from_sequence(["A", "B", "B", "C", "C", "D", "E"]) + >>> ft = biocutils.Factor.from_sequence(["a", "B", "c", "D", "e", "B", "D"]) + >>> biocutils.match(fx, ft, duplicate_method="last") + """ - # Separate loops to reduce branching in the tight inner loop. - if not fail_missing: - for i, y in enumerate(x): - if y in targets: - indices[i] = targets[y] - else: - indices[i] = -1 - else: - for i, y in enumerate(x): - if y not in targets: - raise ValueError("cannot find '" + str(y) + "' in 'targets'") - indices[i] = targets[y] - - return indices + obj = create_match_index(targets, duplicate_method=duplicate_method, dtype=dtype, fail_missing=fail_missing) + return obj.match(x) diff --git a/tests/test_match.py b/tests/test_match.py index c8fd6ba..289d808 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -1,4 +1,4 @@ -from biocutils import match, map_to_index +from biocutils import match, Factor import numpy import pytest @@ -11,16 +11,19 @@ def test_match_simple(): assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] assert mm.dtype == numpy.dtype("int8") - mm2 = match(x, map_to_index(levels)) - assert (mm == mm2).all() + mm = match(x, levels, fail_missing=True, dtype=numpy.uint32) + assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] + assert mm.dtype == numpy.dtype("uint32") def test_match_duplicates(): x = [5, 1, 2, 3, 5, 6, 7, 7, 2, 1] - mm = match(x, [1, 2, 3, 3, 5, 6, 1, 7, 6]) + levels = [1, 2, 3, 3, 5, 6, 1, 7, 6] + + mm = match(x, levels) assert list(mm) == [4, 0, 1, 2, 4, 5, 7, 7, 1, 0] - mm = match(x, [1, 2, 3, 3, 5, 6, 1, 7, 6], duplicate_method="last") + mm = match(x, levels, duplicate_method="last") assert list(mm) == [4, 6, 1, 3, 4, 8, 7, 7, 1, 6] @@ -33,24 +36,83 @@ def test_match_none(): def test_match_dtype(): - mm = match(["A", "F", "B", "D", "F", "A", "C", "F", "B"], ["D", "C", "B", "A"], dtype=numpy.dtype("int32")) + levels = ["D", "C", "B", "A"] + + mm = match(["A", "F", "B", "D", "F", "A", "C", "F", "B"], levels, dtype=numpy.dtype("int32")) assert list(mm) == [3, -1, 2, 0, -1, 3, 1, -1, 2] assert mm.dtype == numpy.dtype("int32") - mm = match(["A", "B", "D", "A", "C", "B"], ["D", "C", "B", "A"], dtype=numpy.dtype("uint32")) + mm = match(["A", "B", "D", "A", "C", "B"], levels, dtype=numpy.dtype("uint32")) assert list(mm) == [3, 2, 0, 3, 1, 2] assert mm.dtype == numpy.dtype("uint32") def test_match_fail_missing(): - x = match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"]) - assert list(x) == [3, -1, 2, 0, -1] + x = ["A", "E", "B", "D", "E"] + levels = ["D", "C", "B", "A"] + mm = match(x, levels) + assert list(mm) == [3, -1, 2, 0, -1] + + with pytest.raises(ValueError, match="cannot find"): + match(x, levels, fail_missing=True) + + with pytest.raises(ValueError, match="cannot find"): + match(x, levels, dtype=numpy.uint32) + + mm = match(["A", "C", "B", "D", "C"], levels, fail_missing=True) + assert list(mm) == [3, 1, 2, 0, 1] + + +def test_match_Factor(): + x = Factor.from_sequence(["A", "C", "B", "D", "A", "A", "C", "D", "B"]) + levels = Factor.from_sequence(["D", "C", "B", "A"]) + + mm = match(x, levels) + assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] + assert mm.dtype == numpy.dtype("int8") + + mm = match(x, levels, fail_missing=True, dtype=numpy.uint32) + assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] + assert mm.dtype == numpy.dtype("uint32") + + # Also works when only one of these is a factor. + mm = match(list(x), levels) + assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] + mm = match(x, list(levels)) + assert list(mm) == [3, 1, 2, 0, 3, 3, 1, 0, 2] + + +def test_match_Factor_duplicates(): + x = Factor.from_sequence([5, 1, 2, 3, 5, 6, 7, 7, 2, 1]) + levels = Factor.from_sequence([1, 2, 3, 3, 5, 6, 1, 7, 6]) + + mm = match(x, levels) + assert list(mm) == [4, 0, 1, 2, 4, 5, 7, 7, 1, 0] + + mm = match(x, levels, duplicate_method="last") + assert list(mm) == [4, 6, 1, 3, 4, 8, 7, 7, 1, 6] + + +def test_match_Factor_none(): + mm = match(Factor.from_sequence(["A", None, "B", "D", None, "A", "C", None, "B"]), Factor.from_sequence(["D", "C", "B", "A"])) + assert list(mm) == [3, -1, 2, 0, -1, 3, 1, -1, 2] + + mm = match(Factor.from_sequence(["A", "B", "D", "A", "C", "B"]), Factor.from_sequence(["D", None, "C", "B", None, "A"])) + assert list(mm) == [5, 3, 0, 5, 2, 3] + + +def test_match_Factor_fail_missing(): + x = Factor.from_sequence(["A", "E", "B", "D", "E"]) + levels = Factor.from_sequence(["D", "C", "B", "A"]) + + mm = match(x, levels) + assert list(mm) == [3, -1, 2, 0, -1] with pytest.raises(ValueError, match="cannot find"): - match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], fail_missing=True) + match(x, levels, fail_missing=True) with pytest.raises(ValueError, match="cannot find"): - match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], dtype=numpy.uint32) + match(x, levels, dtype=numpy.uint32) - x = match(["A", "C", "B", "D", "C"], ["D", "C", "B", "A"], fail_missing=True) - assert list(x) == [3, 1, 2, 0, 1] + mm = match(Factor.from_sequence(["A", "C", "B", "D", "C"]), levels, fail_missing=True) + assert list(mm) == [3, 1, 2, 0, 1] From 65408800eacdaf022cc1b25c90d94b53781b8053 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 07:00:30 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/biocutils/match.py | 118 ++++++++++++++++++++++++++++++++++------- tests/test_match.py | 2 +- 2 files changed, 100 insertions(+), 20 deletions(-) diff --git a/src/biocutils/match.py b/src/biocutils/match.py index 3b7c955..98e6bb2 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Union, Literal +from typing import Any, Optional, Literal from functools import singledispatch import numpy @@ -8,7 +8,7 @@ class MatchIndex: """ - An index for matching one or more ``x`` against different ``targets``. + An index for matching one or more ``x`` against different ``targets``. This is typically constructed by :py:func:`~create_match_index`. """ @@ -17,7 +17,7 @@ def __init__( targets: Any, duplicate_method: Literal["first", "last", "any"] = "first", dtype: Optional[numpy.dtype] = None, - fail_missing: Optional[bool] = None + fail_missing: Optional[bool] = None, ): """ Args: @@ -43,7 +43,7 @@ def __init__( elif isinstance(targets, Factor): # Optimized method when both x and targets are factors. target_index = [None] * len(targets.get_levels()) - first_tie = (duplicate_method == "first" or duplicate_method == "any") + first_tie = duplicate_method == "first" or duplicate_method == "any" for i, code in enumerate(targets.get_codes()): if code < 0: continue @@ -52,7 +52,7 @@ def __init__( mapping = {} for i, lev in enumerate(targets.get_levels()): - candidate = target_index[i] + candidate = target_index[i] if candidate is not None: mapping[lev] = candidate self._map = mapping @@ -86,6 +86,7 @@ def match(self, x: Any) -> numpy.ndarray: """ from .Factor import Factor + indices = numpy.zeros(len(x), dtype=self._dtype) if not isinstance(x, Factor): @@ -134,7 +135,7 @@ def create_match_index( targets: Any, duplicate_method: Literal["first", "last", "any"] = "first", dtype: Optional[numpy.dtype] = None, - fail_missing: Optional[bool] = None + fail_missing: Optional[bool] = None, ) -> MatchIndex: """ Create a index for matching an arbitrary sequence against ``targets``. @@ -154,17 +155,56 @@ def create_match_index( Whether to raise an error if a value cannot be found in ``targets``, see :py:func:`~match` for details. Returns: - A ``MatchIndex``. + A ``MatchIndex``. Other implementations of ``create_match_index()`` may return any object that has a ``match()`` method. Examples: >>> import biocutils - >>> mobj = biocutils.create_match_index(["A", "B", "C", "D"]) - >>> mobj.match(["A", "B", "B", "C", "C", "D", "E"]) - >>> - >>> ft = biocutils.Factor.from_sequence(["a", "B", "c", "D", "e", "B", "D"]) - >>> fobj = biocutils.create_match_index(ft) - >>> fx = biocutils.Factor.from_sequence(["A", "B", "B", "C", "C", "D", "E"]) + >>> mobj = biocutils.create_match_index( + ... [ + ... "A", + ... "B", + ... "C", + ... "D", + ... ] + ... ) + >>> mobj.match( + ... [ + ... "A", + ... "B", + ... "B", + ... "C", + ... "C", + ... "D", + ... "E", + ... ] + ... ) + >>> + >>> ft = biocutils.Factor.from_sequence( + ... [ + ... "a", + ... "B", + ... "c", + ... "D", + ... "e", + ... "B", + ... "D", + ... ] + ... ) + >>> fobj = biocutils.create_match_index( + ... ft + ... ) + >>> fx = biocutils.Factor.from_sequence( + ... [ + ... "A", + ... "B", + ... "B", + ... "C", + ... "C", + ... "D", + ... "E", + ... ] + ... ) >>> fobj.match(fx) """ @@ -181,7 +221,7 @@ def match( ) -> numpy.ndarray: """ Find a matching value of each element of ``x`` in ``targets``. - Calling ``match(x, targets, ...)`` should be equivalent to ``create_match_index(targets, ...).match(x)``. + Calling ``match(x, targets, ...)`` should be equivalent to ``create_match_index(targets, ...).match(x)``. Args: x: @@ -211,11 +251,51 @@ def match( Examples: >>> import biocutils - >>> biocutils.match(["A", "B", "B", "C", "D", "D", "E"], ["A", "B", "C", "D"]) - >>> - >>> fx = biocutils.Factor.from_sequence(["A", "B", "B", "C", "C", "D", "E"]) - >>> ft = biocutils.Factor.from_sequence(["a", "B", "c", "D", "e", "B", "D"]) - >>> biocutils.match(fx, ft, duplicate_method="last") + >>> biocutils.match( + ... [ + ... "A", + ... "B", + ... "B", + ... "C", + ... "D", + ... "D", + ... "E", + ... ], + ... [ + ... "A", + ... "B", + ... "C", + ... "D", + ... ], + ... ) + >>> + >>> fx = biocutils.Factor.from_sequence( + ... [ + ... "A", + ... "B", + ... "B", + ... "C", + ... "C", + ... "D", + ... "E", + ... ] + ... ) + >>> ft = biocutils.Factor.from_sequence( + ... [ + ... "a", + ... "B", + ... "c", + ... "D", + ... "e", + ... "B", + ... "D", + ... ] + ... ) + >>> biocutils.match( + ... fx, + ... ft, + ... duplicate_method="last", + ... ) """ obj = create_match_index(targets, duplicate_method=duplicate_method, dtype=dtype, fail_missing=fail_missing) diff --git a/tests/test_match.py b/tests/test_match.py index 289d808..0cbffe7 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -105,7 +105,7 @@ def test_match_Factor_fail_missing(): x = Factor.from_sequence(["A", "E", "B", "D", "E"]) levels = Factor.from_sequence(["D", "C", "B", "A"]) - mm = match(x, levels) + mm = match(x, levels) assert list(mm) == [3, -1, 2, 0, -1] with pytest.raises(ValueError, match="cannot find"): From 2a1f47efa680434e87d08ea3c8a37d76fa4dfbe8 Mon Sep 17 00:00:00 2001 From: LTLA Date: Thu, 8 Jan 2026 18:03:03 +1100 Subject: [PATCH 3/6] Added deprecation warning about map_to_index. --- src/biocutils/match.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/biocutils/match.py b/src/biocutils/match.py index 98e6bb2..b5b5a69 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -38,6 +38,8 @@ def __init__( if isinstance(targets, dict): # Back-compatible behavior. + import warnings + warnings.warn(DeprecationWarning("'map_to_index()' is deprecated, use 'create_match_index()' instead")) self._map = targets elif isinstance(targets, Factor): From 2cbc8f40982cc35737dce742473ee08430af57ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 07:03:43 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/biocutils/match.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/biocutils/match.py b/src/biocutils/match.py index b5b5a69..d346e3e 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -39,6 +39,7 @@ def __init__( if isinstance(targets, dict): # Back-compatible behavior. import warnings + warnings.warn(DeprecationWarning("'map_to_index()' is deprecated, use 'create_match_index()' instead")) self._map = targets From 20ea2064edf46df1b21a430e5e4e4c0b6a4f0b59 Mon Sep 17 00:00:00 2001 From: LTLA Date: Thu, 8 Jan 2026 22:34:40 +1100 Subject: [PATCH 5/6] Added support for user-specified incomparables=. This avoids hard-coding None and numpy.ma.masked as unmatchable elements. The default is an empty set, which means that None values can now be matched; this is more consistent with the default behavior of R's match(). --- src/biocutils/match.py | 60 ++++++++++++++++++++++++++---------------- tests/test_match.py | 28 +++++++++++++++----- 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/src/biocutils/match.py b/src/biocutils/match.py index d346e3e..5636165 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -1,10 +1,8 @@ -from typing import Any, Optional, Literal +from typing import Any, Optional, Literal, Union, Sequence from functools import singledispatch import numpy -from .is_missing_scalar import is_missing_scalar - class MatchIndex: """ @@ -16,6 +14,7 @@ def __init__( self, targets: Any, duplicate_method: Literal["first", "last", "any"] = "first", + incomparables: Union[set, Sequence] = set(), dtype: Optional[numpy.dtype] = None, fail_missing: Optional[bool] = None, ): @@ -27,6 +26,9 @@ def __init__( duplicate_method: How to handle duplicate entries in ``targets``, see :py:func:`~match` for details. + incomparables: + Values that cannot be compared, see :py:func:`~match` for details. + dtype: NumPy type of the output array, see :py:func:`~match` for details. @@ -45,26 +47,32 @@ def __init__( elif isinstance(targets, Factor): # Optimized method when both x and targets are factors. - target_index = [None] * len(targets.get_levels()) + target_index = [None] * (len(targets.get_levels()) + 1) # add 1 so that code = -1 still behaves correctly. first_tie = duplicate_method == "first" or duplicate_method == "any" for i, code in enumerate(targets.get_codes()): - if code < 0: - continue if not first_tie or target_index[code] is None: target_index[code] = i mapping = {} for i, lev in enumerate(targets.get_levels()): - candidate = target_index[i] + if lev not in incomparables: + candidate = target_index[i] + if candidate is not None: + mapping[lev] = candidate + + if None not in incomparables: + # None matching to another None is still possible. + candidate = target_index[-1] if candidate is not None: - mapping[lev] = candidate + mapping[None] = target_index[-1] + self._map = mapping else: first_tie = duplicate_method == "first" or duplicate_method == "any" mapping = {} for i, val in enumerate(targets): - if not is_missing_scalar(val): + if val not in incomparables: if not first_tie or val not in mapping: mapping[val] = i self._map = mapping @@ -107,28 +115,24 @@ def match(self, x: Any) -> numpy.ndarray: indices[i] = self._map[y] else: - x_index = [-1] * len(x.get_levels()) + x_index = [-1] * (len(x.get_levels()) + 1) # adding 1 so that code = -1 still works. for i, lev in enumerate(x.get_levels()): if lev in self._map: - candidate = self._map[lev] - if candidate is not None: - x_index[i] = candidate + x_index[i] = self._map[lev] + + if None in self._map: + x_index[-1] = self._map[None] # Separate loops to reduce branching in the tight inner loop. if self._fail_missing: for i, code in enumerate(x.get_codes()): - candidate = -1 - if code >= 0: - candidate = x_index[code] + candidate = x_index[code] if candidate < 0: - raise ValueError("cannot find '" + x[i] + "' in 'targets'") + raise ValueError("cannot find '" + str(x[i]) + "' in 'targets'") indices[i] = candidate else: for i, code in enumerate(x.get_codes()): - if code >= 0: - indices[i] = x_index[code] - else: - indices[i] = -1 + indices[i] = x_index[code] return indices @@ -137,6 +141,7 @@ def match(self, x: Any) -> numpy.ndarray: def create_match_index( targets: Any, duplicate_method: Literal["first", "last", "any"] = "first", + incomparables: Union[set, Sequence] = set(), dtype: Optional[numpy.dtype] = None, fail_missing: Optional[bool] = None, ) -> MatchIndex: @@ -151,6 +156,9 @@ def create_match_index( duplicate_method: How to handle duplicate entries in ``targets``, see :py:func:`~match` for details. + incomparables: + Values that cannot be compared, see :py:func:`~match` for details. + dtype: NumPy type of the output array, see :py:func:`~match` for details. @@ -211,7 +219,7 @@ def create_match_index( >>> fobj.match(fx) """ - return MatchIndex(targets, duplicate_method=duplicate_method, dtype=dtype, fail_missing=fail_missing) + return MatchIndex(targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing) @singledispatch @@ -219,6 +227,7 @@ def match( x: Any, targets: Any, duplicate_method: Literal["first", "last", "any"] = "first", + incomparables: Union[set, Sequence] = set(), dtype: Optional[numpy.dtype] = None, fail_missing: Optional[bool] = None, ) -> numpy.ndarray: @@ -239,6 +248,11 @@ def match( How to handle duplicate entries in ``targets``. Either the first, last or any occurrence of each target is reported. + incomparables: + Values of ``x`` or ``targets`` that cannot be compared. + No match will be reported for any value of ``x`` that is in ``incomparables``. + Any object that has an ``__in__`` method can be used here. + dtype: NumPy type of the output array. This should be an integer type; if missing values are expected, the type should be a signed integer. @@ -301,5 +315,5 @@ def match( ... ) """ - obj = create_match_index(targets, duplicate_method=duplicate_method, dtype=dtype, fail_missing=fail_missing) + obj = create_match_index(targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing) return obj.match(x) diff --git a/tests/test_match.py b/tests/test_match.py index 0cbffe7..2adafe3 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -28,11 +28,19 @@ def test_match_duplicates(): def test_match_none(): - mm = match(["A", None, "B", "D", None, "A", "C", None, "B"], ["D", "C", "B", "A"]) + x = ["A", None, "B", "D", None, "A", "C", None, "B"] + mm = match(x, ["D", "C", "B", "A"]) assert list(mm) == [3, -1, 2, 0, -1, 3, 1, -1, 2] - mm = match(["A", "B", "D", "A", "C", "B"], ["D", None, "C", "B", None, "A"]) - assert list(mm) == [5, 3, 0, 5, 2, 3] + lev = ["D", None, "C", "B", "A"] + mm = match(x, lev) + assert list(mm) == [4, 1, 3, 0, 1, 4, 2, 1, 3] + + mm = match(x, lev, incomparables=set([None])) + assert list(mm) == [4, -1, 3, 0, -1, 4, 2, -1, 3] + + with pytest.raises(match="cannot find"): + match(x, lev, incomparables=set([None]), fail_missing=True) def test_match_dtype(): @@ -94,11 +102,19 @@ def test_match_Factor_duplicates(): def test_match_Factor_none(): - mm = match(Factor.from_sequence(["A", None, "B", "D", None, "A", "C", None, "B"]), Factor.from_sequence(["D", "C", "B", "A"])) + x = Factor.from_sequence(["A", None, "B", "D", None, "A", "C", None, "B"]) + mm = match(x, Factor.from_sequence(["D", "C", "B", "A"])) assert list(mm) == [3, -1, 2, 0, -1, 3, 1, -1, 2] - mm = match(Factor.from_sequence(["A", "B", "D", "A", "C", "B"]), Factor.from_sequence(["D", None, "C", "B", None, "A"])) - assert list(mm) == [5, 3, 0, 5, 2, 3] + lev = Factor.from_sequence(["D", None, "C", "B", "A"]) + mm = match(x, lev) + assert list(mm) == [4, 1, 3, 0, 1, 4, 2, 1, 3] + + mm = match(x, lev, incomparables=set([None])) + assert list(mm) == [4, -1, 3, 0, -1, 4, 2, -1, 3] + + with pytest.raises(match="cannot find"): + match(x, lev, incomparables=set([None]), fail_missing=True) def test_match_Factor_fail_missing(): From 73526e846e8d7fb4d0d4132c089ea169c39aad68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 11:37:17 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/biocutils/match.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/biocutils/match.py b/src/biocutils/match.py index 5636165..1570970 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -47,7 +47,7 @@ def __init__( elif isinstance(targets, Factor): # Optimized method when both x and targets are factors. - target_index = [None] * (len(targets.get_levels()) + 1) # add 1 so that code = -1 still behaves correctly. + target_index = [None] * (len(targets.get_levels()) + 1) # add 1 so that code = -1 still behaves correctly. first_tie = duplicate_method == "first" or duplicate_method == "any" for i, code in enumerate(targets.get_codes()): if not first_tie or target_index[code] is None: @@ -60,7 +60,7 @@ def __init__( if candidate is not None: mapping[lev] = candidate - if None not in incomparables: + if None not in incomparables: # None matching to another None is still possible. candidate = target_index[-1] if candidate is not None: @@ -115,7 +115,7 @@ def match(self, x: Any) -> numpy.ndarray: indices[i] = self._map[y] else: - x_index = [-1] * (len(x.get_levels()) + 1) # adding 1 so that code = -1 still works. + x_index = [-1] * (len(x.get_levels()) + 1) # adding 1 so that code = -1 still works. for i, lev in enumerate(x.get_levels()): if lev in self._map: x_index[i] = self._map[lev] @@ -219,7 +219,9 @@ def create_match_index( >>> fobj.match(fx) """ - return MatchIndex(targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing) + return MatchIndex( + targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing + ) @singledispatch @@ -315,5 +317,7 @@ def match( ... ) """ - obj = create_match_index(targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing) + obj = create_match_index( + targets, duplicate_method=duplicate_method, incomparables=incomparables, dtype=dtype, fail_missing=fail_missing + ) return obj.match(x)