diff --git a/notebooks/query_engine_usage.ipynb b/notebooks/query_engine_usage.ipynb index e64ff4e..7cda810 100644 --- a/notebooks/query_engine_usage.ipynb +++ b/notebooks/query_engine_usage.ipynb @@ -112,7 +112,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, we want to select all contacts where any of the fragments constituting the contact overlaps the target region. To perform this action, we use the Overlap class and pass the target region as well as an instance of the `Anchor` class. The `Anchor` dataclass allows us to specify how we want to filter contacts for region overlap. It has two attributes `mode` and `anchors`. `Anchors` indicates the positions we want to filter on (default is all positions) and `mode` specifies whether we require all positions to overlap or any position to overlap. So for example, if we want all of our two-way contacts for which any of the positions overlap, we would use `Anchor(mode='ANY', anchors=[1,2])`." + "First, we want to select all contacts where any of the fragments constituting the contact overlaps the target region. To perform this action, we use the Overlap class and pass the target region as well as an instance of the `Anchor` class. The `Anchor` dataclass allows us to specify how we want to filter contacts for region overlap. It has three attributes `fragment_mode`, `region_mode` and `positions`. `Positions` indicates the fragment positions we want to filter on (default is all positions) and `fragment_mode` specifies whether we require all positions to overlap or any position to overlap. `Region_mode` specifies whether we require an overlap with all supplied regions, which is only relevant for lists of regions. It defaults to \"ALL\". So for example, if we want all of our two-way contacts for which any of the positions overlap, we would use `Anchor(fragment_mode='ANY', positions=[1,2])`." 
] }, { @@ -122,7 +122,7 @@ "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_region, anchor_mode=Anchor(mode=\"ANY\", anchors=[1,2]))\n", + " Overlap(target_region, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1,2]))\n", "]" ] }, @@ -157,7 +157,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -174,7 +174,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `.load_result` method of the `QueryResult` object can be executed using `.load_result`, which returns a `pd.DataFrame`. The resulting dataframe has additional columns that represent the regions, with which the input contacts overlapped." + "The `.compute` method of the `QueryPlan` object can be executed using `.compute()`, which returns a `pd.DataFrame`. The resulting dataframe has additional columns that represent the regions, with which the input contacts overlapped." ] }, { @@ -367,7 +367,7 @@ ], "source": [ "query_steps = [\n", - " Overlap(target_region, anchor_mode=Anchor(mode=\"ANY\", anchors=[1]))\n", + " Overlap(target_region, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]))\n", "]\n", "Query(query_steps=query_steps)\\\n", " .build(contacts)\\\n", @@ -393,8 +393,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Selecting a subset of contacts at multiple genomic regions\n", - "The Overlap class is also capable of selecting contacts at multiple genomic regions. Here, the behavior of `Overlap` deviates from a simple filter, because if a given contact overlaps with multiple regions, it will be returned multiple times." + "## Selecting a subset of contacts at a set of genomic regions\n", + "The Overlap class is also capable of selecting contacts at a set of genomic regions. Here, the default behavior of `Overlap` deviates from a simple filter, because if a given contact overlaps with multiple regions in the set, it will be returned multiple times." 
] }, { @@ -503,7 +503,7 @@ ], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ANY\", anchors=[1]))\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]))\n", "]\n", "Query(query_steps=query_steps)\\\n", " .build(contacts)\\\n", @@ -518,6 +518,89 @@ "In this example, the contact overlapping both regions is duplicated." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we, however, only want to filter contacts that overlap a given set of regions without duplicates being returned, we can pass the `add_overlap_columns` argument to the `Overlap` constructor as `False`. This will only return the respective contacts that have been deduplicated:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrom_1start_1end_1chrom_2start_2end_2
0chr1100200chr110002000
\n", + "
" + ], + "text/plain": [ + " chrom_1 start_1 end_1 chrom_2 start_2 end_2\n", + "0 chr1 100 200 chr1 1000 2000" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_steps = [\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]),\n", + " add_overlap_columns=False)\n", + "]\n", + "Query(query_steps=query_steps)\\\n", + " .build(contacts)\\\n", + " .compute()\\\n", + " .filter(regex=r\"chrom|start|end|id\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, only a a ingle contact is returned, without the corresponding region columns." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -525,6 +608,139 @@ "The same functionality is implemented also for the pixels class." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Selecting a subset of contacts at multiple set of genomic regions\n", + "The `Overlap` class capable of selecting conatcts and pixels base don overlap with multiple sets of genomic regions. Here, the `Anchor` class specifies how the differnet overlap possibilities should be handled: The `fragment_mode` parameter specifies whether we require all (value `ALL`) fragments to overlap, or whether we require any (value `ANY`) fragments to overlap. The `region_mode` parameter then specifies whether we require the fragments to overlap any (`ANY`) of the passed sets of genomic regions or all (`ALL`). The `positions` parameter specifies - as in the case with a single set of genomic regions - which fragments we apply this logic to." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "target_regions = pd.DataFrame({\n", + " \"chrom\": ['chr1'],\n", + " \"start\": [110],\n", + " \"end\": [140],\n", + "})\n", + "\n", + "target_regions_2 = pd.DataFrame({\n", + " \"chrom\": ['chr1'],\n", + " \"start\": [1000],\n", + " \"end\": [1030],\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we require that any of the fragments overlap all of the passed regions" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrom_1start_1end_1chrom_2start_2end_2region_chromregion_startregion_endregion_idregion_chrom_1region_start_1region_end_1region_id_1
0chr1100200chr110002000chr11101400chr1100010300
\n", + "
" + ], + "text/plain": [ + " chrom_1 start_1 end_1 chrom_2 start_2 end_2 region_chrom region_start \\\n", + "0 chr1 100 200 chr1 1000 2000 chr1 110 \n", + "\n", + " region_end region_id region_chrom_1 region_start_1 region_end_1 \\\n", + "0 140 0 chr1 1000 1030 \n", + "\n", + " region_id_1 \n", + "0 0 " + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_steps = [\n", + " Overlap([target_regions, target_regions_2],\n", + " anchor_mode=Anchor(fragment_mode=\"ANY\",region_mode='ALL')\n", + " )\n", + "]\n", + "Query(query_steps=query_steps)\\\n", + " .build(contacts)\\\n", + " .compute()\\\n", + " .filter(regex=r\"chrom|start|end|id\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This way we can filter for contacts between specific target regions, e.g. loop bases or promoters and enhancers." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -535,7 +751,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -555,7 +771,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -599,7 +815,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -615,7 +831,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -637,12 +853,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ANY\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\")),\n", " DistanceTransformation(\n", " distance_mode=DistanceMode.LEFT,\n", " ),\n", @@ -658,7 +874,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, 
"outputs": [ { @@ -824,7 +1040,7 @@ "[250 rows x 7 columns]" ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -846,7 +1062,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -867,12 +1083,12 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ALL\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ALL\")),\n", " DistanceTransformation(),\n", " DistanceAggregation(\n", " value_column='count',\n", @@ -883,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1013,7 +1229,7 @@ "[125 rows x 4 columns]" ] }, - "execution_count": 18, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1033,12 +1249,12 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ALL\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ALL\")),\n", " DistanceTransformation(),\n", " DistanceAggregation(\n", " value_column='count',\n", @@ -1050,7 +1266,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1263,7 +1479,7 @@ "24 100000.0 100000.0 4.3" ] }, - "execution_count": 20, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/spoc/contacts.py b/spoc/contacts.py index 0f08941..23dca99 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -49,6 +49,7 @@ def __init__( label_sorted: bool = False, binary_labels_equal: bool = False, symmetry_flipped: bool = False, + label_values: Optional[List[str]] = None, ) -> None: self.contains_metadata = ( 
"metadata_1" in contact_frame.columns @@ -61,7 +62,6 @@ def __init__( number_fragments=self.number_fragments, contains_metadata=self.contains_metadata, ) - # TODO: make this work for duckdb pyrelation -> switch to mode if isinstance(contact_frame, pd.DataFrame): self.data_mode = DataMode.PANDAS elif isinstance(contact_frame, dd.DataFrame): @@ -75,6 +75,7 @@ def __init__( self.label_sorted = label_sorted self.binary_labels_equal = binary_labels_equal self.symmetry_flipped = symmetry_flipped + self.label_values = label_values @staticmethod def from_uri(uri, mode=DataMode.PANDAS): @@ -116,15 +117,19 @@ def get_label_values(self) -> List[str]: # TODO: This could be put in global metadata of parquet file if not self.contains_metadata: raise ValueError("Contacts do not contain metadata!") - output = set() - for i in range(self.number_fragments): - if self.data_mode == DataMode.DASK: - output.update(self.data[f"metadata_{i+1}"].unique().compute()) - elif self.data_mode == DataMode.PANDAS: - output.update(self.data[f"metadata_{i+1}"].unique()) - else: - raise ValueError("Label values not supported for duckdb!") - return list(output) + if self.label_values is None: + output = set() + for i in range(self.number_fragments): + if self.data_mode == DataMode.DASK: + output.update(self.data[f"metadata_{i+1}"].unique().compute()) + elif self.data_mode == DataMode.PANDAS: + output.update(self.data[f"metadata_{i+1}"].unique()) + else: + raise ValueError("Label values not supported for duckdb!") + # add metadata values and return + self.label_values = list(output) + return list(output) + return self.label_values def get_chromosome_values(self) -> List[str]: """Returns all chromosome values""" @@ -345,16 +350,16 @@ def sort_labels(self, contacts: Contacts) -> Contacts: ) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - # this is a bit of a hack to get the index sorted. 
Dask does not support index sorting - result = ( - dd.concat(subsets).reset_index().sort_values("index").set_index("index") - ) + result = dd.concat(subsets) elif contacts.data_mode == DataMode.PANDAS: - result = pd.concat(subsets).sort_index() + result = pd.concat(subsets) else: raise ValueError("Sorting labels for duckdb relations is not implemented.") return Contacts( - result, number_fragments=contacts.number_fragments, label_sorted=True + result, + number_fragments=contacts.number_fragments, + label_sorted=True, + label_values=label_values, ) def _sort_chromosomes(self, df: DataFrame, number_fragments: int) -> DataFrame: @@ -452,12 +457,9 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts: subsets.append(subset) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - # this is a bit of a hack to get the index sorted. Dask does not support index sorting - result = ( - dd.concat(subsets).reset_index().sort_values("index").set_index("index") - ) + result = dd.concat(subsets) elif contacts.data_mode == DataMode.PANDAS: - result = pd.concat(subsets).sort_index() + result = pd.concat(subsets) else: raise ValueError( "Equate binary labels for duckdb relations is not implemented." 
@@ -467,6 +469,7 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts: number_fragments=contacts.number_fragments, label_sorted=True, binary_labels_equal=True, + label_values=label_values, ) def subset_on_metadata( @@ -506,6 +509,7 @@ def subset_on_metadata( label_sorted=contacts.label_sorted, binary_labels_equal=contacts.binary_labels_equal, symmetry_flipped=contacts.symmetry_flipped, + label_values=label_values, ) def flip_symmetric_contacts( @@ -534,6 +538,7 @@ def flip_symmetric_contacts( label_sorted=True, binary_labels_equal=contacts.binary_labels_equal, symmetry_flipped=True, + label_values=label_values, ) result = self._flip_unlabelled_contacts(contacts.data) if sort_chromosomes: diff --git a/spoc/io.py b/spoc/io.py index 0e9cb77..095aeeb 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -34,7 +34,7 @@ class FileManager: data_mode (DataMode, optional): Data mode. Defaults to DataMode.PANDAS. """ - def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None: + def __init__(self, data_mode: DataMode = DataMode.PANDAS, **kwargs) -> None: if data_mode == DataMode.DUCKDB: self._parquet_reader_func = partial( duckdb.read_parquet, connection=DUCKDB_CONNECTION @@ -45,6 +45,17 @@ def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None: self._parquet_reader_func = pd.read_parquet else: raise ValueError(f"Data mode {data_mode} not supported!") + # store data mode + self._data_mode = data_mode + # set duckdb parameters if they are there + if "duckdb_max_memory" in kwargs and data_mode == DataMode.DUCKDB: + DUCKDB_CONNECTION.execute( + f"PRAGMA memory_limit = '{kwargs['duckdb_max_memory']}'" + ) + if "duckdb_max_threads" in kwargs and data_mode == DataMode.DUCKDB: + DUCKDB_CONNECTION.execute( + f"PRAGMA threads = {kwargs['duckdb_max_threads']}" + ) @staticmethod def write_label_library(path: str, data: Dict[str, bool]) -> None: @@ -254,9 +265,19 @@ def load_pixels( ) # rewrite path to contain parent folder full_pixel_path = Path(path) / pixel_path 
- df = self._parquet_reader_func(full_pixel_path) + if self._data_mode == DataMode.DUCKDB: + df = self._parquet_reader_func(self._get_duckdb_path(full_pixel_path)) + else: + df = self._parquet_reader_func(full_pixel_path) return Pixels(df, **matched_parameters.dict()) + def _get_duckdb_path(self, path: Path) -> str: + """Constructs duckdb path string to handle cases + where parquet files are stored in a directory with multiple files.""" + if Path(path).is_dir(): + return f"{path}/*.parquet" + return str(path) + def load_contacts( self, path: str, global_parameters: Optional[ContactsParameters] = None ) -> Contacts: @@ -287,7 +308,10 @@ def load_contacts( ) # rewrite path to contain parent folder full_contacts_path = Path(path) / contacts_path - df = self._parquet_reader_func(full_contacts_path) + if self._data_mode == DataMode.DUCKDB: + df = self._parquet_reader_func(self._get_duckdb_path(full_contacts_path)) + else: + df = self._parquet_reader_func(str(full_contacts_path)) return Contacts(df, **matched_parameters.dict()) @staticmethod @@ -305,7 +329,7 @@ def _get_object_hash_path(path: str, data_object: Union[Pixels, Contacts]) -> st ) return md5(hash_string.encode(encoding="utf-8")).hexdigest() + ".parquet" - def write_pixels(self, path: str, pixels: Pixels) -> None: + def write_pixels(self, path: str, pixels: Pixels, **kwargs) -> None: """Write pixels Args: @@ -331,13 +355,13 @@ def write_pixels(self, path: str, pixels: Pixels) -> None: raise ValueError( "Writing pixels only suppported for pixels hodling dataframes!" 
) - pixels.data.to_parquet(write_path, row_group_size=1024 * 1024) + pixels.data.to_parquet(str(write_path), **kwargs) # write metadata current_metadata[write_path.name] = pixels.get_global_parameters().dict() with open(metadata_path, "w", encoding="UTF-8") as f: json.dump(current_metadata, f) - def write_contacts(self, path: str, contacts: Contacts) -> None: + def write_contacts(self, path: str, contacts: Contacts, **kwargs) -> None: """Write contacts""" # check whether path exists metadata_path = Path(path) / "metadata.json" @@ -354,7 +378,7 @@ def write_contacts(self, path: str, contacts: Contacts) -> None: raise ValueError( "Writing contacts only suppported for contacts hodling dataframes!" ) - contacts.data.to_parquet(write_path, row_group_size=1024 * 1024) + contacts.data.to_parquet(str(write_path), **kwargs) # write metadata current_metadata[write_path.name] = contacts.get_global_parameters().dict() with open(metadata_path, "w") as f: diff --git a/spoc/models/dataframe_models.py b/spoc/models/dataframe_models.py index af42ff4..a3bb4f9 100644 --- a/spoc/models/dataframe_models.py +++ b/spoc/models/dataframe_models.py @@ -68,7 +68,7 @@ def get_schema(self) -> pa.DataFrameSchema: def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" @@ -89,7 +89,7 @@ def __init__( position_fields: Dict[int, List[str]], contact_order: int, binsize: Optional[int] = None, - region_number: Optional[int] = None, + region_number: Optional[Union[int, List[int]]] = None, half_window_size: Optional[int] = None, ) -> None: self._columns = columns @@ -127,7 +127,7 @@ def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return self._binsize - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, 
List[int]]]: """Returns the number of regions in the genomic data if present.""" return self._region_number @@ -215,8 +215,8 @@ def validate_header(self, data_frame: DataFrame) -> None: Args: data_frame (DataFrame): The DataFrame to validate. """ - for column in data_frame.columns: - if column not in self._schema.columns: + for column in self._schema.columns: + if column not in data_frame.columns: raise pa.errors.SchemaError( self._schema, data_frame, "Header is invalid!" ) @@ -258,7 +258,7 @@ def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return None - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" return None @@ -350,17 +350,16 @@ def get_position_fields(self) -> Dict[int, List[str]]: return { i: ["chrom", f"start_{i}"] for i in range(1, self._number_fragments + 1) } - else: - return { - i: [f"chrom_{i}", f"start_{i}"] - for i in range(1, self._number_fragments + 1) - } + return { + i: [f"chrom_{i}", f"start_{i}"] + for i in range(1, self._number_fragments + 1) + } def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return self._binsize - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" return None diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 6069897..03f2215 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -56,20 +56,21 @@ def __call__(self, *args: Any, **kwds: Any) -> "QueryPlan": """Apply the query step to the data""" -# TODO: think about allowing anchor composition class Anchor(BaseModel): """Represents an anchor. Attributes: - mode (str): The mode of the anchor. (Can be "ANY" or "ALL") - anchors (Optional[List[int]]): The list of anchor values (optional). 
+ fragment_mode (str): The mode for fragment overlap. (Can be "ANY" or "ALL") + region_mode (str): The mode for region overlap. (Can be "ANY" or "ALL"). Defaults to "ALL". + positions (Optional[List[int]]): The list of anchor values (optional). """ - mode: str - anchors: Optional[List[int]] = None + fragment_mode: str + region_mode: str = "ALL" + positions: Optional[List[int]] = None def __repr__(self) -> str: - return f"Anchor(mode={self.mode}, anchors={self.anchors})" + return f"Anchor(fragment_mode={self.fragment_mode}, positions={self.positions}, region_mode={self.region_mode})" def __str__(self) -> str: return self.__repr__() @@ -85,25 +86,60 @@ class Overlap: def __init__( self, - regions: pd.DataFrame, + regions: Union[pd.DataFrame, List[pd.DataFrame]], anchor_mode: Union[Anchor, Tuple[str, List[int]]], half_window_size: Optional[int] = None, + add_overlap_columns: bool = True, ) -> None: """ Initialize the Overlap object. Args: - regions (pd.DataFrame): A DataFrame containing the regions data. + regions (Union[pd.DataFrame, List[pd.DataFrame]]): A DataFrame containing the regions data, + or a list of DataFrames containing the regions data. anchor_mode (Union[Anchor,Tuple[str,List[int]]]): The anchor mode to be used. - half_window_size (Optional[int]): The window size the regions should be expanded to. Defaults to None and is inferred from the data. + half_window_size (Optional[int]): The window size the regions should be expanded to. + Defaults to None and is inferred from the data. 
Returns: None """ - # add ids to regions if they don't exist + self._add_overlap_columns = add_overlap_columns + # preprocess regions + if isinstance(regions, list): + self._regions, half_window_sizes = zip( + *[ + self._prepare_regions(region, half_window_size, index=index) + for index, region in enumerate(regions) + ] + ) + if not all( + half_window_size == half_window_sizes[0] + for half_window_size in half_window_sizes + ): + raise ValueError("All regions need to have the same window size.") + self._half_window_size = half_window_sizes[0] + else: + self._regions, self._half_window_size = self._prepare_regions( + regions, half_window_size + ) + if isinstance(anchor_mode, tuple): + self._anchor = Anchor( + fragment_mode=anchor_mode[0], positions=anchor_mode[1] + ) + else: + self._anchor = anchor_mode + + def _prepare_regions( + self, regions: pd.DataFrame, half_window_size: Optional[int], index: int = 0 + ) -> Tuple[pd.DataFrame, int]: + """Preprocessing of regions including adding an id column.""" if "id" not in regions.columns: regions["id"] = range(len(regions)) if half_window_size is not None: + # check halfwindowsize is positive + if half_window_size < 0: + raise ValueError("Half window size must be positive.") expanded_regions = regions.copy() # create midpoint expanded_regions["midpoint"] = ( @@ -113,29 +149,33 @@ def __init__( expanded_regions["start"] = expanded_regions["midpoint"] - half_window_size expanded_regions["end"] = expanded_regions["midpoint"] + half_window_size # drop midpoint - expanded_regions = expanded_regions.drop(columns=["midpoint"]) - self._regions = RegionSchema.validate( - expanded_regions.add_prefix("region_") - ) - self._half_window_size = half_window_size - else: - self._regions = RegionSchema.validate(regions.add_prefix("region_")) - # infer window size -> variable regions will have largest possible window size - self._half_window_size = int( - (self._regions["region_end"] - self._regions["region_start"]).max() // 2 - ) - if 
isinstance(anchor_mode, tuple): - self._anchor_mode = Anchor(mode=anchor_mode[0], anchors=anchor_mode[1]) - else: - self._anchor_mode = anchor_mode + preprocssed_regions = expanded_regions.drop( + columns=["midpoint"] + ).add_prefix("region_") + preprocssed_regions = RegionSchema.validate(preprocssed_regions) + if index > 0: + preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}") + return preprocssed_regions, half_window_size + preprocssed_regions = RegionSchema.validate(regions.add_prefix("region_")) + # infer window size -> variable regions will have largest possible window size + calculated_half_window_size = int( + ( + preprocssed_regions["region_end"] - preprocssed_regions["region_start"] + ).max() + // 2 + ) + # add index + if index > 0: + preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}") + return preprocssed_regions, calculated_half_window_size def validate(self, data_schema: GenomicDataSchema) -> None: """Validate the filter against the data schema""" # check whether an anchor is specified that is not in the data - if self._anchor_mode.anchors is not None: + if self._anchor.positions is not None: if not all( anchor in data_schema.get_position_fields().keys() - for anchor in self._anchor_mode.anchors + for anchor in self._anchor.positions ): raise ValueError( "An anchor is specified that is not in the data schema." 
@@ -158,7 +198,57 @@ def _convert_to_duckdb( data = data.compute() return duckdb.from_df(data, connection=DUCKDB_CONNECTION) - def _contstruct_filter(self, position_fields: Dict[int, List[str]]) -> str: + def _construct_query_multi_region( + self, + regions: List[duckdb.DuckDBPyRelation], + genomic_df: duckdb.DuckDBPyRelation, + position_fields: Dict[int, List[str]], + ) -> duckdb.DuckDBPyRelation: + """Constructs the query for multiple regions.""" + snipped_df = genomic_df.set_alias("data") + for index, region in enumerate(regions): + snipped_df = snipped_df.join( + region.set_alias(f"regions_{index}"), + self._contstruct_filter( + position_fields, f"regions_{index}", index=index + ), + how="left", + ) + # filter regions based on region mode + if self._anchor.region_mode == "ALL": + return snipped_df.filter( + " and ".join( + [ + f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null" + for index in range(0, len(regions)) + ] + ) + ) + + return snipped_df.filter( + " or ".join( + [ + f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null" + for index in range(0, len(regions)) + ] + ) + ) + + def _constrcut_query_single_region( + self, + regions: duckdb.DuckDBPyRelation, + genomic_df: duckdb.DuckDBPyRelation, + position_fields: Dict[int, List[str]], + ) -> duckdb.DuckDBPyRelation: + """Constructs the query for a single region.""" + return genomic_df.set_alias("data").join( + regions.set_alias("regions"), + self._contstruct_filter(position_fields, "regions"), + ) + + def _contstruct_filter( + self, position_fields: Dict[int, List[str]], region_name: str, index: int = 0 + ) -> str: """Constructs the filter string. Args: @@ -171,21 +261,25 @@ def _contstruct_filter(self, position_fields: Dict[int, List[str]]) -> str: NotImplementedError: If the length of fields is not equal to 3. 
""" query_strings = [] - join_string = " or " if self._anchor_mode.mode == "ANY" else " and " + join_string = " or " if self._anchor.fragment_mode == "ANY" else " and " + if index > 0: + column_index = f"_{index}" + else: + column_index = "" # subset on anchor regions - if self._anchor_mode.anchors is not None: + if self._anchor.positions is not None: subset_positions = [ - position_fields[anchor] for anchor in self._anchor_mode.anchors + position_fields[anchor] for anchor in self._anchor.positions ] else: subset_positions = list(position_fields.values()) for fields in subset_positions: chrom, start, end = fields - output_string = f"""(data.{chrom} = regions.region_chrom and + output_string = f"""(data.{chrom} = {region_name}.region_chrom{column_index} and ( - data.{start} between regions.region_start and regions.region_end or - data.{end} between regions.region_start and regions.region_end or - regions.region_start between data.{start} and data.{end} + data.{start} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or + data.{end} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or + {region_name}.region_start{column_index} between data.{start} and data.{end} ) )""" query_strings.append(output_string) @@ -199,12 +293,17 @@ def _get_transformed_schema( ) -> GenomicDataSchema: """Returns the schema of the transformed data.""" # construct schema + # get region number + if isinstance(self._regions, (list, tuple)): + region_number = [region.shape[0] for region in self._regions] + else: + region_number = self._regions.shape[0] return QueryStepDataSchema( columns=data_frame.columns, position_fields=position_fields, contact_order=input_schema.get_contact_order(), binsize=input_schema.get_binsize(), - region_number=len(self._regions), + region_number=region_number, half_window_size=self._half_window_size, ) @@ -229,7 +328,7 @@ def _add_end_position( ) def __repr__(self) -> str: - return 
f"Overlap(anchor_mode={self._anchor_mode})" + return f"Overlap(anchor_mode={self._anchor})" def __call__(self, genomic_data: GenomicData) -> GenomicData: """Apply the filter to the data""" @@ -240,7 +339,11 @@ def __call__(self, genomic_data: GenomicData) -> GenomicData: genomic_df = genomic_data.data else: genomic_df = self._convert_to_duckdb(genomic_data.data) - regions = self._convert_to_duckdb(self._regions) + # bring regions to duckdb dataframe + if isinstance(self._regions, (list, tuple)): + regions = [self._convert_to_duckdb(region) for region in self._regions] + else: + regions = self._convert_to_duckdb(self._regions) # get position columns and construct filter position_fields = input_schema.get_position_fields() # add end position if not present @@ -253,9 +356,17 @@ def __call__(self, genomic_data: GenomicData) -> GenomicData: for position in position_fields.keys() } # construct query - snipped_df = genomic_df.set_alias("data").join( - regions.set_alias("regions"), self._contstruct_filter(position_fields) - ) + if isinstance(regions, (list, tuple)): + snipped_df = self._construct_query_multi_region( + regions, genomic_df, position_fields + ) + else: + snipped_df = self._constrcut_query_single_region( + regions, genomic_df, position_fields + ) + # remove overlap columns and drop duplicates if requested + if not self._add_overlap_columns: + snipped_df = snipped_df.project("data.*").distinct() return QueryPlan( snipped_df, self._get_transformed_schema(snipped_df, input_schema, position_fields), @@ -511,6 +622,12 @@ def validate(self, data_schema: GenomicDataSchema) -> None: raise ValueError( "Binsize specified in data schema, but distance mode is not set to LEFT." ) + # check whether only a single set of regions has been overlapped + region_number = data_schema.get_region_number() + if isinstance(region_number, list): + raise ValueError( + "Distance transformation requires only a single set of regions overlapped." 
+ ) def _create_transform_columns( self, genomic_df: duckdb.DuckDBPyRelation, input_schema: GenomicDataSchema diff --git a/tests/contacts_tests/test_symmetry.py b/tests/contacts_tests/test_symmetry.py index 0eea19f..2f39109 100644 --- a/tests/contacts_tests/test_symmetry.py +++ b/tests/contacts_tests/test_symmetry.py @@ -60,7 +60,7 @@ def test_labelled_contacts_are_sorted_correctly(unsorted, sorted_contacts, reque ), request.getfixturevalue(sorted_contacts) contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) - pd.testing.assert_frame_equal(result.data, sorted_contacts) + pd.testing.assert_frame_equal(result.data.sort_index(), sorted_contacts) assert result.label_sorted @@ -82,7 +82,7 @@ def test_labelled_contacts_are_sorted_correctly_dask( contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) pd.testing.assert_frame_equal( - result.data.compute().reset_index(drop=True), + result.data.compute().sort_index().reset_index(drop=True), sorted_contacts.reset_index(drop=True), ) assert result.label_sorted @@ -103,7 +103,7 @@ def test_equate_binary_labels(unequated, equated, request): ) contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) - pd.testing.assert_frame_equal(result.data, equated) + pd.testing.assert_frame_equal(result.data.sort_index(), equated) @pytest.mark.parametrize( @@ -123,7 +123,7 @@ def test_equate_binary_labels_dask(unequated, equated, request): contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) pd.testing.assert_frame_equal( - result.data.compute().reset_index(drop=True), equated.reset_index(drop=True) + result.data.compute().sort_index().reset_index(drop=True), equated.reset_index(drop=True) ) diff --git a/tests/io_tests/test_io_contacts.py b/tests/io_tests/test_io_contacts.py index 567f1ee..b51b634 100644 --- a/tests/io_tests/test_io_contacts.py +++ 
b/tests/io_tests/test_io_contacts.py @@ -7,6 +7,7 @@ from pathlib import Path import dask.dataframe as dd +import duckdb import pytest from spoc.contacts import Contacts @@ -74,6 +75,57 @@ def example_contacts_w_metadata(unlabelled_contacts_2d, labelled_binary_contacts shutil.rmtree("tmp") +@pytest.fixture +def example_contacts_w_metadata_dask_parquet_structure( + unlabelled_contacts_2d, labelled_binary_contacts_2d +): + # setup + _create_tmp_dir() + # create contacts directory + contacts_dir = "tmp/contacts_test.parquet" + os.mkdir(contacts_dir) + expected_parameters = [ + ContactsParameters(number_fragments=2), + ContactsParameters(number_fragments=2, metadata_combi=["A", "B"]), + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], label_sorted=True + ), + ContactsParameters( + number_fragments=2, + metadata_combi=["A", "B"], + label_sorted=True, + symmetry_flipped=True, + ), + ] + paths = [ + Path("tmp/contacts_test.parquet/test1.parquet"), + Path("tmp/contacts_test.parquet/test2.parquet"), + Path("tmp/contacts_test.parquet/test3.parquet"), + Path("tmp/contacts_test.parquet/test4.parquet"), + ] + dataframes = [ + dd.from_pandas(unlabelled_contacts_2d, npartitions=2), + dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + ] + # create pixels files + for path, df in zip(paths, dataframes): + df.to_parquet(path) + # create metadata json file + metadata = { + "test1.parquet": expected_parameters[0].dict(), + "test2.parquet": expected_parameters[1].dict(), + "test3.parquet": expected_parameters[2].dict(), + "test4.parquet": expected_parameters[3].dict(), + } + with open(contacts_dir + "/metadata.json", "w") as f: + json.dump(metadata, f) + yield contacts_dir, expected_parameters, paths, dataframes + # teardown + shutil.rmtree("tmp") + + def test_read_contacts_metadata_json(example_contacts_w_metadata): """Test reading 
pixels metadata json file""" contacts_dir, expected_parameters, _, _ = example_contacts_w_metadata @@ -115,6 +167,33 @@ assert contacts.data.compute().equals(df) +def test_read_contacts_as_duckdb_connection(example_contacts_w_metadata): + """Test reading contacts as duckdb connection""" + contacts_dir, expected_parameters, paths, dataframes = example_contacts_w_metadata + # read metadata + for path, expected, df in zip(paths, expected_parameters, dataframes): + contacts = FileManager(DataMode.DUCKDB).load_contacts(contacts_dir, expected) + assert contacts.get_global_parameters() == expected + assert contacts.data.df().equals(df) + + +def test_read_contacts_as_duckdb_connection_from_dask_parquet( + example_contacts_w_metadata_dask_parquet_structure, +): + """Test reading contacts as duckdb connection from a dask parquet structure""" + ( + contacts_dir, + expected_parameters, + paths, + dataframes, + ) = example_contacts_w_metadata_dask_parquet_structure + # read metadata + for path, expected, df in zip(paths, expected_parameters, dataframes): + contacts = FileManager(DataMode.DUCKDB).load_contacts(contacts_dir, expected) + assert contacts.get_global_parameters() == expected + assert contacts.data.df()[df.columns].equals(df.compute()) + + @pytest.mark.parametrize( "df, params", [ @@ -186,6 +265,42 @@ def test_write_dask_contacts_to_new_file(df, params, request): assert contacts.data.compute().equals(contacts_read.data) +@pytest.mark.parametrize( + "df, params", + [ + ("unlabelled_contacts_2d", ContactsParameters(number_fragments=2)), + ( + "labelled_binary_contacts_2d", + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], symmetry_flipped=True + ), + ), + ( + "labelled_binary_contacts_2d", + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], label_sorted=True + ), + ), + ], +) +def test_write_duckdb_contacts_to_new_file(df, params, request): + df = request.getfixturevalue(df) + duckdb_df = 
duckdb.from_df(df) + contacts = Contacts(duckdb_df, **params.dict()) + with tempfile.TemporaryDirectory() as tmpdirname: + file_name = tmpdirname + "/" + "test.parquet" + FileManager(data_mode=DataMode.DUCKDB).write_contacts(file_name, contacts) + # check metadata + metadata = FileManager().list_contacts(file_name) + assert len(metadata) == 1 + assert metadata[0] == contacts.get_global_parameters() + # read contacts + contacts_read = FileManager().load_contacts(file_name, metadata[0]) + # check whether parameters are equal + assert contacts.get_global_parameters() == contacts_read.get_global_parameters() + assert contacts.data.df().equals(contacts_read.data) + + @pytest.mark.parametrize( "df1,df2,params", [ diff --git a/tests/query_engine/conftest.py b/tests/query_engine/conftest.py index e75fdf7..ee2bd46 100644 --- a/tests/query_engine/conftest.py +++ b/tests/query_engine/conftest.py @@ -122,7 +122,7 @@ def pixels_wihtout_regions_fixture(pixel_dataframe): @pytest.fixture(name="contacts_with_single_region") def contacts_with_single_region_fixture(contacts_without_regions, single_region): """Contacts with single region""" - return Overlap(single_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(single_region, anchor_mode=Anchor(fragment_mode="ANY"))( contacts_without_regions, ) @@ -130,15 +130,27 @@ def contacts_with_single_region_fixture(contacts_without_regions, single_region) @pytest.fixture(name="contacts_with_multiple_regions") def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_region): """Contacts with multiple regions""" - return Overlap(multi_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(multi_region, anchor_mode=Anchor(fragment_mode="ANY"))( contacts_without_regions, ) +@pytest.fixture(name="contacts_with_multiple_regions_overlapped") +def contacts_with_multiple_regions_overlapped_fixture( + contacts_without_regions, single_region, single_region_2 +): + """Pixels with multiple regions overlapped""" + return 
Overlap( + [single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + half_window_size=100, + )(contacts_without_regions) + + @pytest.fixture(name="pixels_with_single_region") def pixels_with_single_region_fixture(pixels_without_regions, single_region): """Pixels with single region""" - return Overlap(single_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(single_region, anchor_mode=Anchor(fragment_mode="ANY"))( pixels_without_regions, ) @@ -146,4 +158,6 @@ def pixels_with_single_region_fixture(pixels_without_regions, single_region): @pytest.fixture(name="pixels_with_multiple_regions") def pixels_with_multiple_regions_fixture(pixels_without_regions, multi_region): """Pixels with multiple regions""" - return Overlap(multi_region, anchor_mode=Anchor(mode="ANY"))(pixels_without_regions) + return Overlap(multi_region, anchor_mode=Anchor(fragment_mode="ANY"))( + pixels_without_regions + ) diff --git a/tests/query_engine/test_contact_selection.py b/tests/query_engine/test_contact_selection.py index daae42d..87f36d8 100644 --- a/tests/query_engine/test_contact_selection.py +++ b/tests/query_engine/test_contact_selection.py @@ -60,7 +60,9 @@ def test_any_anchor_region_returns_correct_contacts( """Test that any anchor region returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -83,7 +85,9 @@ def test_all_anchor_regions_returns_correct_contacts( """Test that all anchor regions returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = 
Query(query_steps=query_plan) result = query.build(contacts) @@ -116,7 +120,8 @@ def test_specific_anchor_regions_returns_correct_contacts( contacts = request.getfixturevalue(contact_fixture) query_plan = [ Overlap( - regions=single_region_2, anchor_mode=Anchor(mode="ALL", anchors=anchors) + regions=single_region_2, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), ) ] # execution @@ -153,7 +158,7 @@ def test_specific_anchor_regions_returns_correct_contacts_point_region( query_plan = [ Overlap( regions=single_region_3, - anchor_mode=Anchor(mode="ALL", anchors=anchors), + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), half_window_size=50, ) ] @@ -180,7 +185,9 @@ def test_any_anchor_region_returns_correct_contacts_multi_region( """Test that any anchor region returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -205,7 +212,9 @@ def test_all_anchor_regions_returns_correct_contacts_multi_region( """Test that all anchor regions returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -231,7 +240,9 @@ def test_contacts_duplicated_for_multiple_overlapping_regions( """ # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region_2, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region_2, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) 
@@ -265,7 +276,10 @@ def test_specific_anchor_regions_returns_correct_contacts_multi_region( # setup contacts = request.getfixturevalue(contact_fixture) query_plan = [ - Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL", anchors=anchors)) + Overlap( + regions=multi_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), + ) ] # execution query = Query(query_steps=query_plan) @@ -293,7 +307,10 @@ def test_specific_anchor_region_not_in_contacts_raises_error( # setup contacts = request.getfixturevalue(contact_fixture) query_plan = [ - Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL", anchors=[3])) + Overlap( + regions=single_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=[3]), + ) ] with pytest.raises(ValueError): query = Query(query_steps=query_plan) diff --git a/tests/query_engine/test_distance_aggregation.py b/tests/query_engine/test_distance_aggregation.py index c77ae86..4772f36 100644 --- a/tests/query_engine/test_distance_aggregation.py +++ b/tests/query_engine/test_distance_aggregation.py @@ -232,7 +232,7 @@ def test_aggregations_on_dense_input( region = request.getfixturevalue(region_fixture) mapped_pixels = Query( query_steps=[ - Overlap(region, anchor_mode=Anchor(mode="ANY")), + Overlap(region, anchor_mode=Anchor(fragment_mode="ANY")), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], ).build(pixels) @@ -295,7 +295,7 @@ def test_aggregations_on_dense_input_with_reduced_dimensionality( region = request.getfixturevalue(region_fixture) mapped_pixels = Query( query_steps=[ - Overlap(region, anchor_mode=Anchor(mode="ANY")), + Overlap(region, anchor_mode=Anchor(fragment_mode="ANY")), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], ).build(pixels) @@ -367,7 +367,8 @@ def test_aggregations_on_sparse_input( query_plan = Query( query_steps=[ Overlap( - request.getfixturevalue(region_fixture), anchor_mode=Anchor(mode="ANY") + request.getfixturevalue(region_fixture), + 
anchor_mode=Anchor(fragment_mode="ANY"), ), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], @@ -456,7 +457,8 @@ def test_aggregations_on_sparse_input_with_reduced_dimensionality( query_plan = Query( query_steps=[ Overlap( - request.getfixturevalue(region_fixture), anchor_mode=Anchor(mode="ANY") + request.getfixturevalue(region_fixture), + anchor_mode=Anchor(fragment_mode="ANY"), ), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], diff --git a/tests/query_engine/test_distance_transformation.py b/tests/query_engine/test_distance_transformation.py index 2278006..5532009 100644 --- a/tests/query_engine/test_distance_transformation.py +++ b/tests/query_engine/test_distance_transformation.py @@ -14,6 +14,7 @@ [ "contacts_without_regions", "pixels_without_regions", + "contacts_with_multiple_regions_overlapped", ], ) def test_incompatible_input_rejected(genomic_data_fixture, request): diff --git a/tests/query_engine/test_multi_contact_selection.py b/tests/query_engine/test_multi_contact_selection.py new file mode 100644 index 0000000..3ce7c4f --- /dev/null +++ b/tests/query_engine/test_multi_contact_selection.py @@ -0,0 +1,155 @@ +"""This set of tests tests selection of contacts overlapping +with a list of multiple regions""" +import dask.dataframe as dd +import duckdb +import pytest + +from spoc.contacts import Contacts +from spoc.io import DUCKDB_CONNECTION +from spoc.query_engine import Anchor +from spoc.query_engine import Overlap +from spoc.query_engine import Query + + +@pytest.fixture(name="example_2d_contacts_pandas") +def example_2d_contacts_pandas_fixture(example_2d_df): + """Example 2d contacts""" + return Contacts(example_2d_df) + + +@pytest.fixture(name="example_2d_contacts_dask") +def example_2d_contacts_dask_fixture(example_2d_df): + """Example 2d contacts""" + return Contacts(dd.from_pandas(example_2d_df, npartitions=2)) + + +@pytest.fixture(name="example_2d_contacts_duckdb") +def example_2d_contacts_duckdb_fixture(example_2d_df): + 
"""Example 2d contacts""" + return Contacts(duckdb.from_df(example_2d_df, connection=DUCKDB_CONNECTION)) + + +def test_different_half_window_size_throws_error(single_region, single_region_2): + """Test that different half window size throws error""" + with pytest.raises(ValueError): + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + ) + + +def test_negative_half_window_size_throws_error(single_region, single_region_2): + """Test that negative half window size throws error""" + with pytest.raises(ValueError): + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + half_window_size=-1, + ) + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_contacts", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 4), + ("ALL", "ANY", 1), + ], +) +def test_overlap_without_position_subset( + fragment_mode, + region_mode, + number_contacts, + single_region, + single_region_2, + example_2d_contacts_pandas, +): + """Test that overlap without position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode=fragment_mode, region_mode=region_mode), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(example_2d_contacts_pandas) + + # assert + assert result.compute().shape[0] == number_contacts + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_contacts", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 3), + ("ALL", "ANY", 3), + ], +) +def test_overlap_with_position_subset( + fragment_mode, + region_mode, + number_contacts, + single_region, + single_region_2, + example_2d_contacts_pandas, +): + """Test that overlap with position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor( + fragment_mode=fragment_mode, region_mode=region_mode, positions=[1] + ), + 
half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(example_2d_contacts_pandas) + + # assert + assert result.compute().shape[0] == number_contacts + + +@pytest.mark.parametrize( + "add_overlap_columns,number_contacts", + [ + (True, 5), + (False, 3), + ], +) +def test_duplicates_after_overlap_handled_correctly( + add_overlap_columns, + number_contacts, + multi_region_2, + single_region, + example_2d_contacts_pandas, +): + """Test that duplicates after overlap are handled correctly""" + # setup + query_plan = [ + Overlap( + regions=[multi_region_2, single_region], + anchor_mode=Anchor(fragment_mode="ANY", region_mode="ANY"), + half_window_size=100, + add_overlap_columns=add_overlap_columns, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(example_2d_contacts_pandas) + + # assert + assert result.compute().shape[0] == number_contacts + if not add_overlap_columns: + assert len(result.compute().filter(regex="region").columns) == 0 diff --git a/tests/query_engine/test_multi_pixel_selection.py b/tests/query_engine/test_multi_pixel_selection.py new file mode 100644 index 0000000..ae2621e --- /dev/null +++ b/tests/query_engine/test_multi_pixel_selection.py @@ -0,0 +1,121 @@ +"""These set of tests test selection of pixels overlapping +with a list of mulitple regions""" +import pytest + +from spoc.pixels import Pixels +from spoc.query_engine import Anchor +from spoc.query_engine import Overlap +from spoc.query_engine import Query + + +@pytest.fixture(name="pixels_pandas") +def pixels_pandas_fixture(pixel_dataframe): + """A pandas dataframe containing pixels""" + return Pixels(pixel_dataframe, number_fragments=2, binsize=10) + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_pixels", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 4), + ("ALL", "ANY", 1), + ], +) +def test_overlap_without_position_subset( + fragment_mode, + region_mode, + number_pixels, + 
single_region, + single_region_2, + pixels_pandas, +): + """Test that overlap without position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode=fragment_mode, region_mode=region_mode), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_pixels", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 3), + ("ALL", "ANY", 3), + ], +) +def test_overlap_with_position_subset( + fragment_mode, + region_mode, + number_pixels, + single_region, + single_region_2, + pixels_pandas, +): + """Test that overlap with position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor( + fragment_mode=fragment_mode, region_mode=region_mode, positions=[1] + ), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + + +@pytest.mark.parametrize( + "add_overlap_columns,number_pixels", + [ + (True, 5), + (False, 3), + ], +) +def test_duplicates_after_overlap_handled_correctly( + add_overlap_columns, + number_pixels, + multi_region_2, + single_region, + pixels_pandas, +): + """Test that duplicates after overlap are handled correctly""" + # setup + query_plan = [ + Overlap( + regions=[multi_region_2, single_region], + anchor_mode=Anchor(fragment_mode="ANY", region_mode="ANY"), + half_window_size=100, + add_overlap_columns=add_overlap_columns, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + if not add_overlap_columns: + assert len(result.compute().filter(regex="region").columns) == 0 diff --git 
a/tests/query_engine/test_pixel_selection.py b/tests/query_engine/test_pixel_selection.py index c0c84e4..33be834 100644 --- a/tests/query_engine/test_pixel_selection.py +++ b/tests/query_engine/test_pixel_selection.py @@ -69,7 +69,9 @@ def test_any_anchor_region_returns_correct_pixels( """Test that any anchor region returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -95,7 +97,9 @@ def test_all_anchor_regions_returns_correct_pixels( """Test that all anchor regions returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -128,7 +132,8 @@ def test_specific_anchor_regions_returns_correct_pixels( pixels = request.getfixturevalue(pixel_fixture) query_plan = [ Overlap( - regions=single_region_2, anchor_mode=Anchor(mode="ALL", anchors=anchors) + regions=single_region_2, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), ) ] # execution @@ -153,7 +158,9 @@ def test_any_anchor_region_returns_correct_pixels_multi_region( """Test that any anchor region returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -176,7 +183,9 @@ def test_all_anchor_regions_returns_correct_pixels_multi_region( """Test that all anchor regions returns 
correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -202,7 +211,9 @@ def test_pixels_duplicated_for_multiple_overlapping_regions( """ # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region_2, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region_2, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -234,7 +245,10 @@ def test_specific_anchor_regions_returns_correct_pixels_multi_region( # setup pixels = request.getfixturevalue(pixels_fixture) query_plan = [ - Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL", anchors=anchors)) + Overlap( + regions=multi_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), + ) ] # execution query = Query(query_steps=query_plan) @@ -262,7 +276,10 @@ def test_specific_anchor_region_not_in_pixels_raises_error( # setup pixels = request.getfixturevalue(pixels_fixture) query_plan = [ - Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL", anchors=[3])) + Overlap( + regions=single_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=[3]), + ) ] with pytest.raises(ValueError): query = Query(query_steps=query_plan)