From 32e590847ef03099fe7f2d1a6941e021836d0181 Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 18 Dec 2025 14:25:13 +0000
Subject: [PATCH 1/2] CU-869bhknfm: Refactor setting of filters for embedding
 linker

---
 .../components/linking/embedding_linker.py   | 214 ++++++++++++++----
 1 file changed, 172 insertions(+), 42 deletions(-)

diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py
index c13a48fb6..9ceab3953 100644
--- a/medcat-v2/medcat/components/linking/embedding_linker.py
+++ b/medcat-v2/medcat/components/linking/embedding_linker.py
@@ -11,6 +11,7 @@
 from collections import defaultdict
 import logging
 import math
+import numpy as np

 from medcat.utils.import_utils import ensure_optional_extras_installed
 import medcat
@@ -85,8 +86,9 @@ def __init__(self, cdb: CDB, config: Config) -> None:
             ]
             for name in self._name_keys
         ]
+        self._initialize_filter_structures()

-    def create_embeddings(self,
+    def create_embeddings(self,
                           embedding_model_name: Optional[str] = None,
                           max_length: Optional[int] = None,
                           ):
@@ -281,6 +283,170 @@ def _get_context_vectors(
             texts.append(text)
         return self._embed(texts, self.device)

+    def _initialize_cui_name_mapping(self) -> None:
+        """Call this once during initialization to pre-compute CUI->name."""
+        self._cui_to_name_mask = {}
+
+        for cui, cui_idx in self._cui_to_idx.items():
+            mask = torch.tensor(
+                [cui_idx in name_cui_idxs
+                 for name_cui_idxs in self._name_to_cui_idxs],
+                dtype=torch.bool,
+                device=self.device
+            )
+            self._cui_to_name_mask[cui] = mask
+
+        # Cache _has_cuis_all as well
+        self._has_cuis_all_cached = torch.tensor(
+            [bool(self.cdb.name2info[name]["per_cui_status"])
+             for name in self._name_keys],
+            device=self.device,
+            dtype=torch.bool,
+        )
+
+    def _initialize_filter_structures(self) -> None:
+        """Call once during initialization to create efficient lookup structures."""
+        # Build an inverted index: cui_idx -> list of name indices that contain it
+        # This is the KEY optimization - we flip the lookup direction
+        if not hasattr(self, '_cui_idx_to_name_idxs'):
+            cui2name_indices: defaultdict[
+                int, list[int]] = defaultdict(list)
+
+            for name_idx, cui_idxs in enumerate(self._name_to_cui_idxs):
+                for cui_idx in cui_idxs:
+                    cui2name_indices[cui_idx].append(name_idx)
+
+            # Convert lists to numpy arrays for faster indexing
+            self._cui_idx_to_name_idxs = {
+                cui_idx: np.array(name_idxs, dtype=np.int32)
+                for cui_idx, name_idxs in cui2name_indices.items()
+            }
+
+        # Cache _has_cuis_all
+        if not hasattr(self, '_has_cuis_all_cached'):
+            self._has_cuis_all_cached = torch.tensor(
+                [bool(self.cdb.name2info[name]["per_cui_status"])
+                 for name in self._name_keys],
+                device=self.device,
+                dtype=torch.bool,
+            )
+
+    def _get_include_filters_1cui(
+            self, cui: str, n: int) -> torch.Tensor:
+        """Optimized single CUI include filter using inverted index."""
+        if cui not in self._cui_to_idx:
+            return torch.zeros(n, dtype=torch.bool, device=self.device)
+
+        cui_idx = self._cui_to_idx[cui]
+
+        # Use inverted index: get all name indices that contain this CUI
+        if cui_idx in self._cui_idx_to_name_idxs:
+            name_indices = self._cui_idx_to_name_idxs[cui_idx]
+
+            # Create mask by setting specific indices to True
+            allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
+            allowed_mask[torch.from_numpy(name_indices).to(self.device)] = True
+            return allowed_mask
+        else:
+            return torch.zeros(n, dtype=torch.bool, device=self.device)
+
+    def _get_include_filters_multi_cui(
+            self, include_set: Set[str], n: int) -> torch.Tensor:
+        """Optimized multi-CUI include filter using inverted index."""
+        include_cui_idxs = [
+            self._cui_to_idx[cui] for cui in include_set
+            if cui in self._cui_to_idx
+        ]
+
+        if not include_cui_idxs:
+            return torch.zeros(n, dtype=torch.bool, device=self.device)
+
+        # Collect all name indices from inverted index
+        all_name_indices_list: list[np.ndarray] = []
+        for cui_idx in include_cui_idxs:
+            if cui_idx in self._cui_idx_to_name_idxs:
+                all_name_indices_list.append(
+                    self._cui_idx_to_name_idxs[cui_idx])
+
+        if not all_name_indices_list:
+            return torch.zeros(n, dtype=torch.bool, device=self.device)
+
+        # Concatenate and get unique indices
+        all_name_indices = np.unique(
+            np.concatenate(all_name_indices_list))
+
+        # Create mask
+        allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
+        allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = True
+        return allowed_mask
+
+    def _get_include_filters(
+            self, include_set: Set[str], n: int) -> torch.Tensor:
+        """Route to appropriate include filter method."""
+        if len(include_set) == 1:
+            cui = next(iter(include_set))
+            return self._get_include_filters_1cui(cui, n)
+        else:
+            return self._get_include_filters_multi_cui(
+                include_set, n)
+
+    def _get_exclude_filters_1cui(
+            self, allowed_mask: torch.Tensor, cui: str) -> torch.Tensor:
+        """Optimized single CUI exclude filter using inverted index."""
+        if cui not in self._cui_to_idx:
+            return allowed_mask
+
+        cui_idx = self._cui_to_idx[cui]
+
+        if cui_idx in self._cui_idx_to_name_idxs:
+            name_indices = self._cui_idx_to_name_idxs[cui_idx]
+            # Set specific indices to False
+            allowed_mask[
+                torch.from_numpy(name_indices).to(self.device)] = False
+
+        return allowed_mask
+
+    def _get_exclude_filters_multi_cui(
+            self, allowed_mask: torch.Tensor, exclude_set: Set[str],
+            ) -> torch.Tensor:
+        """Optimized multi-CUI exclude filter using inverted index."""
+        exclude_cui_idxs = [
+            self._cui_to_idx[cui] for cui in exclude_set
+            if cui in self._cui_to_idx
+        ]
+
+        if not exclude_cui_idxs:
+            return allowed_mask
+
+        # Collect all name indices to exclude
+        all_name_indices: list[np.ndarray] = []
+        for cui_idx in exclude_cui_idxs:
+            if cui_idx in self._cui_idx_to_name_idxs:
+                all_name_indices.append(self._cui_idx_to_name_idxs[cui_idx])
+
+        if all_name_indices:
+            all_name_indices = np.unique(np.concatenate(all_name_indices))
+            allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = False
+
+        return allowed_mask
+
+    def _get_exclude_filters(
+            self, exclude_set: Set[str], n: int) -> torch.Tensor:
+        """Route to appropriate exclude filter method."""
+        # Start with all allowed
+        allowed_mask = torch.ones(n, dtype=torch.bool, device=self.device)
+
+        if not exclude_set:
+            return allowed_mask
+
+        if len(exclude_set) == 1:
+            cui = next(iter(exclude_set))
+            return self._get_exclude_filters_1cui(
+                allowed_mask, cui)
+        else:
+            return self._get_exclude_filters_multi_cui(
+                allowed_mask, exclude_set)
+
     def _set_filters(self) -> None:
         include_set = self.cnf_l.filters.cuis
         exclude_set = self.cnf_l.filters.cuis_exclude
@@ -295,54 +461,18 @@ def _set_filters(self) -> None:
             return

         n = len(self._name_keys)
-        allowed_mask = torch.empty(n, dtype=torch.bool, device=self.device)

         if include_set:
-            # if in include set, ignore exclude set.
-            allowed_mask[:] = False
-            include_cui_idxs = {
-                self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx
-            }
-            include_idxs = [
-                name_idx
-                for name_idx, name_cui_idxs in enumerate(self._name_to_cui_idxs)
-                if any(cui in include_cui_idxs for cui in name_cui_idxs)
-            ]
-            allowed_mask[
-                torch.tensor(include_idxs, dtype=torch.long, device=self.device)
-            ] = True
+            allowed_mask = self._get_include_filters(
+                include_set, n)
         else:
-            # only look at exclude if there's no include set
-            allowed_mask[:] = True
-            if exclude_set:
-                exclude_cui_idxs = {
-                    self._cui_to_idx[cui]
-                    for cui in exclude_set
-                    if cui in self._cui_to_idx
-                }
-                exclude_idxs = [
-                    i
-                    for i, name_cui_idxs in enumerate(self._name_to_cui_idxs)
-                    if any(ci in exclude_cui_idxs for ci in name_cui_idxs)
-                ]
-                allowed_mask[
-                    torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)
-                ] = False
+            allowed_mask = self._get_exclude_filters(
+                exclude_set, n)

-        # checking if a name has at least 1 cui related to it.
-        _has_cuis_all = torch.tensor(
-            [
-                bool(self.cdb.name2info[name]["per_cui_status"])
-                for name in self._name_keys
-            ],
-            device=self.device,
-            dtype=torch.bool,
-        )
-        self._valid_names = _has_cuis_all & allowed_mask
+        self._valid_names = self._has_cuis_all_cached & allowed_mask

         self._last_include_set = set(include_set) if include_set is not None else None
         self._last_exclude_set = set(exclude_set) if exclude_set is not None else None
-
     def _disambiguate_by_cui(
         self, cui_candidates: list[str], scores: Tensor
     ) -> tuple[str, float]:

From 92fa0b26aa7e46bd9f0e3c1b6a5d4c3b2a190807 Mon Sep 17 00:00:00 2001
From: Mart Ratas
Date: Thu, 18 Dec 2025 16:57:24 +0200
Subject: [PATCH 2/2] CU-869bfagqw: Fix a small typing issue

---
 medcat-v2/medcat/components/linking/embedding_linker.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py
index 9ceab3953..3d127a9e6 100644
--- a/medcat-v2/medcat/components/linking/embedding_linker.py
+++ b/medcat-v2/medcat/components/linking/embedding_linker.py
@@ -419,13 +419,13 @@ def _get_exclude_filters_multi_cui(
             return allowed_mask

         # Collect all name indices to exclude
-        all_name_indices: list[np.ndarray] = []
+        _all_name_indices: list[np.ndarray] = []
         for cui_idx in exclude_cui_idxs:
             if cui_idx in self._cui_idx_to_name_idxs:
-                all_name_indices.append(self._cui_idx_to_name_idxs[cui_idx])
+                _all_name_indices.append(self._cui_idx_to_name_idxs[cui_idx])

-        if all_name_indices:
-            all_name_indices = np.unique(np.concatenate(all_name_indices))
+        if _all_name_indices:
+            all_name_indices = np.unique(np.concatenate(_all_name_indices))
             allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = False

         return allowed_mask
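
For illustration only (not part of the patches above), here is a minimal, self-contained sketch of the inverted-index idea both commits rely on: pre-compute a cui_idx -> name-indices mapping once, so include/exclude masks can be built by direct tensor indexing instead of rescanning every name's CUI list whenever the filters change. The names build_inverted_index, include_mask and name_to_cui_idxs are hypothetical stand-ins for the structures in embedding_linker.py, and int64 indices are used here so the index tensor is unambiguously valid for torch indexing.

# Illustrative sketch only -- helper names below are hypothetical and do not
# exist in medcat; they merely mirror the data shapes used by the linker.
from collections import defaultdict

import numpy as np
import torch


def build_inverted_index(name_to_cui_idxs):
    """Flip the lookup direction: CUI index -> array of name indices."""
    cui2names = defaultdict(list)
    for name_idx, cui_idxs in enumerate(name_to_cui_idxs):
        for cui_idx in cui_idxs:
            cui2names[cui_idx].append(name_idx)
    # int64 so torch.from_numpy(...) yields a valid index tensor
    return {cui_idx: np.array(idxs, dtype=np.int64)
            for cui_idx, idxs in cui2names.items()}


def include_mask(inv_index, cui_idxs, n, device="cpu"):
    """Boolean mask over n names: True where a name maps to any given CUI index."""
    mask = torch.zeros(n, dtype=torch.bool, device=device)
    hits = [inv_index[ci] for ci in cui_idxs if ci in inv_index]
    if hits:
        name_idxs = np.unique(np.concatenate(hits))
        mask[torch.from_numpy(name_idxs).to(device)] = True
    return mask


# Toy data: each name (row) lists the CUI indices it can stand for.
name_to_cui_idxs = [[0], [0, 2], [1], [2], []]
inv_index = build_inverted_index(name_to_cui_idxs)
n = len(name_to_cui_idxs)

print(include_mask(inv_index, [0], n))   # names 0 and 1 -> True, True, False, False, False
print(~include_mask(inv_index, [2], n))  # exclude CUI 2 -> True, False, True, False, True

With the index built once, a filter update only touches the names that actually map to the filtered CUIs rather than iterating over the full name list per change, which appears to be the intent of the refactor in PATCH 1/2.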