diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 025eb9ed..a948077f 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -5,7 +5,7 @@ from collections import OrderedDict from copy import deepcopy from pathlib import Path -from typing import Any +from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np @@ -170,6 +170,7 @@ def render_shapes( method: str | None = None, table_name: str | None = None, table_layer: str | None = None, + shape: Literal["circle", "hex", "square"] | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -242,6 +243,9 @@ def render_shapes( table_layer: str | None Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in :attr:`sdata.table.X` is used for coloring. + shape: Literal["circle", "hex", "square"] | None + If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is + specified, the shapes are converted to a circle/hexagon/square before rendering. **kwargs : Any Additional arguments for customization. This can include: @@ -286,6 +290,7 @@ def render_shapes( scale=scale, table_name=table_name, table_layer=table_layer, + shape=shape, method=method, ds_reduction=kwargs.get("datashader_reduction"), ) @@ -318,6 +323,7 @@ def render_shapes( transfunc=kwargs.get("transfunc"), table_name=param_values["table_name"], table_layer=param_values["table_layer"], + shape=param_values["shape"], zorder=n_steps, method=param_values["method"], ds_reduction=param_values["ds_reduction"], diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f74259a4..52c8b773 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -37,6 +37,7 @@ from spatialdata_plot.pl.utils import ( _ax_show_and_transform, _convert_alpha_to_datashader_range, + _convert_shapes, _create_image_from_datashader_result, _datashader_aggregate_with_function, _datashader_map_aggregate_to_color, @@ -163,6 +164,15 @@ def _render_shapes( trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system) shapes = gpd.GeoDataFrame(shapes, geometry="geometry") + # convert shapes if necessary + if render_params.shape is not None: + current_type = shapes["geometry"].type + if not (render_params.shape == "circle" and (current_type == "Point").all()): + logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.") + max_extent = np.max( + [shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]] + ) + shapes = _convert_shapes(shapes, render_params.shape, max_extent) # Determine which method to use for rendering method = render_params.method @@ -186,17 +196,17 @@ def _render_shapes( # Handle circles encoded as points with radius if is_point.any(): scale = shapes[is_point]["radius"] * render_params.scale - sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) + shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points tm = trans.get_matrix() - transformed_element = sdata_filt.shapes[element].transform( + transformed_geometry = shapes["geometry"].transform( lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2] ) transformed_element = ShapesModel.parse( gpd.GeoDataFrame( - data=sdata_filt.shapes[element].drop("geometry", axis=1), - geometry=transformed_element, + data=shapes.drop("geometry", axis=1), + geometry=transformed_geometry, ) ) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 5e3af820..15812c0c 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -211,6 +211,7 @@ class ShapesRenderParams: zorder: int = 0 table_name: str | None = None table_layer: str | None = None + shape: Literal["circle", "hex", "square"] | None = None ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 9df1f3d0..ff1f8daa 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import os import warnings from collections import OrderedDict @@ -51,6 +52,7 @@ from scanpy.plotting._tools.scatterplots import _add_categorical_legend from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation from scanpy.plotting.palettes import default_20, default_28, default_102 +from scipy.spatial import ConvexHull from skimage.color import label2rgb from skimage.morphology import erosion, square from skimage.segmentation import find_boundaries @@ -1818,6 +1820,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if size < 0: raise ValueError("Parameter 'size' must be a positive number.") + if element_type == "shapes" and (shape := param_dict.get("shape")) is not None: + if not isinstance(shape, str): + raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.") + if shape not in ["circle", "hex", "square"]: + raise ValueError( + f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']." + ) + table_name = param_dict.get("table_name") table_layer = param_dict.get("table_layer") if table_name and not isinstance(param_dict["table_name"], str): @@ -2030,6 +2040,7 @@ def _validate_shape_render_params( scale: float | int, table_name: str | None, table_layer: str | None, + shape: Literal["circle", "hex", "square"] | None, method: str | None, ds_reduction: str | None, ) -> dict[str, dict[str, Any]]: @@ -2049,6 +2060,7 @@ def _validate_shape_render_params( "scale": scale, "table_name": table_name, "table_layer": table_layer, + "shape": shape, "method": method, "ds_reduction": ds_reduction, } @@ -2069,6 +2081,7 @@ def _validate_shape_render_params( element_params[el]["norm"] = param_dict["norm"] element_params[el]["scale"] = param_dict["scale"] element_params[el]["table_layer"] = param_dict["table_layer"] + element_params[el]["shape"] = param_dict["shape"] element_params[el]["color"] = param_dict["color"] @@ -2487,6 +2500,39 @@ def _prepare_transformation( return trans, trans_data +def _get_datashader_trans_matrix_of_single_element( + trans: Identity | Scale | Affine | MapAxis | Translation, +) -> ArrayLike: + flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) + tm: ArrayLike = trans.to_affine_matrix(("x", "y"), ("x", "y")) + + if isinstance(trans, Identity): + return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + if isinstance(trans, (Scale | Affine)): + # idea: "flip the y-axis", apply transformation, flip back + flip_and_transform: ArrayLike = flip_matrix @ tm @ flip_matrix + return flip_and_transform + if isinstance(trans, MapAxis): + # no flipping needed + return tm + # for a Translation, we need the transposed transformation matrix + tm_T = tm.T + assert isinstance(tm_T, np.ndarray) + return tm_T + + +def _get_transformation_matrix_for_datashader( + trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, +) -> ArrayLike: + """Get the affine matrix needed to transform shapes for rendering with datashader.""" + if isinstance(trans, SDSequence): + tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + for x in trans.transformations: + tm = tm @ _get_datashader_trans_matrix_of_single_element(x) + return tm + return _get_datashader_trans_matrix_of_single_element(trans) + + def _datashader_map_aggregate_to_color( agg: DataArray, cmap: str | list[str] | ListedColormap, @@ -2588,6 +2634,124 @@ def _hex_no_alpha(hex: str) -> str: raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") +def _convert_shapes( + shapes: GeoDataFrame, target_shape: str, max_extent: float, warn_above_extent_fraction: float = 0.5 +) -> GeoDataFrame: + """Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape.""" + # NOTE: possible follow-up: when converting equally sized shapes to hex, automatically scale resulting hexagons + # so that they are perfectly adjacent to each other + + if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0: + warn_above_extent_fraction = 0.5 # set to default if the value is outside [0, 1] + warn_shape_size = False + + # define individual conversion methods + def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + vertices = [ + (center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle))) + for angle in range(0, 360, 60) + ] + return shapely.Polygon(vertices), None + + def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + vertices = [ + (center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle))) + for angle in range(45, 360, 90) + ] + return shapely.Polygon(vertices), None + + def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]: + return center, radius + + def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: + center, radius = _polygon_to_circle(polygon) + return _circle_to_hexagon(center, radius) + + def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: + center, radius = _polygon_to_circle(polygon) + return _circle_to_square(center, radius) + + def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]: + coords = np.array(polygon.exterior.coords) + circle_points = coords[ConvexHull(coords).vertices] + center = np.mean(circle_points, axis=0) + radius = max(float(np.linalg.norm(p - center)) for p in circle_points) + assert isinstance(radius, float) # shut up mypy + if 2 * radius > max_extent * warn_above_extent_fraction: + nonlocal warn_shape_size + warn_shape_size = True + return shapely.Point(center), radius + + def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: + center, radius = _multipolygon_to_circle(multipolygon) + return _circle_to_hexagon(center, radius) + + def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: + center, radius = _multipolygon_to_circle(multipolygon) + return _circle_to_square(center, radius) + + def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]: + coords = [] + for polygon in multipolygon.geoms: + coords.extend(polygon.exterior.coords) + points = np.array(coords) + circle_points = points[ConvexHull(points).vertices] + center = np.mean(circle_points, axis=0) + radius = max(float(np.linalg.norm(p - center)) for p in circle_points) + assert isinstance(radius, float) # shut up mypy + if 2 * radius > max_extent * warn_above_extent_fraction: + nonlocal warn_shape_size + warn_shape_size = True + return shapely.Point(center), radius + + # define dict with all conversion methods + if target_shape == "circle": + conversion_methods = { + "Point": _circle_to_circle, + "Polygon": _polygon_to_circle, + "Multipolygon": _multipolygon_to_circle, + } + pass + elif target_shape == "hex": + conversion_methods = { + "Point": _circle_to_hexagon, + "Polygon": _polygon_to_hexagon, + "Multipolygon": _multipolygon_to_hexagon, + } + else: + conversion_methods = { + "Point": _circle_to_square, + "Polygon": _polygon_to_square, + "Multipolygon": _multipolygon_to_square, + } + + # convert every shape + for i in range(shapes.shape[0]): + if shapes["geometry"][i].type == "Point": + converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore + elif shapes["geometry"][i].type == "Polygon": + converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore + elif shapes["geometry"][i].type == "MultiPolygon": + converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore + else: + error_type = shapes["geometry"][i].type + raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.") + shapes["geometry"][i] = converted + if radius is not None: + if "radius" not in shapes.columns: + shapes["radius"] = np.nan + shapes["radius"][i] = radius + + if warn_shape_size: + logger.info( + f"When converting the shapes, the size of at least one target shape extends " + f"{warn_above_extent_fraction * 100}% of the original total bound of the shapes. The conversion" + " might not give satisfying results in this scenario." + ) + + return shapes + + def _convert_alpha_to_datashader_range(alpha: float) -> float: """Convert alpha from the range [0, 1] to the range [0, 255] used in datashader.""" # prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes diff --git a/tests/_images/Shapes_can_render_circles_to_hex.png b/tests/_images/Shapes_can_render_circles_to_hex.png new file mode 100644 index 00000000..026fdd1e Binary files /dev/null and b/tests/_images/Shapes_can_render_circles_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_circles_to_square.png b/tests/_images/Shapes_can_render_circles_to_square.png new file mode 100644 index 00000000..13003c8a Binary files /dev/null and b/tests/_images/Shapes_can_render_circles_to_square.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_circle.png b/tests/_images/Shapes_can_render_multipolygons_to_circle.png new file mode 100644 index 00000000..fe5a17e7 Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_circle.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_hex.png b/tests/_images/Shapes_can_render_multipolygons_to_hex.png new file mode 100644 index 00000000..e5ac72dc Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_square.png b/tests/_images/Shapes_can_render_multipolygons_to_square.png new file mode 100644 index 00000000..0646e548 Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_square.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_circle.png b/tests/_images/Shapes_can_render_polygons_to_circle.png new file mode 100644 index 00000000..fc3e3906 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_circle.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_hex.png b/tests/_images/Shapes_can_render_polygons_to_hex.png new file mode 100644 index 00000000..45d3be26 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_square.png b/tests/_images/Shapes_can_render_polygons_to_square.png new file mode 100644 index 00000000..02bf5e03 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_circles_to_hex.png b/tests/_images/Shapes_datashader_can_render_circles_to_hex.png new file mode 100644 index 00000000..d6aebef8 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_circles_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_circles_to_square.png b/tests/_images/Shapes_datashader_can_render_circles_to_square.png new file mode 100644 index 00000000..c5776d4a Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_circles_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png new file mode 100644 index 00000000..ee543dc1 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png new file mode 100644 index 00000000..7028fc4c Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png new file mode 100644 index 00000000..90701900 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png b/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png new file mode 100644 index 00000000..01f93369 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png b/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png new file mode 100644 index 00000000..8f460b6c Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_square.png b/tests/_images/Shapes_datashader_can_render_polygons_to_square.png new file mode 100644 index 00000000..2ae09482 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_circles.png b/tests/_images/Shapes_datashader_can_transform_circles.png index 49659e0d..9efe6bd5 100644 Binary files a/tests/_images/Shapes_datashader_can_transform_circles.png and b/tests/_images/Shapes_datashader_can_transform_circles.png differ diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 27924a71..fb676386 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -569,6 +569,53 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show() + def test_plot_can_render_circles_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex").pl.show() + + def test_plot_can_render_circles_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square").pl.show() + + def test_plot_can_render_polygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex").pl.show() + + def test_plot_can_render_polygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square").pl.show() + + def test_plot_can_render_polygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle").pl.show() + + def test_plot_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex").pl.show() + + def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square").pl.show() + + def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show() + + def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_circles_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show() def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show()