diff --git a/.gitignore b/.gitignore index 580789d6..02a43f70 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,6 @@ tests/figures/ # other _version.py /temp/ + +# pixi +pixi.lock diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 029168be..d34d9fe4 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1,7 +1,6 @@ from __future__ import annotations import sys -import warnings from collections import OrderedDict from copy import deepcopy from pathlib import Path @@ -23,6 +22,7 @@ from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor +from spatialdata_plot._logging import logger from spatialdata_plot.pl.render import ( _render_images, _render_labels, @@ -272,11 +272,7 @@ def render_shapes( """ # TODO add Normalize object in tutorial notebook and point to that notebook here if "vmin" in kwargs or "vmax" in kwargs: - warnings.warn( - "`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.", - DeprecationWarning, - stacklevel=2, - ) + logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.") params_dict = _validate_shape_render_params( self._sdata, element=element, @@ -423,11 +419,7 @@ def render_points( """ # TODO add Normalize object in tutorial notebook and point to that notebook here if "vmin" in kwargs or "vmax" in kwargs: - warnings.warn( - "`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.", - DeprecationWarning, - stacklevel=2, - ) + logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.") params_dict = _validate_points_render_params( self._sdata, element=element, @@ -544,11 +536,7 @@ def render_images( """ # TODO add Normalize object in tutorial notebook and point to that notebook here if "vmin" in kwargs or "vmax" in kwargs: - warnings.warn( - "`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.", - DeprecationWarning, - stacklevel=2, - ) + logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.") params_dict = _validate_image_render_params( self._sdata, element=element, @@ -679,11 +667,7 @@ def render_labels( """ # TODO add Normalize object in tutorial notebook and point to that notebook here if "vmin" in kwargs or "vmax" in kwargs: - warnings.warn( - "`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.", - DeprecationWarning, - stacklevel=2, - ) + logger.warning("`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.") params_dict = _validate_label_render_params( self._sdata, element=element, @@ -918,9 +902,7 @@ def show( # go through tree for i, cs in enumerate(coordinate_systems): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - sdata = self._copy() + sdata = self._copy() _, has_images, has_labels, has_points, has_shapes = ( cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist() ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 76e57d82..396efba9 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -147,9 +147,9 @@ def _render_shapes( color_vector = np.asarray(color_vector, dtype=float) if np.isnan(color_vector).any(): nan_count = int(np.isnan(color_vector).sum()) - logger.warning( - f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." - ) + msg = f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." + warnings.warn(msg, UserWarning, stacklevel=2) + logger.warning(msg) # Using dict.fromkeys here since set returns in arbitrary order # remove the color of NaN values, else it might be assigned to a category @@ -656,12 +656,14 @@ def _render_points( cols = sc.get.obs_df(adata, [col_for_color]) # maybe set color based on type if isinstance(cols[col_for_color].dtype, pd.CategoricalDtype): - _maybe_set_colors( - source=adata, - target=adata, - key=col_for_color, - palette=palette, - ) + uns_color_key = f"{col_for_color}_colors" + if uns_color_key in adata.uns: + _maybe_set_colors( + source=adata, + target=adata, + key=col_for_color, + palette=palette, + ) # when user specified a single color, we emulate the form of `na_color` and use it default_color = ( @@ -778,7 +780,7 @@ def _render_points( agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - color_key = ( + color_key: list[str] | None = ( list(color_vector.categories.values) if (type(color_vector) is pd.core.arrays.categorical.Categorical) and (len(color_vector.categories.values) > 1) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index cccb587e..95df77c0 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -425,8 +425,10 @@ def _as_rgba_array(x: Any) -> np.ndarray: fill_c[is_num] = cmap(used_norm(num[is_num])) # non-numeric entries as explicit colors - if (~is_num).any(): - fill_c[~is_num] = ColorConverter().to_rgba_array(c_series[~is_num].tolist()) + # treat missing values as na_color, and only convert valid color-like entries + non_numeric_mask = (~is_num) & c_series.notna() + if non_numeric_mask.any(): + fill_c[non_numeric_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_mask].tolist()) # Case C: single color or list of color-like specs (strings or tuples) else: @@ -834,35 +836,101 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - # numerical case, return early - # TODO temporary split until refactor is complete + # numerical vs. categorical case if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype): - if ( - not isinstance(element, GeoDataFrame) - and isinstance(palette, list) - and palette[0] is not None - or isinstance(element, GeoDataFrame) - and isinstance(palette, list) - ): - logger.warning( - "Ignoring categorical palette which is given for a continuous variable. " - "Consider using `cmap` to pass a ColorMap." - ) - return None, color_source_vector, False + is_numeric_like = pd.api.types.is_numeric_dtype(color_source_vector.dtype) + is_object_series = isinstance(color_source_vector, pd.Series) and color_source_vector.dtype == "O" + + # If it's an object-typed series but not coercible to float, treat as categorical + if is_object_series and not _is_coercable_to_float(color_source_vector): + color_source_vector = pd.Categorical(color_source_vector) + else: + is_numeric_like = True + + # Continuous case: return early + if is_numeric_like: + if ( + not isinstance(element, GeoDataFrame) + and isinstance(palette, list) + and palette[0] is not None + or isinstance(element, GeoDataFrame) + and isinstance(palette, list) + ): + logger.warning( + "Ignoring categorical palette which is given for a continuous variable. " + "Consider using `cmap` to pass a ColorMap." + ) + return None, color_source_vector, False color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` - color_mapping = _get_categorical_color_mapping( - adata=sdata.get(table_name, None), - cluster_key=value_to_plot, - color_source_vector=color_source_vector, - cmap_params=cmap_params, - alpha=alpha, - groups=groups, - palette=palette, - na_color=na_color, - render_type=render_type, - ) + # Use the provided table_name parameter, fall back to only one present + table_to_use: str | None + if table_name is not None and table_name in sdata.tables: + table_to_use = table_name + elif table_name is not None and table_name not in sdata.tables: + logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.") + table_to_use = None + else: + table_keys = list(sdata.tables.keys()) + if table_keys: + table_to_use = table_keys[0] + logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.") + else: + table_to_use = None + + adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None + + # Check if custom colors exist in the table's .uns slot + if value_to_plot is not None and _has_colors_in_uns(sdata, table_name, value_to_plot): + # Extract colors directly from the table's .uns slot + # Convert Color to ColorLike (str) for the function + na_color_like: ColorLike = na_color.get_hex() if isinstance(na_color, Color) else na_color + color_mapping = _extract_colors_from_table_uns( + sdata=sdata, + table_name=table_name, + col_to_colorby=value_to_plot, + color_source_vector=color_source_vector, + na_color=na_color_like, + ) + if color_mapping is not None: + if isinstance(palette, str): + palette = [palette] + color_mapping = _modify_categorical_color_mapping( + mapping=color_mapping, + groups=groups, + palette=palette, + ) + else: + logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.") + # Fall back to the existing method if extraction fails + color_mapping = _get_categorical_color_mapping( + adata=sdata[table_to_use], + cluster_key=value_to_plot, + color_source_vector=color_source_vector, + cmap_params=cmap_params, + alpha=alpha, + groups=groups, + palette=palette, + na_color=na_color, + render_type=render_type, + ) + else: + color_mapping = None + + if color_mapping is None: + # Use the existing color mapping method + color_mapping = _get_categorical_color_mapping( + adata=adata_for_mapping, + cluster_key=value_to_plot, + color_source_vector=color_source_vector, + cmap_params=cmap_params, + alpha=alpha, + groups=groups, + palette=palette, + na_color=na_color, + render_type=render_type, + ) color_source_vector = color_source_vector.set_categories(color_mapping.keys()) if color_mapping is None: @@ -874,7 +942,9 @@ def _set_color_source_vec( return color_source_vector, color_vector, True logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.") - color = np.full(sdata[table_name].n_obs, na_color.get_hex_with_alpha()) + # Fallback: color everything with na_color; use element length when table is unknown + n_obs = len(element) if element is not None else (sdata[table_name].n_obs if table_name in sdata.tables else 0) + color = np.full(n_obs, na_color.get_hex_with_alpha()) return color, color, False @@ -916,8 +986,9 @@ def _map_color_seg( else: # Case D: User didn't specify a column to color by, but modified the na_color val_im = map_array(seg.copy(), cell_id, cell_id) - if "#" in str(color_vector[0]): - # we have hex colors + first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0] + if _is_color_like(first_value): + # we have color-like values (e.g., hex or named colors) assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like." cols = colors.to_rgba_array(color_vector) else: @@ -966,6 +1037,172 @@ def _generate_base_categorial_color_mapping( return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) +def _has_colors_in_uns( + sdata: sd.SpatialData, + table_name: str | None, + col_to_colorby: str, +) -> bool: + """ + Check if _colors exists in the specified table's .uns slot. + + Parameters + ---------- + sdata + SpatialData object containing tables + table_name + Name of the table to check. If None, uses the first available table. + col_to_colorby + Name of the categorical column (e.g., "celltype") + + Returns + ------- + True if _colors exists in the table's .uns, False otherwise + """ + color_key = f"{col_to_colorby}_colors" + + # Determine which table to use + if table_name is not None: + if table_name not in sdata.tables: + return False + table_to_use = table_name + else: + if len(sdata.tables.keys()) == 0: + return False + # When no table is specified, check all tables for the color key + return any(color_key in adata.uns for adata in sdata.tables.values()) + + adata = sdata.tables[table_to_use] + return color_key in adata.uns + + +def _extract_colors_from_table_uns( + sdata: sd.SpatialData, + table_name: str | None, + col_to_colorby: str, + color_source_vector: ArrayLike | pd.Series[CategoricalDtype], + na_color: ColorLike, +) -> Mapping[str, str] | None: + """ + Extract categorical colors from the _colors pattern in adata.uns. + + This function looks for colors stored in the format _colors in the + specified table's .uns slot and creates a mapping from categories to colors. + + Parameters + ---------- + sdata + SpatialData object containing tables + table_name + Name of the table to look in. If None, uses the first available table. + col_to_colorby + Name of the categorical column (e.g., "celltype") + color_source_vector + Categorical vector containing the categories to map + na_color + Color to use for NaN/missing values + + Returns + ------- + Mapping from category names to hex colors, or None if colors not found + """ + color_key = f"{col_to_colorby}_colors" + + # Determine which table to use + if table_name is not None: + if table_name not in sdata.tables: + logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}") + return None + table_to_use = table_name + else: + if len(sdata.tables) == 0: + logger.warning("No tables found in sdata.") + return None + # No explicit table provided: search all tables for the color key + candidate_tables: list[str] = [ + name + for name, ad in sdata.tables.items() + if color_key in ad.uns # type: ignore[union-attr] + ] + if not candidate_tables: + logger.debug(f"Color key '{color_key}' not found in any table uns.") + return None + table_to_use = candidate_tables[0] + if len(candidate_tables) > 1: + logger.warning( + f"Color key '{color_key}' found in multiple tables {candidate_tables}; using table '{table_to_use}'." + ) + logger.info(f"No table name provided, using '{table_to_use}' for color extraction.") + + adata = sdata.tables[table_to_use] + + # Check if the color pattern exists + if color_key not in adata.uns: + logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.") + return None + + # Extract colors and categories + stored_colors = adata.uns[color_key] + categories = color_source_vector.categories.tolist() + + # Validate na_color format and convert to hex string + if isinstance(na_color, Color): + na_color_hex = na_color.get_hex() + else: + na_color_str = str(na_color) + if "#" not in na_color_str: + logger.warning("Expected `na_color` to be a hex color, converting...") + na_color_hex = to_hex(to_rgba(na_color)[:3]) + else: + na_color_hex = na_color_str + + # Strip alpha channel from na_color if present + if len(na_color_hex) == 9: # #rrggbbaa format + na_color_hex = na_color_hex[:7] # Keep only #rrggbb + + def _to_hex_no_alpha(color_value: Any) -> str | None: + try: + rgba = to_rgba(color_value)[:3] + hex_color: str = to_hex(rgba) + if len(hex_color) == 9: + hex_color = hex_color[:7] + return hex_color + except (TypeError, ValueError) as e: + logger.warning(f"Error converting color '{color_value}' to hex format: {e}") + return None + + color_mapping: dict[str, str] = {} + + if isinstance(stored_colors, Mapping): + for category in categories: + raw_color = stored_colors.get(category) + if raw_color is None: + logger.warning(f"No color specified for '{category}' in '{color_key}', using na_color.") + color_mapping[category] = na_color_hex + continue + hex_color = _to_hex_no_alpha(raw_color) + color_mapping[category] = hex_color if hex_color is not None else na_color_hex + logger.info(f"Successfully extracted {len(color_mapping)} colors from '{color_key}' in table '{table_to_use}'.") + else: + try: + hex_colors = [_to_hex_no_alpha(color) for color in stored_colors] + except TypeError: + logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.") + return None + + for i, category in enumerate(categories): + if i < len(hex_colors) and hex_colors[i] is not None: + hex_color = hex_colors[i] + assert hex_color is not None # type narrowing for mypy + color_mapping[category] = hex_color + else: + logger.warning(f"Not enough colors provided for category '{category}', using na_color.") + color_mapping[category] = na_color_hex + logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.") + + color_mapping["NaN"] = na_color_hex + return color_mapping + + def _modify_categorical_color_mapping( mapping: Mapping[str, str], groups: list[str] | str | None = None, diff --git a/tests/_images/Labels_respects_custom_colors_from_uns.png b/tests/_images/Labels_respects_custom_colors_from_uns.png new file mode 100644 index 00000000..d540e08f Binary files /dev/null and b/tests/_images/Labels_respects_custom_colors_from_uns.png differ diff --git a/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png b/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png new file mode 100644 index 00000000..88a05954 Binary files /dev/null and b/tests/_images/Labels_respects_custom_colors_from_uns_with_groups_and_palette.png differ diff --git a/tests/_images/Points_respects_custom_colors_from_uns_for_points.png b/tests/_images/Points_respects_custom_colors_from_uns_for_points.png new file mode 100644 index 00000000..ddc3b5fa Binary files /dev/null and b/tests/_images/Points_respects_custom_colors_from_uns_for_points.png differ diff --git a/tests/_images/Shapes_respects_custom_colors_from_uns.png b/tests/_images/Shapes_respects_custom_colors_from_uns.png new file mode 100644 index 00000000..742a1397 Binary files /dev/null and b/tests/_images/Shapes_respects_custom_colors_from_uns.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 8d96ec96..baa2ce29 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -263,7 +263,7 @@ def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs["table"].layers["normalized"] = get_standard_RNG().random(sdata_blobs["table"].X.shape) sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show() - def _prepare_small_labels(self, sdata_blobs: SpatialData) -> SpatialData: + def _prepare_labels_with_small_objects(self, sdata_blobs: SpatialData) -> SpatialData: # add a categorical column adata = sdata_blobs["table"] sdata_blobs["table"].obs["category"] = ["a"] * 10 + ["b"] * 10 + ["c"] * 6 @@ -289,15 +289,76 @@ def _prepare_small_labels(self, sdata_blobs: SpatialData) -> SpatialData: def test_plot_can_handle_dropping_small_labels_after_rasterize_continuous(self, sdata_blobs: SpatialData): # reported here https://github.com/scverse/spatialdata-plot/issues/443 - sdata_blobs = self._prepare_small_labels(sdata_blobs) + sdata_blobs = self._prepare_labels_with_small_objects(sdata_blobs) sdata_blobs.pl.render_labels("blobs_labels_large", color="channel_0_sum", table_name="table").pl.show() def test_plot_can_handle_dropping_small_labels_after_rasterize_categorical(self, sdata_blobs: SpatialData): - sdata_blobs = self._prepare_small_labels(sdata_blobs) + sdata_blobs = self._prepare_labels_with_small_objects(sdata_blobs) sdata_blobs.pl.render_labels("blobs_labels_large", color="category", table_name="table").pl.show() + def test_plot_respects_custom_colors_from_uns(self, sdata_blobs: SpatialData): + labels_name = "blobs_labels" + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData( + get_standard_RNG().normal(size=(n_obs, 10)), + obs=pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]), + ) + adata.obs["instance_id"] = instances.values + adata.obs["category"] = get_standard_RNG().choice(["a", "b", "c"], size=adata.n_obs) + adata.obs["category"][:3] = ["a", "b", "c"] + adata.obs["region"] = labels_name + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region=labels_name, + ) + sdata_blobs["other_table"] = table + sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") + sdata_blobs["other_table"].uns["category_colors"] = ["red", "green", "blue"] # purple, green ,yellow + + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + + def test_plot_respects_custom_colors_from_uns_with_groups_and_palette( + self, + sdata_blobs: SpatialData, + ): + labels_name = "blobs_labels" + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData( + get_standard_RNG().normal(size=(n_obs, 10)), + obs=pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["a", "b", "c"]), + ) + adata.obs["instance_id"] = instances.values + adata.obs["category"] = get_standard_RNG().choice(["a", "b", "c"], size=adata.n_obs) + adata.obs["category"][:3] = ["a", "b", "c"] + adata.obs["region"] = labels_name + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region=labels_name, + ) + sdata_blobs["other_table"] = table + sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") + sdata_blobs["other_table"].uns["category_colors"] = { + "a": "red", + "b": "green", + "c": "blue", + } + + # palette overwrites uns colors + sdata_blobs.pl.render_labels( + "blobs_labels", + color="category", + groups=["a", "b"], + palette=["yellow", "cyan"], + ).pl.show() + def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 5e3fc38a..704a7d6b 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -70,6 +70,18 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData): palette=["lightgreen", "darkblue"], ).pl.show() + def test_plot_respects_custom_colors_from_uns_for_points(self, sdata_blobs: SpatialData): + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" + + # set a custom palette in `.uns` for the categorical column + sdata_blobs["table"].uns["genes_colors"] = ["#800080", "#008000", "#FFFF00"] + + sdata_blobs.pl.render_points( + element="blobs_points", + color="genes", + ).pl.show() + def test_plot_coloring_with_cmap(self, sdata_blobs: SpatialData): sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 8fc39582..f4bd2308 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -617,6 +617,20 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show() + def test_plot_respects_custom_colors_from_uns(self, sdata_blobs: SpatialData): + shapes_name = "blobs_polygons" + # Ensure that the table annotations point to the shapes element + sdata_blobs["table"].obs["region"] = pd.Categorical([shapes_name] * sdata_blobs["table"].n_obs) + sdata_blobs.set_table_annotates_spatialelement("table", region=shapes_name) + + categories = get_standard_RNG().choice(["a", "b", "c"], size=sdata_blobs["table"].n_obs) + categories[:3] = ["a", "b", "c"] + categories = pd.Categorical(categories, categories=["a", "b", "c"]) + sdata_blobs["table"].obs["category"] = categories + sdata_blobs["table"].uns["category_colors"] = ["red", "green", "blue"] # purple, green, yellow + + sdata_blobs.pl.render_shapes(shapes_name, color="category", table_name="table").pl.show() + def test_plot_can_render_circles_to_hex(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex").pl.show() @@ -725,39 +739,43 @@ def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): ).pl.show() ) - def test_plot_can_handle_nan_values_in_color_data(self, sdata_blobs: SpatialData): - """Test that NaN values in color data are handled gracefully.""" - sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) - sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" - # Add color column with NaN values - sdata_blobs.shapes["blobs_circles"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] +def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData): + """Test that NaN values in color data are handled gracefully.""" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" - # Test that rendering works with NaN values and issues warning - with pytest.warns(UserWarning, match="Found 1 NaN values in color data"): - sdata_blobs.pl.render_shapes(element="blobs_circles", color="color_with_nan", na_color="red").pl.show() + # Add color column with NaN values + sdata_blobs.shapes["blobs_circles"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] - def test_plot_colorbar_normalization_with_nan_values(self, sdata_blobs: SpatialData): - """Test that colorbar normalization works correctly with NaN values.""" - sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) - sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + # Test that rendering works with NaN values and issues warning + with pytest.warns(UserWarning, match="Found 1 NaN values in color data"): + sdata_blobs.pl.render_shapes(element="blobs_circles", color="color_with_nan", na_color="red").pl.show() - sdata_blobs.shapes["blobs_polygons"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] - # Test colorbar with NaN values - should use nanmin/nanmax - sdata_blobs.pl.render_shapes(element="blobs_polygons", color="color_with_nan", na_color="gray").pl.show() +def test_plot_colorbar_normalization_with_nan_values(sdata_blobs: SpatialData): + """Test that colorbar normalization works correctly with NaN values.""" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" - def test_plot_can_handle_non_numeric_radius_values(self, sdata_blobs: SpatialData): - """Test that non-numeric radius values are handled gracefully.""" - sdata_blobs.shapes["blobs_circles"]["radius_mixed"] = [1.0, "invalid", 3.0, np.nan, 5.0] + sdata_blobs.shapes["blobs_polygons"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] - sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show() + # Test colorbar with NaN values - should use nanmin/nanmax + sdata_blobs.pl.render_shapes(element="blobs_polygons", color="color_with_nan", na_color="gray").pl.show() - def test_plot_can_handle_mixed_numeric_and_color_data(self, sdata_blobs: SpatialData): - """Test handling of mixed numeric and color-like data.""" - sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) - sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" - sdata_blobs.shapes["blobs_circles"]["mixed_data"] = [1.0, 2.0, np.nan, "red", 5.0] +def test_plot_can_handle_non_numeric_radius_values(sdata_blobs: SpatialData): + """Test that non-numeric radius values are handled gracefully.""" + sdata_blobs.shapes["blobs_circles"]["radius_mixed"] = [1.0, "invalid", 3.0, np.nan, 5.0] + + sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show() + + +def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData): + """Test handling of mixed numeric and color-like data.""" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" + + sdata_blobs.shapes["blobs_circles"]["mixed_data"] = [1.0, 2.0, np.nan, "red", 5.0] - sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show() + sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show()