diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 8a1dc737..cc96e26b 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1023,6 +1023,20 @@ def _render_labels( table_layer=table_layer, ) + # rasterize could have removed labels from label + # only problematic if color is specified + if rasterize and color is not None: + labels_in_rasterized_image = np.unique(label.values) + mask = np.isin(instance_id, labels_in_rasterized_image) + instance_id = instance_id[mask] + color_vector = color_vector[mask] + if isinstance(color_vector.dtype, pd.CategoricalDtype): + color_vector = color_vector.remove_unused_categories() + assert color_source_vector is not None + color_source_vector = color_source_vector[mask] + else: + assert color_source_vector is None + def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage: labels = _map_color_seg( seg=label.values, diff --git a/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_categorical.png b/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_categorical.png new file mode 100644 index 00000000..36d7ae5a Binary files /dev/null and b/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_categorical.png differ diff --git a/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_continuous.png b/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_continuous.png new file mode 100644 index 00000000..ac0bacda Binary files /dev/null and b/tests/_images/Labels_can_handle_dropping_small_labels_after_rasterize_continuous.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index a3f592b0..3f50effb 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -9,7 +9,7 @@ from matplotlib.colors import Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData, deepcopy, get_element_instances -from spatialdata.models import TableModel +from spatialdata.models import Labels2DModel, TableModel import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over @@ -76,7 +76,13 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData): fill_alpha=1, outline_alpha=0, ) - .pl.render_labels(element="blobs_labels", na_color="blue", fill_alpha=0, outline_alpha=1, contour_px=15) + .pl.render_labels( + element="blobs_labels", + na_color="blue", + fill_alpha=0, + outline_alpha=1, + contour_px=15, + ) .pl.show() ) @@ -146,7 +152,11 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs: def test_plot_can_control_label_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels( - "blobs_labels", color="channel_0_sum", outline_alpha=0.4, fill_alpha=0.0, contour_px=15 + "blobs_labels", + color="channel_0_sum", + outline_alpha=0.4, + fill_alpha=0.0, + contour_px=15, ).pl.show() def test_plot_can_control_label_infill(self, sdata_blobs: SpatialData): @@ -162,7 +172,11 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_infill( sdata_blobs: SpatialData, ): sdata_blobs.pl.render_labels( - "blobs_labels", color="channel_0_sum", fill_alpha=0.1, outline_alpha=0.7, contour_px=15 + "blobs_labels", + color="channel_0_sum", + fill_alpha=0.1, + outline_alpha=0.7, + contour_px=15, ).pl.show() def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline( @@ -233,7 +247,10 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels( - "blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over() + "blobs_labels", + color="channel_0_sum", + norm=Normalize(400, 1000, clip=True), + cmap=_viridis_with_under_over(), ).pl.show() def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData): @@ -247,3 +264,38 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData): def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData): sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape) sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show() + + def _prepare_small_labels(self, sdata_blobs: SpatialData) -> SpatialData: + # add a categorical column + adata = sdata_blobs["table"] + sdata_blobs["table"].obs["category"] = ["a"] * 10 + ["b"] * 10 + ["c"] * 6 + + sdata_blobs["table"].obs["category"] = sdata_blobs["table"].obs["category"].astype("category") + + labels = sdata_blobs["blobs_labels"].data.compute() + + # make label 1 small + mask = labels == 1 + labels[mask] = 0 + labels[200, 200] = 1 + + sdata_blobs["blobs_labels"] = Labels2DModel.parse(labels) + + # tile the labels object + arr = da.tile(sdata_blobs["blobs_labels"], (4, 4)) + sdata_blobs["blobs_labels_large"] = Labels2DModel.parse(arr) + + adata.obs["region"] = "blobs_labels_large" + sdata_blobs.set_table_annotates_spatialelement("table", region="blobs_labels_large") + return sdata_blobs + + def test_plot_can_handle_dropping_small_labels_after_rasterize_continuous(self, sdata_blobs: SpatialData): + # reported here https://github.com/scverse/spatialdata-plot/issues/443 + sdata_blobs = self._prepare_small_labels(sdata_blobs) + + sdata_blobs.pl.render_labels("blobs_labels_large", color="channel_0_sum", table_name="table").pl.show() + + def test_plot_can_handle_dropping_small_labels_after_rasterize_categorical(self, sdata_blobs: SpatialData): + sdata_blobs = self._prepare_small_labels(sdata_blobs) + + sdata_blobs.pl.render_labels("blobs_labels_large", color="category", table_name="table").pl.show()