diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py
index 11352047..0fbffefe 100644
--- a/docs/extensions/typed_returns.py
+++ b/docs/extensions/typed_returns.py
@@ -12,7 +12,7 @@
def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
for line in lines:
if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line):
- yield f'-{m["param"]} (:class:`~{m["type"]}`)'
+ yield f"-{m['param']} (:class:`~{m['type']}`)"
else:
yield line
diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py
index e9dd630d..2dd9019f 100644
--- a/src/spatialdata_plot/pl/basic.py
+++ b/src/spatialdata_plot/pl/basic.py
@@ -290,7 +290,7 @@ def render_shapes(
norm=norm,
na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
)
- sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
+ sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
@@ -433,7 +433,7 @@ def render_points(
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
)
- sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
+ sdata.plotting_tree[f"{n_steps + 1}_render_points"] = PointsRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
@@ -538,7 +538,6 @@ def render_images(
n_steps = len(sdata.plotting_tree.keys())
for element, param_values in params_dict.items():
-
cmap_params: list[CmapParams] | CmapParams
if isinstance(cmap, list):
cmap_params = [
@@ -557,7 +556,7 @@ def render_images(
na_color=param_values["na_color"],
**kwargs,
)
- sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
+ sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams(
element=element,
channel=param_values["channel"],
cmap_params=cmap_params,
@@ -683,7 +682,7 @@ def render_labels(
norm=norm,
na_color=param_values["na_color"], # type: ignore[arg-type]
)
- sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
+ sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
element=element,
color=param_values["color"],
groups=param_values["groups"],
diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py
index 5bafe7a8..d3330b11 100644
--- a/src/spatialdata_plot/pl/render.py
+++ b/src/spatialdata_plot/pl/render.py
@@ -37,6 +37,7 @@
_ax_show_and_transform,
_create_image_from_datashader_result,
_datashader_aggregate_with_function,
+ _datashader_map_aggregate_to_color,
_datshader_get_how_kw_for_spread,
_decorate_axs,
_get_collection_shape,
@@ -229,18 +230,20 @@ def _render_shapes(
line_width=render_params.outline_params.linewidth,
)
+ ds_span = None
if norm.vmin is not None or norm.vmax is not None:
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
- norm.clip = True # NOTE: mpl currently behaves like clip is always True
+ ds_span = [norm.vmin, norm.vmax]
if norm.vmin == norm.vmax:
- # data is mapped to 0
- agg = agg - agg
- else:
- agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
+ # edge case, value vmin is rendered as the middle of the cmap
+ ds_span = [0, 1]
if norm.clip:
- agg = np.maximum(agg, 0)
- agg = np.minimum(agg, 1)
+ agg = (agg - agg) + 0.5
+ else:
+ agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
+ agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
+ agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
color_key = (
[x[:-2] for x in color_vector.categories.values]
@@ -256,13 +259,12 @@ def _render_shapes(
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = ds_cmap[:-2]
- ds_result = ds.tf.shade(
+ ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
color_key=color_key,
min_alpha=np.min([254, render_params.fill_alpha * 255]),
- how="linear",
- )
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
elif aggregate_with_reduction is not None: # to shut up mypy
ds_cmap = render_params.cmap_params.cmap
# in case all elements have the same value X: we render them using cmap(0.0),
@@ -272,12 +274,13 @@ def _render_shapes(
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
- ds_result = ds.tf.shade(
+ ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
- how="linear",
min_alpha=np.min([254, render_params.fill_alpha * 255]),
- )
+ span=ds_span,
+ clip=norm.clip,
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
# shade outlines if needed
outline_color = render_params.outline_params.outline_color
@@ -294,7 +297,7 @@ def _render_shapes(
cmap=outline_color,
min_alpha=np.min([254, render_params.outline_alpha * 255]),
how="linear",
- )
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_cax = _ax_show_and_transform(
@@ -322,8 +325,10 @@ def _render_shapes(
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
- vmin = norm.vmin
- vmax = norm.vmin + 1
+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
+ vmin = norm.vmin - 0.5
+ vmax = norm.vmin + 0.5
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=render_params.cmap_params.cmap,
@@ -586,18 +591,21 @@ def _render_points(
else:
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())
+ ds_span = None
if norm.vmin is not None or norm.vmax is not None:
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax
- norm.clip = True # NOTE: mpl currently behaves like clip is always True
+ ds_span = [norm.vmin, norm.vmax]
if norm.vmin == norm.vmax:
- # data is mapped to 0
- agg = agg - agg
- else:
- agg = (agg - norm.vmin) / (norm.vmax - norm.vmin)
+ ds_span = [0, 1]
if norm.clip:
- agg = np.maximum(agg, 0)
- agg = np.minimum(agg, 1)
+ # all data is mapped to 0.5
+ agg = (agg - agg) + 0.5
+ else:
+ # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2
+ agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1)
+ agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
+ agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
color_key = (
list(color_vector.categories.values)
@@ -615,13 +623,12 @@ def _render_points(
color_vector = np.asarray([x[:-2] for x in color_vector])
if color_by_categorical or col_for_color is None:
- ds_result = ds.tf.shade(
+ ds_result = _datashader_map_aggregate_to_color(
ds.tf.spread(agg, px=px),
cmap=color_vector[0],
color_key=color_key,
min_alpha=np.min([254, render_params.alpha * 255]),
- how="linear",
- )
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
else:
spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction)
agg = ds.tf.spread(agg, px=px, how=spread_how)
@@ -631,15 +638,17 @@ def _render_points(
# in case all elements have the same value X: we render them using cmap(0.0),
# using an artificial "span" of [X, X + 1] for the color bar
# else: all elements would get alpha=0 and the color bar would have a weird range
- if aggregate_with_reduction[0] == aggregate_with_reduction[1]:
+ if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]):
ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False)
aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1)
- ds_result = ds.tf.shade(
+ ds_result = _datashader_map_aggregate_to_color(
agg,
cmap=ds_cmap,
- how="linear",
- )
+ span=ds_span,
+ clip=norm.clip,
+ min_alpha=np.min([254, render_params.alpha * 255]),
+ ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_ax_show_and_transform(
@@ -656,8 +665,10 @@ def _render_points(
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
- vmin = norm.vmin
- vmax = norm.vmin + 1
+ # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
+ # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
+ vmin = norm.vmin - 0.5
+ vmax = norm.vmin + 0.5
cax = ScalarMappable(
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=render_params.cmap_params.cmap,
@@ -723,7 +734,6 @@ def _render_images(
legend_params: LegendParams,
rasterize: bool,
) -> None:
-
sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=False,
@@ -781,9 +791,6 @@ def _render_images(
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
- if render_params.cmap_params.norm: # type: ignore[attr-defined]
- layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
-
cmap = (
_get_linear_colormap(palette, "k")[0]
if isinstance(palette, list) and all(isinstance(p, str) for p in palette)
@@ -794,7 +801,10 @@ def _render_images(
cmap._init()
cmap._lut[:, -1] = render_params.alpha
- _ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
+ # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
+ _ax_show_and_transform(
+ layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm
+ )
if legend_params.colorbar:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py
index 8b99d173..c11c6590 100644
--- a/src/spatialdata_plot/pl/utils.py
+++ b/src/spatialdata_plot/pl/utils.py
@@ -305,7 +305,6 @@ def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, fl
def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None:
-
centroid = _get_centroid_of_pathpatch(pathpatch)
vertices = pathpatch.get_path().vertices
scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices])
@@ -677,7 +676,7 @@ def _get_colors_for_categorical_obs(
palette = default_102
else:
palette = ["grey" for _ in range(len_cat)]
- logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
+ logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
else:
# raise error when user didn't provide the right number of colors in palette
if isinstance(palette, list) and len(palette) != len(categories):
@@ -1654,7 +1653,7 @@ def _ensure_table_and_layer_exist_in_sdata(
if table_layer in sdata.tables[tname].layers:
if found_table:
raise ValueError(
- "Trying to guess 'table_name' based on 'table_layer', " "but found multiple matches."
+ "Trying to guess 'table_name' based on 'table_layer', but found multiple matches."
)
found_table = True
@@ -1727,7 +1726,6 @@ def _validate_label_render_params(
element_params: dict[str, dict[str, Any]] = {}
for el in param_dict["element"]:
-
# ensure that the element exists in the SpatialData object
_ = param_dict["sdata"][el]
@@ -1788,7 +1786,6 @@ def _validate_points_render_params(
element_params: dict[str, dict[str, Any]] = {}
for el in param_dict["element"]:
-
# ensure that the element exists in the SpatialData object
_ = param_dict["sdata"][el]
@@ -1859,7 +1856,6 @@ def _validate_shape_render_params(
element_params: dict[str, dict[str, Any]] = {}
for el in param_dict["element"]:
-
# ensure that the element exists in the SpatialData object
_ = param_dict["sdata"][el]
@@ -1896,7 +1892,6 @@ def _validate_shape_render_params(
def _validate_col_for_column_table(
sdata: SpatialData, element_name: str, col_for_color: str | None, table_name: str | None, labels: bool = False
) -> tuple[str | None, str | None]:
-
if not labels and col_for_color in sdata[element_name].columns:
table_name = None
elif table_name is not None:
@@ -2023,6 +2018,7 @@ def _ax_show_and_transform(
cmap: ListedColormap | LinearSegmentedColormap | None = None,
zorder: int = 0,
extent: list[float] | None = None,
+ norm: Normalize | None = None,
) -> matplotlib.image.AxesImage:
# default extent in mpl:
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
@@ -2045,6 +2041,7 @@ def _ax_show_and_transform(
alpha=alpha,
zorder=zorder,
extent=tuple(image_extent),
+ norm=norm,
)
im.set_transform(trans_data)
else:
@@ -2053,6 +2050,7 @@ def _ax_show_and_transform(
cmap=cmap,
zorder=zorder,
extent=tuple(image_extent),
+ norm=norm,
)
im.set_transform(trans_data)
return im
@@ -2117,10 +2115,10 @@ def _get_extent_and_range_for_datashader_canvas(
def _create_image_from_datashader_result(
- ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
+ ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], factor: float, ax: Axes
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
# create SpatialImage from datashader output to get it back to original size
- rgba_image_data = ds_result.to_numpy().base
+ rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
rgba_image = Image2DModel.parse(
rgba_image_data,
@@ -2266,3 +2264,51 @@ def _get_transformation_matrix_for_datashader(
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
return tm
return _get_datashader_trans_matrix_of_single_element(trans)
+
+
+def _datashader_map_aggregate_to_color(
+ agg: DataArray,
+ cmap: str | list[str] | ListedColormap,
+ color_key: None | list[str] = None,
+ min_alpha: float = 40,
+ span: None | list[float] = None,
+ clip: bool = True,
+) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
+ """ds.tf.shade() part, ensuring correct clipping behavior.
+
+ If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
+ This ensures the correct clipping behavior, because else datashader would always automatically clip.
+ """
+ if not clip and isinstance(cmap, Colormap) and span is not None:
+ # in case we use datashader together with a Normalize object where clip=False
+ # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
+ agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
+ img_in = ds.tf.shade(
+ agg_in,
+ cmap=cmap,
+ span=(span[0], span[1]),
+ how="linear",
+ color_key=color_key,
+ min_alpha=min_alpha,
+ )
+
+ agg_under = agg.where(agg < span[0])
+ img_under = ds.tf.shade(
+ agg_under, cmap=[to_hex(cmap.get_under())[:7]], min_alpha=min_alpha, color_key=color_key
+ )
+
+ agg_over = agg.where(agg > span[1])
+ img_over = ds.tf.shade(agg_over, cmap=[to_hex(cmap.get_over())[:7]], min_alpha=min_alpha, color_key=color_key)
+
+ # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0
+ stack = img_under.to_numpy().base
+ if stack is None:
+ stack = img_in.to_numpy().base
+ else:
+ stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0]
+ img_over = img_over.to_numpy().base
+ if img_over is not None:
+ stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
+ return stack
+
+ return ds.tf.shade(agg, cmap=cmap, color_key=color_key, min_alpha=min_alpha, span=span, how="linear")
diff --git a/tests/_images/Images_can_pass_normalize_clip_False.png b/tests/_images/Images_can_pass_normalize_clip_False.png
index b97a3587..98b05197 100644
Binary files a/tests/_images/Images_can_pass_normalize_clip_False.png and b/tests/_images/Images_can_pass_normalize_clip_False.png differ
diff --git a/tests/_images/Images_can_pass_normalize_clip_True.png b/tests/_images/Images_can_pass_normalize_clip_True.png
index faa35b3b..aa161fdf 100644
Binary files a/tests/_images/Images_can_pass_normalize_clip_True.png and b/tests/_images/Images_can_pass_normalize_clip_True.png differ
diff --git a/tests/_images/Labels_can_color_with_norm_and_clipping.png b/tests/_images/Labels_can_color_with_norm_and_clipping.png
new file mode 100644
index 00000000..ebd0302f
Binary files /dev/null and b/tests/_images/Labels_can_color_with_norm_and_clipping.png differ
diff --git a/tests/_images/Labels_can_color_with_norm_no_clipping.png b/tests/_images/Labels_can_color_with_norm_no_clipping.png
new file mode 100644
index 00000000..f3c6a545
Binary files /dev/null and b/tests/_images/Labels_can_color_with_norm_no_clipping.png differ
diff --git a/tests/_images/Points_can_use_norm_with_clip.png b/tests/_images/Points_can_use_norm_with_clip.png
new file mode 100644
index 00000000..0abb42c6
Binary files /dev/null and b/tests/_images/Points_can_use_norm_with_clip.png differ
diff --git a/tests/_images/Points_can_use_norm_without_clip.png b/tests/_images/Points_can_use_norm_without_clip.png
new file mode 100644
index 00000000..f20eac25
Binary files /dev/null and b/tests/_images/Points_can_use_norm_without_clip.png differ
diff --git a/tests/_images/Points_datashader_can_use_norm_with_clip.png b/tests/_images/Points_datashader_can_use_norm_with_clip.png
new file mode 100644
index 00000000..7bb613b0
Binary files /dev/null and b/tests/_images/Points_datashader_can_use_norm_with_clip.png differ
diff --git a/tests/_images/Points_datashader_can_use_norm_without_clip.png b/tests/_images/Points_datashader_can_use_norm_without_clip.png
new file mode 100644
index 00000000..450d2c54
Binary files /dev/null and b/tests/_images/Points_datashader_can_use_norm_without_clip.png differ
diff --git a/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png b/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png
new file mode 100644
index 00000000..2d7ccf8a
Binary files /dev/null and b/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png differ
diff --git a/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png b/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png
new file mode 100644
index 00000000..5aaf98b2
Binary files /dev/null and b/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png differ
diff --git a/tests/_images/Shapes_can_color_with_norm_no_clipping.png b/tests/_images/Shapes_can_color_with_norm_no_clipping.png
new file mode 100644
index 00000000..01294587
Binary files /dev/null and b/tests/_images/Shapes_can_color_with_norm_no_clipping.png differ
diff --git a/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png b/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png
new file mode 100644
index 00000000..2e754d19
Binary files /dev/null and b/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png differ
diff --git a/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png b/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png
new file mode 100644
index 00000000..64b053b6
Binary files /dev/null and b/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png differ
diff --git a/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png
new file mode 100644
index 00000000..54842f7b
Binary files /dev/null and b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png differ
diff --git a/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png
new file mode 100644
index 00000000..230e3aeb
Binary files /dev/null and b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png differ
diff --git a/tests/conftest.py b/tests/conftest.py
index 27adff68..73ea410b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@
from functools import wraps
from pathlib import Path
+import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
@@ -149,6 +150,23 @@ def test_sdata_multiple_images_diverging_dims():
return sdata
+@pytest.fixture
+def sdata_blobs_shapes_annotated() -> SpatialData:
+ """Get blobs sdata with continuous annotation of polygons."""
+ blob = blobs()
+ blob["table"].obs["region"] = "blobs_polygons"
+ blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
+ blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
+ return blob
+
+
+def _viridis_with_under_over() -> matplotlib.colors.ListedColormap:
+ cmap = matplotlib.colormaps["viridis"]
+ cmap.set_under("black")
+ cmap.set_over("grey")
+ return cmap
+
+
# Code below taken from spatialdata main repo
diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py
index c4e43977..5484ac4e 100644
--- a/tests/pl/test_render_images.py
+++ b/tests/pl/test_render_images.py
@@ -7,7 +7,7 @@
from spatialdata import SpatialData
import spatialdata_plot # noqa: F401
-from tests.conftest import DPI, PlotTester, PlotTesterMeta
+from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
@@ -67,12 +67,16 @@ def test_plot_can_render_two_channels_str_from_multiscale_image(self, sdata_blob
sdata_blobs_str.pl.render_images(element="blobs_multiscale_image", channel=["c1", "c2"]).pl.show()
def test_plot_can_pass_normalize_clip_True(self, sdata_blobs: SpatialData):
- norm = Normalize(vmin=0, vmax=0.4, clip=True)
- sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()
+ norm = Normalize(vmin=0.1, vmax=0.5, clip=True)
+ sdata_blobs.pl.render_images(
+ element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over()
+ ).pl.show()
def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData):
- norm = Normalize(vmin=0, vmax=0.4, clip=False)
- sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()
+ norm = Normalize(vmin=0.1, vmax=0.5, clip=False)
+ sdata_blobs.pl.render_images(
+ element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over()
+ ).pl.show()
def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show()
diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py
index d7697bd7..0196ae86 100644
--- a/tests/pl/test_render_labels.py
+++ b/tests/pl/test_render_labels.py
@@ -6,12 +6,13 @@
import pytest
import scanpy as sc
from anndata import AnnData
+from matplotlib.colors import Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata.models import TableModel
import spatialdata_plot # noqa: F401
-from tests.conftest import DPI, PlotTester, PlotTesterMeta
+from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
@@ -234,6 +235,19 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str
sdata_blobs["other_table"] = table
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
+ def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_labels(
+ "blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over()
+ ).pl.show()
+
+ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_labels(
+ "blobs_labels",
+ color="channel_0_sum",
+ norm=Normalize(400, 1000, clip=False),
+ cmap=_viridis_with_under_over(),
+ ).pl.show()
+
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()
diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py
index 3ddef2bb..ad340f70 100644
--- a/tests/pl/test_render_points.py
+++ b/tests/pl/test_render_points.py
@@ -7,13 +7,14 @@
import pandas as pd
import scanpy as sc
from anndata import AnnData
+from matplotlib.colors import Normalize
from spatialdata import SpatialData, deepcopy
from spatialdata.models import PointsModel, TableModel
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation
from spatialdata.transformations._utils import _set_transformations
import spatialdata_plot # noqa: F401
-from tests.conftest import DPI, PlotTester, PlotTesterMeta
+from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
@@ -225,6 +226,56 @@ def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points("blobs_points", method="datashader", color="black", size=5).pl.show()
+ def test_plot_can_use_norm_with_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id", size=40, norm=Normalize(3, 7, clip=True), cmap=_viridis_with_under_over()
+ ).pl.show()
+
+ def test_plot_can_use_norm_without_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id", size=40, norm=Normalize(3, 7, clip=False), cmap=_viridis_with_under_over()
+ ).pl.show()
+
+ def test_plot_datashader_can_use_norm_with_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id",
+ size=40,
+ norm=Normalize(3, 7, clip=True),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_can_use_norm_without_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id",
+ size=40,
+ norm=Normalize(3, 7, clip=False),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id",
+ size=40,
+ norm=Normalize(5, 5, clip=True),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData):
+ sdata_blobs.pl.render_points(
+ color="instance_id",
+ size=40,
+ norm=Normalize(5, 5, clip=False),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
def test_plot_can_annotate_points_with_table_obs(self, sdata_blobs: SpatialData):
nrows, ncols = 200, 3
feature_matrix = RNG.random((nrows, ncols))
diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py
index 7abb0783..affeebd0 100644
--- a/tests/pl/test_render_shapes.py
+++ b/tests/pl/test_render_shapes.py
@@ -16,7 +16,7 @@
from spatialdata.transformations._utils import _set_transformations
import spatialdata_plot # noqa: F401
-from tests.conftest import DPI, PlotTester, PlotTesterMeta
+from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
RNG = np.random.default_rng(seed=42)
sc.pl.set_rcParams_defaults()
@@ -456,6 +456,51 @@ def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show()
+ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
+ sdata_blobs_shapes_annotated.pl.render_shapes(
+ element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over()
+ ).pl.show()
+
+ def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
+ sdata_blobs_shapes_annotated.pl.render_shapes(
+ element="blobs_polygons",
+ color="value",
+ norm=Normalize(2, 4, clip=True),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
+ sdata_blobs_shapes_annotated.pl.render_shapes(
+ element="blobs_polygons",
+ color="value",
+ norm=Normalize(2, 4, clip=False),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs_shapes_annotated: SpatialData):
+ sdata_blobs_shapes_annotated.pl.render_shapes(
+ element="blobs_polygons",
+ color="value",
+ norm=Normalize(3, 3, clip=False),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
+ def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs_shapes_annotated: SpatialData):
+ sdata_blobs_shapes_annotated.pl.render_shapes(
+ element="blobs_polygons",
+ color="value",
+ norm=Normalize(3, 3, clip=True),
+ cmap=_viridis_with_under_over(),
+ method="datashader",
+ datashader_reduction="max",
+ ).pl.show()
+
def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialData):
nrows, ncols = 5, 3
feature_matrix = RNG.random((nrows, ncols))