Skip to content
Merged
14 changes: 14 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,20 @@ def _render_labels(
table_layer=table_layer,
)

# rasterize could have removed labels from label
# only problematic if color is specified
if rasterize and color is not None:
labels_in_rasterized_image = np.unique(label.values)
mask = np.isin(instance_id, labels_in_rasterized_image)
instance_id = instance_id[mask]
color_vector = color_vector[mask]
if isinstance(color_vector.dtype, pd.CategoricalDtype):
color_vector = color_vector.remove_unused_categories()
assert color_source_vector is not None
color_source_vector = color_source_vector[mask]
else:
assert color_source_vector is None

def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
labels = _map_color_seg(
seg=label.values,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
62 changes: 57 additions & 5 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from matplotlib.colors import Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData, deepcopy, get_element_instances
from spatialdata.models import TableModel
from spatialdata.models import Labels2DModel, TableModel

import spatialdata_plot # noqa: F401
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
Expand Down Expand Up @@ -76,7 +76,13 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
fill_alpha=1,
outline_alpha=0,
)
.pl.render_labels(element="blobs_labels", na_color="blue", fill_alpha=0, outline_alpha=1, contour_px=15)
.pl.render_labels(
element="blobs_labels",
na_color="blue",
fill_alpha=0,
outline_alpha=1,
contour_px=15,
)
.pl.show()
)

Expand Down Expand Up @@ -146,7 +152,11 @@ def test_plot_two_calls_with_coloring_result_in_two_colorbars(self, sdata_blobs:

def test_plot_can_control_label_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels(
"blobs_labels", color="channel_0_sum", outline_alpha=0.4, fill_alpha=0.0, contour_px=15
"blobs_labels",
color="channel_0_sum",
outline_alpha=0.4,
fill_alpha=0.0,
contour_px=15,
).pl.show()

def test_plot_can_control_label_infill(self, sdata_blobs: SpatialData):
Expand All @@ -162,7 +172,11 @@ def test_plot_label_colorbar_uses_alpha_of_less_transparent_infill(
sdata_blobs: SpatialData,
):
sdata_blobs.pl.render_labels(
"blobs_labels", color="channel_0_sum", fill_alpha=0.1, outline_alpha=0.7, contour_px=15
"blobs_labels",
color="channel_0_sum",
fill_alpha=0.1,
outline_alpha=0.7,
contour_px=15,
).pl.show()

def test_plot_label_colorbar_uses_alpha_of_less_transparent_outline(
Expand Down Expand Up @@ -233,7 +247,10 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str

def test_plot_can_color_with_norm_and_clipping(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels(
"blobs_labels", color="channel_0_sum", norm=Normalize(400, 1000, clip=True), cmap=_viridis_with_under_over()
"blobs_labels",
color="channel_0_sum",
norm=Normalize(400, 1000, clip=True),
cmap=_viridis_with_under_over(),
).pl.show()

def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
Expand All @@ -247,3 +264,38 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()

def _prepare_small_labels(self, sdata_blobs: SpatialData) -> SpatialData:
# add a categorical column
adata = sdata_blobs["table"]
sdata_blobs["table"].obs["category"] = ["a"] * 10 + ["b"] * 10 + ["c"] * 6

sdata_blobs["table"].obs["category"] = sdata_blobs["table"].obs["category"].astype("category")

labels = sdata_blobs["blobs_labels"].data.compute()

# make label 1 small
mask = labels == 1
labels[mask] = 0
labels[200, 200] = 1

sdata_blobs["blobs_labels"] = Labels2DModel.parse(labels)

# tile the labels object
arr = da.tile(sdata_blobs["blobs_labels"], (4, 4))
sdata_blobs["blobs_labels_large"] = Labels2DModel.parse(arr)

adata.obs["region"] = "blobs_labels_large"
sdata_blobs.set_table_annotates_spatialelement("table", region="blobs_labels_large")
return sdata_blobs

def test_plot_can_handle_dropping_small_labels_after_rasterize_continuous(self, sdata_blobs: SpatialData):
# reported here https://github.com/scverse/spatialdata-plot/issues/443
sdata_blobs = self._prepare_small_labels(sdata_blobs)

sdata_blobs.pl.render_labels("blobs_labels_large", color="channel_0_sum", table_name="table").pl.show()

def test_plot_can_handle_dropping_small_labels_after_rasterize_categorical(self, sdata_blobs: SpatialData):
sdata_blobs = self._prepare_small_labels(sdata_blobs)

sdata_blobs.pl.render_labels("blobs_labels_large", color="category", table_name="table").pl.show()
Loading