From 58503ec8b5c7e7111c325d5d771b51e09cc76f08 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Fri, 15 Mar 2024 18:21:07 +0100 Subject: [PATCH 01/17] changed interface of Anchor to support multi queries --- notebooks/query_engine_usage.ipynb | 66 +++++++++---------- spoc/query_engine.py | 37 +++++++---- tests/query_engine/conftest.py | 10 +-- tests/query_engine/test_contact_selection.py | 35 +++++++--- .../query_engine/test_distance_aggregation.py | 10 +-- tests/query_engine/test_pixel_selection.py | 33 +++++++--- 6 files changed, 121 insertions(+), 70 deletions(-) diff --git a/notebooks/query_engine_usage.ipynb b/notebooks/query_engine_usage.ipynb index e64ff4e..d01b9f1 100644 --- a/notebooks/query_engine_usage.ipynb +++ b/notebooks/query_engine_usage.ipynb @@ -112,17 +112,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, we want to select all contacts where any of the fragments constituting the contact overlaps the target region. To perform this action, we use the Overlap class and pass the target region as well as an instance of the `Anchor` class. The `Anchor` dataclass allows us to specify how we want to filter contacts for region overlap. It has two attributes `mode` and `anchors`. `Anchors` indicates the positions we want to filter on (default is all positions) and `mode` specifies whether we require all positions to overlap or any position to overlap. So for example, if we want all of our two-way contacts for which any of the positions overlap, we would use `Anchor(mode='ANY', anchors=[1,2])`." + "First, we want to select all contacts where any of the fragments constituting the contact overlaps the target region. To perform this action, we use the Overlap class and pass the target region as well as an instance of the `Anchor` class. The `Anchor` dataclass allows us to specify how we want to filter contacts for region overlap. It has three attributes `fragment_mode`, `region_mode` and `positions`. 
`Positions` indicates the fragment positions we want to filter on (default is all positions) and `fragment_mode` specifies whether we require all positions to overlap or any position to overlap. `Region_mode` specifies whether we require an overlap with all supplied regions, which is only relevant for lists of regions. It defaults to \"ALL\". So for example, if we want all of our two-way contacts for which any of the positions overlap, we would use `Anchor(fragment_mode='ANY', positions=[1,2])`." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_region, anchor_mode=Anchor(mode=\"ANY\", anchors=[1,2]))\n", + " Overlap(target_region, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1,2]))\n", "]" ] }, @@ -135,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -151,16 +151,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -278,7 +278,7 @@ "2 400 0 " ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -360,14 +360,14 @@ "0 400 0 " ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query_steps = [\n", - " Overlap(target_region, anchor_mode=Anchor(mode=\"ANY\", anchors=[1]))\n", + " Overlap(target_region, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]))\n", "]\n", "Query(query_steps=query_steps)\\\n", " 
.build(contacts)\\\n", @@ -406,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -419,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -496,14 +496,14 @@ "1 200 1 " ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ANY\", anchors=[1]))\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]))\n", "]\n", "Query(query_steps=query_steps)\\\n", " .build(contacts)\\\n", @@ -535,7 +535,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -555,7 +555,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -599,7 +599,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -615,7 +615,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -637,12 +637,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ANY\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\")),\n", " DistanceTransformation(\n", " distance_mode=DistanceMode.LEFT,\n", " ),\n", @@ -658,7 +658,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -824,7 +824,7 @@ "[250 rows x 7 columns]" ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -846,7 +846,7 @@ }, { "cell_type": "code", - "execution_count": 16, + 
"execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -867,12 +867,12 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ALL\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ALL\")),\n", " DistanceTransformation(),\n", " DistanceAggregation(\n", " value_column='count',\n", @@ -883,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1013,7 +1013,7 @@ "[125 rows x 4 columns]" ] }, - "execution_count": 18, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1033,12 +1033,12 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "query_steps = [\n", - " Overlap(target_regions, anchor_mode=Anchor(mode=\"ALL\")),\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ALL\")),\n", " DistanceTransformation(),\n", " DistanceAggregation(\n", " value_column='count',\n", @@ -1050,7 +1050,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1263,7 +1263,7 @@ "24 100000.0 100000.0 4.3" ] }, - "execution_count": 20, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 6069897..248dc75 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -56,25 +56,36 @@ def __call__(self, *args: Any, **kwds: Any) -> "QueryPlan": """Apply the query step to the data""" -# TODO: think about allowing anchor composition class Anchor(BaseModel): """Represents an anchor. Attributes: - mode (str): The mode of the anchor. (Can be "ANY" or "ALL") - anchors (Optional[List[int]]): The list of anchor values (optional). + fragment_mode (str): The mode for fragment overlap. 
(Can be "ANY" or "ALL") + region_mode (str): The mode for region overlap. (Can be "ANY" or "ALL"). Defaults to "ALL". + positions (Optional[List[int]]): The list of anchor values (optional). """ - mode: str - anchors: Optional[List[int]] = None + fragment_mode: str + region_mode: str = "ALL" + positions: Optional[List[int]] = None def __repr__(self) -> str: - return f"Anchor(mode={self.mode}, anchors={self.anchors})" + return f"Anchor(fragment_mode={self.fragment_mode}, positions={self.positions})" def __str__(self) -> str: return self.__repr__() +class MultiOverlap: + """ + This class represents an overlap calculation with multiple + genomic regions used for contact and pixel selection. + It provides methods to validate the filter against a data schema, + convert data to a duckdb relation, construct a filter string, + and apply the filter to the data. + """ + + class Overlap: """ This class represents an overlap calculation used for contact and pixel selection. @@ -125,17 +136,19 @@ def __init__( (self._regions["region_end"] - self._regions["region_start"]).max() // 2 ) if isinstance(anchor_mode, tuple): - self._anchor_mode = Anchor(mode=anchor_mode[0], anchors=anchor_mode[1]) + self._anchor_mode = Anchor( + fragment_mode=anchor_mode[0], positions=anchor_mode[1] + ) else: self._anchor_mode = anchor_mode def validate(self, data_schema: GenomicDataSchema) -> None: """Validate the filter against the data schema""" # check whether an anchor is specified that is not in the data - if self._anchor_mode.anchors is not None: + if self._anchor_mode.positions is not None: if not all( anchor in data_schema.get_position_fields().keys() - for anchor in self._anchor_mode.anchors + for anchor in self._anchor_mode.positions ): raise ValueError( "An anchor is specified that is not in the data schema." @@ -171,11 +184,11 @@ def _contstruct_filter(self, position_fields: Dict[int, List[str]]) -> str: NotImplementedError: If the length of fields is not equal to 3. 
""" query_strings = [] - join_string = " or " if self._anchor_mode.mode == "ANY" else " and " + join_string = " or " if self._anchor_mode.fragment_mode == "ANY" else " and " # subset on anchor regions - if self._anchor_mode.anchors is not None: + if self._anchor_mode.positions is not None: subset_positions = [ - position_fields[anchor] for anchor in self._anchor_mode.anchors + position_fields[anchor] for anchor in self._anchor_mode.positions ] else: subset_positions = list(position_fields.values()) diff --git a/tests/query_engine/conftest.py b/tests/query_engine/conftest.py index e75fdf7..40b3afd 100644 --- a/tests/query_engine/conftest.py +++ b/tests/query_engine/conftest.py @@ -122,7 +122,7 @@ def pixels_wihtout_regions_fixture(pixel_dataframe): @pytest.fixture(name="contacts_with_single_region") def contacts_with_single_region_fixture(contacts_without_regions, single_region): """Contacts with single region""" - return Overlap(single_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(single_region, anchor_mode=Anchor(fragment_mode="ANY"))( contacts_without_regions, ) @@ -130,7 +130,7 @@ def contacts_with_single_region_fixture(contacts_without_regions, single_region) @pytest.fixture(name="contacts_with_multiple_regions") def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_region): """Contacts with multiple regions""" - return Overlap(multi_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(multi_region, anchor_mode=Anchor(fragment_mode="ANY"))( contacts_without_regions, ) @@ -138,7 +138,7 @@ def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_regio @pytest.fixture(name="pixels_with_single_region") def pixels_with_single_region_fixture(pixels_without_regions, single_region): """Pixels with single region""" - return Overlap(single_region, anchor_mode=Anchor(mode="ANY"))( + return Overlap(single_region, anchor_mode=Anchor(fragment_mode="ANY"))( pixels_without_regions, ) @@ -146,4 +146,6 @@ def 
pixels_with_single_region_fixture(pixels_without_regions, single_region): @pytest.fixture(name="pixels_with_multiple_regions") def pixels_with_multiple_regions_fixture(pixels_without_regions, multi_region): """Pixels with multiple regions""" - return Overlap(multi_region, anchor_mode=Anchor(mode="ANY"))(pixels_without_regions) + return Overlap(multi_region, anchor_mode=Anchor(fragment_mode="ANY"))( + pixels_without_regions + ) diff --git a/tests/query_engine/test_contact_selection.py b/tests/query_engine/test_contact_selection.py index daae42d..87f36d8 100644 --- a/tests/query_engine/test_contact_selection.py +++ b/tests/query_engine/test_contact_selection.py @@ -60,7 +60,9 @@ def test_any_anchor_region_returns_correct_contacts( """Test that any anchor region returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -83,7 +85,9 @@ def test_all_anchor_regions_returns_correct_contacts( """Test that all anchor regions returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -116,7 +120,8 @@ def test_specific_anchor_regions_returns_correct_contacts( contacts = request.getfixturevalue(contact_fixture) query_plan = [ Overlap( - regions=single_region_2, anchor_mode=Anchor(mode="ALL", anchors=anchors) + regions=single_region_2, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), ) ] # execution @@ -153,7 +158,7 @@ def test_specific_anchor_regions_returns_correct_contacts_point_region( query_plan 
= [ Overlap( regions=single_region_3, - anchor_mode=Anchor(mode="ALL", anchors=anchors), + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), half_window_size=50, ) ] @@ -180,7 +185,9 @@ def test_any_anchor_region_returns_correct_contacts_multi_region( """Test that any anchor region returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -205,7 +212,9 @@ def test_all_anchor_regions_returns_correct_contacts_multi_region( """Test that all anchor regions returns correct contacts""" # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -231,7 +240,9 @@ def test_contacts_duplicated_for_multiple_overlapping_regions( """ # setup contacts = request.getfixturevalue(contact_fixture) - query_plan = [Overlap(regions=multi_region_2, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region_2, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(contacts) @@ -265,7 +276,10 @@ def test_specific_anchor_regions_returns_correct_contacts_multi_region( # setup contacts = request.getfixturevalue(contact_fixture) query_plan = [ - Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL", anchors=anchors)) + Overlap( + regions=multi_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), + ) ] # execution query = Query(query_steps=query_plan) @@ -293,7 +307,10 @@ def test_specific_anchor_region_not_in_contacts_raises_error( # setup 
contacts = request.getfixturevalue(contact_fixture) query_plan = [ - Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL", anchors=[3])) + Overlap( + regions=single_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=[3]), + ) ] with pytest.raises(ValueError): query = Query(query_steps=query_plan) diff --git a/tests/query_engine/test_distance_aggregation.py b/tests/query_engine/test_distance_aggregation.py index c77ae86..4772f36 100644 --- a/tests/query_engine/test_distance_aggregation.py +++ b/tests/query_engine/test_distance_aggregation.py @@ -232,7 +232,7 @@ def test_aggregations_on_dense_input( region = request.getfixturevalue(region_fixture) mapped_pixels = Query( query_steps=[ - Overlap(region, anchor_mode=Anchor(mode="ANY")), + Overlap(region, anchor_mode=Anchor(fragment_mode="ANY")), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], ).build(pixels) @@ -295,7 +295,7 @@ def test_aggregations_on_dense_input_with_reduced_dimensionality( region = request.getfixturevalue(region_fixture) mapped_pixels = Query( query_steps=[ - Overlap(region, anchor_mode=Anchor(mode="ANY")), + Overlap(region, anchor_mode=Anchor(fragment_mode="ANY")), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], ).build(pixels) @@ -367,7 +367,8 @@ def test_aggregations_on_sparse_input( query_plan = Query( query_steps=[ Overlap( - request.getfixturevalue(region_fixture), anchor_mode=Anchor(mode="ANY") + request.getfixturevalue(region_fixture), + anchor_mode=Anchor(fragment_mode="ANY"), ), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], @@ -456,7 +457,8 @@ def test_aggregations_on_sparse_input_with_reduced_dimensionality( query_plan = Query( query_steps=[ Overlap( - request.getfixturevalue(region_fixture), anchor_mode=Anchor(mode="ANY") + request.getfixturevalue(region_fixture), + anchor_mode=Anchor(fragment_mode="ANY"), ), DistanceTransformation(distance_mode=DistanceMode.LEFT), ], diff --git a/tests/query_engine/test_pixel_selection.py 
b/tests/query_engine/test_pixel_selection.py index c0c84e4..33be834 100644 --- a/tests/query_engine/test_pixel_selection.py +++ b/tests/query_engine/test_pixel_selection.py @@ -69,7 +69,9 @@ def test_any_anchor_region_returns_correct_pixels( """Test that any anchor region returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -95,7 +97,9 @@ def test_all_anchor_regions_returns_correct_pixels( """Test that all anchor regions returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=single_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -128,7 +132,8 @@ def test_specific_anchor_regions_returns_correct_pixels( pixels = request.getfixturevalue(pixel_fixture) query_plan = [ Overlap( - regions=single_region_2, anchor_mode=Anchor(mode="ALL", anchors=anchors) + regions=single_region_2, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), ) ] # execution @@ -153,7 +158,9 @@ def test_any_anchor_region_returns_correct_pixels_multi_region( """Test that any anchor region returns correct pixels""" # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ANY"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ANY")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -176,7 +183,9 @@ def test_all_anchor_regions_returns_correct_pixels_multi_region( """Test that all anchor regions returns correct pixels""" # setup pixels = 
request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -202,7 +211,9 @@ def test_pixels_duplicated_for_multiple_overlapping_regions( """ # setup pixels = request.getfixturevalue(pixels_fixture) - query_plan = [Overlap(regions=multi_region_2, anchor_mode=Anchor(mode="ALL"))] + query_plan = [ + Overlap(regions=multi_region_2, anchor_mode=Anchor(fragment_mode="ALL")) + ] # execution query = Query(query_steps=query_plan) result = query.build(pixels) @@ -234,7 +245,10 @@ def test_specific_anchor_regions_returns_correct_pixels_multi_region( # setup pixels = request.getfixturevalue(pixels_fixture) query_plan = [ - Overlap(regions=multi_region, anchor_mode=Anchor(mode="ALL", anchors=anchors)) + Overlap( + regions=multi_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=anchors), + ) ] # execution query = Query(query_steps=query_plan) @@ -262,7 +276,10 @@ def test_specific_anchor_region_not_in_pixels_raises_error( # setup pixels = request.getfixturevalue(pixels_fixture) query_plan = [ - Overlap(regions=single_region, anchor_mode=Anchor(mode="ALL", anchors=[3])) + Overlap( + regions=single_region, + anchor_mode=Anchor(fragment_mode="ALL", positions=[3]), + ) ] with pytest.raises(ValueError): query = Query(query_steps=query_plan) From 5c6deaf1a94bc23559eede73f2adcd44e0240d84 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Fri, 15 Mar 2024 18:52:32 +0100 Subject: [PATCH 02/17] constructor can handle multiple regions --- spoc/query_engine.py | 73 +++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 248dc75..5f5d587 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -70,22 +70,12 @@ class Anchor(BaseModel): 
positions: Optional[List[int]] = None def __repr__(self) -> str: - return f"Anchor(fragment_mode={self.fragment_mode}, positions={self.positions})" + return f"Anchor(fragment_mode={self.fragment_mode}, positions={self.positions}, region_mode={self.region_mode})" def __str__(self) -> str: return self.__repr__() -class MultiOverlap: - """ - This class represents an overlap calculation with multiple - genomic regions used for contact and pixel selection. - It provides methods to validate the filter against a data schema, - convert data to a duckdb relation, construct a filter string, - and apply the filter to the data. - """ - - class Overlap: """ This class represents an overlap calculation used for contact and pixel selection. @@ -96,7 +86,7 @@ class Overlap: def __init__( self, - regions: pd.DataFrame, + regions: Union[pd.DataFrame, List[pd.DataFrame]], anchor_mode: Union[Anchor, Tuple[str, List[int]]], half_window_size: Optional[int] = None, ) -> None: @@ -104,14 +94,39 @@ def __init__( Initialize the Overlap object. Args: - regions (pd.DataFrame): A DataFrame containing the regions data. + regions (Union[pd.DataFrame, List[pd.DataFrame]]): A DataFrame containing the regions data, + or a list of DataFrames containing the regions data. anchor_mode (Union[Anchor,Tuple[str,List[int]]]): The anchor mode to be used. half_window_size (Optional[int]): The window size the regions should be expanded to. Defaults to None and is inferred from the data. 
Returns: None """ - # add ids to regions if they don't exist + # preprocess regions + if isinstance(regions, list): + regions, half_window_sizes = zip( + *[self._prepare_regions(region, half_window_size) for region in regions] + ) + if not all( + half_window_size == half_window_sizes[0] + for half_window_size in half_window_sizes + ): + raise ValueError("All regions need to have the same window size.") + else: + self._regions, self._half_window_size = self._prepare_regions( + regions, half_window_size + ) + if isinstance(anchor_mode, tuple): + self._anchor_mode = Anchor( + fragment_mode=anchor_mode[0], positions=anchor_mode[1] + ) + else: + self._anchor_mode = anchor_mode + + def _prepare_regions( + self, regions: pd.DataFrame, half_window_size: Optional[int] + ) -> Tuple[pd.DataFrame, int]: + """Preprocessing of regions including adding an id column.""" if "id" not in regions.columns: regions["id"] = range(len(regions)) if half_window_size is not None: @@ -124,23 +139,19 @@ def __init__( expanded_regions["start"] = expanded_regions["midpoint"] - half_window_size expanded_regions["end"] = expanded_regions["midpoint"] + half_window_size # drop midpoint - expanded_regions = expanded_regions.drop(columns=["midpoint"]) - self._regions = RegionSchema.validate( - expanded_regions.add_prefix("region_") - ) - self._half_window_size = half_window_size - else: - self._regions = RegionSchema.validate(regions.add_prefix("region_")) - # infer window size -> variable regions will have largest possible window size - self._half_window_size = int( - (self._regions["region_end"] - self._regions["region_start"]).max() // 2 - ) - if isinstance(anchor_mode, tuple): - self._anchor_mode = Anchor( - fragment_mode=anchor_mode[0], positions=anchor_mode[1] - ) - else: - self._anchor_mode = anchor_mode + preprocssed_regions = expanded_regions.drop( + columns=["midpoint"] + ).add_prefix("region_") + return preprocssed_regions, half_window_size + preprocssed_regions = 
RegionSchema.validate(regions.add_prefix("region_")) + # infer window size -> variable regions will have largest possible window size + calculated_half_window_size = int( + ( + preprocssed_regions["region_end"] - preprocssed_regions["region_start"] + ).max() + // 2 + ) + return preprocssed_regions, calculated_half_window_size def validate(self, data_schema: GenomicDataSchema) -> None: """Validate the filter against the data schema""" From 9689a15a8fb8d17c635a20f7b14ef2f2c555d8f5 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Sun, 17 Mar 2024 15:50:51 +0100 Subject: [PATCH 03/17] first implementation of multioverlap --- spoc/models/dataframe_models.py | 10 +- spoc/query_engine.py | 107 ++++++++++++++---- .../test_multi_contact_selection.py | 98 ++++++++++++++++ 3 files changed, 190 insertions(+), 25 deletions(-) create mode 100644 tests/query_engine/test_multi_contact_selection.py diff --git a/spoc/models/dataframe_models.py b/spoc/models/dataframe_models.py index af42ff4..9cbd4a9 100644 --- a/spoc/models/dataframe_models.py +++ b/spoc/models/dataframe_models.py @@ -68,7 +68,7 @@ def get_schema(self) -> pa.DataFrameSchema: def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" @@ -89,7 +89,7 @@ def __init__( position_fields: Dict[int, List[str]], contact_order: int, binsize: Optional[int] = None, - region_number: Optional[int] = None, + region_number: Optional[Union[int, List[int]]] = None, half_window_size: Optional[int] = None, ) -> None: self._columns = columns @@ -127,7 +127,7 @@ def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return self._binsize - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data 
if present.""" return self._region_number @@ -258,7 +258,7 @@ def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return None - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" return None @@ -360,7 +360,7 @@ def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" return self._binsize - def get_region_number(self) -> Optional[int]: + def get_region_number(self) -> Optional[Union[int, List[int]]]: """Returns the number of regions in the genomic data if present.""" return None diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 5f5d587..48b5802 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -97,14 +97,15 @@ def __init__( regions (Union[pd.DataFrame, List[pd.DataFrame]]): A DataFrame containing the regions data, or a list of DataFrames containing the regions data. anchor_mode (Union[Anchor,Tuple[str,List[int]]]): The anchor mode to be used. - half_window_size (Optional[int]): The window size the regions should be expanded to. Defaults to None and is inferred from the data. + half_window_size (Optional[int]): The window size the regions should be expanded to. + Defaults to None and is inferred from the data. 
Returns: None """ # preprocess regions if isinstance(regions, list): - regions, half_window_sizes = zip( + self._regions, half_window_sizes = zip( *[self._prepare_regions(region, half_window_size) for region in regions] ) if not all( @@ -112,16 +113,17 @@ def __init__( for half_window_size in half_window_sizes ): raise ValueError("All regions need to have the same window size.") + self._half_window_size = half_window_sizes[0] else: self._regions, self._half_window_size = self._prepare_regions( regions, half_window_size ) if isinstance(anchor_mode, tuple): - self._anchor_mode = Anchor( + self._anchor = Anchor( fragment_mode=anchor_mode[0], positions=anchor_mode[1] ) else: - self._anchor_mode = anchor_mode + self._anchor = anchor_mode def _prepare_regions( self, regions: pd.DataFrame, half_window_size: Optional[int] @@ -130,6 +132,9 @@ def _prepare_regions( if "id" not in regions.columns: regions["id"] = range(len(regions)) if half_window_size is not None: + # check halfwindowsize is positive + if half_window_size < 0: + raise ValueError("Half window size must be positive.") expanded_regions = regions.copy() # create midpoint expanded_regions["midpoint"] = ( @@ -156,10 +161,10 @@ def _prepare_regions( def validate(self, data_schema: GenomicDataSchema) -> None: """Validate the filter against the data schema""" # check whether an anchor is specified that is not in the data - if self._anchor_mode.positions is not None: + if self._anchor.positions is not None: if not all( anchor in data_schema.get_position_fields().keys() - for anchor in self._anchor_mode.positions + for anchor in self._anchor.positions ): raise ValueError( "An anchor is specified that is not in the data schema." 
@@ -182,7 +187,55 @@ def _convert_to_duckdb( data = data.compute() return duckdb.from_df(data, connection=DUCKDB_CONNECTION) - def _contstruct_filter(self, position_fields: Dict[int, List[str]]) -> str: + def _construct_query_multi_region( + self, + regions: List[duckdb.DuckDBPyRelation], + genomic_df: duckdb.DuckDBPyRelation, + position_fields: Dict[int, List[str]], + ) -> duckdb.DuckDBPyRelation: + """Constructs the query for multiple regions.""" + snipped_df = genomic_df.set_alias("data") + for index, region in enumerate(regions): + snipped_df = snipped_df.join( + region.set_alias(f"regions_{index}"), + self._contstruct_filter(position_fields, f"regions_{index}"), + how="left", + ) + # filter regions based on region mode + if self._anchor.region_mode == "ALL": + return snipped_df.filter( + " and ".join( + [ + f"regions_{index}.region_chrom is not null" + for index in range(0, len(regions)) + ] + ) + ) + + return snipped_df.filter( + " or ".join( + [ + f"regions_{index}.region_chrom is not null" + for index in range(0, len(regions)) + ] + ) + ) + + def _constrcut_query_single_region( + self, + regions: duckdb.DuckDBPyRelation, + genomic_df: duckdb.DuckDBPyRelation, + position_fields: Dict[int, List[str]], + ) -> duckdb.DuckDBPyRelation: + """Constructs the query for a single region.""" + return genomic_df.set_alias("data").join( + regions.set_alias("regions"), + self._contstruct_filter(position_fields, "regions"), + ) + + def _contstruct_filter( + self, position_fields: Dict[int, List[str]], region_name: str + ) -> str: """Constructs the filter string. Args: @@ -195,21 +248,21 @@ def _contstruct_filter(self, position_fields: Dict[int, List[str]]) -> str: NotImplementedError: If the length of fields is not equal to 3. 
""" query_strings = [] - join_string = " or " if self._anchor_mode.fragment_mode == "ANY" else " and " + join_string = " or " if self._anchor.fragment_mode == "ANY" else " and " # subset on anchor regions - if self._anchor_mode.positions is not None: + if self._anchor.positions is not None: subset_positions = [ - position_fields[anchor] for anchor in self._anchor_mode.positions + position_fields[anchor] for anchor in self._anchor.positions ] else: subset_positions = list(position_fields.values()) for fields in subset_positions: chrom, start, end = fields - output_string = f"""(data.{chrom} = regions.region_chrom and + output_string = f"""(data.{chrom} = {region_name}.region_chrom and ( - data.{start} between regions.region_start and regions.region_end or - data.{end} between regions.region_start and regions.region_end or - regions.region_start between data.{start} and data.{end} + data.{start} between {region_name}.region_start and {region_name}.region_end or + data.{end} between {region_name}.region_start and {region_name}.region_end or + {region_name}.region_start between data.{start} and data.{end} ) )""" query_strings.append(output_string) @@ -223,12 +276,17 @@ def _get_transformed_schema( ) -> GenomicDataSchema: """Returns the schema of the transformed data.""" # construct schema + # get region number + if isinstance(self._regions, (list, tuple)): + region_number = [region.shape[0] for region in self._regions] + else: + region_number = self._regions.shape[0] return QueryStepDataSchema( columns=data_frame.columns, position_fields=position_fields, contact_order=input_schema.get_contact_order(), binsize=input_schema.get_binsize(), - region_number=len(self._regions), + region_number=region_number, half_window_size=self._half_window_size, ) @@ -253,7 +311,7 @@ def _add_end_position( ) def __repr__(self) -> str: - return f"Overlap(anchor_mode={self._anchor_mode})" + return f"Overlap(anchor_mode={self._anchor})" def __call__(self, genomic_data: GenomicData) -> 
GenomicData: """Apply the filter to the data""" @@ -264,7 +322,11 @@ def __call__(self, genomic_data: GenomicData) -> GenomicData: genomic_df = genomic_data.data else: genomic_df = self._convert_to_duckdb(genomic_data.data) - regions = self._convert_to_duckdb(self._regions) + # bring regions to duckdb dataframe + if isinstance(self._regions, (list, tuple)): + regions = [self._convert_to_duckdb(region) for region in self._regions] + else: + regions = self._convert_to_duckdb(self._regions) # get position columns and construct filter position_fields = input_schema.get_position_fields() # add end position if not present @@ -277,9 +339,14 @@ def __call__(self, genomic_data: GenomicData) -> GenomicData: for position in position_fields.keys() } # construct query - snipped_df = genomic_df.set_alias("data").join( - regions.set_alias("regions"), self._contstruct_filter(position_fields) - ) + if isinstance(regions, (list, tuple)): + snipped_df = self._construct_query_multi_region( + regions, genomic_df, position_fields + ) + else: + snipped_df = self._constrcut_query_single_region( + regions, genomic_df, position_fields + ) return QueryPlan( snipped_df, self._get_transformed_schema(snipped_df, input_schema, position_fields), diff --git a/tests/query_engine/test_multi_contact_selection.py b/tests/query_engine/test_multi_contact_selection.py new file mode 100644 index 0000000..94b3c72 --- /dev/null +++ b/tests/query_engine/test_multi_contact_selection.py @@ -0,0 +1,98 @@ +"""These set of tests test selection of contacts overlapping +with a list of mulitple regions""" +import dask.dataframe as dd +import duckdb +import pytest + +from spoc.contacts import Contacts +from spoc.io import DUCKDB_CONNECTION +from spoc.query_engine import Anchor +from spoc.query_engine import Overlap +from spoc.query_engine import Query + + +@pytest.fixture(name="example_2d_contacts_pandas") +def example_2d_contacts_pandas_fixture(example_2d_df): + """Example 2d contacts""" + return 
Contacts(example_2d_df) + + +@pytest.fixture(name="example_2d_contacts_dask") +def example_2d_contacts_dask_fixture(example_2d_df): + """Example 2d contacts""" + return Contacts(dd.from_pandas(example_2d_df, npartitions=2)) + + +@pytest.fixture(name="example_2d_contacts_duckdb") +def example_2d_contacts_duckdb_fixture(example_2d_df): + """Example 2d contacts""" + return Contacts(duckdb.from_df(example_2d_df, connection=DUCKDB_CONNECTION)) + + +def test_different_half_window_size_throws_error(single_region, single_region_2): + """Test that different half window size throws error""" + with pytest.raises(ValueError): + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + ) + + +def test_negative_half_window_size_throws_error(single_region, single_region_2): + """Test that negative half window size throws error""" + with pytest.raises(ValueError): + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + half_window_size=-1, + ) + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_contacts", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 4), + ("ALL", "ANY", 1), + ], +) +def test_overlap_without_position_subset( + fragment_mode, + region_mode, + number_contacts, + single_region, + single_region_2, + example_2d_contacts_pandas, +): + """Test that overlap without position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode=fragment_mode, region_mode=region_mode), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(example_2d_contacts_pandas) + + # assert + assert result.compute().shape[0] == number_contacts + + +def test_overlap_with_position_subset(): + """Test that overlap with position subset""" + + +def test_overlap_without_adding_columns_does_not_duplicate_contacts(): + """Test that overlap without adding columns""" + + +# 
validation tests + + +def test_specific_fragment_not_in_contacts(): + """Test that specific fragment not in contacts throws errors""" From 0792aed7577f64eab637de42c789d84b7b2de605 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Sun, 17 Mar 2024 16:26:52 +0100 Subject: [PATCH 04/17] finished implementation of multiregion handling --- spoc/models/dataframe_models.py | 9 ++- spoc/query_engine.py | 5 ++ .../test_multi_contact_selection.py | 69 +++++++++++++++++-- 3 files changed, 72 insertions(+), 11 deletions(-) diff --git a/spoc/models/dataframe_models.py b/spoc/models/dataframe_models.py index 9cbd4a9..ea4d0f6 100644 --- a/spoc/models/dataframe_models.py +++ b/spoc/models/dataframe_models.py @@ -350,11 +350,10 @@ def get_position_fields(self) -> Dict[int, List[str]]: return { i: ["chrom", f"start_{i}"] for i in range(1, self._number_fragments + 1) } - else: - return { - i: [f"chrom_{i}", f"start_{i}"] - for i in range(1, self._number_fragments + 1) - } + return { + i: [f"chrom_{i}", f"start_{i}"] + for i in range(1, self._number_fragments + 1) + } def get_binsize(self) -> Optional[int]: """Returns the binsize of the genomic data""" diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 48b5802..33fd75c 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -89,6 +89,7 @@ def __init__( regions: Union[pd.DataFrame, List[pd.DataFrame]], anchor_mode: Union[Anchor, Tuple[str, List[int]]], half_window_size: Optional[int] = None, + add_overlap_columns: bool = True, ) -> None: """ Initialize the Overlap object. 
@@ -103,6 +104,7 @@ def __init__( Returns: None """ + self._add_overlap_columns = add_overlap_columns # preprocess regions if isinstance(regions, list): self._regions, half_window_sizes = zip( @@ -347,6 +349,9 @@ def __call__(self, genomic_data: GenomicData) -> GenomicData: snipped_df = self._constrcut_query_single_region( regions, genomic_df, position_fields ) + # remove overlap columns and drop ducpliates if requested + if not self._add_overlap_columns: + snipped_df = snipped_df.project("data.*").distinct() return QueryPlan( snipped_df, self._get_transformed_schema(snipped_df, input_schema, position_fields), diff --git a/tests/query_engine/test_multi_contact_selection.py b/tests/query_engine/test_multi_contact_selection.py index 94b3c72..3ce7c4f 100644 --- a/tests/query_engine/test_multi_contact_selection.py +++ b/tests/query_engine/test_multi_contact_selection.py @@ -83,16 +83,73 @@ def test_overlap_without_position_subset( assert result.compute().shape[0] == number_contacts -def test_overlap_with_position_subset(): +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_contacts", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 3), + ("ALL", "ANY", 3), + ], +) +def test_overlap_with_position_subset( + fragment_mode, + region_mode, + number_contacts, + single_region, + single_region_2, + example_2d_contacts_pandas, +): """Test that overlap with position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor( + fragment_mode=fragment_mode, region_mode=region_mode, positions=[1] + ), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + # run + result = query.build(example_2d_contacts_pandas) -def test_overlap_without_adding_columns_does_not_duplicate_contacts(): - """Test that overlap without adding columns""" + # assert + assert result.compute().shape[0] == number_contacts -# validation tests +@pytest.mark.parametrize( + "add_overlap_columns,number_contacts", + [ 
+ (True, 5), + (False, 3), + ], +) +def test_duplicates_after_overlap_handled_correctly( + add_overlap_columns, + number_contacts, + multi_region_2, + single_region, + example_2d_contacts_pandas, +): + """Test that duplicates after overlap are handled correctly""" + # setup + query_plan = [ + Overlap( + regions=[multi_region_2, single_region], + anchor_mode=Anchor(fragment_mode="ANY", region_mode="ANY"), + half_window_size=100, + add_overlap_columns=add_overlap_columns, + ) + ] + query = Query(query_steps=query_plan) + # run + result = query.build(example_2d_contacts_pandas) -def test_specific_fragment_not_in_contacts(): - """Test that specific fragment not in contacts throws errors""" + # assert + assert result.compute().shape[0] == number_contacts + if not add_overlap_columns: + assert len(result.compute().filter(regex="region").columns) == 0 From 027e94458bb38e997934aa7fb70f13fc83ef85c4 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Sun, 17 Mar 2024 18:34:40 +0100 Subject: [PATCH 05/17] added testcase for distance transformation and multiple regions --- spoc/query_engine.py | 41 ++++++++++++++----- tests/query_engine/conftest.py | 12 ++++++ .../test_distance_transformation.py | 1 + 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/spoc/query_engine.py b/spoc/query_engine.py index 33fd75c..03f2215 100644 --- a/spoc/query_engine.py +++ b/spoc/query_engine.py @@ -108,7 +108,10 @@ def __init__( # preprocess regions if isinstance(regions, list): self._regions, half_window_sizes = zip( - *[self._prepare_regions(region, half_window_size) for region in regions] + *[ + self._prepare_regions(region, half_window_size, index=index) + for index, region in enumerate(regions) + ] ) if not all( half_window_size == half_window_sizes[0] @@ -128,7 +131,7 @@ def __init__( self._anchor = anchor_mode def _prepare_regions( - self, regions: pd.DataFrame, half_window_size: Optional[int] + self, regions: pd.DataFrame, half_window_size: Optional[int], index: int = 0 ) 
-> Tuple[pd.DataFrame, int]: """Preprocessing of regions including adding an id column.""" if "id" not in regions.columns: @@ -149,6 +152,9 @@ def _prepare_regions( preprocssed_regions = expanded_regions.drop( columns=["midpoint"] ).add_prefix("region_") + preprocssed_regions = RegionSchema.validate(preprocssed_regions) + if index > 0: + preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}") return preprocssed_regions, half_window_size preprocssed_regions = RegionSchema.validate(regions.add_prefix("region_")) # infer window size -> variable regions will have largest possible window size @@ -158,6 +164,9 @@ def _prepare_regions( ).max() // 2 ) + # add index + if index > 0: + preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}") return preprocssed_regions, calculated_half_window_size def validate(self, data_schema: GenomicDataSchema) -> None: @@ -200,7 +209,9 @@ def _construct_query_multi_region( for index, region in enumerate(regions): snipped_df = snipped_df.join( region.set_alias(f"regions_{index}"), - self._contstruct_filter(position_fields, f"regions_{index}"), + self._contstruct_filter( + position_fields, f"regions_{index}", index=index + ), how="left", ) # filter regions based on region mode @@ -208,7 +219,7 @@ def _construct_query_multi_region( return snipped_df.filter( " and ".join( [ - f"regions_{index}.region_chrom is not null" + f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null" for index in range(0, len(regions)) ] ) @@ -217,7 +228,7 @@ def _construct_query_multi_region( return snipped_df.filter( " or ".join( [ - f"regions_{index}.region_chrom is not null" + f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null" for index in range(0, len(regions)) ] ) @@ -236,7 +247,7 @@ def _constrcut_query_single_region( ) def _contstruct_filter( - self, position_fields: Dict[int, List[str]], region_name: str + self, position_fields: Dict[int, List[str]], region_name: str, index: int = 0 
) -> str: """Constructs the filter string. @@ -251,6 +262,10 @@ def _contstruct_filter( """ query_strings = [] join_string = " or " if self._anchor.fragment_mode == "ANY" else " and " + if index > 0: + column_index = f"_{index}" + else: + column_index = "" # subset on anchor regions if self._anchor.positions is not None: subset_positions = [ @@ -260,11 +275,11 @@ def _contstruct_filter( subset_positions = list(position_fields.values()) for fields in subset_positions: chrom, start, end = fields - output_string = f"""(data.{chrom} = {region_name}.region_chrom and + output_string = f"""(data.{chrom} = {region_name}.region_chrom{column_index} and ( - data.{start} between {region_name}.region_start and {region_name}.region_end or - data.{end} between {region_name}.region_start and {region_name}.region_end or - {region_name}.region_start between data.{start} and data.{end} + data.{start} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or + data.{end} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or + {region_name}.region_start{column_index} between data.{start} and data.{end} ) )""" query_strings.append(output_string) @@ -607,6 +622,12 @@ def validate(self, data_schema: GenomicDataSchema) -> None: raise ValueError( "Binsize specified in data schema, but distance mode is not set to LEFT." ) + # check wheter there has only been a single region overlapped + region_number = data_schema.get_region_number() + if isinstance(region_number, list): + raise ValueError( + "Distance transformation requires only a single set of regions overlapped." 
+ ) def _create_transform_columns( self, genomic_df: duckdb.DuckDBPyRelation, input_schema: GenomicDataSchema diff --git a/tests/query_engine/conftest.py b/tests/query_engine/conftest.py index 40b3afd..ee2bd46 100644 --- a/tests/query_engine/conftest.py +++ b/tests/query_engine/conftest.py @@ -135,6 +135,18 @@ def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_regio ) +@pytest.fixture(name="contacts_with_multiple_regions_overlapped") +def contacts_with_multiple_regions_overlapped_fixture( + contacts_without_regions, single_region, single_region_2 +): + """Pixels with multiple regions overlapped""" + return Overlap( + [single_region, single_region_2], + anchor_mode=Anchor(fragment_mode="ANY"), + half_window_size=100, + )(contacts_without_regions) + + @pytest.fixture(name="pixels_with_single_region") def pixels_with_single_region_fixture(pixels_without_regions, single_region): """Pixels with single region""" diff --git a/tests/query_engine/test_distance_transformation.py b/tests/query_engine/test_distance_transformation.py index 2278006..5532009 100644 --- a/tests/query_engine/test_distance_transformation.py +++ b/tests/query_engine/test_distance_transformation.py @@ -14,6 +14,7 @@ [ "contacts_without_regions", "pixels_without_regions", + "contacts_with_multiple_regions_overlapped", ], ) def test_incompatible_input_rejected(genomic_data_fixture, request): From f8096d497570195f286dd59824034c423bc65ae3 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Sun, 17 Mar 2024 18:58:02 +0100 Subject: [PATCH 06/17] added tests for pixels for multiple overlap --- .../test_multi_pixel_selection.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 tests/query_engine/test_multi_pixel_selection.py diff --git a/tests/query_engine/test_multi_pixel_selection.py b/tests/query_engine/test_multi_pixel_selection.py new file mode 100644 index 0000000..ae2621e --- /dev/null +++ b/tests/query_engine/test_multi_pixel_selection.py @@ -0,0 
+1,121 @@ +"""These set of tests test selection of pixels overlapping +with a list of mulitple regions""" +import pytest + +from spoc.pixels import Pixels +from spoc.query_engine import Anchor +from spoc.query_engine import Overlap +from spoc.query_engine import Query + + +@pytest.fixture(name="pixels_pandas") +def pixels_pandas_fixture(pixel_dataframe): + """A pandas dataframe containing pixels""" + return Pixels(pixel_dataframe, number_fragments=2, binsize=10) + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_pixels", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 4), + ("ALL", "ANY", 1), + ], +) +def test_overlap_without_position_subset( + fragment_mode, + region_mode, + number_pixels, + single_region, + single_region_2, + pixels_pandas, +): + """Test that overlap without position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor(fragment_mode=fragment_mode, region_mode=region_mode), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + + +@pytest.mark.parametrize( + "fragment_mode,region_mode,number_pixels", + [ + ("ALL", "ALL", 0), + ("ANY", "ALL", 0), + ("ANY", "ANY", 3), + ("ALL", "ANY", 3), + ], +) +def test_overlap_with_position_subset( + fragment_mode, + region_mode, + number_pixels, + single_region, + single_region_2, + pixels_pandas, +): + """Test that overlap with position subset""" + # setup + query_plan = [ + Overlap( + regions=[single_region, single_region_2], + anchor_mode=Anchor( + fragment_mode=fragment_mode, region_mode=region_mode, positions=[1] + ), + half_window_size=100, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + + +@pytest.mark.parametrize( + "add_overlap_columns,number_pixels", + [ + (True, 5), + 
(False, 3), + ], +) +def test_duplicates_after_overlap_handled_correctly( + add_overlap_columns, + number_pixels, + multi_region_2, + single_region, + pixels_pandas, +): + """Test that duplicates after overlap are handled correctly""" + # setup + query_plan = [ + Overlap( + regions=[multi_region_2, single_region], + anchor_mode=Anchor(fragment_mode="ANY", region_mode="ANY"), + half_window_size=100, + add_overlap_columns=add_overlap_columns, + ) + ] + query = Query(query_steps=query_plan) + + # run + result = query.build(pixels_pandas) + + # assert + assert result.compute().shape[0] == number_pixels + if not add_overlap_columns: + assert len(result.compute().filter(regex="region").columns) == 0 From 49cb6de94de48fd59eeac908d84e3c9eac34db68 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Tue, 19 Mar 2024 22:06:04 +0100 Subject: [PATCH 07/17] added documentation for multi region overlap --- notebooks/query_engine_usage.ipynb | 246 +++++++++++++++++++++++++++-- 1 file changed, 231 insertions(+), 15 deletions(-) diff --git a/notebooks/query_engine_usage.ipynb b/notebooks/query_engine_usage.ipynb index d01b9f1..7cda810 100644 --- a/notebooks/query_engine_usage.ipynb +++ b/notebooks/query_engine_usage.ipynb @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -135,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -151,16 +151,16 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -174,12 +174,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `.load_result` method of the `QueryResult` object can be executed using `.load_result`, which returns a `pd.DataFrame`. 
The resulting dataframe has additional columns that represent the regions, with which the input contacts overlapped." + "The `QueryPlan` object can be executed using its `.compute()` method, which returns a `pd.DataFrame`. The resulting dataframe has additional columns that represent the regions with which the input contacts overlapped." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -278,7 +278,7 @@ "2 400 0 " ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -360,14 +360,14 @@ "0 400 0 " ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -393,8 +393,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Selecting a subset of contacts at multiple genomic regions\n", - "The Overlap class is also capable of selecting contacts at multiple genomic regions. Here, the behavior of `Overlap` deviates from a simple filter, because if a given contact overlaps with multiple regions, it will be returned multiple times." + "## Selecting a subset of contacts at a set of genomic regions\n", + "The Overlap class is also capable of selecting contacts at a set of genomic regions. Here, the default behavior of `Overlap` deviates from a simple filter, because if a given contact overlaps with multiple regions in the set, it will be returned multiple times."
] }, { @@ -406,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -419,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -496,7 +496,7 @@ "1 200 1 " ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -518,6 +518,89 @@ "In this example, the contact overlapping both regions is duplicated." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we, however, only want to filter contacts that overlap a given set of regions without duplicates being returned, we can set the `add_overlap_columns` argument of the `Overlap` constructor to `False`. This will only return the respective contacts, deduplicated:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrom_1start_1end_1chrom_2start_2end_2
0chr1100200chr110002000
\n", + "
" + ], + "text/plain": [ + " chrom_1 start_1 end_1 chrom_2 start_2 end_2\n", + "0 chr1 100 200 chr1 1000 2000" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_steps = [\n", + " Overlap(target_regions, anchor_mode=Anchor(fragment_mode=\"ANY\", positions=[1]),\n", + " add_overlap_columns=False)\n", + "]\n", + "Query(query_steps=query_steps)\\\n", + " .build(contacts)\\\n", + " .compute()\\\n", + " .filter(regex=r\"chrom|start|end|id\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, only a single contact is returned, without the corresponding region columns." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -525,6 +608,139 @@ "The same functionality is implemented also for the pixels class." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Selecting a subset of contacts at multiple sets of genomic regions\n", + "The `Overlap` class is capable of selecting contacts and pixels based on overlap with multiple sets of genomic regions. Here, the `Anchor` class specifies how the different overlap possibilities should be handled: The `fragment_mode` parameter specifies whether we require all (value `ALL`) fragments to overlap, or whether we require any (value `ANY`) fragments to overlap. The `region_mode` parameter then specifies whether we require the fragments to overlap any (`ANY`) of the passed sets of genomic regions or all (`ALL`). The `positions` parameter specifies - as in the case with a single set of genomic regions - which fragments we apply this logic to."
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "target_regions = pd.DataFrame({\n", + " \"chrom\": ['chr1'],\n", + " \"start\": [110],\n", + " \"end\": [140],\n", + "})\n", + "\n", + "target_regions_2 = pd.DataFrame({\n", + " \"chrom\": ['chr1'],\n", + " \"start\": [1000],\n", + " \"end\": [1030],\n", + "})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we require that any of the fragments overlap all of the passed regions" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrom_1start_1end_1chrom_2start_2end_2region_chromregion_startregion_endregion_idregion_chrom_1region_start_1region_end_1region_id_1
0chr1100200chr110002000chr11101400chr1100010300
\n", + "
" + ], + "text/plain": [ + " chrom_1 start_1 end_1 chrom_2 start_2 end_2 region_chrom region_start \\\n", + "0 chr1 100 200 chr1 1000 2000 chr1 110 \n", + "\n", + " region_end region_id region_chrom_1 region_start_1 region_end_1 \\\n", + "0 140 0 chr1 1000 1030 \n", + "\n", + " region_id_1 \n", + "0 0 " + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_steps = [\n", + " Overlap([target_regions, target_regions_2],\n", + " anchor_mode=Anchor(fragment_mode=\"ANY\",region_mode='ALL')\n", + " )\n", + "]\n", + "Query(query_steps=query_steps)\\\n", + " .build(contacts)\\\n", + " .compute()\\\n", + " .filter(regex=r\"chrom|start|end|id\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This way we can filter for contacts between specific target regions, e.g. loop bases or promoters and enhancers." + ] + }, { "cell_type": "markdown", "metadata": {}, From b0cad64a452d31c2d143b02d7589db4335d39e77 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 09:01:19 +0200 Subject: [PATCH 08/17] made get label values more efficient --- spoc/contacts.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/spoc/contacts.py b/spoc/contacts.py index 0f08941..4655f68 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -49,6 +49,7 @@ def __init__( label_sorted: bool = False, binary_labels_equal: bool = False, symmetry_flipped: bool = False, + label_values: Optional[List[str]] = None ) -> None: self.contains_metadata = ( "metadata_1" in contact_frame.columns @@ -61,7 +62,6 @@ def __init__( number_fragments=self.number_fragments, contains_metadata=self.contains_metadata, ) - # TODO: make this work for duckdb pyrelation -> switch to mode if isinstance(contact_frame, pd.DataFrame): self.data_mode = DataMode.PANDAS elif isinstance(contact_frame, dd.DataFrame): @@ -75,6 +75,7 @@ def __init__( self.label_sorted = label_sorted 
self.binary_labels_equal = binary_labels_equal self.symmetry_flipped = symmetry_flipped + self.label_values = label_values @staticmethod def from_uri(uri, mode=DataMode.PANDAS): @@ -116,15 +117,19 @@ def get_label_values(self) -> List[str]: # TODO: This could be put in global metadata of parquet file if not self.contains_metadata: raise ValueError("Contacts do not contain metadata!") - output = set() - for i in range(self.number_fragments): - if self.data_mode == DataMode.DASK: - output.update(self.data[f"metadata_{i+1}"].unique().compute()) - elif self.data_mode == DataMode.PANDAS: - output.update(self.data[f"metadata_{i+1}"].unique()) - else: - raise ValueError("Label values not supported for duckdb!") - return list(output) + if self.label_values is None: + output = set() + for i in range(self.number_fragments): + if self.data_mode == DataMode.DASK: + output.update(self.data[f"metadata_{i+1}"].unique().compute()) + elif self.data_mode == DataMode.PANDAS: + output.update(self.data[f"metadata_{i+1}"].unique()) + else: + raise ValueError("Label values not supported for duckdb!") + # add metadata values and return + self.label_values = list(output) + return list(output) + return self.label_values def get_chromosome_values(self) -> List[str]: """Returns all chromosome values""" From bf1f6e8a14d725445971d9b9034f75d008f33dda Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 12:17:48 +0200 Subject: [PATCH 09/17] removed unnecessary sorting of index from sort labels --- spoc/contacts.py | 5 ++--- tests/contacts_tests/test_symmetry.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/spoc/contacts.py b/spoc/contacts.py index 4655f68..5de24eb 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -350,12 +350,11 @@ def sort_labels(self, contacts: Contacts) -> Contacts: ) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - # this is a bit of a hack to get the index sorted. 
Dask does not support index sorting result = ( - dd.concat(subsets).reset_index().sort_values("index").set_index("index") + dd.concat(subsets) ) elif contacts.data_mode == DataMode.PANDAS: - result = pd.concat(subsets).sort_index() + result = pd.concat(subsets) else: raise ValueError("Sorting labels for duckdb relations is not implemented.") return Contacts( diff --git a/tests/contacts_tests/test_symmetry.py b/tests/contacts_tests/test_symmetry.py index 0eea19f..df9169c 100644 --- a/tests/contacts_tests/test_symmetry.py +++ b/tests/contacts_tests/test_symmetry.py @@ -60,7 +60,7 @@ def test_labelled_contacts_are_sorted_correctly(unsorted, sorted_contacts, reque ), request.getfixturevalue(sorted_contacts) contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) - pd.testing.assert_frame_equal(result.data, sorted_contacts) + pd.testing.assert_frame_equal(result.data.sort_index(), sorted_contacts) assert result.label_sorted @@ -82,7 +82,7 @@ def test_labelled_contacts_are_sorted_correctly_dask( contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) pd.testing.assert_frame_equal( - result.data.compute().reset_index(drop=True), + result.data.compute().sort_index().reset_index(drop=True), sorted_contacts.reset_index(drop=True), ) assert result.label_sorted From 19a2f5d9c74cf6c9e8c88cba5df2bb647c686776 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 12:31:37 +0200 Subject: [PATCH 10/17] removed unnecessary index sort --- spoc/contacts.py | 5 ++--- tests/contacts_tests/test_symmetry.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/spoc/contacts.py b/spoc/contacts.py index 5de24eb..6bb14de 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -456,12 +456,11 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts: subsets.append(subset) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - # this is a bit of a hack to get the 
index sorted. Dask does not support index sorting result = ( - dd.concat(subsets).reset_index().sort_values("index").set_index("index") + dd.concat(subsets) ) elif contacts.data_mode == DataMode.PANDAS: - result = pd.concat(subsets).sort_index() + result = pd.concat(subsets) else: raise ValueError( "Equate binary labels for duckdb relations is not implemented." diff --git a/tests/contacts_tests/test_symmetry.py b/tests/contacts_tests/test_symmetry.py index df9169c..2f39109 100644 --- a/tests/contacts_tests/test_symmetry.py +++ b/tests/contacts_tests/test_symmetry.py @@ -103,7 +103,7 @@ def test_equate_binary_labels(unequated, equated, request): ) contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) - pd.testing.assert_frame_equal(result.data, equated) + pd.testing.assert_frame_equal(result.data.sort_index(), equated) @pytest.mark.parametrize( @@ -123,7 +123,7 @@ def test_equate_binary_labels_dask(unequated, equated, request): contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) pd.testing.assert_frame_equal( - result.data.compute().reset_index(drop=True), equated.reset_index(drop=True) + result.data.compute().sort_index().reset_index(drop=True), equated.reset_index(drop=True) ) From e1f7b0ca7a87b628aaa8b5eb60cd64ff907d2d20 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 13:16:36 +0200 Subject: [PATCH 11/17] propagation of label values --- spoc/contacts.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/spoc/contacts.py b/spoc/contacts.py index 6bb14de..5a1398a 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -49,7 +49,7 @@ def __init__( label_sorted: bool = False, binary_labels_equal: bool = False, symmetry_flipped: bool = False, - label_values: Optional[List[str]] = None + label_values: Optional[List[str]] = None, ) -> None: self.contains_metadata = ( "metadata_1" in contact_frame.columns 
@@ -350,15 +350,16 @@ def sort_labels(self, contacts: Contacts) -> Contacts: ) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - result = ( - dd.concat(subsets) - ) + result = dd.concat(subsets) elif contacts.data_mode == DataMode.PANDAS: result = pd.concat(subsets) else: raise ValueError("Sorting labels for duckdb relations is not implemented.") return Contacts( - result, number_fragments=contacts.number_fragments, label_sorted=True + result, + number_fragments=contacts.number_fragments, + label_sorted=True, + label_values=label_values, ) def _sort_chromosomes(self, df: DataFrame, number_fragments: int) -> DataFrame: @@ -456,9 +457,7 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts: subsets.append(subset) # determine which method to use for concatenation if contacts.data_mode == DataMode.DASK: - result = ( - dd.concat(subsets) - ) + result = dd.concat(subsets) elif contacts.data_mode == DataMode.PANDAS: result = pd.concat(subsets) else: @@ -509,6 +508,7 @@ def subset_on_metadata( label_sorted=contacts.label_sorted, binary_labels_equal=contacts.binary_labels_equal, symmetry_flipped=contacts.symmetry_flipped, + label_values=label_values, ) def flip_symmetric_contacts( @@ -537,6 +537,7 @@ def flip_symmetric_contacts( label_sorted=True, binary_labels_equal=contacts.binary_labels_equal, symmetry_flipped=True, + label_values=label_values, ) result = self._flip_unlabelled_contacts(contacts.data) if sort_chromosomes: From a111187da06d96738d98b7854a11f8f26168a1cc Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 13:48:38 +0200 Subject: [PATCH 12/17] added propagation of label values --- spoc/contacts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spoc/contacts.py b/spoc/contacts.py index 5a1398a..23dca99 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -469,6 +469,7 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts: number_fragments=contacts.number_fragments, 
label_sorted=True, binary_labels_equal=True, + label_values=label_values, ) def subset_on_metadata( From 7d98991faa942c5e02e27a353c2cb381a9b85f8e Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 16:54:01 +0200 Subject: [PATCH 13/17] Fixed bug where duckdb can't read dask output --- spoc/io.py | 19 +++++++- spoc/models/dataframe_models.py | 4 +- tests/io_tests/test_io_contacts.py | 78 ++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 4 deletions(-) diff --git a/spoc/io.py b/spoc/io.py index 0e9cb77..1205d70 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -45,6 +45,8 @@ def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None: self._parquet_reader_func = pd.read_parquet else: raise ValueError(f"Data mode {data_mode} not supported!") + # store data mode + self._data_mode = data_mode @staticmethod def write_label_library(path: str, data: Dict[str, bool]) -> None: @@ -254,9 +256,19 @@ def load_pixels( ) # rewrite path to contain parent folder full_pixel_path = Path(path) / pixel_path - df = self._parquet_reader_func(full_pixel_path) + if self._data_mode == DataMode.DUCKDB: + df = self._parquet_reader_func(self._get_duckdb_path(full_pixel_path)) + else: + df = self._parquet_reader_func(full_pixel_path) return Pixels(df, **matched_parameters.dict()) + def _get_duckdb_path(self, path: Path) -> str: + """Constructs duckdb path string to handle cases + where parquet files are stored in a directory with multiple files.""" + if Path(path).is_dir(): + return f"{path}/*.parquet" + return str(path) + def load_contacts( self, path: str, global_parameters: Optional[ContactsParameters] = None ) -> Contacts: @@ -287,7 +299,10 @@ def load_contacts( ) # rewrite path to contain parent folder full_contacts_path = Path(path) / contacts_path - df = self._parquet_reader_func(full_contacts_path) + if self._data_mode == DataMode.DUCKDB: + df = self._parquet_reader_func(self._get_duckdb_path(full_contacts_path)) + else: + df = 
self._parquet_reader_func(str(full_contacts_path)) return Contacts(df, **matched_parameters.dict()) @staticmethod diff --git a/spoc/models/dataframe_models.py b/spoc/models/dataframe_models.py index ea4d0f6..a3bb4f9 100644 --- a/spoc/models/dataframe_models.py +++ b/spoc/models/dataframe_models.py @@ -215,8 +215,8 @@ def validate_header(self, data_frame: DataFrame) -> None: Args: data_frame (DataFrame): The DataFrame to validate. """ - for column in data_frame.columns: - if column not in self._schema.columns: + for column in self._schema.columns: + if column not in data_frame.columns: raise pa.errors.SchemaError( self._schema, data_frame, "Header is invalid!" ) diff --git a/tests/io_tests/test_io_contacts.py b/tests/io_tests/test_io_contacts.py index 567f1ee..ad310a0 100644 --- a/tests/io_tests/test_io_contacts.py +++ b/tests/io_tests/test_io_contacts.py @@ -74,6 +74,57 @@ def example_contacts_w_metadata(unlabelled_contacts_2d, labelled_binary_contacts shutil.rmtree("tmp") +@pytest.fixture +def example_contacts_w_metadata_dask_parquet_structure( + unlabelled_contacts_2d, labelled_binary_contacts_2d +): + # setup + _create_tmp_dir() + # create contacts directory + contacts_dir = "tmp/contacts_test.parquet" + os.mkdir(contacts_dir) + expected_parameters = [ + ContactsParameters(number_fragments=2), + ContactsParameters(number_fragments=2, metadata_combi=["A", "B"]), + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], label_sorted=True + ), + ContactsParameters( + number_fragments=2, + metadata_combi=["A", "B"], + label_sorted=True, + symmetry_flipped=True, + ), + ] + paths = [ + Path("tmp/contacts_test.parquet/test1.parquet"), + Path("tmp/contacts_test.parquet/test2.parquet"), + Path("tmp/contacts_test.parquet/test3.parquet"), + Path("tmp/contacts_test.parquet/test4.parquet"), + ] + dataframes = [ + dd.from_pandas(unlabelled_contacts_2d, npartitions=2), + dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + 
dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + dd.from_pandas(labelled_binary_contacts_2d, npartitions=2), + ] + # create pixels files + for path, df in zip(paths, dataframes): + df.to_parquet(path) + # create metadata json file + metadata = { + "test1.parquet": expected_parameters[0].dict(), + "test2.parquet": expected_parameters[1].dict(), + "test3.parquet": expected_parameters[2].dict(), + "test4.parquet": expected_parameters[3].dict(), + } + with open(contacts_dir + "/metadata.json", "w") as f: + json.dump(metadata, f) + yield contacts_dir, expected_parameters, paths, dataframes + # teardown + shutil.rmtree("tmp") + + def test_read_contacts_metadata_json(example_contacts_w_metadata): """Test reading pixels metadata json file""" contacts_dir, expected_parameters, _, _ = example_contacts_w_metadata @@ -115,6 +166,33 @@ def test_read_contacts_as_dask_df(example_contacts_w_metadata): assert contacts.data.compute().equals(df) +def test_read_contacts_as_duckdb_connection(example_contacts_w_metadata): + """Test reading contacts as pandas dataframe""" + contacts_dir, expected_parameters, paths, dataframes = example_contacts_w_metadata + # read metadata + for path, expected, df in zip(paths, expected_parameters, dataframes): + contacts = FileManager(DataMode.DUCKDB).load_contacts(contacts_dir, expected) + assert contacts.get_global_parameters() == expected + assert contacts.data.df().equals(df) + + +def test_read_contacts_as_duckdb_connection_from_dask_parquet( + example_contacts_w_metadata_dask_parquet_structure, +): + """Test reading contacts as pandas dataframe""" + ( + contacts_dir, + expected_parameters, + paths, + dataframes, + ) = example_contacts_w_metadata_dask_parquet_structure + # read metadata + for path, expected, df in zip(paths, expected_parameters, dataframes): + contacts = FileManager(DataMode.DUCKDB).load_contacts(contacts_dir, expected) + assert contacts.get_global_parameters() == expected + assert 
contacts.data.df()[df.columns].equals(df.compute()) + + @pytest.mark.parametrize( "df, params", [ From ee15b47387bda2ef4bd82bd3913db8a5fb9f37c1 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 19:22:12 +0200 Subject: [PATCH 14/17] fixed bug with duckdb writing --- spoc/io.py | 4 ++-- tests/io_tests/test_io_contacts.py | 37 ++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/spoc/io.py b/spoc/io.py index 1205d70..e0b9ebb 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -346,7 +346,7 @@ def write_pixels(self, path: str, pixels: Pixels) -> None: raise ValueError( "Writing pixels only suppported for pixels hodling dataframes!" ) - pixels.data.to_parquet(write_path, row_group_size=1024 * 1024) + pixels.data.to_parquet(str(write_path)) # write metadata current_metadata[write_path.name] = pixels.get_global_parameters().dict() with open(metadata_path, "w", encoding="UTF-8") as f: @@ -369,7 +369,7 @@ def write_contacts(self, path: str, contacts: Contacts) -> None: raise ValueError( "Writing contacts only suppported for contacts hodling dataframes!" 
) - contacts.data.to_parquet(write_path, row_group_size=1024 * 1024) + contacts.data.to_parquet(str(write_path)) # write metadata current_metadata[write_path.name] = contacts.get_global_parameters().dict() with open(metadata_path, "w") as f: diff --git a/tests/io_tests/test_io_contacts.py b/tests/io_tests/test_io_contacts.py index ad310a0..b51b634 100644 --- a/tests/io_tests/test_io_contacts.py +++ b/tests/io_tests/test_io_contacts.py @@ -7,6 +7,7 @@ from pathlib import Path import dask.dataframe as dd +import duckdb import pytest from spoc.contacts import Contacts @@ -264,6 +265,42 @@ def test_write_dask_contacts_to_new_file(df, params, request): assert contacts.data.compute().equals(contacts_read.data) +@pytest.mark.parametrize( + "df, params", + [ + ("unlabelled_contacts_2d", ContactsParameters(number_fragments=2)), + ( + "labelled_binary_contacts_2d", + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], symmetry_flipped=True + ), + ), + ( + "labelled_binary_contacts_2d", + ContactsParameters( + number_fragments=2, metadata_combi=["A", "B"], label_sorted=True + ), + ), + ], +) +def test_write_duckdb_contacts_to_new_file(df, params, request): + df = request.getfixturevalue(df) + duckdb_df = duckdb.from_df(df) + contacts = Contacts(duckdb_df, **params.dict()) + with tempfile.TemporaryDirectory() as tmpdirname: + file_name = tmpdirname + "/" + "test.parquet" + FileManager(data_mode=DataMode.DUCKDB).write_contacts(file_name, contacts) + # check metadata + metadata = FileManager().list_contacts(file_name) + assert len(metadata) == 1 + assert metadata[0] == contacts.get_global_parameters() + # read contacts + contacts_read = FileManager().load_contacts(file_name, metadata[0]) + # check whether parameters are equal + assert contacts.get_global_parameters() == contacts_read.get_global_parameters() + assert contacts.data.df().equals(contacts_read.data) + + @pytest.mark.parametrize( "df1,df2,params", [ From aceeda0da27fbf5678204aef3d0279e74ff06dfd Mon 
Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 19:53:33 +0200 Subject: [PATCH 15/17] added duckdb parameters --- spoc/io.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/spoc/io.py b/spoc/io.py index e0b9ebb..596d159 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -34,7 +34,7 @@ class FileManager: data_mode (DataMode, optional): Data mode. Defaults to DataMode.PANDAS. """ - def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None: + def __init__(self, data_mode: DataMode = DataMode.PANDAS, **kwargs) -> None: if data_mode == DataMode.DUCKDB: self._parquet_reader_func = partial( duckdb.read_parquet, connection=DUCKDB_CONNECTION @@ -47,6 +47,15 @@ def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None: raise ValueError(f"Data mode {data_mode} not supported!") # store data mode self._data_mode = data_mode + # set duckdb parameters if they are there + if "duckdb_max_memory" in kwargs and data_mode == DataMode.DUCKDB: + DUCKDB_CONNECTION.execute( + f"PRAGMA memory_limit = {kwargs['duckdb_max_memory']}" + ) + if "duckdb_max_threads" in kwargs and data_mode == DataMode.DUCKDB: + DUCKDB_CONNECTION.execute( + f"PRAGMA threads = {kwargs['duckdb_max_threads']}" + ) @staticmethod def write_label_library(path: str, data: Dict[str, bool]) -> None: From eeccd894f4e6ffaba163122fd16bd7f9cdc8b702 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Mon, 1 Apr 2024 19:58:40 +0200 Subject: [PATCH 16/17] fixed typo --- spoc/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spoc/io.py b/spoc/io.py index 596d159..8a333c2 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -50,7 +50,7 @@ def __init__(self, data_mode: DataMode = DataMode.PANDAS, **kwargs) -> None: # set duckdb parameters if they are there if "duckdb_max_memory" in kwargs and data_mode == DataMode.DUCKDB: DUCKDB_CONNECTION.execute( - f"PRAGMA memory_limit = {kwargs['duckdb_max_memory']}" + f"PRAGMA memory_limit = '{kwargs['duckdb_max_memory']}'" ) 
if "duckdb_max_threads" in kwargs and data_mode == DataMode.DUCKDB: DUCKDB_CONNECTION.execute( From a80f7c1c1a4a3c39e1eb857ec5f972d51a4c0e95 Mon Sep 17 00:00:00 2001 From: Michael Mitter Date: Tue, 2 Apr 2024 20:26:49 +0200 Subject: [PATCH 17/17] passing kwargs to write parquet --- spoc/io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spoc/io.py b/spoc/io.py index 8a333c2..095aeeb 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -329,7 +329,7 @@ def _get_object_hash_path(path: str, data_object: Union[Pixels, Contacts]) -> st ) return md5(hash_string.encode(encoding="utf-8")).hexdigest() + ".parquet" - def write_pixels(self, path: str, pixels: Pixels) -> None: + def write_pixels(self, path: str, pixels: Pixels, **kwargs) -> None: """Write pixels Args: @@ -355,13 +355,13 @@ def write_pixels(self, path: str, pixels: Pixels) -> None: raise ValueError( "Writing pixels only suppported for pixels hodling dataframes!" ) - pixels.data.to_parquet(str(write_path)) + pixels.data.to_parquet(str(write_path), **kwargs) # write metadata current_metadata[write_path.name] = pixels.get_global_parameters().dict() with open(metadata_path, "w", encoding="UTF-8") as f: json.dump(current_metadata, f) - def write_contacts(self, path: str, contacts: Contacts) -> None: + def write_contacts(self, path: str, contacts: Contacts, **kwargs) -> None: """Write contacts""" # check whether path exists metadata_path = Path(path) / "metadata.json" @@ -378,7 +378,7 @@ def write_contacts(self, path: str, contacts: Contacts) -> None: raise ValueError( "Writing contacts only suppported for contacts hodling dataframes!" ) - contacts.data.to_parquet(str(write_path)) + contacts.data.to_parquet(str(write_path), **kwargs) # write metadata current_metadata[write_path.name] = contacts.get_global_parameters().dict() with open(metadata_path, "w") as f: