diff --git a/tests/_images/Labels_label_categorical_color.png b/tests/_images/Labels_label_categorical_color.png new file mode 100644 index 00000000..1d43aa1b Binary files /dev/null and b/tests/_images/Labels_label_categorical_color.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 343eceb1..1e609f83 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -5,8 +5,9 @@ import pandas as pd import pytest import scanpy as sc +from anndata import AnnData from spatial_image import to_spatial_image -from spatialdata import SpatialData, deepcopy +from spatialdata import SpatialData, deepcopy, get_element_instances from spatialdata.models import TableModel import spatialdata_plot # noqa: F401 @@ -208,3 +209,27 @@ def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(se sdata_blobs.pl.render_labels( "blobs_labels", color="which_max", groups=["channel_0_sum"], palette="red" ).pl.show(ax=axs[1]) + + def test_plot_label_categorical_color(self, sdata_blobs: SpatialData): + self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + + def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str): + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData( + RNG.normal(size=(n_obs, 10)), + obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]), + ) + adata.obs["instance_id"] = instances.values + adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) + adata.obs["category"][:3] = ["a", "b", "c"] + adata.obs["region"] = labels_name + table = TableModel.parse( + adata=adata, + region_key="region", + instance_key="instance_id", + region=labels_name, + ) + sdata_blobs["other_table"] = table + sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")