@@ -369,12 +369,20 @@ def _render_shapes(
369369 agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
370370 agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
371371
372- color_key = (
373- [_hex_no_alpha (x ) for x in color_vector .categories .values ]
374- if (type (color_vector ) is pd .core .arrays .categorical .Categorical )
375- and (len (color_vector .categories .values ) > 1 )
376- else None
377- )
372+ color_key : dict [str , str ] | None = None
373+ if color_by_categorical and col_for_color is not None :
374+ cat_series = pd .Categorical (transformed_element [col_for_color ])
375+ colors_arr = np .asarray (color_vector , dtype = object )
376+ color_key = {}
377+ for cat in cat_series .categories :
378+ if cat == "nan" :
379+ key_color = render_params .cmap_params .na_color .get_hex ()
380+ else :
381+ idx = np .flatnonzero (cat_series == cat )
382+ key_color = colors_arr [idx [0 ]] if idx .size else render_params .cmap_params .na_color .get_hex ()
383+ if isinstance (key_color , str ) and key_color .startswith ("#" ):
384+ key_color = _hex_no_alpha (key_color )
385+ color_key [str (cat )] = key_color
378386
379387 if color_by_categorical or col_for_color is None :
380388 ds_cmap = None
@@ -897,16 +905,20 @@ def _render_points(
897905 agg = agg .where ((agg <= norm .vmin ) | (np .isnan (agg )), other = 2 )
898906 agg = agg .where ((agg != norm .vmin ) | (np .isnan (agg )), other = 0.5 )
899907
900- color_key : list [str ] | None = (
901- list (color_vector .categories .values )
902- if (type (color_vector ) is pd .core .arrays .categorical .Categorical )
903- and (len (color_vector .categories .values ) > 1 )
904- else None
905- )
906-
907- # remove alpha from color if it's hex
908- if color_key is not None and color_key [0 ][0 ] == "#" :
909- color_key = [_hex_no_alpha (x ) for x in color_key ]
908+ color_key : dict [str , str ] | None = None
909+ if color_by_categorical and col_for_color is not None :
910+ cat_series = pd .Categorical (transformed_element [col_for_color ])
911+ colors_arr = np .asarray (color_vector , dtype = object )
912+ color_key = {}
913+ for cat in cat_series .categories :
914+ if cat == "nan" :
915+ key_color = render_params .cmap_params .na_color .get_hex ()
916+ else :
917+ idx = np .flatnonzero (cat_series == cat )
918+ key_color = colors_arr [idx [0 ]] if idx .size else render_params .cmap_params .na_color .get_hex ()
919+ if isinstance (key_color , str ) and key_color .startswith ("#" ):
920+ key_color = _hex_no_alpha (key_color )
921+ color_key [str (cat )] = key_color
910922 if isinstance (color_vector [0 ], str ) and (color_vector is not None and color_vector [0 ][0 ] == "#" ):
911923 color_vector = np .asarray ([_hex_no_alpha (x ) for x in color_vector ])
912924
0 commit comments