Skip to content

Commit 3dc4f10

Browse files
committed
split test so that it actually render points with mpl
1 parent 8aadf07 commit 3dc4f10

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,20 @@ def _render_shapes(
369369
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
370370
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
371371

372-
color_key = (
373-
[_hex_no_alpha(x) for x in color_vector.categories.values]
374-
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
375-
and (len(color_vector.categories.values) > 1)
376-
else None
377-
)
372+
color_key: dict[str, str] | None = None
373+
if color_by_categorical and col_for_color is not None:
374+
cat_series = pd.Categorical(transformed_element[col_for_color])
375+
colors_arr = np.asarray(color_vector, dtype=object)
376+
color_key = {}
377+
for cat in cat_series.categories:
378+
if cat == "nan":
379+
key_color = render_params.cmap_params.na_color.get_hex()
380+
else:
381+
idx = np.flatnonzero(cat_series == cat)
382+
key_color = colors_arr[idx[0]] if idx.size else render_params.cmap_params.na_color.get_hex()
383+
if isinstance(key_color, str) and key_color.startswith("#"):
384+
key_color = _hex_no_alpha(key_color)
385+
color_key[str(cat)] = key_color
378386

379387
if color_by_categorical or col_for_color is None:
380388
ds_cmap = None
@@ -897,16 +905,20 @@ def _render_points(
897905
agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2)
898906
agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5)
899907

900-
color_key: list[str] | None = (
901-
list(color_vector.categories.values)
902-
if (type(color_vector) is pd.core.arrays.categorical.Categorical)
903-
and (len(color_vector.categories.values) > 1)
904-
else None
905-
)
906-
907-
# remove alpha from color if it's hex
908-
if color_key is not None and color_key[0][0] == "#":
909-
color_key = [_hex_no_alpha(x) for x in color_key]
908+
color_key: dict[str, str] | None = None
909+
if color_by_categorical and col_for_color is not None:
910+
cat_series = pd.Categorical(transformed_element[col_for_color])
911+
colors_arr = np.asarray(color_vector, dtype=object)
912+
color_key = {}
913+
for cat in cat_series.categories:
914+
if cat == "nan":
915+
key_color = render_params.cmap_params.na_color.get_hex()
916+
else:
917+
idx = np.flatnonzero(cat_series == cat)
918+
key_color = colors_arr[idx[0]] if idx.size else render_params.cmap_params.na_color.get_hex()
919+
if isinstance(key_color, str) and key_color.startswith("#"):
920+
key_color = _hex_no_alpha(key_color)
921+
color_key[str(cat)] = key_color
910922
if isinstance(color_vector[0], str) and (color_vector is not None and color_vector[0][0] == "#"):
911923
color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector])
912924

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3046,7 +3046,7 @@ def _prepare_transformation(
30463046
def _datashader_map_aggregate_to_color(
30473047
agg: DataArray,
30483048
cmap: str | list[str] | ListedColormap,
3049-
color_key: None | list[str] = None,
3049+
color_key: None | list[str] | Mapping[str, str] = None,
30503050
min_alpha: float = 40,
30513051
span: None | list[float] = None,
30523052
clip: bool = True,

tests/pl/test_render_points.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,12 @@ def test_plot_can_annotate_points_with_table_layer(self, sdata_blobs: SpatialDat
522522

523523
sdata_blobs.pl.render_points("blobs_points", color="feature0", size=10, table_layer="normalized").pl.show()
524524

525-
def test_plot_can_annotate_points_with_nan_in_table_obs_categorical(
525+
def test_plot_can_annotate_points_with_nan_in_table_obs_categorical_matplotlib(
526526
self, sdata_blobs_points_with_nans_in_table: SpatialData
527527
):
528-
sdata_blobs_points_with_nans_in_table.pl.render_points("blobs_points", color="category", size=30).pl.show()
528+
sdata_blobs_points_with_nans_in_table.pl.render_points(
529+
"blobs_points", color="category", size=40, method="matplotlib"
530+
).pl.show()
529531

530532
def test_plot_can_annotate_points_with_nan_in_table_obs_categorical_datashader(
531533
self, sdata_blobs_points_with_nans_in_table: SpatialData

0 commit comments

Comments
 (0)