Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any
from typing import Any, Literal

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -170,6 +170,7 @@ def render_shapes(
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
shape: Literal["circle", "hex", "square"] | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -242,6 +243,9 @@ def render_shapes(
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.
shape: Literal["circle", "hex", "square"] | None
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
specified, the shapes are converted to a circle/hexagon/square before rendering.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down Expand Up @@ -286,6 +290,7 @@ def render_shapes(
scale=scale,
table_name=table_name,
table_layer=table_layer,
shape=shape,
method=method,
ds_reduction=kwargs.get("datashader_reduction"),
)
Expand Down Expand Up @@ -318,6 +323,7 @@ def render_shapes(
transfunc=kwargs.get("transfunc"),
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
shape=param_values["shape"],
zorder=n_steps,
method=param_values["method"],
ds_reduction=param_values["ds_reduction"],
Expand Down
18 changes: 14 additions & 4 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
_convert_alpha_to_datashader_range,
_convert_shapes,
_create_image_from_datashader_result,
_datashader_aggregate_with_function,
_datashader_map_aggregate_to_color,
Expand Down Expand Up @@ -163,6 +164,15 @@ def _render_shapes(
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
# convert shapes if necessary
if render_params.shape is not None:
current_type = shapes["geometry"].type
if not (render_params.shape == "circle" and (current_type == "Point").all()):
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
max_extent = np.max(
[shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]]
)
shapes = _convert_shapes(shapes, render_params.shape, max_extent)

# Determine which method to use for rendering
method = render_params.method
Expand All @@ -186,17 +196,17 @@ def _render_shapes(
# Handle circles encoded as points with radius
if is_point.any():
scale = shapes[is_point]["radius"] * render_params.scale
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
tm = trans.get_matrix()
transformed_element = sdata_filt.shapes[element].transform(
transformed_geometry = shapes["geometry"].transform(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
)
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(
data=sdata_filt.shapes[element].drop("geometry", axis=1),
geometry=transformed_element,
data=shapes.drop("geometry", axis=1),
geometry=transformed_geometry,
)
)

Expand Down
1 change: 1 addition & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class ShapesRenderParams:
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
shape: Literal["circle", "hex", "square"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


Expand Down
164 changes: 164 additions & 0 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
import os
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -51,6 +52,7 @@
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
from scanpy.plotting.palettes import default_20, default_28, default_102
from scipy.spatial import ConvexHull
from skimage.color import label2rgb
from skimage.morphology import erosion, square
from skimage.segmentation import find_boundaries
Expand Down Expand Up @@ -1818,6 +1820,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if size < 0:
raise ValueError("Parameter 'size' must be a positive number.")

if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
if not isinstance(shape, str):
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
if shape not in ["circle", "hex", "square"]:
raise ValueError(
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
)

table_name = param_dict.get("table_name")
table_layer = param_dict.get("table_layer")
if table_name and not isinstance(param_dict["table_name"], str):
Expand Down Expand Up @@ -2030,6 +2040,7 @@ def _validate_shape_render_params(
scale: float | int,
table_name: str | None,
table_layer: str | None,
shape: Literal["circle", "hex", "square"] | None,
method: str | None,
ds_reduction: str | None,
) -> dict[str, dict[str, Any]]:
Expand All @@ -2049,6 +2060,7 @@ def _validate_shape_render_params(
"scale": scale,
"table_name": table_name,
"table_layer": table_layer,
"shape": shape,
"method": method,
"ds_reduction": ds_reduction,
}
Expand All @@ -2069,6 +2081,7 @@ def _validate_shape_render_params(
element_params[el]["norm"] = param_dict["norm"]
element_params[el]["scale"] = param_dict["scale"]
element_params[el]["table_layer"] = param_dict["table_layer"]
element_params[el]["shape"] = param_dict["shape"]

element_params[el]["color"] = param_dict["color"]

Expand Down Expand Up @@ -2487,6 +2500,39 @@ def _prepare_transformation(
return trans, trans_data


def _get_datashader_trans_matrix_of_single_element(
trans: Identity | Scale | Affine | MapAxis | Translation,
) -> ArrayLike:
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
tm: ArrayLike = trans.to_affine_matrix(("x", "y"), ("x", "y"))

if isinstance(trans, Identity):
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
if isinstance(trans, (Scale | Affine)):
# idea: "flip the y-axis", apply transformation, flip back
flip_and_transform: ArrayLike = flip_matrix @ tm @ flip_matrix
return flip_and_transform
if isinstance(trans, MapAxis):
# no flipping needed
return tm
# for a Translation, we need the transposed transformation matrix
tm_T = tm.T
assert isinstance(tm_T, np.ndarray)
return tm_T


def _get_transformation_matrix_for_datashader(
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
) -> ArrayLike:
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
if isinstance(trans, SDSequence):
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
for x in trans.transformations:
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
return tm
return _get_datashader_trans_matrix_of_single_element(trans)


def _datashader_map_aggregate_to_color(
agg: DataArray,
cmap: str | list[str] | ListedColormap,
Expand Down Expand Up @@ -2588,6 +2634,124 @@ def _hex_no_alpha(hex: str) -> str:
raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")


def _convert_shapes(
shapes: GeoDataFrame, target_shape: str, max_extent: float, warn_above_extent_fraction: float = 0.5
) -> GeoDataFrame:
"""Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape."""
# NOTE: possible follow-up: when converting equally sized shapes to hex, automatically scale resulting hexagons
# so that they are perfectly adjacent to each other

if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
warn_above_extent_fraction = 0.5 # set to default if the value is outside [0, 1]
warn_shape_size = False

# define individual conversion methods
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
vertices = [
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
for angle in range(0, 360, 60)
]
return shapely.Polygon(vertices), None

def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
vertices = [
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
for angle in range(45, 360, 90)
]
return shapely.Polygon(vertices), None

def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
return center, radius

def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
center, radius = _polygon_to_circle(polygon)
return _circle_to_hexagon(center, radius)

def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
center, radius = _polygon_to_circle(polygon)
return _circle_to_square(center, radius)

def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
coords = np.array(polygon.exterior.coords)
circle_points = coords[ConvexHull(coords).vertices]
center = np.mean(circle_points, axis=0)
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
assert isinstance(radius, float) # shut up mypy
if 2 * radius > max_extent * warn_above_extent_fraction:
nonlocal warn_shape_size
warn_shape_size = True
return shapely.Point(center), radius

def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
center, radius = _multipolygon_to_circle(multipolygon)
return _circle_to_hexagon(center, radius)

def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
center, radius = _multipolygon_to_circle(multipolygon)
return _circle_to_square(center, radius)

def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
coords = []
for polygon in multipolygon.geoms:
coords.extend(polygon.exterior.coords)
points = np.array(coords)
circle_points = points[ConvexHull(points).vertices]
center = np.mean(circle_points, axis=0)
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
assert isinstance(radius, float) # shut up mypy
if 2 * radius > max_extent * warn_above_extent_fraction:
nonlocal warn_shape_size
warn_shape_size = True
return shapely.Point(center), radius

# define dict with all conversion methods
if target_shape == "circle":
conversion_methods = {
"Point": _circle_to_circle,
"Polygon": _polygon_to_circle,
"Multipolygon": _multipolygon_to_circle,
}
pass
elif target_shape == "hex":
conversion_methods = {
"Point": _circle_to_hexagon,
"Polygon": _polygon_to_hexagon,
"Multipolygon": _multipolygon_to_hexagon,
}
else:
conversion_methods = {
"Point": _circle_to_square,
"Polygon": _polygon_to_square,
"Multipolygon": _multipolygon_to_square,
}

# convert every shape
for i in range(shapes.shape[0]):
if shapes["geometry"][i].type == "Point":
converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore
elif shapes["geometry"][i].type == "Polygon":
converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore
elif shapes["geometry"][i].type == "MultiPolygon":
converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore
else:
error_type = shapes["geometry"][i].type
raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.")
shapes["geometry"][i] = converted
if radius is not None:
if "radius" not in shapes.columns:
shapes["radius"] = np.nan
shapes["radius"][i] = radius

if warn_shape_size:
logger.info(
f"When converting the shapes, the size of at least one target shape extends "
f"{warn_above_extent_fraction * 100}% of the original total bound of the shapes. The conversion"
" might not give satisfying results in this scenario."
)

return shapes


def _convert_alpha_to_datashader_range(alpha: float) -> float:
"""Convert alpha from the range [0, 1] to the range [0, 255] used in datashader."""
# prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_datashader_can_transform_circles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 47 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,53 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat

sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show()

def test_plot_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex").pl.show()

def test_plot_can_render_circles_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square").pl.show()

def test_plot_can_render_polygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex").pl.show()

def test_plot_can_render_polygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square").pl.show()

def test_plot_can_render_polygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle").pl.show()

def test_plot_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex").pl.show()

def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square").pl.show()

def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show()

def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_circles_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show()
def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show()

Expand Down