diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py index 11352047..0fbffefe 100644 --- a/docs/extensions/typed_returns.py +++ b/docs/extensions/typed_returns.py @@ -12,7 +12,7 @@ def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: for line in lines: if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): - yield f'-{m["param"]} (:class:`~{m["type"]}`)' + yield f"-{m['param']} (:class:`~{m['type']}`)" else: yield line diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index e9dd630d..2dd9019f 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -290,7 +290,7 @@ def render_shapes( norm=norm, na_color=params_dict[element]["na_color"], # type: ignore[arg-type] ) - sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams( + sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams( element=element, color=param_values["color"], col_for_color=param_values["col_for_color"], @@ -433,7 +433,7 @@ def render_points( norm=norm, na_color=param_values["na_color"], # type: ignore[arg-type] ) - sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams( + sdata.plotting_tree[f"{n_steps + 1}_render_points"] = PointsRenderParams( element=element, color=param_values["color"], col_for_color=param_values["col_for_color"], @@ -538,7 +538,6 @@ def render_images( n_steps = len(sdata.plotting_tree.keys()) for element, param_values in params_dict.items(): - cmap_params: list[CmapParams] | CmapParams if isinstance(cmap, list): cmap_params = [ @@ -557,7 +556,7 @@ def render_images( na_color=param_values["na_color"], **kwargs, ) - sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams( + sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams( element=element, channel=param_values["channel"], cmap_params=cmap_params, @@ -683,7 +682,7 @@ def render_labels( norm=norm, na_color=param_values["na_color"], # type: ignore[arg-type] ) - sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams( + sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( element=element, color=param_values["color"], groups=param_values["groups"], diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 5bafe7a8..d3330b11 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -37,6 +37,7 @@ _ax_show_and_transform, _create_image_from_datashader_result, _datashader_aggregate_with_function, + _datashader_map_aggregate_to_color, _datshader_get_how_kw_for_spread, _decorate_axs, _get_collection_shape, @@ -229,18 +230,20 @@ def _render_shapes( line_width=render_params.outline_params.linewidth, ) + ds_span = None if norm.vmin is not None or norm.vmax is not None: norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - norm.clip = True # NOTE: mpl currently behaves like clip is always True + ds_span = [norm.vmin, norm.vmax] if norm.vmin == norm.vmax: - # data is mapped to 0 - agg = agg - agg - else: - agg = (agg - norm.vmin) / (norm.vmax - norm.vmin) + # edge case, value vmin is rendered as the middle of the cmap + ds_span = [0, 1] if norm.clip: - agg = np.maximum(agg, 0) - agg = np.minimum(agg, 1) + agg = (agg - agg) + 0.5 + else: + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key = ( [x[:-2] for x in color_vector.categories.values] @@ -256,13 +259,12 @@ def _render_shapes( if isinstance(ds_cmap, str) and ds_cmap[0] == "#": ds_cmap = ds_cmap[:-2] - ds_result = ds.tf.shade( + ds_result = _datashader_map_aggregate_to_color( agg, cmap=ds_cmap, color_key=color_key, min_alpha=np.min([254, render_params.fill_alpha * 255]), - how="linear", - ) + ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes elif aggregate_with_reduction is not None: # to shut up mypy ds_cmap = render_params.cmap_params.cmap # in case all elements have the same value X: we render them using cmap(0.0), @@ -272,12 +274,13 @@ def _render_shapes( ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1) - ds_result = ds.tf.shade( + ds_result = _datashader_map_aggregate_to_color( agg, cmap=ds_cmap, - how="linear", min_alpha=np.min([254, render_params.fill_alpha * 255]), - ) + span=ds_span, + clip=norm.clip, + ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes # shade outlines if needed outline_color = render_params.outline_params.outline_color @@ -294,7 +297,7 @@ def _render_shapes( cmap=outline_color, min_alpha=np.min([254, render_params.outline_alpha * 255]), how="linear", - ) + ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) _cax = _ax_show_and_transform( @@ -322,8 +325,10 @@ def _render_shapes( vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - vmin = norm.vmin - vmax = norm.vmin + 1 + # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and + # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) + vmin = norm.vmin - 0.5 + vmax = norm.vmin + 0.5 cax = ScalarMappable( norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), cmap=render_params.cmap_params.cmap, @@ -586,18 +591,21 @@ def _render_points( else: agg = cvs.points(transformed_element, "x", "y", agg=ds.count()) + ds_span = None if norm.vmin is not None or norm.vmax is not None: norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - norm.clip = True # NOTE: mpl currently behaves like clip is always True + ds_span = [norm.vmin, norm.vmax] if norm.vmin == norm.vmax: - # data is mapped to 0 - agg = agg - agg - else: - agg = (agg - norm.vmin) / (norm.vmax - norm.vmin) + ds_span = [0, 1] if norm.clip: - agg = np.maximum(agg, 0) - agg = np.minimum(agg, 1) + # all data is mapped to 0.5 + agg = (agg - agg) + 0.5 + else: + # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2 + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) color_key = ( list(color_vector.categories.values) @@ -615,13 +623,12 @@ def _render_points( color_vector = np.asarray([x[:-2] for x in color_vector]) if color_by_categorical or col_for_color is None: - ds_result = ds.tf.shade( + ds_result = _datashader_map_aggregate_to_color( ds.tf.spread(agg, px=px), cmap=color_vector[0], color_key=color_key, min_alpha=np.min([254, render_params.alpha * 255]), - how="linear", - ) + ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes else: spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction) agg = ds.tf.spread(agg, px=px, how=spread_how) @@ -631,15 +638,17 @@ def _render_points( # in case all elements have the same value X: we render them using cmap(0.0), # using an artificial "span" of [X, X + 1] for the color bar # else: all elements would get alpha=0 and the color bar would have a weird range - if aggregate_with_reduction[0] == aggregate_with_reduction[1]: + if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]): ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) aggregate_with_reduction = (aggregate_with_reduction[0], aggregate_with_reduction[0] + 1) - ds_result = ds.tf.shade( + ds_result = _datashader_map_aggregate_to_color( agg, cmap=ds_cmap, - how="linear", - ) + span=ds_span, + clip=norm.clip, + min_alpha=np.min([254, render_params.alpha * 255]), + ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) _ax_show_and_transform( @@ -656,8 +665,10 @@ def _render_points( vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - vmin = norm.vmin - vmax = norm.vmin + 1 + # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and + # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) + vmin = norm.vmin - 0.5 + vmax = norm.vmin + 0.5 cax = ScalarMappable( norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), cmap=render_params.cmap_params.cmap, @@ -723,7 +734,6 @@ def _render_images( legend_params: LegendParams, rasterize: bool, ) -> None: - sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, filter_tables=False, @@ -781,9 +791,6 @@ def _render_images( if n_channels == 1 and not isinstance(render_params.cmap_params, list): layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze() - if render_params.cmap_params.norm: # type: ignore[attr-defined] - layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined] - cmap = ( _get_linear_colormap(palette, "k")[0] if isinstance(palette, list) and all(isinstance(p, str) for p in palette) @@ -794,7 +801,10 @@ def _render_images( cmap._init() cmap._lut[:, -1] = render_params.alpha - _ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder) + # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip. + _ax_show_and_transform( + layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, norm=render_params.cmap_params.norm + ) if legend_params.colorbar: sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 8b99d173..c11c6590 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -305,7 +305,6 @@ def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, fl def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None: - centroid = _get_centroid_of_pathpatch(pathpatch) vertices = pathpatch.get_path().vertices scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices]) @@ -677,7 +676,7 @@ def _get_colors_for_categorical_obs( palette = default_102 else: palette = ["grey" for _ in range(len_cat)] - logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.") + logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") else: # raise error when user didn't provide the right number of colors in palette if isinstance(palette, list) and len(palette) != len(categories): @@ -1654,7 +1653,7 @@ def _ensure_table_and_layer_exist_in_sdata( if table_layer in sdata.tables[tname].layers: if found_table: raise ValueError( - "Trying to guess 'table_name' based on 'table_layer', " "but found multiple matches." + "Trying to guess 'table_name' based on 'table_layer', but found multiple matches." ) found_table = True @@ -1727,7 +1726,6 @@ def _validate_label_render_params( element_params: dict[str, dict[str, Any]] = {} for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object _ = param_dict["sdata"][el] @@ -1788,7 +1786,6 @@ def _validate_points_render_params( element_params: dict[str, dict[str, Any]] = {} for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object _ = param_dict["sdata"][el] @@ -1859,7 +1856,6 @@ def _validate_shape_render_params( element_params: dict[str, dict[str, Any]] = {} for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object _ = param_dict["sdata"][el] @@ -1896,7 +1892,6 @@ def _validate_shape_render_params( def _validate_col_for_column_table( sdata: SpatialData, element_name: str, col_for_color: str | None, table_name: str | None, labels: bool = False ) -> tuple[str | None, str | None]: - if not labels and col_for_color in sdata[element_name].columns: table_name = None elif table_name is not None: @@ -2023,6 +2018,7 @@ def _ax_show_and_transform( cmap: ListedColormap | LinearSegmentedColormap | None = None, zorder: int = 0, extent: list[float] | None = None, + norm: Normalize | None = None, ) -> matplotlib.image.AxesImage: # default extent in mpl: image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5] @@ -2045,6 +2041,7 @@ def _ax_show_and_transform( alpha=alpha, zorder=zorder, extent=tuple(image_extent), + norm=norm, ) im.set_transform(trans_data) else: @@ -2053,6 +2050,7 @@ def _ax_show_and_transform( cmap=cmap, zorder=zorder, extent=tuple(image_extent), + norm=norm, ) im.set_transform(trans_data) return im @@ -2117,10 +2115,10 @@ def _get_extent_and_range_for_datashader_canvas( def _create_image_from_datashader_result( - ds_result: ds.transfer_functions.Image, factor: float, ax: Axes + ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], factor: float, ax: Axes ) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]: # create SpatialImage from datashader output to get it back to original size - rgba_image_data = ds_result.to_numpy().base + rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1)) rgba_image = Image2DModel.parse( rgba_image_data, @@ -2266,3 +2264,51 @@ def _get_transformation_matrix_for_datashader( tm = tm @ _get_datashader_trans_matrix_of_single_element(x) return tm return _get_datashader_trans_matrix_of_single_element(trans) + + +def _datashader_map_aggregate_to_color( + agg: DataArray, + cmap: str | list[str] | ListedColormap, + color_key: None | list[str] = None, + min_alpha: float = 40, + span: None | list[float] = None, + clip: bool = True, +) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: + """ds.tf.shade() part, ensuring correct clipping behavior. + + If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. + This ensures the correct clipping behavior, because else datashader would always automatically clip. + """ + if not clip and isinstance(cmap, Colormap) and span is not None: + # in case we use datashader together with a Normalize object where clip=False + # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 + agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) + img_in = ds.tf.shade( + agg_in, + cmap=cmap, + span=(span[0], span[1]), + how="linear", + color_key=color_key, + min_alpha=min_alpha, + ) + + agg_under = agg.where(agg < span[0]) + img_under = ds.tf.shade( + agg_under, cmap=[to_hex(cmap.get_under())[:7]], min_alpha=min_alpha, color_key=color_key + ) + + agg_over = agg.where(agg > span[1]) + img_over = ds.tf.shade(agg_over, cmap=[to_hex(cmap.get_over())[:7]], min_alpha=min_alpha, color_key=color_key) + + # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 + stack = img_under.to_numpy().base + if stack is None: + stack = img_in.to_numpy().base + else: + stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] + img_over = img_over.to_numpy().base + if img_over is not None: + stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] + return stack + + return ds.tf.shade(agg, cmap=cmap, color_key=color_key, min_alpha=min_alpha, span=span, how="linear") diff --git a/tests/_images/Images_can_pass_normalize_clip_False.png b/tests/_images/Images_can_pass_normalize_clip_False.png index b97a3587..98b05197 100644 Binary files a/tests/_images/Images_can_pass_normalize_clip_False.png and b/tests/_images/Images_can_pass_normalize_clip_False.png differ diff --git a/tests/_images/Images_can_pass_normalize_clip_True.png b/tests/_images/Images_can_pass_normalize_clip_True.png index faa35b3b..aa161fdf 100644 Binary files a/tests/_images/Images_can_pass_normalize_clip_True.png and b/tests/_images/Images_can_pass_normalize_clip_True.png differ diff --git a/tests/_images/Labels_can_color_with_norm_and_clipping.png b/tests/_images/Labels_can_color_with_norm_and_clipping.png new file mode 100644 index 00000000..ebd0302f Binary files /dev/null and b/tests/_images/Labels_can_color_with_norm_and_clipping.png differ diff --git a/tests/_images/Labels_can_color_with_norm_no_clipping.png b/tests/_images/Labels_can_color_with_norm_no_clipping.png new file mode 100644 index 00000000..f3c6a545 Binary files /dev/null and b/tests/_images/Labels_can_color_with_norm_no_clipping.png differ diff --git a/tests/_images/Points_can_use_norm_with_clip.png b/tests/_images/Points_can_use_norm_with_clip.png new file mode 100644 index 00000000..0abb42c6 Binary files /dev/null and b/tests/_images/Points_can_use_norm_with_clip.png differ diff --git a/tests/_images/Points_can_use_norm_without_clip.png b/tests/_images/Points_can_use_norm_without_clip.png new file mode 100644 index 00000000..f20eac25 Binary files /dev/null and b/tests/_images/Points_can_use_norm_without_clip.png differ diff --git a/tests/_images/Points_datashader_can_use_norm_with_clip.png b/tests/_images/Points_datashader_can_use_norm_with_clip.png new file mode 100644 index 00000000..7bb613b0 Binary files /dev/null and b/tests/_images/Points_datashader_can_use_norm_with_clip.png differ diff --git a/tests/_images/Points_datashader_can_use_norm_without_clip.png b/tests/_images/Points_datashader_can_use_norm_without_clip.png new file mode 100644 index 00000000..450d2c54 Binary files /dev/null and b/tests/_images/Points_datashader_can_use_norm_without_clip.png differ diff --git a/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png b/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png new file mode 100644 index 00000000..2d7ccf8a Binary files /dev/null and b/tests/_images/Points_datashader_norm_vmin_eq_vmax_with_clip.png differ diff --git a/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png b/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png new file mode 100644 index 00000000..5aaf98b2 Binary files /dev/null and b/tests/_images/Points_datashader_norm_vmin_eq_vmax_without_clip.png differ diff --git a/tests/_images/Shapes_can_color_with_norm_no_clipping.png b/tests/_images/Shapes_can_color_with_norm_no_clipping.png new file mode 100644 index 00000000..01294587 Binary files /dev/null and b/tests/_images/Shapes_can_color_with_norm_no_clipping.png differ diff --git a/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png b/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png new file mode 100644 index 00000000..2e754d19 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_color_with_norm_and_clipping.png differ diff --git a/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png b/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png new file mode 100644 index 00000000..64b053b6 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_color_with_norm_no_clipping.png differ diff --git a/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png new file mode 100644 index 00000000..54842f7b Binary files /dev/null and b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_with_clip.png differ diff --git a/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png new file mode 100644 index 00000000..230e3aeb Binary files /dev/null and b/tests/_images/Shapes_datashader_norm_vmin_eq_vmax_without_clip.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 27adff68..73ea410b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from functools import wraps from pathlib import Path +import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -149,6 +150,23 @@ def test_sdata_multiple_images_diverging_dims(): return sdata +@pytest.fixture +def sdata_blobs_shapes_annotated() -> SpatialData: + """Get blobs sdata with continuous annotation of polygons.""" + blob = blobs() + blob["table"].obs["region"] = "blobs_polygons" + blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5] + return blob + + +def _viridis_with_under_over() -> matplotlib.colors.ListedColormap: + cmap = matplotlib.colormaps["viridis"] + cmap.set_under("black") + cmap.set_over("grey") + return cmap + + # Code below taken from spatialdata main repo diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index c4e43977..5484ac4e 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -7,7 +7,7 @@ from spatialdata import SpatialData import spatialdata_plot # noqa: F401 -from tests.conftest import DPI, PlotTester, PlotTesterMeta +from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over RNG = np.random.default_rng(seed=42) sc.pl.set_rcParams_defaults() @@ -67,12 +67,16 @@ def test_plot_can_render_two_channels_str_from_multiscale_image(self, sdata_blob sdata_blobs_str.pl.render_images(element="blobs_multiscale_image", channel=["c1", "c2"]).pl.show() def test_plot_can_pass_normalize_clip_True(self, sdata_blobs: SpatialData): - norm = Normalize(vmin=0, vmax=0.4, clip=True) - sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show() + norm = Normalize(vmin=0.1, vmax=0.5, clip=True) + sdata_blobs.pl.render_images( + element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over() + ).pl.show() def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData): - norm = Normalize(vmin=0, vmax=0.4, clip=False) - sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show() + norm = Normalize(vmin=0.1, vmax=0.5, clip=False) + sdata_blobs.pl.render_images( + element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over() + ).pl.show() def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show() diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index d7697bd7..0196ae86 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -6,12 +6,13 @@ import pytest import scanpy as sc from anndata import AnnData +from matplotlib.colors import Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData, deepcopy, get_element_instances from spatialdata.models import TableModel import spatialdata_plot # noqa: F401 -from tests.conftest import DPI, PlotTester, PlotTesterMeta +from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over RNG = np.random.default_rng(seed=42) sc.pl.set_rcParams_defaults() @@ -234,6 +235,19 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str sdata_blobs["other_table"] = table sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") + def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels( + "blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over() + ).pl.show() + + def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels( + "blobs_labels", + color="channel_0_sum", + norm=Normalize(400, 1000, clip=False), + cmap=_viridis_with_under_over(), + ).pl.show() + def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData): sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape) sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show() diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 3ddef2bb..ad340f70 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -7,13 +7,14 @@ import pandas as pd import scanpy as sc from anndata import AnnData +from matplotlib.colors import Normalize from spatialdata import SpatialData, deepcopy from spatialdata.models import PointsModel, TableModel from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation from spatialdata.transformations._utils import _set_transformations import spatialdata_plot # noqa: F401 -from tests.conftest import DPI, PlotTester, PlotTesterMeta +from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over RNG = np.random.default_rng(seed=42) sc.pl.set_rcParams_defaults() @@ -225,6 +226,56 @@ def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_points("blobs_points", method="datashader", color="black", size=5).pl.show() + def test_plot_can_use_norm_with_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", size=40, norm=Normalize(3, 7, clip=True), cmap=_viridis_with_under_over() + ).pl.show() + + def test_plot_can_use_norm_without_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", size=40, norm=Normalize(3, 7, clip=False), cmap=_viridis_with_under_over() + ).pl.show() + + def test_plot_datashader_can_use_norm_with_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", + size=40, + norm=Normalize(3, 7, clip=True), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_can_use_norm_without_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", + size=40, + norm=Normalize(3, 7, clip=False), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", + size=40, + norm=Normalize(5, 5, clip=True), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + color="instance_id", + size=40, + norm=Normalize(5, 5, clip=False), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + def test_plot_can_annotate_points_with_table_obs(self, sdata_blobs: SpatialData): nrows, ncols = 200, 3 feature_matrix = RNG.random((nrows, ncols)) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 7abb0783..affeebd0 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -16,7 +16,7 @@ from spatialdata.transformations._utils import _set_transformations import spatialdata_plot # noqa: F401 -from tests.conftest import DPI, PlotTester, PlotTesterMeta +from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over RNG = np.random.default_rng(seed=42) sc.pl.set_rcParams_defaults() @@ -456,6 +456,51 @@ def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show() + def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over() + ).pl.show() + + def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", + color="value", + norm=Normalize(2, 4, clip=True), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", + color="value", + norm=Normalize(2, 4, clip=False), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_norm_vmin_eq_vmax_without_clip(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", + color="value", + norm=Normalize(3, 3, clip=False), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + + def test_plot_datashader_norm_vmin_eq_vmax_with_clip(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", + color="value", + norm=Normalize(3, 3, clip=True), + cmap=_viridis_with_under_over(), + method="datashader", + datashader_reduction="max", + ).pl.show() + def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialData): nrows, ncols = 5, 3 feature_matrix = RNG.random((nrows, ncols))