diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 63e1851..93ef3f6 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -6,6 +6,6 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - uses: psf/black@20.8b1 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + - uses: psf/black@stable diff --git a/.github/workflows/cli-coverage.yml b/.github/workflows/cli-coverage.yml index 06a23e4..5d35574 100644 --- a/.github/workflows/cli-coverage.yml +++ b/.github/workflows/cli-coverage.yml @@ -8,7 +8,7 @@ jobs: cli-coverage-report: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.12"] os: [ ubuntu-latest ] # can't use macOS when using service containers or container jobs runs-on: ${{ matrix.os }} services: @@ -28,7 +28,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.12' - name: Install uv run: pip install uv diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index e1da342..2516e66 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,7 @@ jobs: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml index 876596f..5b594a0 100644 --- a/.github/workflows/run-pytest.yml +++ b/.github/workflows/run-pytest.yml @@ -12,7 +12,7 @@ jobs: pytest: strategy: matrix: - python-version: ["3.9", "3.12"] + python-version: ["3.10", "3.13"] os: [ubuntu-latest] # can't use macOS when using service containers or container jobs runs-on: ${{ matrix.os }} services: diff --git a/bbconf/_version.py b/bbconf/_version.py index ea370a8..9e78220 100644 --- a/bbconf/_version.py +++ b/bbconf/_version.py @@ -1 +1 @@ -__version__ = "0.12.0" +__version__ = "0.14.0" diff --git a/bbconf/bbagent.py b/bbconf/bbagent.py index 59a58b6..ca8de8c 100644 --- a/bbconf/bbagent.py +++ b/bbconf/bbagent.py @@ -1,36 +1,36 @@ import logging +import statistics from functools import cached_property from pathlib import Path -from typing import List, Union, Dict -import numpy as np -import statistics +from typing import Dict, List, Union +import numpy as np from sqlalchemy.orm import Session -from sqlalchemy.sql import distinct, func, select, and_, or_ +from sqlalchemy.sql import and_, distinct, func, or_, select from bbconf.config_parser.bedbaseconfig import BedBaseConfig from bbconf.db_utils import ( Bed, BedMetadata, BedSets, + BedStats, + Files, + GeoGsmStatus, License, - UsageBedSetMeta, UsageBedMeta, + UsageBedSetMeta, UsageFiles, UsageSearch, - GeoGsmStatus, - BedStats, - Files, ) from bbconf.models.base_models import ( - StatsReturn, - UsageModel, - FileStats, - UsageStats, AllFilesInfo, - FileInfo, BinValues, + FileInfo, + FileStats, GEOStatistics, + StatsReturn, + UsageModel, + UsageStats, ) from bbconf.modules.bedfiles import BedAgentBedFile from bbconf.modules.bedsets import BedAgentBedSet diff --git a/bbconf/config_parser/bedbaseconfig.py b/bbconf/config_parser/bedbaseconfig.py index a6cedc3..3b7feb4 100644 --- a/bbconf/config_parser/bedbaseconfig.py +++ b/bbconf/config_parser/bedbaseconfig.py @@ -1,3 +1,4 @@ +import io import logging import os import warnings @@ -5,25 +6,30 @@ from typing import List, Literal, Union import boto3 +import 
joblib import qdrant_client -from qdrant_client import QdrantClient, models +import requests import s3fs import yacman import zarr from botocore.exceptions import BotoCoreError, EndpointConnectionError +from fastembed import TextEmbedding from geniml.region2vec.main import Region2VecExModel from geniml.search import BED2BEDSearchInterface from geniml.search.backends import BiVectorBackend, QdrantBackend from geniml.search.interfaces import BiVectorSearchInterface from geniml.search.query2vec import BED2Vec from pephubclient import PEPHubClient +from qdrant_client import QdrantClient, models +from sentence_transformers import SparseEncoder +from umap import UMAP from zarr import Group as Z_GROUP +from zarr.storage import FsspecStore from bbconf.config_parser.const import ( S3_BEDSET_PATH_FOLDER, S3_FILE_PATH_FOLDER, S3_PLOTS_PATH_FOLDER, - TEXT_EMBEDDING_DIMENSION, ) from bbconf.config_parser.models import ConfigFile from bbconf.const import PKG_NAME, ZARR_TOKENIZED_FOLDER @@ -59,22 +65,55 @@ def __init__(self, config: Union[Path, str], init_ml: bool = True): self._config = self._read_config_file(self.cfg_path) self._db_engine = self._init_db_engine() - self._qdrant_engine = self._init_qdrant_backend() - self._qdrant_text_engine = self._init_qdrant_text_backend() - self._qdrant_advanced_engine = self._init_qdrant_advanced_backend() + try: + self.qdrant_client: QdrantClient = self._init_qdrant_client() + except Exception as err: + _LOGGER.error( + f"Unable to create Qdrant client. Skipping ML model initialization. Error: {err}" + ) + init_ml = False if init_ml: - self._b2bsi = self._init_b2bsi_object() - self._r2v = self._init_r2v_object() - self._bivec = self._init_bivec_object() + + self.dense_encoder: TextEmbedding = self._init_dense_encoder() + self.sparce_encoder: Union[SparseEncoder, None] = self._init_sparce_model() + self._umap_encoder: Union[UMAP, None] = self._init_umap_model() + self.r2v_encoder: Union[Region2VecExModel, None] = self._init_r2v_encoder() + + self._init_qdrant_hybrid( + qdrant_cl=self.qdrant_client, + dense_encoder=self.dense_encoder, + ) + + self.qdrant_file_backend: Union[QdrantBackend, None] = ( + self._init_qdrant_file_backend(qdrant_cl=self.qdrant_client) + ) # used for bivec search + self._qdrant_text_backend: Union[QdrantBackend, None] = ( + self._init_qdrant_text_backend( + qdrant_cl=self.qdrant_client, + dense_encoder=self.dense_encoder, + ) + ) # used for bivec search + + self.b2b_search_interface = self._init_b2b_search_interface( + qdrant_file_backend=self.qdrant_file_backend, + region_encoder=self.r2v_encoder, + ) + + self.bivec_search_interface = self._init_bivec_interface( + qdrant_file_backend=self.qdrant_file_backend, + qdrant_text_backend=self._qdrant_text_backend, + text_encoder=self.dense_encoder, + ) else: _LOGGER.info( "Skipping initialization of ML models, init_ml parameter set to False." 
) - - self._b2bsi = None - self._r2v = None - self._bivec = None + self.r2v_encoder = None + self.b2b_search_interface = None + self.bivec_search_interface = None + self._umap_encoder: Union[UMAP, None] = None + self.sparce_encoder = None self._phc = self._init_pephubclient() self._boto3_client = self._init_boto3_client() @@ -121,43 +160,6 @@ def db_engine(self) -> BaseEngine: """ return self._db_engine - @property - def b2bsi(self) -> Union[BED2BEDSearchInterface, None]: - """ - Get bed2bednn object - - :return: bed2bednn object - """ - return self._b2bsi - - @property - def r2v(self) -> Region2VecExModel: - """ - Get region2vec object - - :return: region2vec object - """ - return self._r2v - - @property - def bivec(self) -> BiVectorSearchInterface: - """ - Get bivec search interface object - - :return: bivec search interface object - """ - - return self._bivec - - @property - def qdrant_engine(self) -> QdrantBackend: - """ - Get qdrant engine - - :return: qdrant engine - """ - return self._qdrant_engine - @property def phc(self) -> PEPHubClient: """ @@ -189,20 +191,30 @@ def zarr_root(self) -> Union[Z_GROUP, None]: endpoint_url=self._config.s3.endpoint_url, key=self._config.s3.aws_access_key_id, secret=self._config.s3.aws_secret_access_key, + asynchronous=True, ) except BotoCoreError as e: _LOGGER.error(f"Error in creating s3fs object: {e}") warnings.warn(f"Error in creating s3fs object: {e}", UserWarning) return None - s3_path = f"s3://{self._config.s3.bucket}/{ZARR_TOKENIZED_FOLDER}" + s3_path = f"{self._config.s3.bucket}/{ZARR_TOKENIZED_FOLDER}" - zarr_store = s3fs.S3Map( root=s3_path, s3=s3fc_obj, check=False, create=self._config.s3.modify_access + store = FsspecStore( + fs=s3fc_obj, + path=s3_path, ) - cache = zarr.LRUStoreCache(zarr_store, max_size=2**28) - return zarr.group(store=cache, overwrite=False) + try: + root = zarr.open_group( + store=store, + mode="a" if self._config.s3.modify_access else "r", + ) + return root + except Exception as e: + _LOGGER.error(f"Error opening zarr group: {e}") + warnings.warn(f"Error opening zarr group: {e}", UserWarning) + return None def _init_db_engine(self) -> BaseEngine: """ @@ -219,81 +231,136 @@ def _init_db_engine(self) -> BaseEngine: drivername=f"{self._config.database.dialect}+{self._config.database.driver}", ) - def _init_qdrant_backend(self) -> QdrantBackend: + def _init_qdrant_client(self) -> QdrantClient: """ Create qdrant client object using credentials provided in config file - - :return: QdrantClient """ - _LOGGER.info("Initializing qdrant engine...") + _LOGGER.info("Initializing qdrant client...") + try: - return QdrantBackend( - collection=self._config.qdrant.file_collection, - qdrant_host=self._config.qdrant.host, - qdrant_port=self._config.qdrant.port, - qdrant_api_key=self._config.qdrant.api_key, + qdrant_cl = QdrantClient( + url=self.config.qdrant.host, + port=self.config.qdrant.port, + api_key=self.config.qdrant.api_key, ) + except qdrant_client.http.exceptions.ResponseHandlingException as err: + raise BedBaseConfError( + f"Unable to connect to Qdrant: {err}" + ) + + return qdrant_cl + + def _init_qdrant_file_backend( + self, qdrant_cl: QdrantClient ) -> Union[QdrantBackend, None]: + """ + Create a QdrantBackend for the BED file collection from an existing qdrant client + + :param qdrant_cl: QdrantClient object + :return: QdrantBackend, or None on failure + """ + + _LOGGER.info("Initializing qdrant bivec file backend...") + + if not isinstance(qdrant_cl, QdrantClient): _LOGGER.error( f"Error in Connection to qdrant! skipping... 
Error: {err}. Qdrant host: {self._config.qdrant.host}" ) - warnings.warn( - f"error in Connection to qdrant! skipping... Error: {err}", UserWarning + return None + + try: + return QdrantBackend( + qdrant_client=qdrant_cl, + collection=self.config.qdrant.file_collection, ) + except Exception as e: + _LOGGER.error(f"Unable to create Qdrant collection: {e}") + return None - def _init_qdrant_text_backend(self) -> Union[QdrantBackend, None]: + def _init_qdrant_text_backend( + self, qdrant_cl: QdrantClient, dense_encoder: TextEmbedding + ) -> Union[QdrantBackend, None]: """ Create qdrant client text embedding object using credentials provided in config file + :param qdrant_cl: QdrantClient object + :param dense_encoder: TextEmbedding model for encoding text queries :return: QdrantClient """ - _LOGGER.info("Initializing qdrant text engine...") + _LOGGER.info("Initializing qdrant bivec text backend...") + + if not isinstance(qdrant_cl, QdrantClient): + _LOGGER.error( + f"Unable to create Qdrant bivec text collection, qdrant client is None." + ) + return None + if not isinstance(dense_encoder, TextEmbedding): + _LOGGER.error( + f"Unable to create Qdrant bivec text collection, dense encoder is None." + ) + return None + + dimensions = int(dense_encoder.get_embedding_size(self._config.path.text2vec)) try: return QdrantBackend( - dim=TEXT_EMBEDDING_DIMENSION, + qdrant_client=qdrant_cl, + dim=dimensions, collection=self.config.qdrant.text_collection, - qdrant_host=self.config.qdrant.host, - qdrant_api_key=self.config.qdrant.api_key, ) except Exception as e: - _LOGGER.error( - f"Error in Connection to qdrant text! skipping {e}. Qdrant host: {self._config.qdrant.host}" - ) - warnings.warn( - "Error in Connection to qdrant text! skipping...", UserWarning - ) + _LOGGER.error(f"Unable to create Qdrant collection: {e}") return None - def _init_qdrant_advanced_backend(self) -> Union[QdrantClient, None]: + def _init_qdrant_hybrid( + self, qdrant_cl: QdrantClient, dense_encoder: TextEmbedding + ) -> None: """ - Create qdrant client text embedding object using credentials provided in config file + Create the hybrid qdrant collection (dense and sparse vectors) using credentials provided in config file + :param qdrant_cl: QdrantClient object + :param dense_encoder: TextEmbedding model for encoding text queries :return: QdrantClient """ - COLLECTION_NAME = self.config.qdrant.search_collection - DIMENSIONS = 384 - - _LOGGER.info("Initializing qdrant text advanced engine...") + _LOGGER.info("Initializing qdrant hybrid collection...") - try: - qdrant_cl = QdrantClient( - url=self.config.qdrant.host, - port=self.config.qdrant.port, - api_key=self.config.qdrant.api_key, + if not isinstance(qdrant_cl, QdrantClient): + _LOGGER.error( + f"Unable to create Qdrant hybrid collection, qdrant client is None." + ) + return None + if not isinstance(dense_encoder, TextEmbedding): + _LOGGER.error( + f"Unable to create Qdrant hybrid collection, dense encoder is None." ) + return None + + dimensions = int(dense_encoder.get_embedding_size(self._config.path.text2vec)) + collection_name = self.config.qdrant.search_collection - if not qdrant_cl.collection_exists(COLLECTION_NAME): + try: + if not qdrant_cl.collection_exists(collection_name): _LOGGER.info( "Collection 'bedbase_query_search' does not exist, creating it." 
) + qdrant_cl.create_collection( - collection_name=COLLECTION_NAME, - vectors_config=models.VectorParams( - size=DIMENSIONS, distance=models.Distance.COSINE - ), + collection_name=collection_name, + vectors_config={ + "dense": models.VectorParams( + size=dimensions, distance=models.Distance.COSINE + ), + }, + sparse_vectors_config={ + "sparse": models.SparseVectorParams( + index=models.SparseIndexParams( + on_disk=False, + ) + ) + }, quantization_config=models.ScalarQuantization( scalar=models.ScalarQuantizationConfig( type=models.ScalarType.INT8, @@ -304,56 +371,67 @@ def _init_qdrant_advanced_backend(self) -> Union[QdrantClient, None]: ) qdrant_cl.create_payload_index( - collection_name=COLLECTION_NAME, + collection_name=collection_name, field_name="assay", - field_schema="keyword", + field_schema=models.PayloadSchemaType.KEYWORD, ) qdrant_cl.create_payload_index( - collection_name=COLLECTION_NAME, + collection_name=collection_name, field_name="genome_alias", - field_schema="keyword", + field_schema=models.PayloadSchemaType.KEYWORD, ) - return qdrant_cl - - except qdrant_client.http.exceptions.ResponseHandlingException as err: + except Exception as err: _LOGGER.error( - f"Error in Connection to qdrant! skipping... Error: {err}. Qdrant host: {self._config.qdrant.host}" + f"Error in creating Qdrant hybrid collection! skipping... Error: {err}. Qdrant host: {self._config.qdrant.host}" ) warnings.warn( - f"error in Connection to qdrant! skipping... Error: {err}", UserWarning + f"error in creating Qdrant hybrid collection! skipping... Error: {err}", + UserWarning, ) return None - def _init_bivec_object(self) -> Union[BiVectorSearchInterface, None]: + def _init_bivec_interface( + self, + qdrant_file_backend: QdrantBackend, + qdrant_text_backend: QdrantBackend, + text_encoder: TextEmbedding, + ) -> Union[BiVectorSearchInterface, None]: """ Create BiVectorSearchInterface object using credentials provided in config file + :param qdrant_file_backend: QdrantBackend for file vectors + :param qdrant_text_backend: QdrantBackend for text vectors + :param text_encoder: TextEmbedding model for encoding text queries :return: BiVectorSearchInterface """ _LOGGER.info("Initializing BiVectorBackend...") search_backend = BiVectorBackend( - metadata_backend=self._qdrant_text_engine, bed_backend=self._qdrant_engine + metadata_backend=qdrant_text_backend, bed_backend=qdrant_file_backend ) _LOGGER.info("Initializing BiVectorSearchInterface...") search_interface = BiVectorSearchInterface( backend=search_backend, - query2vec=self.config.path.text2vec, + query2vec=text_encoder, ) return search_interface - def _init_b2bsi_object(self) -> Union[BED2BEDSearchInterface, None]: + def _init_b2b_search_interface( + self, + qdrant_file_backend: QdrantBackend, + region_encoder: Union[Region2VecExModel, str], + ) -> Union[BED2BEDSearchInterface, None]: """ Create Bed 2 BED search interface and return this object :return: Bed2BEDSearchInterface object """ try: - _LOGGER.info("Initializing search interfaces...") + _LOGGER.info("Initializing BED-to-BED search interface...") return BED2BEDSearchInterface( - backend=self.qdrant_engine, - query2vec=BED2Vec(model=self._config.path.region2vec), + backend=qdrant_file_backend, + query2vec=BED2Vec(model=region_encoder), ) except Exception as e: _LOGGER.error("Error in creating BED2BEDSearchInterface object: " + str(e)) @@ -363,26 +441,101 @@ def _init_b2bsi_object(self) -> Union[BED2BEDSearchInterface, None]: ) return None - @staticmethod - def _init_pephubclient() -> 
Union[PEPHubClient, None]: + def _init_r2v_encoder(self) -> Union[Region2VecExModel, None]: """ - Create Pephub client object using credentials provided in config file + Create Region2VecExModel object from the model path provided in the config file + """ + try: + _LOGGER.info( + f"Initializing region2vec encoder... Model used: {self.config.path.region2vec}" + ) + return Region2VecExModel(self.config.path.region2vec) + except Exception as e: + _LOGGER.error(f"Error in creating Region2VecExModel object: {e}") + warnings.warn( + f"Error in creating Region2VecExModel object: {e}", UserWarning + ) + return None - :return: PephubClient + def _init_dense_encoder(self) -> Union[None, TextEmbedding]: + """ + Initialize dense model from the specified path or huggingface model hub """ - # try: - # _LOGGER.info("Initializing PEPHub client...") - # return PEPHubClient() - # except Exception as e: - # _LOGGER.error(f"Error in creating PephubClient object: {e}") - # warnings.warn(f"Error in creating PephubClient object: {e}", UserWarning) - # return None - return None + _LOGGER.info( + f"Initializing dense encoder... Model used: {self.config.path.text2vec}" + ) + dense_encoder = TextEmbedding(self.config.path.text2vec) + return dense_encoder + + def _init_sparce_model(self) -> Union[None, SparseEncoder]: + """ + Initialize SparseEncoder model from the specified path or huggingface model hub + """ + try: + _LOGGER.info( + f"Initializing sparse encoder... Model used: {self.config.path.sparse_model}" + ) + sparse_encoder = SparseEncoder(self.config.path.sparse_model) + except Exception as e: + _LOGGER.error(f"Error in creating SparseEncoder object: {e}") + warnings.warn(f"Error in creating SparseEncoder object: {e}", UserWarning) + return None + return sparse_encoder + + def _init_umap_model(self) -> Union[UMAP, None]: + """ + Load UMAP model from the specified path or URL + """ + + if not self.config.path.umap_model: + _LOGGER.warning( + "UMAP model path is not specified in the configuration; UMAP will not be used." + ) + return None + + model_path = self.config.path.umap_model + umap_model = None + if model_path.startswith(("http://", "https://")): + + try: + response = requests.get(model_path) + response.raise_for_status() + buffer = io.BytesIO(response.content) + umap_model = joblib.load(buffer) + _LOGGER.info(f"UMAP model loaded from URL: {model_path}") + except requests.RequestException as e: + _LOGGER.error(f"Error downloading UMAP model from URL: {e}") + return None + except TypeError as e: + _LOGGER.error( + f"Error loading UMAP model from URL; unable to open pickle file. Error: {e}" + ) + return None + else: + try: + with open(model_path, "rb") as file: + umap_model = joblib.load(file) + _LOGGER.info(f"UMAP model loaded from local path: {model_path}") + except FileNotFoundError as e: + _LOGGER.error(f"Error loading UMAP model from local path: {e}") + return None + except TypeError as e: + _LOGGER.error( + f"Error loading UMAP model from local path; unable to open pickle file. 
Error: {e}" + ) + return None + + if not isinstance(umap_model, UMAP): + _LOGGER.error(f"Loaded object is not a UMAP instance: {type(umap_model)}") + return None + # np.random.seed(42) + umap_model.random_state = 42 + return umap_model def _init_boto3_client( self, - ) -> boto3.client: + ) -> Union[boto3.client, None]: """ Create Pephub client object using credentials provided in config file @@ -400,20 +553,6 @@ def _init_boto3_client( warnings.warn(f"Error in creating boto3 client object: {e}", UserWarning) return None - def _init_r2v_object(self) -> Union[Region2VecExModel, None]: - """ - Create Region2VecExModel object using credentials provided in config file - """ - try: - _LOGGER.info("Initializing R2V object...") - return Region2VecExModel(self.config.path.region2vec) - except Exception as e: - _LOGGER.error(f"Error in creating Region2VecExModel object: {e}") - warnings.warn( - f"Error in creating Region2VecExModel object: {e}", UserWarning - ) - return None - def upload_s3(self, file_path: str, s3_path: Union[Path, str]) -> None: """ Upload file to s3. @@ -530,6 +669,23 @@ def delete_files_s3(self, files: List[FileModel]) -> None: self.delete_s3(file.path_thumbnail) return None + @staticmethod + def _init_pephubclient() -> Union[PEPHubClient, None]: + """ + Create Pephub client object using credentials provided in config file + + :return: PephubClient + """ + + # try: + # _LOGGER.info("Initializing PEPHub client...") + # return PEPHubClient() + # except Exception as e: + # _LOGGER.error(f"Error in creating PephubClient object: {e}") + # warnings.warn(f"Error in creating PephubClient object: {e}", UserWarning) + # return None + return None + def get_prefixed_uri(self, postfix: str, access_id: str) -> str: """ Return uri with correct prefix (schema) diff --git a/bbconf/config_parser/const.py b/bbconf/config_parser/const.py index b9541f5..61aad4e 100644 --- a/bbconf/config_parser/const.py +++ b/bbconf/config_parser/const.py @@ -5,16 +5,17 @@ DEFAULT_QDRANT_HOST = "localhost" DEFAULT_QDRANT_PORT = 6333 -DEFAULT_QDRANT_COLLECTION_NAME = "bedbase" -DEFAULT_QDRANT_TEXT_COLLECTION_NAME = "bed_text" -DEFAULT_QDRANT_SEARCH_COLLECTION_NAME = "bedbase_query_search" +DEFAULT_QDRANT_FILE_COLLECTION_NAME = "bedbase" +DEFAULT_QDRANT_BIVEC_COLLECTION_NAME = "bed_text" +DEFAULT_QDRANT_HYBRID_COLLECTION_NAME = "bedbase_query_search" DEFAULT_QDRANT_API_KEY = None DEFAULT_SERVER_PORT = 80 DEFAULT_SERVER_HOST = "0.0.0.0" DEFAULT_TEXT2VEC_MODEL = "sentence-transformers/all-MiniLM-L6-v2" -DEFAULT_REGION2_VEC_MODEL = "databio/r2v-ChIP-atlas-hg38" +DEFAULT_SPARSE_MODEL = "prithivida/Splade_PP_en_v2" +DEFAULT_REGION2_VEC_MODEL = "databio/r2v_encoder-ChIP-atlas-hg38" DEFAULT_PEPHUB_NAMESPACE = "databio" DEFAULT_PEPHUB_NAME = "bedbase_all" diff --git a/bbconf/config_parser/models.py b/bbconf/config_parser/models.py index 13726a6..b0786dd 100644 --- a/bbconf/config_parser/models.py +++ b/bbconf/config_parser/models.py @@ -13,15 +13,16 @@ DEFAULT_PEPHUB_NAME, DEFAULT_PEPHUB_NAMESPACE, DEFAULT_PEPHUB_TAG, - DEFAULT_QDRANT_COLLECTION_NAME, + DEFAULT_QDRANT_BIVEC_COLLECTION_NAME, + DEFAULT_QDRANT_FILE_COLLECTION_NAME, + DEFAULT_QDRANT_HYBRID_COLLECTION_NAME, DEFAULT_QDRANT_PORT, - DEFAULT_QDRANT_TEXT_COLLECTION_NAME, DEFAULT_REGION2_VEC_MODEL, DEFAULT_S3_BUCKET, DEFAULT_SERVER_HOST, DEFAULT_SERVER_PORT, + DEFAULT_SPARSE_MODEL, DEFAULT_TEXT2VEC_MODEL, - DEFAULT_QDRANT_SEARCH_COLLECTION_NAME, ) _LOGGER = logging.getLogger(__name__) @@ -53,9 +54,9 @@ class ConfigQdrant(BaseModel): host: str port: int = 
DEFAULT_QDRANT_PORT api_key: Optional[str] = None - file_collection: str = DEFAULT_QDRANT_COLLECTION_NAME - text_collection: Optional[str] = DEFAULT_QDRANT_TEXT_COLLECTION_NAME - search_collection: Optional[str] = DEFAULT_QDRANT_SEARCH_COLLECTION_NAME + file_collection: str = DEFAULT_QDRANT_FILE_COLLECTION_NAME + text_collection: Optional[str] = DEFAULT_QDRANT_BIVEC_COLLECTION_NAME + search_collection: Optional[str] = DEFAULT_QDRANT_HYBRID_COLLECTION_NAME class ConfigServer(BaseModel): @@ -67,6 +68,8 @@ class ConfigPath(BaseModel): region2vec: str = DEFAULT_REGION2_VEC_MODEL # vec2vec: str = DEFAULT_VEC2VEC_MODEL text2vec: str = DEFAULT_TEXT2VEC_MODEL + sparse_model: str = DEFAULT_SPARSE_MODEL + umap_model: Union[str, None] = None # Path or link to pre-trained UMAP model class AccessMethodsStruct(BaseModel): diff --git a/bbconf/db_utils.py b/bbconf/db_utils.py index d49902a..b2520e9 100644 --- a/bbconf/db_utils.py +++ b/bbconf/db_utils.py @@ -9,12 +9,12 @@ ForeignKey, Result, Select, + String, UniqueConstraint, event, select, - String, ) -from sqlalchemy.dialects.postgresql import JSON, ARRAY +from sqlalchemy.dialects.postgresql import ARRAY, JSON from sqlalchemy.engine import URL, Engine, create_engine from sqlalchemy.event import listens_for from sqlalchemy.exc import IntegrityError, ProgrammingError diff --git a/bbconf/models/base_models.py b/bbconf/models/base_models.py index a771185..20c1c26 100644 --- a/bbconf/models/base_models.py +++ b/bbconf/models/base_models.py @@ -1,5 +1,5 @@ -from typing import List, Optional, Union, Dict import datetime +from typing import Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field diff --git a/bbconf/modules/bedfiles.py b/bbconf/modules/bedfiles.py index 48ad904..7505f01 100644 --- a/bbconf/modules/bedfiles.py +++ b/bbconf/modules/bedfiles.py @@ -5,21 +5,19 @@ import numpy as np from geniml.bbclient import BBClient -from geniml.io import RegionSet from geniml.search.backends import QdrantBackend from gtars.models import RegionSet as GRegionSet from pephubclient.exceptions import ResponseError from pydantic import BaseModel +from qdrant_client import models from qdrant_client.http.models import PointStruct -from qdrant_client.models import Distance, PointIdsList, VectorParams -from sqlalchemy import and_, delete, func, or_, select, cast +from qdrant_client.models import PointIdsList, QueryResponse +from sqlalchemy import and_, cast, delete, func, or_, select +from sqlalchemy.dialects import postgresql from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, aliased from sqlalchemy.orm.attributes import flag_modified -from sqlalchemy.dialects import postgresql from tqdm import tqdm -from fastembed import TextEmbedding -from qdrant_client import models from bbconf.config_parser.bedbaseconfig import BedBaseConfig from bbconf.const import DEFAULT_LICENSE, PKG_NAME, ZARR_TOKENIZED_FOLDER @@ -84,14 +82,10 @@ def __init__(self, config: BedBaseConfig, bbagent_obj=None): """ self._sa_engine = config.db_engine.engine self._db_engine = config.db_engine - self._qdrant_engine = config.qdrant_engine self._boto3_client = config.boto3_client - self._config = config + self.config = config self.bb_agent = bbagent_obj - self._embedding_model = TextEmbedding(config.config.path.text2vec) - # self._embedding_model = TextEmbedding("BAAI/bge-large-en-v1.5") - def get(self, identifier: str, full: bool = False) -> BedMetadataAll: """ Get file metadata by identifier. 
@@ -120,7 +114,7 @@ def get(self, identifier: str, full: bool = False) -> BedMetadataAll: FileModel( **result.__dict__, object_id=f"bed.{identifier}.{result.name}", - access_methods=self._config.construct_access_method_list( + access_methods=self.config.construct_access_method_list( result.path ), ), @@ -134,7 +128,7 @@ def get(self, identifier: str, full: bool = False) -> BedMetadataAll: FileModel( **result.__dict__, object_id=f"bed.{identifier}.{result.name}", - access_methods=self._config.construct_access_method_list( + access_methods=self.config.construct_access_method_list( result.path ), ), @@ -170,10 +164,10 @@ def get(self, identifier: str, full: bool = False) -> BedMetadataAll: try: if full: bed_metadata = BedPEPHubRestrict( - **self._config.phc.sample.get( - namespace=self._config.config.phc.namespace, - name=self._config.config.phc.name, - tag=self._config.config.phc.tag, + **self.config.phc.sample.get( + namespace=self.config.config.phc.namespace, + name=self.config.config.phc.name, + tag=self.config.config.phc.tag, sample_name=identifier, ) ) @@ -249,7 +243,7 @@ def get_plots(self, identifier: str) -> BedPlots: FileModel( **result.__dict__, object_id=f"bed.{identifier}.{result.name}", - access_methods=self._config.construct_access_method_list( + access_methods=self.config.construct_access_method_list( result.path ), ), @@ -271,8 +265,8 @@ def get_neighbours( if not self.exists(identifier): raise BEDFileNotFoundError(f"Bed file with id: {identifier} not found.") s = identifier - results = self._qdrant_engine.qd_client.query_points( - collection_name=self._config.config.qdrant.file_collection, + results = self.config.qdrant_file_backend.qd_client.query_points( + collection_name=self.config.config.qdrant.file_collection, query="-".join([s[:8], s[8:12], s[12:16], s[16:20], s[20:]]), limit=limit, offset=offset, @@ -318,7 +312,7 @@ def get_files(self, identifier: str) -> BedFiles: FileModel( **result.__dict__, object_id=f"bed.{identifier}.{result.name}", - access_methods=self._config.construct_access_method_list( + access_methods=self.config.construct_access_method_list( result.path ), ), @@ -333,10 +327,10 @@ def get_raw_metadata(self, identifier: str) -> BedPEPHub: :return: project metadata """ try: - bed_metadata = self._config.phc.sample.get( - namespace=self._config.config.phc.namespace, - name=self._config.config.phc.name, - tag=self._config.config.phc.tag, + bed_metadata = self.config.phc.sample.get( + namespace=self.config.config.phc.namespace, + name=self.config.config.phc.name, + tag=self.config.config.phc.tag, sample_name=identifier, ) except Exception as e: @@ -389,8 +383,8 @@ def get_embedding(self, identifier: str) -> BedEmbeddingResult: """ if not self.exists(identifier): raise BEDFileNotFoundError(f"Bed file with id: {identifier} not found.") - result = self._qdrant_engine.qd_client.retrieve( - collection_name=self._config.config.qdrant.file_collection, + result = self.config.qdrant_file_backend.qd_client.retrieve( + collection_name=self.config.config.qdrant.file_collection, ids=[identifier], with_vectors=True, with_payload=True, @@ -628,12 +622,12 @@ def add( # Upload files to s3 if upload_s3: if files: - files = self._config.upload_files_s3( + files = self.config.upload_files_s3( identifier, files=files, base_path=local_path, type="files" ) if plots: - plots = self._config.upload_files_s3( + plots = self.config.upload_files_s3( identifier, files=plots, base_path=local_path, type="plots" ) with Session(self._sa_engine) as session: @@ -934,7 +928,7 @@ def 
_update_plots( _LOGGER.info("Updating bed file plots..") if plots: - plots = self._config.upload_files_s3( + plots = self.config.upload_files_s3( bed_object.id, files=plots, base_path=local_path, type="plots" ) plots_dict = plots.model_dump( @@ -982,7 +976,7 @@ def _update_files( _LOGGER.info("Updating bed files..") if files: - files = self._config.upload_files_s3( + files = self.config.upload_files_s3( bed_object.id, files=files, base_path=local_path, type="files" ) @@ -1104,16 +1098,16 @@ def delete(self, identifier: str) -> None: self.delete_pephub_sample(identifier) if delete_qdrant: self.delete_qdrant_point(identifier) - self._config.delete_files_s3(files) + self.config.delete_files_s3(files) def upload_pephub(self, identifier: str, metadata: dict, overwrite: bool = False): if not metadata: _LOGGER.warning("No metadata provided. Skipping pephub upload..") return False - self._config.phc.sample.create( - namespace=self._config.config.phc.namespace, - name=self._config.config.phc.name, - tag=self._config.config.phc.tag, + self.config.phc.sample.create( + namespace=self.config.config.phc.namespace, + name=self.config.config.phc.name, + tag=self.config.config.phc.tag, sample_name=identifier, sample_dict=metadata, overwrite=overwrite, @@ -1126,10 +1120,10 @@ def update_pephub( if not metadata: _LOGGER.warning("No metadata provided. Skipping pephub upload..") return None - self._config.phc.sample.update( - namespace=self._config.config.phc.namespace, - name=self._config.config.phc.name, - tag=self._config.config.phc.tag, + self.config.phc.sample.update( + namespace=self.config.config.phc.namespace, + name=self.config.config.phc.name, + tag=self.config.config.phc.tag, sample_name=identifier, sample_dict=metadata, ) @@ -1143,10 +1137,10 @@ def delete_pephub_sample(self, identifier: str): :param identifier: bed file identifier """ try: - self._config.phc.sample.remove( - namespace=self._config.config.phc.namespace, - name=self._config.config.phc.name, - tag=self._config.config.phc.tag, + self.config.phc.sample.remove( + namespace=self.config.config.phc.namespace, + name=self.config.config.phc.name, + tag=self.config.config.phc.tag, sample_name=identifier, ) except ResponseError as e: @@ -1155,7 +1149,7 @@ def delete_pephub_sample(self, identifier: str): def upload_file_qdrant( self, bed_id: str, - bed_file: Union[str, RegionSet], + bed_file: Union[str, GRegionSet], payload: dict = None, ) -> None: """ @@ -1171,19 +1165,19 @@ def upload_file_qdrant( _LOGGER.debug(f"Adding bed file to qdrant. bed_id: {bed_id}") - if not isinstance(self._qdrant_engine, QdrantBackend): + if not isinstance(self.config.qdrant_file_backend, QdrantBackend): raise QdrantInstanceNotInitializedError("Could not upload file.") bed_embedding = self._embed_file(bed_file) - self._qdrant_engine.load( + self.config.qdrant_file_backend.load( ids=[bed_id], vectors=bed_embedding, payloads=[{**payload}], ) return None - def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray: + def _embed_file(self, bed_file: Union[str, GRegionSet]) -> np.ndarray: """ Create embedding for bed file @@ -1192,9 +1186,9 @@ def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray: :return np array of embeddings """ - if self._qdrant_engine is None: + if self.config.qdrant_file_backend is None: raise QdrantInstanceNotInitializedError - if not self._config.r2v: + if not self.config.r2v_encoder: raise BedBaseConfError( "Could not add add region to qdrant. Invalid type, or path. 
" ) @@ -1204,17 +1198,31 @@ def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray: try: bed_region_set = GRegionSet(bed_file) except RuntimeError as _: - bed_region_set = RegionSet(bed_file) - elif isinstance(bed_file, RegionSet) or isinstance(bed_file, GRegionSet): + bed_region_set = GRegionSet(bed_file) + elif isinstance(bed_file, GRegionSet) or isinstance(bed_file, GRegionSet): bed_region_set = bed_file else: raise BedBaseConfError( "Could not add add region to qdrant. Invalid type, or path. " ) - bed_embedding = np.mean(self._config.r2v.encode(bed_region_set), axis=0) + bed_embedding = np.mean(self.config.r2v_encoder.encode(bed_region_set), axis=0) vec_dim = bed_embedding.shape[0] return bed_embedding.reshape(1, vec_dim) + def _get_umap_file(self, bed_file: Union[str, GRegionSet]) -> np.ndarray: + """ + Create UMAP for bed file + + :param bed_file: bed file path or region set + """ + + if self.config._umap_encoder is None: + raise BedBaseConfError("UMAP model is not initialized.") + + bed_embedding = self._embed_file(bed_file) + bed_umap = self.config._umap_encoder.transform(bed_embedding) + return bed_umap + def text_to_bed_search( self, query: str, @@ -1235,7 +1243,9 @@ def text_to_bed_search( """ _LOGGER.info(f"Looking for: {query}") - results = self._config.bivec.query_search(query, limit=limit, offset=offset) + results = self.config.bivec_search_interface.query_search( + query, limit=limit, offset=offset + ) results_list = [] for result in results: result_id = result["id"].replace("-", "") @@ -1256,8 +1266,8 @@ def text_to_bed_search( ) if with_metadata: - count = self._config._qdrant_advanced_engine.get_collection( - collection_name=self._config.config.qdrant.file_collection + count = self.config.qdrant_client.get_collection( + collection_name=self.config.config.qdrant.file_collection ).points_count else: count = 0 @@ -1270,7 +1280,7 @@ def text_to_bed_search( def bed_to_bed_search( self, - region_set: RegionSet, + region_set: GRegionSet, limit: int = 10, offset: int = 0, ) -> BedListSearchResult: @@ -1283,7 +1293,7 @@ def bed_to_bed_search( :return: BedListSetResults """ - results = self._config.b2bsi.query_search( + results = self.config.b2b_search_interface.query_search( region_set, limit=limit, offset=offset ) results_list = [] @@ -1486,8 +1496,8 @@ def reindex_qdrant(self, batch: int = 100, purge: bool = False) -> None: pbar.set_description( "Uploading points to qdrant using batch..." 
) - operation_info = self._config.qdrant_engine.qd_client.upsert( - collection_name=self._config.config.qdrant.file_collection, + operation_info = self.config.qdrant_file_backend.qd_client.upsert( + collection_name=self.config.config.qdrant.file_collection, points=points_list, ) pbar.write("Uploaded batch to qdrant.") @@ -1500,8 +1510,8 @@ def reindex_qdrant(self, batch: int = 100, purge: bool = False) -> None: pbar.update(1) _LOGGER.info("Uploading points to qdrant using batches...") - operation_info = self._config.qdrant_engine.qd_client.upsert( - collection_name=self._config.config.qdrant.file_collection, + operation_info = self.config.qdrant_file_backend.qd_client.upsert( + collection_name=self.config.config.qdrant.file_collection, points=points_list, ) assert operation_info.status == "completed" @@ -1515,22 +1525,13 @@ def delete_qdrant_point(self, identifier: str) -> None: :return: None """ - result = self._config.qdrant_engine.qd_client.delete( - collection_name=self._config.config.qdrant.file_collection, + result = self.config.qdrant_file_backend.qd_client.delete( + collection_name=self.config.config.qdrant.file_collection, points_selector=PointIdsList( points=[identifier], ), ) - return result - - def create_qdrant_collection(self) -> bool: - """ - Create qdrant collection for bed files. - """ - return self._config.qdrant_engine.qd_client.create_collection( - collection_name=self._config.config.qdrant.file_collection, - vectors_config=VectorParams(size=100, distance=Distance.DOT), - ) + return None def exists(self, identifier: str) -> bool: """ @@ -1625,14 +1626,15 @@ def add_tokenized( ) if self.exist_tokenized(bed_id, universe_id): + _LOGGER.info("Tokenized file already exists in the database.") if not overwrite: - if not overwrite: - raise TokenizeFileExistsError( - "Tokenized file already exists in the database. " - "Set overwrite to True to overwrite it." - ) - else: - self.delete_tokenized(bed_id, universe_id) + raise TokenizeFileExistsError( + "Tokenized file already exists in the database. " + "Set overwrite to True to overwrite it." 
+ ) + else: + _LOGGER.info("Overwriting existing tokenized file in the database.") + self.delete_tokenized(bed_id, universe_id) path = self._add_zarr_s3( bed_id=bed_id, @@ -1640,7 +1642,7 @@ def add_tokenized( tokenized_vector=token_vector, overwrite=overwrite, ) - path = os.path.join(f"s3://{self._config.config.s3.bucket}", path) + path = os.path.join(f"s3://{self.config.config.s3.bucket}", path) new_token = TokenizedBed(bed_id=bed_id, universe_id=universe_id, path=path) session.add(new_token) @@ -1663,16 +1665,32 @@ def _add_zarr_s3( :return: zarr path """ - univers_group = self._config.zarr_root.require_group(universe_id) + root = self.config.zarr_root - if not univers_group.get(bed_id): + # Handle group creation (require_group is deprecated in zarr 3.x) + try: + univers_group = root[universe_id] + except KeyError: + univers_group = root.create_group(universe_id) + + # Check existence + bed_exists = bed_id in univers_group + + if not bed_exists: _LOGGER.info("Saving tokenized vector to s3") - path = univers_group.create_dataset(bed_id, data=tokenized_vector).path + array = univers_group.create_array( + name=bed_id, + data=np.array(tokenized_vector, dtype="int64"), + ) + path = array.name elif overwrite: _LOGGER.info("Overwriting tokenized vector in s3") - path = univers_group.create_dataset( - bed_id, data=tokenized_vector, overwrite=True - ).path + del univers_group[bed_id] + array = univers_group.create_array( + name=bed_id, + data=np.array(tokenized_vector, dtype="int64"), + ) + path = array.name else: raise TokenizeFileExistsError( "Tokenized file already exists in the database. " @@ -1693,12 +1711,19 @@ def get_tokenized(self, bed_id: str, universe_id: str) -> TokenizedBedResponse: if not self.exist_tokenized(bed_id, universe_id): raise TokenizeFileNotExistError("Tokenized file not found in the database.") - univers_group = self._config.zarr_root.require_group(universe_id) + + root = self.config.zarr_root + + try: + univers_group = root[universe_id] + data = list(univers_group[bed_id][:]) # Explicit slice for full read + except KeyError: + raise TokenizeFileNotExistError("Tokenized file not found in the database.") return TokenizedBedResponse( universe_id=universe_id, bed_id=bed_id, - tokenized_bed=list(univers_group[bed_id]), + tokenized_bed=data, ) def delete_tokenized(self, bed_id: str, universe_id: str) -> None: @@ -1712,9 +1737,14 @@ def delete_tokenized(self, bed_id: str, universe_id: str) -> None: """ if not self.exist_tokenized(bed_id, universe_id): raise TokenizeFileNotExistError("Tokenized file not found in the database.") - univers_group = self._config.zarr_root.require_group(universe_id) - del univers_group[bed_id] + root = self.config.zarr_root + + try: + univers_group = root[universe_id] + del univers_group[bed_id] # Delete syntax unchanged + except KeyError: + raise TokenizeFileNotExistError("Tokenized file not found in the database.") with Session(self._sa_engine) as session: statement = delete(TokenizedBed).where( @@ -1786,7 +1816,7 @@ def get_tokenized_link( file_path = self._get_tokenized_path(bed_id, universe_id) return TokenizedPathResponse( - endpoint_url=self._config.config.s3.endpoint_url, + endpoint_url=self.config.config.s3.endpoint_url, file_path=file_path, bed_id=bed_id, universe_id=universe_id, @@ -1814,7 +1844,9 @@ def get_missing_plots( t2_alias = aliased(Files) # Define the subquery - subquery = select(t2_alias).where(t2_alias.name == plot_name).subquery() + subquery = ( + select(t2_alias).where(and_(t2_alias.name == plot_name)).subquery() + ) query 
= ( select(Bed.id) @@ -1869,7 +1901,9 @@ def get_missing_files(self, limit: int = 1000, offset: int = 0) -> List[str]: t2_alias = aliased(Files) # Define the subquery - subquery = select(t2_alias).where(t2_alias.name == "bigbed_file").subquery() + subquery = ( + select(t2_alias).where(and_(t2_alias.name == "bigbed_file")).subquery() + ) query = ( select(Bed.id) @@ -1991,7 +2025,103 @@ def _update_sources( session.commit() - def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> None: + # def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> None: + # """ + # Reindex all bed files for semantic database + # + # :param batch: number of files to upload in one batch + # :param purge: resets indexed in database for all files to False + # + # :return: None + # """ + # + # # Add column that will indicate if this file is indexed or not + # statement = ( + # select(Bed) + # .join(BedMetadata, Bed.id == BedMetadata.id) + # .where(Bed.indexed == False) + # .limit(150000) + # ) + # + # with Session(self._sa_engine) as session: + # + # if purge: + # _LOGGER.info("Purging indexed files in the database ...") + # session.query(Bed).update({Bed.indexed: False}) + # session.commit() + # _LOGGER.info("Purged indexed files in the database successfully!") + # + # _LOGGER.info("Fetching data from the database ...") + # results = session.scalars(statement) + # + # _LOGGER.info("Fetch data successfully!") + # + # points = [] + # results = [result for result in results] + # + # with tqdm(total=len(results), position=0, leave=True) as pbar: + # processed_number = 0 + # for result in results: + # text = ( + # f"biosample is {result.annotations.cell_line} / {result.annotations.cell_type} / " + # f"{result.annotations.tissue} with target {result.annotations.target} " + # f"assay {result.annotations.assay}." + # f"File name {result.name} with summary {result.description}" + # ) + # + # embeddings_list = list(self.config.dense_encoder.embed(text)) + # # result_list.append( + # data = VectorMetadata( + # id=result.id, + # name=result.name, + # description=result.description, + # genome_alias=result.genome_alias, + # genome_digest=result.genome_digest, + # cell_line=result.annotations.cell_line, + # cell_type=result.annotations.cell_type, + # tissue=result.annotations.tissue, + # target=result.annotations.target, + # treatment=result.annotations.treatment, + # assay=result.annotations.assay, + # species_name=result.annotations.species_name, + # ) + # + # points.append( + # PointStruct( + # id=result.id, + # vector=list(embeddings_list[0]), + # payload=data.model_dump(), + # ) + # ) + # processed_number += 1 + # result.indexed = True + # + # if processed_number % batch == 0: + # pbar.set_description( + # "Uploading points to qdrant using batch..." 
+ # ) + # operation_info = self.config._qdrant_advanced_engine.upsert( + # collection_name=self.config.config.qdrant.search_collection, + # points=points, + # ) + # session.commit() + # pbar.write("Uploaded batch to qdrant.") + # points = [] + # assert operation_info.status == "completed" + # + # pbar.write(f"File: {result.id} successfully indexed.") + # pbar.update(1) + # + # operation_info = self.config._qdrant_advanced_engine.upsert( + # collection_name=self.config.config.qdrant.search_collection, + # points=points, + # ) + # assert operation_info.status == "completed" + # session.commit() + # + # return None + + def reindex_hybrid_search(self, batch: int = 1000, purge: bool = False) -> None: """ Reindex all bed files for semantic database @@ -2006,7 +2136,7 @@ def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> Non select(Bed) .join(BedMetadata, Bed.id == BedMetadata.id) .where(Bed.indexed == False) - .limit(150000) + .limit(150) ) with Session(self._sa_engine) as session: @@ -2035,7 +2165,26 @@ def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> Non f"File name {result.name} with summary {result.description}" ) - embeddings_list = list(self._embedding_model.embed(text)) + embeddings_list = list(self.config.dense_encoder.embed(text)) + + if self.config.sparce_encoder: + sparse_result = self.config.sparce_encoder.encode( + text + ).coalesce() + + sparse_embeddings = models.SparseVector( + indices=sparse_result.indices().tolist()[0], + values=sparse_result.values().tolist(), + ) + + point_vectors = { + "dense": list(embeddings_list[0]), + "sparse": sparse_embeddings, + } + else: + point_vectors = { + "dense": list(embeddings_list[0]), + } # result_list.append( data = VectorMetadata( id=result.id, @@ -2055,7 +2204,7 @@ def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> Non points.append( PointStruct( id=result.id, - vector=list(embeddings_list[0]), + vector=point_vectors, payload=data.model_dump(), ) ) @@ -2066,8 +2215,8 @@ def reindex_semantic_search(self, batch: int = 1000, purge: bool = False) -> Non pbar.set_description( "Uploading points to qdrant using batch..." 
) - operation_info = self._config._qdrant_advanced_engine.upsert( - collection_name=self._config.config.qdrant.search_collection, + operation_info = self.config.qdrant_client.upsert( + collection_name=self.config.config.qdrant.search_collection, points=points, ) session.commit() @@ -2078,8 +2227,8 @@ pbar.write(f"File: {result.id} successfully indexed.") pbar.update(1) - operation_info = self._config._qdrant_advanced_engine.upsert( - collection_name=self._config.config.qdrant.search_collection, + operation_info = self.config.qdrant_client.upsert( + collection_name=self.config.config.qdrant.search_collection, points=points, ) assert operation_info.status == "completed" @@ -2127,11 +2276,11 @@ def semantic_search( ) ) - embeddings_list = list(self._embedding_model.embed(query))[0] + embeddings_list = list(self.config.dense_encoder.embed(query))[0] - results = self._config._qdrant_advanced_engine.search( - collection_name=self._config.config.qdrant.search_collection, - query_vector=list(embeddings_list), + results: QueryResponse = self.config.qdrant_client.query_points( + collection_name=self.config.config.qdrant.search_collection, + query=list(embeddings_list), limit=limit, offset=offset, search_params=models.SearchParams( @@ -2146,7 +2295,113 @@ ) result_list = [] - for result in results: + for result in results.points: result_id = result.id.replace("-", "") if with_metadata: metadata = self.get(result_id, full=False) else: metadata = None result_list.append( QdrantSearchResult( id=result_id, payload=result.payload, score=result.score, metadata=metadata, ) ) if with_metadata: count = self.bb_agent.get_stats().bedfiles_number else: count = 0 return BedListSearchResult( count=count, limit=limit, offset=offset, results=result_list, ) + def hybrid_search( + self, + query: str, + genome_alias: str = "", + assay: str = "", + limit: int = 100, + offset: int = 0, + with_metadata: bool = True, + ) -> BedListSearchResult: + """ + Run hybrid search for bed files using qdrant. + Unlike the bivec search, this is a direct qdrant search that combines sparse and dense embeddings. + + :param query: text query to search for + :param genome_alias: genome alias to filter results + :param assay: filter by assay type + :param limit: number of results to return + :param offset: offset to start from + :param with_metadata: if True, metadata will be returned in the results. Default is True. 
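+        Example (illustrative values; a sketch assuming an initialized BedBaseAgent whose hybrid collection has already been indexed): +            >>> bbagent.bed.hybrid_search(query="liver", genome_alias="hg38", assay="ChIP-seq", limit=10)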
+ + :return: list of bed file metadata + """ + + must_statement = [] + + if genome_alias: + must_statement.append( + models.FieldCondition( + key="genome_alias", + match=models.MatchValue(value=genome_alias), + ) + ) + if assay: + must_statement.append( + models.FieldCondition( + key="assay", + match=models.MatchValue(value=assay), + ) + ) + + dense_query = list(list(self.config.dense_encoder.embed(query))[0]) + if self.config.sparce_encoder: + sparse_result = self.config.sparce_encoder.encode(query).coalesce() + sparse_embeddings = models.SparseVector( + indices=sparse_result.indices().tolist()[0], + values=sparse_result.values().tolist(), + ) + + hybrid_query = [ + # Dense retrieval: semantic understanding + models.Prefetch(query=dense_query, using="dense", limit=limit), + # Sparse retrieval: exact technical term matching + models.Prefetch(query=sparse_embeddings, using="sparse", limit=limit), + ] + else: + hybrid_query = [ + # Dense retrieval: semantic understanding + models.Prefetch(query=dense_query, using="dense", limit=limit), + ] + + results = self.config.qdrant_client.query_points( + collection_name=self.config.config.qdrant.search_collection, + limit=limit, + offset=offset, + prefetch=hybrid_query, + query=models.FusionQuery(fusion=models.Fusion.RRF), + with_payload=True, + with_vectors=True, + search_params=models.SearchParams( + exact=True, + ), + query_filter=( + models.Filter(must=must_statement) if must_statement else None + ), + ) + + result_list = [] + for result in results.points: result_id = result.id.replace("-", "") if with_metadata: diff --git a/docs/changelog.md b/docs/changelog.md index 44bbe79..b028266 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) format. 
+### [0.14.0] - 2025-11-24 +### Added: +- Conversion of BED files to UMAP coordinates using a predefined model + ### [0.12.0] - 2025-09-11 ### Added: - New qdrant semantic search diff --git a/manual_testing.py b/manual_testing.py index 2fc2205..ee2ecdb 100644 --- a/manual_testing.py +++ b/manual_testing.py @@ -1,14 +1,15 @@ import os +import time +import matplotlib.pyplot as plt +import numpy as np import s3fs import zarr +from zarr.storage import FsspecStore # from dotenv import load_dotenv from geniml.io import RegionSet from gtars.utils import read_tokens_from_gtok -import matplotlib.pyplot as plt -import numpy as np -import time # from gtars.tokenizers import RegionSet @@ -38,23 +39,33 @@ def zarr_local(): tokenized_name = "0dcdf8986a72a3d85805bbc9493a13026l" overwrite = True - root = zarr.group( - store="/home/bnt4me/virginia/repos/bbconf/zarr_test", overwrite=False + # Use zarr.open_group instead of zarr.group + root = zarr.open_group( + store="/home/bnt4me/virginia/repos/bbconf/zarr_test", mode="a" ) - univers_group = root.require_group("7126993b14054a32de2da4a0b9173be5") - if not univers_group.get(tokenized_name): + # Handle group creation (require_group is deprecated) + universe_id = "7126993b14054a32de2da4a0b9173be5" + try: + univers_group = root[universe_id] + except KeyError: + univers_group = root.create_group(universe_id) + + # Check existence and handle overwrite + if tokenized_name not in univers_group: print("not overwriting") - ua = univers_group.create_dataset(tokenized_name, data=tok_regions) + ua = univers_group.create_array( + name=tokenized_name, data=np.array(tok_regions, dtype="int64") + ) elif overwrite: print("overwriting") - ua = univers_group.create_dataset( - tokenized_name, data=tok_regions, overwrite=True + del univers_group[tokenized_name] + ua = univers_group.create_array( + name=tokenized_name, data=np.array(tok_regions, dtype="int64") ) else: raise ValueError("fff") ua = univers_group - univers_group._delitem_nosync() def zarr_s3(): @@ -87,16 +98,34 @@ def zarr_s3(): endpoint_url=os.getenv("AWS_ENDPOINT_URL"), key=os.getenv("AWS_ACCESS_KEY_ID"), secret=os.getenv("AWS_SECRET_ACCESS_KEY"), + asynchronous=False, + skip_instance_cache=True, ) print(os.getenv("AWS_SECRET_ACCESS_KEY")) - s3_path = "s3://bedbase/new/" + s3_path = "bedbase/new/" # Remove s3:// prefix for FsspecStore - zarr_store = s3fs.S3Map(root=s3_path, s3=s3fc_obj, check=False, create=True) - cache = zarr.LRUStoreCache(zarr_store, max_size=2**28) + # Use FsspecStore instead of S3Map + LRUStoreCache + store = FsspecStore(fs=s3fc_obj, path=s3_path) - root = zarr.group(store=cache, overwrite=False) - univers_group = root.require_group("7126993b14054a32de2da4a0b9173be5") - univers_group.create_dataset(tokenized_name, data=tok_regions, overwrite=True) + # Use zarr.open_group instead of zarr.group + root = zarr.open_group(store=store, mode="a") + + # Handle group creation (require_group is deprecated) + universe_id = "7126993b14054a32de2da4a0b9173be5" + try: + univers_group = root[universe_id] + except KeyError: + univers_group = root.create_group(universe_id) + + # Handle overwrite + if tokenized_name in univers_group: + if overwrite: + del univers_group[tokenized_name] + + # Use create_array instead of create_dataset + univers_group.create_array( + name=tokenized_name, data=np.array(tok_regions, dtype="int64") ) f = univers_group[tokenized_name] @@ -109,16 +138,24 @@ def get_from_s3(): # endpoint_url="https://s3.us-west-002.backblazeb2.com/", # key=os.getenv("AWS_ACCESS_KEY_ID"), # 
secret=os.getenv("AWS_SECRET_ACCESS_KEY"), + asynchronous=False, + skip_instance_cache=True, ) import s3fs - s3fc_obj = s3fs.S3FileSystem(endpoint_url="https://s3.us-west-002.backblazeb2.com/") - s3_path = "s3://bedbase/tokenized.zarr/" - zarr_store = s3fs.S3Map(root=s3_path, s3=s3fc_obj, check=False, create=True) - cache = zarr.LRUStoreCache(zarr_store, max_size=2**28) + s3fc_obj = s3fs.S3FileSystem( + endpoint_url="https://s3.us-west-002.backblazeb2.com/", + asynchronous=False, + skip_instance_cache=True, + ) + s3_path = "bedbase/tokenized.zarr/" # Remove s3:// prefix for FsspecStore + + # Use FsspecStore instead of S3Map + LRUStoreCache + store = FsspecStore(fs=s3fc_obj, path=s3_path) - root = zarr.group(store=cache, overwrite=False) + # Use zarr.open_group instead of zarr.group + root = zarr.open_group(store=store, mode="r") # print(str(root.tree)) @@ -329,8 +366,9 @@ def new_search(): agent = BedBaseAgent(config="/home/bnt4me/virginia/repos/bedhost/config.yaml") time1 = time.time() - results = agent.bed.reindex_semantic_search() + # results = agent.bed.reindex_semantic_search() # results = agent.bed.comp_search() + results = agent.bed.hybrid_search("leukemia") time2 = time.time() print(f"Time taken: {time2 - time1} seconds") @@ -370,10 +408,10 @@ def reindex_files(): # neighbour_beds() # sql_search() # config_t() - compreh_stats() + # compreh_stats() # get_unprocessed_files() # get_genomes() - # new_search() + new_search() # external_search() # get_assay_list() diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index ce2fca3..ee55ed0 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -1,7 +1,7 @@ yacman >= 0.9.1 sqlalchemy >= 2.0.0 gtars >= 0.4.0 -geniml[ml] >= 0.7.1 +geniml[ml] >= 0.8.3 psycopg >= 3.1.15 coloredlogs pydantic >= 2.9.0 @@ -9,8 +9,9 @@ botocore >= 1.34.0, < 1.36.0 boto3 >= 1.34.54, < 1.36.0 pephubclient >= 0.4.5 sqlalchemy_schemadisplay -zarr < 3.0.0 +zarr >= 3.0.0 pyyaml >= 6.0.1 # for s3fs because of the errors s3fs >= 2024.3.1 pandas >= 2.0.0 pybiocfilecache == 0.6.1 +umap-learn >= 0.5.8 diff --git a/tests/test_common.py b/tests/test_common.py index d7921e9..9e9b029 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,12 +1,13 @@ +import datetime + import pytest from bbconf.const import DEFAULT_LICENSE -from bbconf.models.base_models import UsageModel from bbconf.exceptions import BedBaseConfError -import datetime +from bbconf.models.base_models import UsageModel from .conftest import SERVICE_UNAVAILABLE -from .utils import ContextManagerDBTesting, BED_TEST_ID, BEDSET_TEST_ID +from .utils import BED_TEST_ID, BEDSET_TEST_ID, ContextManagerDBTesting @pytest.mark.skipif(SERVICE_UNAVAILABLE, reason="Database is not available")
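For context, the dense-plus-sparse retrieval that `_init_qdrant_hybrid` and `hybrid_search` rely on can be exercised standalone with qdrant-client. A minimal sketch of that pattern follows; it is not part of the patch, and the local Qdrant URL, the 384-dimension vector size, the point id, and the sparse indices/values are all illustrative assumptions:

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost", port=6333)
collection = "bedbase_query_search"  # matches DEFAULT_QDRANT_HYBRID_COLLECTION_NAME

# Named dense + sparse vector spaces, mirroring _init_qdrant_hybrid
if not client.collection_exists(collection):
    client.create_collection(
        collection_name=collection,
        vectors_config={
            "dense": models.VectorParams(size=384, distance=models.Distance.COSINE)
        },
        sparse_vectors_config={"sparse": models.SparseVectorParams()},
    )

# One toy point; real payloads carry assay/genome_alias fields for filtering
client.upsert(
    collection_name=collection,
    points=[
        models.PointStruct(
            id="11111111-2222-3333-4444-555555555555",
            vector={
                "dense": [0.1] * 384,
                "sparse": models.SparseVector(indices=[7, 42], values=[0.8, 0.3]),
            },
            payload={"assay": "ChIP-seq", "genome_alias": "hg38"},
        )
    ],
)

# Reciprocal Rank Fusion over the dense and sparse prefetch branches,
# the same query shape hybrid_search() builds
hits = client.query_points(
    collection_name=collection,
    prefetch=[
        models.Prefetch(query=[0.1] * 384, using="dense", limit=10),
        models.Prefetch(
            query=models.SparseVector(indices=[7, 42], values=[0.8, 0.3]),
            using="sparse",
            limit=10,
        ),
    ],
    query=models.FusionQuery(fusion=models.Fusion.RRF),
    limit=10,
)
print([p.id for p in hits.points])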