diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 546b71b1..025eb9ed 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -61,7 +61,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str +ColorLike = tuple[float, ...] | list[float] | str @register_spatial_data_accessor("pl") @@ -156,14 +156,14 @@ def _copy( def render_shapes( self, element: str | None = None, - color: str | None = None, - fill_alpha: float | int = 1.0, + color: ColorLike | None = None, + fill_alpha: float | int | None = None, groups: list[str] | str | None = None, palette: list[str] | str | None = None, na_color: ColorLike | None = "default", - outline_width: float | int = 1.5, - outline_color: str | list[float] = "#000000", - outline_alpha: float | int = 0.0, + outline_width: float | int | tuple[float | int, float | int] | None = None, + outline_color: ColorLike | tuple[ColorLike] | None = None, + outline_alpha: float | int | tuple[float | int, float | int] | None = None, cmap: Colormap | str | None = None, norm: Normalize | None = None, scale: float | int = 1.0, @@ -186,15 +186,18 @@ def render_shapes( element : str | None, optional The name of the shapes element to render. If `None`, all shapes elements in the `SpatialData` object will be used. - color : str | None - Can either be string representing a color-like or key in :attr:`sdata.table.obs`. The latter can be used to - color by categorical or continuous variables. If `element` is `None`, if possible the color will be - broadcasted to all elements. For this, the table in which the color key is found must annotate the - respective element (region must be set to the specific element). If the color column is found in multiple - locations, please provide the table_name to be used for the elements. - fill_alpha : float | int, default 1.0 - Alpha value for the fill of shapes. If the alpha channel is present in a cmap passed by the user, this value - will multiply the value present in the cmap. + color : ColorLike | None, optional + Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or + "#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value is + indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in + :attr:`sdata.table.obs`. The latter can be used to color by categorical or continuous variables. If + `element` is `None`, if possible the color will be broadcasted to all elements. For this, the table in which + the color key is found must annotate the respective element (region must be set to the specific element). If + the color column is found in multiple locations, please provide the table_name to be used for the elements. + fill_alpha : float | int | None, optional + Alpha value for the fill of shapes. By default, it is set to 1.0 or, if a color is given that implies an + alpha, that value is used for `fill_alpha`. If an alpha channel is present in a cmap passed by the user, + `fill_alpha` will overwrite the value present in the cmap. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of them. Other values are set to NA. If elment is None, broadcasting behaviour is attempted (use the same @@ -204,18 +207,25 @@ def render_shapes( match the number of groups. If element is None, broadcasting behaviour is attempted (use the same values for all elements). If groups is provided but not palette, palette is set to default "lightgray". na_color : ColorLike | None, default "default" (gets set to "lightgray") - Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation + Color to be used for NA values, if present. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values won't be shown. - outline_width : float | int, default 1.5 - Width of the border. - outline_color : str | list[float], default "#000000" + outline_width : float | int | tuple[float | int, float | int], optional + Width of the border. If 2 values are given (tuple), 2 borders are shown with these widths (outer & inner). + If `outline_color` and/or `outline_alpha` are used to indicate that one/two outlines should be drawn, the + default outline widths 1.5 and 0.5 are used for outer/only and inner outline respectively. + outline_color : ColorLike | tuple[ColorLike], optional Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g. - "#000000ff", the last two positions are ignored, since the alpha of the outlines is solely controlled by - `outline_alpha`. - outline_alpha : float | int, default 0.0 - Alpha value for the outline of shapes. Invisible by default. + "#000000ff", and `outline_alpha` is not given, this value controls the opacity of the outline. If 2 values + are given (tuple), 2 borders are shown with these colors (outer & inner). If `outline_width` and/or + `outline_alpha` are used to indicate that one/two outlines should be drawn, the default outline colors + "#000000" and "#ffffff are used for outer/only and inner outline respectively. + outline_alpha : float | int | tuple[float | int, float | int] | None, optional + Alpha value for the outline of shapes. Invisible by default, meaning outline_alpha=0.0 if both outline_color + and outline_width are not specified. Else, outlines are rendered with the alpha implied by outline_color, or + with outline_alpha=1.0 if outline_color does not imply an alpha. For two outlines, alpha values can be + passed in a tuple of length 2. cmap : Colormap | str | None, optional Colormap for discrete or continuous annotations using 'color', see :class:`matplotlib.colors.Colormap`. norm : bool | Normalize, default False @@ -283,8 +293,12 @@ def render_shapes( sdata = self._copy() sdata = _verify_plotting_tree(sdata) n_steps = len(sdata.plotting_tree.keys()) - outline_params = _set_outline(outline_alpha > 0, outline_width, outline_color) for element, param_values in params_dict.items(): + final_outline_alpha, outline_params = _set_outline( + params_dict[element]["outline_alpha"], + params_dict[element]["outline_width"], + params_dict[element]["outline_color"], + ) cmap_params = _prepare_cmap_norm( cmap=cmap, norm=norm, @@ -299,7 +313,7 @@ def render_shapes( outline_params=outline_params, cmap_params=cmap_params, palette=param_values["palette"], - outline_alpha=param_values["outline_alpha"], + outline_alpha=final_outline_alpha, fill_alpha=param_values["fill_alpha"], transfunc=kwargs.get("transfunc"), table_name=param_values["table_name"], @@ -316,8 +330,8 @@ def render_shapes( def render_points( self, element: str | None = None, - color: str | None = None, - alpha: float | int = 1.0, + color: ColorLike | None = None, + alpha: float | int | None = None, groups: list[str] | str | None = None, palette: list[str] | str | None = None, na_color: ColorLike | None = "default", @@ -343,14 +357,17 @@ def render_points( element : str | None, optional The name of the points element to render. If `None`, all points elements in the `SpatialData` object will be used. - color : str | None - Can either be string representing a color-like or key in :attr:`sdata.table.obs`. The latter can be used to - color by categorical or continuous variables. If `element` is `None`, if possible the color will be - broadcasted to all elements. For this, the table in which the color key is found must annotate the - respective element (region must be set to the specific element). If the color column is found in multiple - locations, please provide the table_name to be used for the elements. - alpha : float | int, default 1.0 - Alpha value for the points. + color : str | None, optional + Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or + "#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value is + indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in + :attr:`sdata.table.obs`. The latter can be used to color by categorical or continuous variables. If + `element` is `None`, if possible the color will be broadcasted to all elements. For this, the table in which + the color key is found must annotate the respective element (region must be set to the specific element). If + the color column is found in multiple locations, please provide the table_name to be used for the elements. + alpha : float | int | None, optional + Alpha value for the points. By default, it is set to 1.0 or, if a color is given that implies an alpha, that + value is used instead. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of them. Other values are set to NA. If `element` is `None`, broadcasting behaviour is attempted (use the same @@ -360,7 +377,7 @@ def render_points( match the number of groups. If `element` is `None`, broadcasting behaviour is attempted (use the same values for all elements). If groups is provided but not palette, palette is set to default "lightgray". na_color : ColorLike | None, default "default" (gets set to "lightgray") - Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation + Color to be used for NA values, if present. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values won't be shown. cmap : Colormap | str | None, optional @@ -601,7 +618,7 @@ def render_labels( element : str | None The name of the labels element to render. If `None`, all label elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible. - color : list[str] | str | None + color : str | None Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of :attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the color column is found in multiple locations, please provide the table_name to be used for the element if you @@ -626,7 +643,7 @@ def render_labels( won't be shown. outline_alpha : float | int, default 0.0 Alpha value for the outline of the labels. Invisible by default. - fill_alpha : float | int, default 0.3 + fill_alpha : float | int, default 0.4 Alpha value for the fill of the labels. scale : str | None Influences the resolution of the rendering. Possibilities for setting this parameter: diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 23d06ce5..f74259a4 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -25,6 +25,7 @@ from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( + Color, FigParams, ImageRenderParams, LabelsRenderParams, @@ -35,6 +36,7 @@ ) from spatialdata_plot.pl.utils import ( _ax_show_and_transform, + _convert_alpha_to_datashader_range, _create_image_from_datashader_result, _datashader_aggregate_with_function, _datashader_map_aggregate_to_color, @@ -53,7 +55,6 @@ _prepare_transformation, _rasterize_if_necessary, _set_color_source_vec, - to_hex, ) _Normalize = Normalize | abc.Sequence[Normalize] @@ -114,7 +115,7 @@ def _render_shapes( value_to_plot=col_for_color, groups=groups, palette=render_params.palette, - na_color=render_params.color or render_params.cmap_params.na_color, + na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color, cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, @@ -129,7 +130,7 @@ def _render_shapes( norm = copy(render_params.cmap_params.norm) if len(color_vector) == 0: - color_vector = [render_params.cmap_params.na_color] + color_vector = [render_params.cmap_params.na_color.get_hex_with_alpha()] # filter by `groups` if isinstance(groups, list) and color_source_vector is not None: @@ -147,7 +148,10 @@ def _render_shapes( else: palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) - if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color): + if ( + len(set(color_vector)) != 1 + or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha() + ): # necessary in case different shapes elements are annotated with one table if color_source_vector is not None and col_for_color is not None: color_source_vector = color_source_vector.remove_unused_categories() @@ -232,12 +236,20 @@ def _render_shapes( aggregate_with_reduction = (agg.min(), agg.max()) else: agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count()) + # render outlines if needed - if (render_outlines := render_params.outline_alpha) > 0: + assert len(render_params.outline_alpha) == 2 # shut up mypy + if render_params.outline_alpha[0] > 0: agg_outlines = cvs.line( transformed_element, geometry="geometry", - line_width=render_params.outline_params.linewidth, + line_width=render_params.outline_params.outer_outline_linewidth, + ) + if render_params.outline_alpha[1] > 0: + agg_inner_outlines = cvs.line( + transformed_element, + geometry="geometry", + line_width=render_params.outline_params.inner_outline_linewidth, ) ds_span = None @@ -273,8 +285,8 @@ def _render_shapes( agg, cmap=ds_cmap, color_key=color_key, - min_alpha=np.min([254, render_params.fill_alpha * 255]), - ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes + min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), + ) elif aggregate_with_reduction is not None: # to shut up mypy ds_cmap = render_params.cmap_params.cmap # in case all elements have the same value X: we render them using cmap(0.0), @@ -290,27 +302,51 @@ def _render_shapes( ds_result = _datashader_map_aggregate_to_color( agg, cmap=ds_cmap, - min_alpha=np.min([254, render_params.fill_alpha * 255]), + min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), span=ds_span, clip=norm.clip, - ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes - - # shade outlines if needed - outline_color = render_params.outline_params.outline_color - if isinstance(outline_color, str) and outline_color.startswith("#") and len(outline_color) == 9: - logger.info( - "alpha component of given RGBA value for outline color is discarded, because outline_alpha" - " takes precedent." ) - outline_color = outline_color[:-2] - if render_outlines: + # shade outlines if needed + if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): + outline_color = render_params.outline_params.outer_outline_color.get_hex() ds_outlines = ds.tf.shade( agg_outlines, cmap=outline_color, - min_alpha=np.min([254, render_params.outline_alpha * 255]), + min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[0]), + how="linear", + ) + # inner outlines + if render_params.outline_alpha[1] > 0 and isinstance(render_params.outline_params.inner_outline_color, Color): + outline_color = render_params.outline_params.inner_outline_color.get_hex() + ds_inner_outlines = ds.tf.shade( + agg_inner_outlines, + cmap=outline_color, + min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[1]), how="linear", - ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes + ) + + # render outline image(s) + if render_params.outline_alpha[0] > 0: + rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax) + _ax_show_and_transform( + rgba_image, + trans_data, + ax, + zorder=render_params.zorder, + alpha=render_params.outline_alpha[0], + extent=x_ext + y_ext, + ) + if render_params.outline_alpha[1] > 0: + rgba_image, trans_data = _create_image_from_datashader_result(ds_inner_outlines, factor, ax) + _ax_show_and_transform( + rgba_image, + trans_data, + ax, + zorder=render_params.zorder, + alpha=render_params.outline_alpha[1], + extent=x_ext + y_ext, + ) rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) _cax = _ax_show_and_transform( @@ -321,17 +357,6 @@ def _render_shapes( alpha=render_params.fill_alpha, extent=x_ext + y_ext, ) - # render outline image if needed - if render_outlines: - rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, - ax, - zorder=render_params.zorder, - alpha=render_params.outline_alpha, - extent=x_ext + y_ext, - ) cax = None if aggregate_with_reduction is not None: @@ -350,6 +375,48 @@ def _render_shapes( ) elif method == "matplotlib": + # render outlines separately to ensure they are always underneath the shape + if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): + _cax = _get_collection_shape( + shapes=shapes, + s=render_params.scale, + c=np.array(["white"]), # hack, will be invisible bc fill_alpha=0 + render_params=render_params, + rasterized=sc_settings._vector_friendly, + cmap=None, + norm=None, + fill_alpha=0.0, + outline_alpha=render_params.outline_alpha[0], + outline_color=render_params.outline_params.outer_outline_color.get_hex(), + linewidth=render_params.outline_params.outer_outline_linewidth, + zorder=render_params.zorder, + # **kwargs, + ) + cax = ax.add_collection(_cax) + # Transform the paths in PatchCollection + for path in _cax.get_paths(): + path.vertices = trans.transform(path.vertices) + if render_params.outline_alpha[1] > 0 and isinstance(render_params.outline_params.inner_outline_color, Color): + _cax = _get_collection_shape( + shapes=shapes, + s=render_params.scale, + c=np.array(["white"]), # hack, will be invisible bc fill_alpha=0 + render_params=render_params, + rasterized=sc_settings._vector_friendly, + cmap=None, + norm=None, + fill_alpha=0.0, + outline_alpha=render_params.outline_alpha[1], + outline_color=render_params.outline_params.inner_outline_color.get_hex(), + linewidth=render_params.outline_params.inner_outline_linewidth, + zorder=render_params.zorder, + # **kwargs, + ) + cax = ax.add_collection(_cax) + # Transform the paths in PatchCollection + for path in _cax.get_paths(): + path.vertices = trans.transform(path.vertices) + _cax = _get_collection_shape( shapes=shapes, s=render_params.scale, @@ -359,7 +426,7 @@ def _render_shapes( cmap=render_params.cmap_params.cmap, norm=norm, fill_alpha=render_params.fill_alpha, - outline_alpha=render_params.outline_alpha, + outline_alpha=0.0, zorder=render_params.zorder, # **kwargs, ) @@ -377,7 +444,10 @@ def _render_shapes( vmax=render_params.cmap_params.norm.vmax or max(color_vector), ) - if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color): + if ( + len(set(color_vector)) != 1 + or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha() + ): # necessary in case different shapes elements are annotated with one table if color_source_vector is not None and render_params.col_for_color is not None: color_source_vector = color_source_vector.remove_unused_categories() @@ -420,7 +490,7 @@ def _render_points( col_for_color = render_params.col_for_color table_name = render_params.table_name table_layer = render_params.table_layer - color = render_params.color + color = render_params.color.get_hex() if render_params.color else None groups = render_params.groups palette = render_params.palette @@ -531,7 +601,10 @@ def _render_points( ) # when user specified a single color, we emulate the form of `na_color` and use it - default_color = color if col_for_color is None and color is not None else render_params.cmap_params.na_color + default_color = ( + render_params.color if col_for_color is None and color is not None else render_params.cmap_params.na_color + ) + assert isinstance(default_color, Color) # shut up mypy color_source_vector, color_vector, _ = _set_color_source_vec( sdata=sdata_filt, @@ -652,8 +725,8 @@ def _render_points( ds.tf.spread(agg, px=px), cmap=color_vector[0], color_key=color_key, - min_alpha=np.min([254, render_params.alpha * 255]), - ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes + min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + ) else: spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction) agg = ds.tf.spread(agg, px=px, how=spread_how) @@ -675,8 +748,8 @@ def _render_points( cmap=ds_cmap, span=ds_span, clip=norm.clip, - min_alpha=np.min([254, render_params.alpha * 255]), - ) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes + min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + ) rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) _ax_show_and_transform( @@ -726,7 +799,10 @@ def _render_points( ax.set_xbound(extent["x"]) ax.set_ybound(extent["y"]) - if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color): + if ( + len(set(color_vector)) != 1 + or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha() + ): if color_source_vector is None: palette = ListedColormap(dict.fromkeys(color_vector)) else: @@ -1094,7 +1170,6 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) seg_erosionpx=seg_erosionpx, seg_boundaries=seg_boundaries, na_color=render_params.cmap_params.na_color, - na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user, ) _cax = ax.imshow( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index b44175c3..5e3af820 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -4,8 +4,9 @@ from dataclasses import dataclass from typing import Literal +import numpy as np from matplotlib.axes import Axes -from matplotlib.colors import Colormap, ListedColormap, Normalize +from matplotlib.colors import Colormap, ListedColormap, Normalize, rgb2hex, to_hex from matplotlib.figure import Figure _FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] @@ -17,14 +18,133 @@ ColorLike = tuple[float, ...] | str +# NOTE: defined here instead of utils to avoid circular import +@dataclass(kw_only=True) +class Color: + """Validate, parse and store a single color. + + Accepts a color and an alpha value. + If no color or "default" is given, the default color "lightgray" is used. + If no alpha is given, the default of completely opaque is used ("ff"). + At all times, if color indicates an alpha value, for instance as part of a hex string, the alpha parameter takes + precedence if given. + """ + + color: str + alpha: str + default_color_set: bool = False + user_defined_alpha: bool = False + + def __init__( + self, color: None | str | list[float] | tuple[float, ...] = "default", alpha: float | int | None = None + ) -> None: + # 1) Validate alpha value + if alpha is None: + self.alpha = "ff" # default: completely opaque + elif isinstance(alpha, float | int): + if alpha <= 1.0 and alpha >= 0.0: + # Convert float alpha to hex representation + self.alpha = hex(int(np.round(alpha * 255)))[2:].lower() + if len(self.alpha) == 1: + self.alpha = "0" + self.alpha + self.user_defined_alpha = True + else: + raise ValueError(f"Invalid alpha value `{alpha}`, must lie within [0.0, 1.0].") + else: + raise ValueError(f"Invalid alpha value `{alpha}`, must be None or a float | int within [0.0, 1.0].") + + # 2) Validate color value + if color is None: + self.color = to_hex("lightgray", keep_alpha=False) + # setting color to None should lead to full transparency (except alpha is set manually) + if alpha is None: + self.alpha = "00" + elif color == "default": + self.default_color_set = True + self.color = to_hex("lightgray", keep_alpha=False) + elif isinstance(color, str): + # already hex + if color.startswith("#"): + if len(color) not in [7, 9]: + raise ValueError("Invalid hex color length: only formats '#RRGGBB' and '#RRGGBBAA' are supported.") + self.color = color.lower() + if not all(c in "0123456789abcdef" for c in self.color[1:]): + raise ValueError("Invalid hex color: contains non-hex characters") + if len(self.color) == 9: + if alpha is None: + self.alpha = self.color[7:] + self.user_defined_alpha = True + self.color = self.color[:7] + else: + try: + float(color) + except ValueError: + # we're not dealing with what matplotlib considers greyscale + pass + else: + raise TypeError( + f"Invalid type `{type(color)}` for a color, expecting str | None | tuple[float, ...] | " + "list[float]. Note that unlike in matplotlib, giving a string of a number within [0, 1] as a " + "greyscale value is not supported here!" + ) + # matplotlib raises ValueError in case of invalid color name + self.color = to_hex(color, keep_alpha=False) + elif isinstance(color, list | tuple): + if len(color) < 3: + raise ValueError(f"Color `{color}` can't be interpreted as RGB(A) array, needs 3 or 4 values!") + if len(color) > 4: + raise ValueError(f"Color `{color}` can't be interpreted as RGB(A) array, needs 3 or 4 values!") + # get first 3-4 values + r, g, b = color[0], color[1], color[2] + a = 1.0 + if len(color) == 4: + a = color[3] + self.user_defined_alpha = True + if ( + not isinstance(r, int | float) + or not isinstance(g, int | float) + or not isinstance(b, int | float) + or not isinstance(a, int | float) + ): + raise ValueError(f"Invalid color `{color}`, all values in RGB(A) array must be int or float.") + if any(np.array([r, g, b, a]) > 1) or any(np.array([r, g, b, a]) < 0): + raise ValueError(f"Invalid color `{color}`, all values in RGB(A) array must be within [0.0, 1.0].") + self.color = rgb2hex((r, g, b, a), keep_alpha=False) + if alpha is None: + self.alpha = rgb2hex((r, g, b, a), keep_alpha=True)[7:] + else: + raise TypeError( + f"Invalid type `{type(color)}` for color, expecting str | None | tuple[float, ...] | list[float]." + ) + + def get_hex_with_alpha(self) -> str: + """Get color value as '#RRGGBBAA'.""" + return self.color + self.alpha + + def get_hex(self) -> str: + """Get color value as '#RRGGBB'.""" + return self.color + + def get_alpha_as_float(self) -> float: + """Return alpha as value within [0.0, 1.0].""" + return int(self.alpha, 16) / 255 + + def color_modified_by_user(self) -> bool: + """Get whether a color was passed when the object was created.""" + return not self.default_color_set + + def alpha_is_user_defined(self) -> bool: + """Get whether an alpha was set during object creation.""" + return self.user_defined_alpha + + @dataclass class CmapParams: """Cmap params.""" cmap: Colormap norm: Normalize - na_color: ColorLike - na_color_modified_by_user: bool = False + na_color: Color cmap_is_default: bool = True @@ -45,9 +165,10 @@ class FigParams: class OutlineParams: """Cmap params.""" - outline: bool - outline_color: str | list[float] - linewidth: float + outer_outline_color: Color | None = None + outer_outline_linewidth: float = 1.5 + inner_outline_color: Color | None = None + inner_outline_linewidth: float = 0.5 @dataclass @@ -72,17 +193,17 @@ class ScalebarParams: @dataclass class ShapesRenderParams: - """Labels render parameters..""" + """Shapes render parameters..""" cmap_params: CmapParams outline_params: OutlineParams element: str - color: str | None = None + color: Color | None = None col_for_color: str | None = None groups: str | list[str] | None = None contour_px: int | None = None palette: ListedColormap | list[str] | None = None - outline_alpha: float = 1.0 + outline_alpha: tuple[float, float] = (1.0, 1.0) fill_alpha: float = 0.3 scale: float = 1.0 transfunc: Callable[[float], float] | None = None @@ -99,7 +220,7 @@ class PointsRenderParams: cmap_params: CmapParams element: str - color: str | None = None + color: Color | None = None col_for_color: str | None = None groups: str | list[str] | None = None palette: ListedColormap | list[str] | None = None @@ -115,7 +236,7 @@ class PointsRenderParams: @dataclass class ImageRenderParams: - """Labels render parameters..""" + """Image render parameters..""" cmap_params: list[CmapParams] | CmapParams element: str diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 1368110d..9df1f3d0 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -72,6 +72,7 @@ from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( CmapParams, + Color, FigParams, ImageRenderParams, LabelsRenderParams, @@ -88,7 +89,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str +ColorLike = tuple[float, ...] | list[float] | str def _verify_plotting_tree(sdata: SpatialData) -> SpatialData: @@ -142,15 +143,13 @@ def _get_coordinate_system_mapping(sdata: SpatialData) -> dict[str, list[str]]: def _is_color_like(color: Any) -> bool: - """Check if a value is a valid color, returns False for pseudo-bools. + """Check if a value is a valid color. For discussion, see: https://github.com/scverse/spatialdata-plot/issues/327. matplotlib accepts strings in [0, 1] as grey-scale values - therefore, "0" and "1" are considered valid colors. However, we won't do that so we're filtering these out. """ - if isinstance(color, bool): - return False if isinstance(color, str): try: num_value = float(color) @@ -159,6 +158,9 @@ def _is_color_like(color: Any) -> bool: except ValueError: # we're not dealing with what matplotlib considers greyscale pass + if color.startswith("#") and len(color) not in [7, 9]: + # we only accept hex colors in the form #RRGGBB or #RRGGBBAA, not short forms as matplotlib does + return False return bool(colors.is_color_like(color)) @@ -265,37 +267,6 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: return cs_contents -def _sanitise_na_color(na_color: ColorLike | None) -> tuple[str, bool]: - """Return the color's hex value and a boolean indicating if the user changed the default color. - - Returns the hex representation of the color and a boolean indicating whether the - color was changed by the user or not. Our default is "lightgray", but when we - render labels, we give them random colors instead. However, the user could've - manually specified "lightgray" as the color, so we need to check for that. - - Parameters - ---------- - na_color (ColorLike | None): The color input specified by the user. - - Returns - ------- - tuple[str, bool]: A tuple containing the hex color code and a boolean - indicating if the color was user-specified. - """ - if na_color == "default": - # user kept the default - return to_hex("lightgray"), False - if na_color is None: - # user wants to hide NAs - return "#FFFFFF00", True # zero alpha so it's hidden - if colors.is_color_like(na_color): - # user specified a color (including "lightgray") - return to_hex(na_color), True - - # Handle unexpected values (optional) - raise ValueError(f"Invalid na_color value: {na_color}") - - def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]: # Extract the vertices from the PathPatch path = pathpatch.get_path() @@ -327,6 +298,8 @@ def _get_collection_shape( render_params: ShapesRenderParams, fill_alpha: None | float = None, outline_alpha: None | float = None, + outline_color: None | str | list[float] = "white", + linewidth: float = 0.0, **kwargs: Any, ) -> PatchCollection: """ @@ -339,6 +312,8 @@ def _get_collection_shape( - norm: Normalization for the color map. - fill_alpha (float, optional): Opacity for the fill color. - outline_alpha (float, optional): Opacity for the outline. + - outline_color (optional): Color for the outline. + - linewidth (float, optional): Width for the outline. - **kwargs: Additional keyword arguments. Returns @@ -377,11 +352,13 @@ def _get_collection_shape( c = cmap(norm(c)) fill_c = ColorConverter().to_rgba_array(c) - fill_c[..., -1] *= render_params.fill_alpha + # fill_c[..., -1] *= fill_alpha # NOTE: this contradicts matplotlib behavior, therefore discarded + if fill_alpha is not None: + fill_c[..., -1] = fill_alpha - if render_params.outline_params.outline: - outline_c = ColorConverter().to_rgba_array(render_params.outline_params.outline_color) - outline_c[..., -1] = render_params.outline_alpha + if outline_alpha and outline_alpha > 0.0: + outline_c = ColorConverter().to_rgba_array(outline_color) + outline_c[..., -1] = outline_alpha outline_c = outline_c.tolist() else: outline_c = [None] @@ -411,7 +388,8 @@ def _assign_fill_and_outline_to_row( def _process_polygon(row: pd.Series, s: float) -> dict[str, Any]: coords = np.array(row["geometry"].exterior.coords) centroid = np.mean(coords, axis=0) - scaled_coords = (centroid + (coords - centroid) * s).tolist() + scaled_vectors = (coords - centroid) * s + scaled_coords = (centroid + scaled_vectors).tolist() return { **row.to_dict(), "geometry": mpatches.Polygon(scaled_coords, closed=True), @@ -456,7 +434,7 @@ def _create_patches(shapes_df: GeoDataFrame, fill_c: list[Any], outline_c: list[ return PatchCollection( patches["geometry"].values.tolist(), snap=False, - lw=render_params.outline_params.linewidth, + lw=linewidth, facecolor=patches["fill_c"], edgecolor=None if all(outline is None for outline in outline_c) else outline_c, **kwargs, @@ -512,7 +490,7 @@ def _get_scalebar( def _prepare_cmap_norm( cmap: Colormap | str | None = None, norm: Normalize | None = None, - na_color: ColorLike | None = None, + na_color: Color = Color(), ) -> CmapParams: # TODO: check refactoring norm out here as it gets overwritten later cmap_is_default = cmap is None @@ -528,41 +506,109 @@ def _prepare_cmap_norm( if norm is None: norm = Normalize(vmin=None, vmax=None, clip=False) - na_color, na_color_modified_by_user = _sanitise_na_color(na_color) - cmap.set_bad(na_color) + cmap.set_bad(na_color.get_hex_with_alpha()) return CmapParams( cmap=cmap, norm=norm, na_color=na_color, cmap_is_default=cmap_is_default, - na_color_modified_by_user=na_color_modified_by_user, ) def _set_outline( - outline: bool = False, - outline_width: float = 1.5, - outline_color: str | list[float] = "#0000000ff", # black, white + outline_alpha: float | int | tuple[float | int, float | int] | None, + outline_width: int | float | tuple[float | int, float | int] | None, + outline_color: Color | tuple[Color, Color | None] | None, **kwargs: Any, -) -> OutlineParams: - if not isinstance(outline_width, int | float): - raise TypeError(f"Invalid type of `outline_width`: {type(outline_width)}, expected `int` or `float`.") - if outline_width == 0.0: - outline = False - if outline_width < 0.0: - logger.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1) * outline_width}") - outline_width *= -1 - - # the default black and white colors can be changed using the contour_config parameter - if len(outline_color) in {3, 4} and all(isinstance(c, float) for c in outline_color): - outline_color = matplotlib.colors.to_hex(outline_color) - - if outline: +) -> tuple[tuple[float, float], OutlineParams]: + """Create OutlineParams object for shapes, including possibility of double outline. + + Rules for outline rendering: + 1) outline_alpha always takes precedence if given by the user. + In absence of outline_alpha: + 2) If outline_color is specified and implying an alpha (e.g. RGBA array or #RRGGBBAA): that alpha is used + 3) If outline_color (w/o implying an alpha) and/or outline_width is specified: alpha of outlines set to 1.0 + """ + # A) User doesn't want to see outlines + if ( + (outline_alpha and outline_alpha == 0.0) + or (isinstance(outline_alpha, tuple) and np.all(np.array(outline_alpha) == 0.0)) + or not (outline_alpha or outline_width or outline_color) + ): + return (0.0, 0.0), OutlineParams(None, 1.5, None, 0.5) + + # B) User wants to see at least 1 outline + if isinstance(outline_width, tuple): + if len(outline_width) != 2: + raise ValueError( + f"Tuple of length {len(outline_width)} was passed for outline_width. When specifying multiple outlines," + " please pass a tuple of exactly length 2." + ) + if not outline_color: + outline_color = (Color("#000000"), Color("#ffffff")) + elif not isinstance(outline_color, tuple): + raise ValueError( + "No tuple was passed for outline_color, while two outlines were specified by using the outline_width " + "argument. Please specify the outline colors in a tuple of length two." + ) + + if isinstance(outline_color, tuple): + if len(outline_color) != 2: + raise ValueError( + f"Tuple of length {len(outline_color)} was passed for outline_color. When specifying multiple outlines," + " please pass a tuple of exactly length 2." + ) + if not outline_width: + outline_width = (1.5, 0.5) + elif not isinstance(outline_width, tuple): + raise ValueError( + "No tuple was passed for outline_width, while two outlines were specified by using the outline_color " + "argument. Please specify the outline widths in a tuple of length two." + ) + + if isinstance(outline_width, float | int): + outline_width = (outline_width, 0.0) + elif not outline_width: + outline_width = (1.5, 0.0) + if isinstance(outline_color, Color): + outline_color = (outline_color, None) + elif not outline_color: + outline_color = (Color("#000000ff"), None) + + assert isinstance(outline_color, tuple), "outline_color is not a tuple" # shut up mypy + assert isinstance(outline_width, tuple), "outline_width is not a tuple" + + for ow in outline_width: + if not isinstance(ow, int | float): + raise TypeError(f"Invalid type of `outline_width`: {type(ow)}, expected `int` or `float`.") + + if outline_alpha: + if isinstance(outline_alpha, int | float): + # for a single outline: second width value is 0.0 + outline_alpha = (outline_alpha, 0.0) if outline_width[1] == 0.0 else (outline_alpha, outline_alpha) + else: + # if alpha wasn't explicitly specified by the user + outer_ol_alpha = outline_color[0].get_alpha_as_float() if isinstance(outline_color[0], Color) else 1.0 + inner_ol_alpha = outline_color[1].get_alpha_as_float() if isinstance(outline_color[1], Color) else 1.0 + outline_alpha = (outer_ol_alpha, inner_ol_alpha) + + # handle possible linewidths of 0.0 => outline won't be rendered in the first place + if outline_width[0] == 0.0: + outline_alpha = (0.0, outline_alpha[1]) + if outline_width[1] == 0.0: + outline_alpha = (outline_alpha[0], 0.0) + + if outline_alpha[0] > 0.0 or outline_alpha[1] > 0.0: kwargs.pop("edgecolor", None) # remove edge from kwargs if present kwargs.pop("alpha", None) # remove alpha from kwargs if present - return OutlineParams(outline, outline_color, outline_width) + return outline_alpha, OutlineParams( + outline_color[0], + outline_width[0], + outline_color[1], + outline_width[1], + ) def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> plt.Figure | plt.Axes: @@ -723,7 +769,7 @@ def _set_color_source_vec( sdata: sd.SpatialData, element: SpatialElement | None, value_to_plot: str | None, - na_color: ColorLike, + na_color: Color, element_name: list[str] | str | None = None, groups: list[str] | str | None = None, palette: list[str] | str | None = None, @@ -734,7 +780,7 @@ def _set_color_source_vec( render_type: Literal["points"] | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: - color = np.full(len(element), na_color) + color = np.full(len(element), na_color.get_hex_with_alpha()) return color, color, False # Figure out where to get the color from @@ -800,7 +846,7 @@ def _set_color_source_vec( return color_source_vector, color_vector, True logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.") - color = np.full(sdata[table_name].n_obs, to_hex(na_color)) + color = np.full(sdata[table_name].n_obs, na_color.get_hex_with_alpha()) return color, color, False @@ -810,8 +856,7 @@ def _map_color_seg( color_vector: ArrayLike | pd.Series[CategoricalDtype], color_source_vector: pd.Series[CategoricalDtype], cmap_params: CmapParams, - na_color: ColorLike, - na_color_modified_by_user: bool = False, + na_color: Color, seg_erosionpx: int | None = None, seg_boundaries: bool = False, ) -> ArrayLike: @@ -834,8 +879,8 @@ def _map_color_seg( if color_source_vector is not None and ( set(color_vector) == set(color_source_vector) and len(set(color_vector)) == 1 - and set(color_vector) == {na_color} - and not na_color_modified_by_user + and set(color_vector) == {na_color.get_hex_with_alpha()} + and not na_color.color_modified_by_user() ): val_im = map_array(seg.copy(), cell_id, cell_id) RNG = default_rng(42) @@ -876,21 +921,17 @@ def _generate_base_categorial_color_mapping( adata: AnnData, cluster_key: str, color_source_vector: ArrayLike | pd.Series[CategoricalDtype], - na_color: ColorLike, + na_color: Color, cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns: colors = adata.uns[f"{cluster_key}_colors"] categories = color_source_vector.categories.tolist() + ["NaN"] - if "#" not in na_color: - # should be unreachable, but just for safety - raise ValueError("Expected `na_color` to be a hex color, but got a non-hex color.") colors = [to_hex(to_rgba(color)[:3]) for color in colors] - na_color = to_hex(to_rgba(na_color)[:3]) - if na_color and len(categories) > len(colors): - return dict(zip(categories, colors + [na_color], strict=True)) + if len(categories) > len(colors): + return dict(zip(categories, colors + [na_color.get_hex_with_alpha()], strict=True)) return dict(zip(categories, colors, strict=True)) @@ -952,7 +993,7 @@ def _get_default_categorial_color_mapping( def _get_categorical_color_mapping( adata: AnnData, - na_color: ColorLike, + na_color: Color, cluster_key: str | None = None, color_source_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None, cmap_params: CmapParams | None = None, @@ -1027,7 +1068,7 @@ def _decorate_axs( adata: AnnData | None = None, palette: ListedColormap | str | list[str] | None = None, alpha: float = 1.0, - na_color: ColorLike | None = "#d3d3d3", # lightgray + na_color: Color = Color("default"), legend_fontsize: int | float | _FontSize | None = None, legend_fontweight: int | _FontWeight = "bold", legend_loc: str | None = "right margin", @@ -1068,7 +1109,7 @@ def _decorate_axs( legend_fontweight=legend_fontweight, legend_fontsize=legend_fontsize, legend_fontoutline=path_effect, - na_color=[na_color], + na_color=[na_color.get_hex()], na_in_legend=na_in_legend, multi_panel=fig_params.axs is not None, ) @@ -1326,14 +1367,13 @@ def _rasterize_if_necessary( target_y_dims = dpi * height target_x_dims = dpi * width - # TODO: when exactly do we want to rasterize? + # Heuristics for when to rasterize do_rasterization = y_dims > target_y_dims + 100 or x_dims > target_x_dims + 100 if x_dims < 2000 and y_dims < 2000: do_rasterization = False if do_rasterization: logger.info("Rasterizing image for faster rendering.") - # TODO: do we want min here? target_unit_to_pixels = min(target_y_dims / y_dims, target_x_dims / x_dims) image = rasterize( image, @@ -1606,28 +1646,93 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st "points", "labels", }: - if not isinstance(color, str): - raise TypeError("Parameter 'color' must be a string.") + if not isinstance(color, str | tuple | list): + raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") if element_type in {"shapes", "points"}: if _is_color_like(color): logger.info("Value for parameter 'color' appears to be a color, using it as such.") param_dict["col_for_color"] = None - else: + param_dict["color"] = Color(color) + if param_dict["color"].alpha_is_user_defined(): + if element_type == "points" and param_dict.get("alpha") is None: + param_dict["alpha"] = param_dict["color"].get_alpha_as_float() + elif element_type == "shapes" and param_dict.get("fill_alpha") is None: + param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() + else: + logger.info( + f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' " + "is set and its value takes precedence." + ) + elif isinstance(color, str): param_dict["col_for_color"] = color param_dict["color"] = None + else: + raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") elif "color" in param_dict and element_type != "labels": param_dict["col_for_color"] = None if outline_width := param_dict.get("outline_width"): - if not isinstance(outline_width, float | int): - raise TypeError("Parameter 'outline_width' must be numeric.") - if outline_width < 0: + # outline_width only exists for shapes at the moment + if isinstance(outline_width, tuple): + for ow in outline_width: + if isinstance(ow, float | int): + if ow < 0: + raise ValueError("Parameter 'outline_width' cannot contain negative values.") + else: + raise TypeError("Parameter 'outline_width' must contain only numerics when it is a tuple.") + elif not isinstance(outline_width, float | int): + raise TypeError("Parameter 'outline_width' must be numeric or a tuple of two numerics.") + if isinstance(outline_width, float | int) and outline_width < 0: raise ValueError("Parameter 'outline_width' cannot be negative.") - if (outline_alpha := param_dict.get("outline_alpha")) and ( - not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1 - ): - raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") + if outline_alpha := param_dict.get("outline_alpha"): + if isinstance(outline_alpha, tuple): + if element_type != "shapes": + raise ValueError("Parameter 'outline_alpha' must be a single numeric.") + if len(outline_alpha) == 1: + if not isinstance(outline_alpha[0], float | int) or not 0 <= outline_alpha[0] <= 1: + raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") + param_dict["outline_alpha"] = outline_alpha[0] + elif len(outline_alpha) < 1: + raise ValueError("Empty tuple is not supported as input for outline_alpha!") + else: + if len(outline_alpha) > 2: + logger.warning( + f"Tuple of length {len(outline_alpha)} was passed for outline_alpha, only first two positions " + "are used since more than 2 outlines are not supported!" + ) + if ( + not isinstance(outline_alpha[0], float | int) + or not isinstance(outline_alpha[1], float | int) + or not 0 <= outline_alpha[0] <= 1 + or not 0 <= outline_alpha[1] <= 1 + ): + raise TypeError("Parameter 'outline_alpha' must contain numeric values between 0 and 1.") + param_dict["outline_alpha"] = (outline_alpha[0], outline_alpha[1]) + elif not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1: + raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") + + if outline_color := param_dict.get("outline_color"): + if not isinstance(outline_color, str | tuple | list): + raise TypeError("Parameter 'color' must be a string or a tuple/list of floats or colors.") + if isinstance(outline_color, tuple | list): + if len(outline_color) < 1: + raise ValueError("Empty tuple is not supported as input for outline_color!") + if len(outline_color) == 1: + param_dict["outline_color"] = Color(outline_color[0]) + elif len(outline_color) == 2: + # assuming the case of 2 outlines + param_dict["outline_color"] = (Color(outline_color[0]), Color(outline_color[1])) + elif len(outline_color) in [3, 4]: + # assuming RGB(A) array + param_dict["outline_color"] = Color(outline_color) + else: + raise ValueError( + f"Tuple/List of length {len(outline_color)} was passed for outline_color. Valid options would be: " + "tuple of 2 colors (for 2 outlines) or an RGB(A) array, aka a list/tuple of 3-4 floats." + ) + else: + param_dict["outline_color"] = Color(outline_color) if contour_px is not None and contour_px <= 0: raise ValueError("Parameter 'contour_px' must be a positive number.") @@ -1637,12 +1742,18 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise TypeError("Parameter 'alpha' must be numeric.") if not 0 <= alpha <= 1: raise ValueError("Parameter 'alpha' must be between 0 and 1.") + elif element_type == "points": + # set default alpha for points if not given by user explicitly or implicitly (as part of color) + param_dict["alpha"] = 1.0 if (fill_alpha := param_dict.get("fill_alpha")) is not None: if not isinstance(fill_alpha, float | int): raise TypeError("Parameter 'fill_alpha' must be numeric.") if fill_alpha < 0: raise ValueError("Parameter 'fill_alpha' cannot be negative.") + elif element_type == "shapes": + # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) + param_dict["fill_alpha"] = 1.0 if (cmap := param_dict.get("cmap")) is not None and (palette := param_dict.get("palette")) is not None: raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.") @@ -1683,10 +1794,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st else: raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.") - if (na_color := param_dict.get("na_color")) != "default" and ( - na_color is not None and not _is_color_like(na_color) - ): - raise ValueError("Parameter 'na_color' must be color-like.") + # validation happens within Color constructor + param_dict["na_color"] = Color(param_dict.get("na_color")) if (norm := param_dict.get("norm")) is not None: if element_type in {"images", "labels"} and not isinstance(norm, Normalize): @@ -1846,8 +1955,8 @@ def _validate_label_render_params( def _validate_points_render_params( sdata: sd.SpatialData, element: str | None, - alpha: float | int, - color: str | None, + alpha: float | int | None, + color: ColorLike | None, groups: list[str] | str | None, palette: list[str] | str | None, na_color: ColorLike | None, @@ -1908,14 +2017,14 @@ def _validate_points_render_params( def _validate_shape_render_params( sdata: sd.SpatialData, element: str | None, - fill_alpha: float | int, + fill_alpha: float | int | None, groups: list[str] | str | None, palette: list[str] | str | None, - color: list[str] | str | None, + color: ColorLike | None, na_color: ColorLike | None, - outline_width: float | int, - outline_color: str | list[float], - outline_alpha: float | int, + outline_width: float | int | tuple[float | int, float | int] | None, + outline_color: ColorLike | tuple[ColorLike] | None, + outline_alpha: float | int | tuple[float | int, float | int] | None, cmap: list[Colormap | str] | Colormap | str | None, norm: Normalize | None, scale: float | int, @@ -2477,3 +2586,9 @@ def _hex_no_alpha(hex: str) -> str: return "#" + hex_digits[:6] raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") + + +def _convert_alpha_to_datashader_range(alpha: float) -> float: + """Convert alpha from the range [0, 1] to the range [0, 255] used in datashader.""" + # prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes + return min([254, alpha * 255]) diff --git a/tests/_images/Points_alpha_overwrites_opacity_from_color.png b/tests/_images/Points_alpha_overwrites_opacity_from_color.png new file mode 100644 index 00000000..b7790708 Binary files /dev/null and b/tests/_images/Points_alpha_overwrites_opacity_from_color.png differ diff --git a/tests/_images/Points_color_recognises_actual_color_as_color.png b/tests/_images/Points_can_color_by_color_name.png similarity index 100% rename from tests/_images/Points_color_recognises_actual_color_as_color.png rename to tests/_images/Points_can_color_by_color_name.png diff --git a/tests/_images/Points_can_color_by_hex.png b/tests/_images/Points_can_color_by_hex.png new file mode 100644 index 00000000..93212977 Binary files /dev/null and b/tests/_images/Points_can_color_by_hex.png differ diff --git a/tests/_images/Points_can_color_by_hex_with_alpha.png b/tests/_images/Points_can_color_by_hex_with_alpha.png new file mode 100644 index 00000000..2ca1a695 Binary files /dev/null and b/tests/_images/Points_can_color_by_hex_with_alpha.png differ diff --git a/tests/_images/Points_can_color_by_rgb_array.png b/tests/_images/Points_can_color_by_rgb_array.png new file mode 100644 index 00000000..b7790708 Binary files /dev/null and b/tests/_images/Points_can_color_by_rgb_array.png differ diff --git a/tests/_images/Points_can_color_by_rgba_array.png b/tests/_images/Points_can_color_by_rgba_array.png new file mode 100644 index 00000000..74000340 Binary files /dev/null and b/tests/_images/Points_can_color_by_rgba_array.png differ diff --git a/tests/_images/Shapes_alpha_overwrites_opacity_from_color.png b/tests/_images/Shapes_alpha_overwrites_opacity_from_color.png new file mode 100644 index 00000000..64b4feb2 Binary files /dev/null and b/tests/_images/Shapes_alpha_overwrites_opacity_from_color.png differ diff --git a/tests/_images/Shapes_color_recognises_actual_color_as_color.png b/tests/_images/Shapes_can_color_by_color_name.png similarity index 100% rename from tests/_images/Shapes_color_recognises_actual_color_as_color.png rename to tests/_images/Shapes_can_color_by_color_name.png diff --git a/tests/_images/Shapes_can_color_by_hex.png b/tests/_images/Shapes_can_color_by_hex.png new file mode 100644 index 00000000..de1cae1f Binary files /dev/null and b/tests/_images/Shapes_can_color_by_hex.png differ diff --git a/tests/_images/Shapes_can_color_by_hex_with_alpha.png b/tests/_images/Shapes_can_color_by_hex_with_alpha.png new file mode 100644 index 00000000..284766df Binary files /dev/null and b/tests/_images/Shapes_can_color_by_hex_with_alpha.png differ diff --git a/tests/_images/Shapes_can_color_by_rgb_array.png b/tests/_images/Shapes_can_color_by_rgb_array.png new file mode 100644 index 00000000..64b4feb2 Binary files /dev/null and b/tests/_images/Shapes_can_color_by_rgb_array.png differ diff --git a/tests/_images/Shapes_can_color_by_rgba_array.png b/tests/_images/Shapes_can_color_by_rgba_array.png new file mode 100644 index 00000000..1429e83e Binary files /dev/null and b/tests/_images/Shapes_can_color_by_rgba_array.png differ diff --git a/tests/_images/Shapes_can_render_circles_with_colored_outline.png b/tests/_images/Shapes_can_render_circles_with_colored_outline.png index 23c0e75f..43510653 100644 Binary files a/tests/_images/Shapes_can_render_circles_with_colored_outline.png and b/tests/_images/Shapes_can_render_circles_with_colored_outline.png differ diff --git a/tests/_images/Shapes_can_render_circles_with_default_outline_width.png b/tests/_images/Shapes_can_render_circles_with_default_outline_width.png index 2d0b11c6..eb6aafab 100644 Binary files a/tests/_images/Shapes_can_render_circles_with_default_outline_width.png and b/tests/_images/Shapes_can_render_circles_with_default_outline_width.png differ diff --git a/tests/_images/Shapes_can_render_circles_with_outline.png b/tests/_images/Shapes_can_render_circles_with_outline.png index 2d0b11c6..eb6aafab 100644 Binary files a/tests/_images/Shapes_can_render_circles_with_outline.png and b/tests/_images/Shapes_can_render_circles_with_outline.png differ diff --git a/tests/_images/Shapes_can_render_circles_with_specified_outline_width.png b/tests/_images/Shapes_can_render_circles_with_specified_outline_width.png index c3d6e271..7b31504f 100644 Binary files a/tests/_images/Shapes_can_render_circles_with_specified_outline_width.png and b/tests/_images/Shapes_can_render_circles_with_specified_outline_width.png differ diff --git a/tests/_images/Shapes_can_render_double_outline_with_diff_alpha.png b/tests/_images/Shapes_can_render_double_outline_with_diff_alpha.png new file mode 100644 index 00000000..944a049b Binary files /dev/null and b/tests/_images/Shapes_can_render_double_outline_with_diff_alpha.png differ diff --git a/tests/_images/Shapes_can_render_polygons_with_outline.png b/tests/_images/Shapes_can_render_polygons_with_outline.png index 12cfde7b..ef938f9d 100644 Binary files a/tests/_images/Shapes_can_render_polygons_with_outline.png and b/tests/_images/Shapes_can_render_polygons_with_outline.png differ diff --git a/tests/_images/Shapes_can_render_polygons_with_rgb_colored_outline.png b/tests/_images/Shapes_can_render_polygons_with_rgb_colored_outline.png index 7d02401b..a3800d4c 100644 Binary files a/tests/_images/Shapes_can_render_polygons_with_rgb_colored_outline.png and b/tests/_images/Shapes_can_render_polygons_with_rgb_colored_outline.png differ diff --git a/tests/_images/Shapes_can_render_polygons_with_rgba_colored_outline.png b/tests/_images/Shapes_can_render_polygons_with_rgba_colored_outline.png index ffc5e422..6d268c2b 100644 Binary files a/tests/_images/Shapes_can_render_polygons_with_rgba_colored_outline.png and b/tests/_images/Shapes_can_render_polygons_with_rgba_colored_outline.png differ diff --git a/tests/_images/Shapes_can_render_polygons_with_str_colored_outline.png b/tests/_images/Shapes_can_render_polygons_with_str_colored_outline.png index a43b1027..920fdc7b 100644 Binary files a/tests/_images/Shapes_can_render_polygons_with_str_colored_outline.png and b/tests/_images/Shapes_can_render_polygons_with_str_colored_outline.png differ diff --git a/tests/_images/Shapes_can_render_shapes_with_colored_double_outline.png b/tests/_images/Shapes_can_render_shapes_with_colored_double_outline.png new file mode 100644 index 00000000..d1381b97 Binary files /dev/null and b/tests/_images/Shapes_can_render_shapes_with_colored_double_outline.png differ diff --git a/tests/_images/Shapes_can_render_shapes_with_double_outline.png b/tests/_images/Shapes_can_render_shapes_with_double_outline.png new file mode 100644 index 00000000..2a01f8e5 Binary files /dev/null and b/tests/_images/Shapes_can_render_shapes_with_double_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png b/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png new file mode 100644 index 00000000..36ce6f7b Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png b/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png new file mode 100644 index 00000000..fc7d0270 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_colored_outline.png index d4db8bd6..ea8944fc 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_colored_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png b/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png index d0f2d5e5..5f374369 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png and b/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_outline.png b/tests/_images/Shapes_datashader_can_render_with_outline.png index cf017519..c4d0b5c0 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_outline.png and b/tests/_images/Shapes_datashader_can_render_with_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png index ca512872..9ce92a8b 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png index 6fb9c127..c54b2b6f 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_circles.png b/tests/_images/Shapes_datashader_can_transform_circles.png index 60cde073..49659e0d 100644 Binary files a/tests/_images/Shapes_datashader_can_transform_circles.png and b/tests/_images/Shapes_datashader_can_transform_circles.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_multipolygons.png b/tests/_images/Shapes_datashader_can_transform_multipolygons.png index 09e56c63..03fde2cf 100644 Binary files a/tests/_images/Shapes_datashader_can_transform_multipolygons.png and b/tests/_images/Shapes_datashader_can_transform_multipolygons.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_polygons.png b/tests/_images/Shapes_datashader_can_transform_polygons.png index fb2552ff..f58a9bd4 100644 Binary files a/tests/_images/Shapes_datashader_can_transform_polygons.png and b/tests/_images/Shapes_datashader_can_transform_polygons.png differ diff --git a/tests/_images/Shapes_outline_alpha_takes_precedence.png b/tests/_images/Shapes_outline_alpha_takes_precedence.png new file mode 100644 index 00000000..02daf58f Binary files /dev/null and b/tests/_images/Shapes_outline_alpha_takes_precedence.png differ diff --git a/tests/_images/Utils_set_outline_accepts_str_or_float_or_list_thereof.png b/tests/_images/Utils_set_outline_accepts_str_or_float_or_list_thereof.png index 3bc09956..673a4ce1 100644 Binary files a/tests/_images/Utils_set_outline_accepts_str_or_float_or_list_thereof.png and b/tests/_images/Utils_set_outline_accepts_str_or_float_or_list_thereof.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 9a93c7bd..2a60f4c1 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -84,9 +84,24 @@ def test_plot_can_stack_render_points(self, sdata_blobs: SpatialData): .pl.show() ) - def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData): + def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_points(element="blobs_points", color="red").pl.show() + def test_plot_can_color_by_rgb_array(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points(element="blobs_points", color=[0.5, 0.5, 1.0]).pl.show() + + def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points(element="blobs_points", color=[0.5, 0.5, 1.0, 0.5]).pl.show() + + def test_plot_can_color_by_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points(element="blobs_points", color="#88a136").pl.show() + + def test_plot_can_color_by_hex_with_alpha(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points(element="blobs_points", color="#88a13688").pl.show() + + def test_plot_alpha_overwrites_opacity_from_color(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points(element="blobs_points", color=[0.5, 0.5, 1.0, 0.5], alpha=1.0).pl.show() + def test_plot_points_coercable_categorical_color(self, sdata_blobs: SpatialData): n_obs = len(sdata_blobs["blobs_points"]) adata = AnnData( diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 5b890b1d..27924a71 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -40,7 +40,7 @@ def test_plot_can_render_circles_with_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show() def test_plot_can_render_circles_with_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_color="red").pl.show() + sdata_blobs.pl.render_shapes(element="blobs_circles", outline_color="red").pl.show() def test_plot_can_render_polygons(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show() @@ -49,17 +49,13 @@ def test_plot_can_render_polygons_with_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1).pl.show() def test_plot_can_render_polygons_with_str_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1, outline_color="red").pl.show() + sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color="red").pl.show() def test_plot_can_render_polygons_with_rgb_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes( - element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0, 1.0) - ).pl.show() + sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 0.0, 1.0, 1.0)).pl.show() def test_plot_can_render_polygons_with_rgba_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes( - element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0) - ).pl.show() + sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)).pl.show() def test_plot_can_render_empty_geometry(self, sdata_blobs: SpatialData): sdata_blobs.shapes["blobs_circles"].at[0, "geometry"] = gpd.points_from_xy([None], [None])[0] @@ -69,7 +65,7 @@ def test_plot_can_render_circles_with_default_outline_width(self, sdata_blobs: S sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show() def test_plot_can_render_circles_with_specified_outline_width(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_width=3.0).pl.show() + sdata_blobs.pl.render_shapes(element="blobs_circles", outline_width=3.0).pl.show() def test_plot_can_render_multipolygons(self): def _make_multi(): @@ -270,8 +266,23 @@ def test_plot_can_stack_render_shapes(self, sdata_blobs: SpatialData): .pl.show() ) - def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData): - (sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show()) + def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show() + + def test_plot_can_color_by_rgb_array(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color=[0.5, 0.5, 1.0]).pl.show() + + def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color=[0.5, 0.5, 1.0, 0.5]).pl.show() + + def test_plot_can_color_by_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color="#88a136").pl.show() + + def test_plot_can_color_by_hex_with_alpha(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color="#88a13688").pl.show() + + def test_plot_alpha_overwrites_opacity_from_color(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", color=[0.5, 0.5, 1.0, 0.5], fill_alpha=1.0).pl.show() def test_plot_shapes_coercable_categorical_color(self, sdata_blobs: SpatialData): n_obs = len(sdata_blobs["blobs_polygons"]) @@ -391,23 +402,19 @@ def test_plot_datashader_can_render_with_diff_alpha_outline(self, sdata_blobs: S sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=0.5).pl.show() def test_plot_datashader_can_render_with_diff_width_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_alpha=1.0, outline_width=5.0 - ).pl.show() + sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_width=5.0).pl.show() def test_plot_datashader_can_render_with_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_alpha=1, outline_color="red" - ).pl.show() + sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_color="red").pl.show() def test_plot_datashader_can_render_with_rgb_colored_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0) + method="datashader", element="blobs_polygons", outline_color=(0.0, 0.0, 1.0) ).pl.show() def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0) + method="datashader", element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0) ).pl.show() def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData): @@ -562,6 +569,35 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show() + def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show() + + def test_plot_can_render_shapes_with_colored_double_outline(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes( + "blobs_polygons", outline_width=(10.0, 5.0), outline_color=("purple", "orange") + ).pl.show() + + def test_plot_can_render_double_outline_with_diff_alpha(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes( + element="blobs_circles", outline_color=("red", "blue"), outline_alpha=(0.7, 0.3), outline_width=(20, 10) + ).pl.show() + + def test_plot_outline_alpha_takes_precedence(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes( + element="blobs_circles", outline_color=("#ff660033", "#33aa0066"), outline_width=(20, 10), outline_alpha=1.0 + ).pl.show() + + def test_plot_datashader_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0), method="datashader").pl.show() + + def test_plot_datashader_can_render_shapes_with_colored_double_outline(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes( + "blobs_polygons", + outline_width=(10.0, 5.0), + outline_color=("purple", "orange"), + method="datashader", + ).pl.show() + def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index cd324d32..0eef85e3 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -6,7 +6,7 @@ from spatialdata import SpatialData import spatialdata_plot -from spatialdata_plot.pl.utils import _get_subplots, _sanitise_na_color +from spatialdata_plot.pl.utils import _get_subplots from tests.conftest import DPI, PlotTester, PlotTesterMeta sc.pl.set_rcParams_defaults() @@ -89,45 +89,6 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]): assert spatialdata_plot.pl.utils._is_color_like(color) == result -@pytest.mark.parametrize( - "input_output", - [ - (None, ("#FFFFFF00", True)), - ("default", ("#d3d3d3ff", False)), - ("red", ("#ff0000ff", True)), - ((1, 0, 0), ("#ff0000ff", True)), - ((1, 0, 0, 0.5), ("#ff000080", True)), - ], -) -def test_utils_sanitise_na_color(input_output): - from spatialdata_plot.pl.utils import _sanitise_na_color - - func_input, expected_output = input_output - - assert _sanitise_na_color(func_input) == expected_output - - -@pytest.mark.parametrize( - "input_output", - [ - (None, ("#FFFFFF00", True)), - ("default", ("#d3d3d3ff", False)), - ("red", ("#ff0000ff", True)), - ((1, 0, 0), ("#ff0000ff", True)), - ((1, 0, 0, 0.5), ("#ff000080", True)), - ], -) -def test_utils_sanitise_na_color_accepts_valid_inputs(input_output): - func_input, expected_output = input_output - - assert _sanitise_na_color(func_input) == expected_output - - -def test_utils_sanitise_na_color_fails_when_input_isnt_a_color(): - with pytest.raises(ValueError): - _sanitise_na_color((1, 0)) - - @pytest.mark.parametrize( "input_output", [