Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.

Commit feb3cfd

Browse files
shared coloraxis/colorbar
1 parent c6c91fa commit feb3cfd

File tree

1 file changed

+17
-36
lines changed

1 file changed

+17
-36
lines changed

plotly_express/_core.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,7 @@ def make_mapping(args, variable):
137137
)
138138

139139

140-
def make_trace_kwargs(
141-
args, trace_spec, g, mapping_labels, sizeref, color_range, show_colorbar
142-
):
140+
def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
143141

144142
if "line_close" in args and args["line_close"]:
145143
g = g.append(g.iloc[0])
@@ -235,11 +233,9 @@ def make_trace_kwargs(
235233
v_label_col = get_decorated_label(args, col, None)
236234
mapping_labels[v_label_col] = "%%{customdata[%d]}" % i
237235
elif k == "color":
238-
colorbar_container = None
239236
if trace_spec.constructor == go.Choropleth:
240237
result["z"] = g[v]
241-
colorbar_container = result
242-
color_letter = "z"
238+
result["z"]["coloraxis"] = "coloraxis1"
243239
mapping_labels[v_label] = "%{z}"
244240
else:
245241
colorable = "marker"
@@ -248,18 +244,8 @@ def make_trace_kwargs(
248244
if colorable not in result:
249245
result[colorable] = dict()
250246
result[colorable]["color"] = g[v]
251-
colorbar_container = result[colorable]
252-
color_letter = "c"
247+
result[colorable]["coloraxis"] = "coloraxis1"
253248
mapping_labels[v_label] = "%%{%s.color}" % colorable
254-
d = len(args["color_continuous_scale"]) - 1
255-
colorbar_container["colorscale"] = [
256-
[(1.0 * i) / (1.0 * d), x]
257-
for i, x in enumerate(args["color_continuous_scale"])
258-
]
259-
colorbar_container["showscale"] = show_colorbar
260-
colorbar_container[color_letter + "min"] = color_range[0]
261-
colorbar_container[color_letter + "max"] = color_range[1]
262-
colorbar_container["colorbar"] = dict(title=v_label)
263249
elif k == "animation_group":
264250
result["ids"] = g[v]
265251
elif k == "locations":
@@ -690,7 +676,6 @@ def infer_config(args, constructor, trace_patch):
690676
if "size" in args and args["size"]:
691677
sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2
692678

693-
color_range = None
694679
if "color" in args:
695680
if "color_continuous_scale" in args:
696681
if "color_discrete_sequence" not in args:
@@ -708,15 +693,7 @@ def infer_config(args, constructor, trace_patch):
708693
else:
709694
grouped_attrs.append("marker.color")
710695

711-
if "color" in attrs and args["color"]:
712-
cmin = args["data_frame"][args["color"]].min()
713-
cmax = args["data_frame"][args["color"]].max()
714-
if args["color_continuous_midpoint"] is not None:
715-
cmid = args["color_continuous_midpoint"]
716-
delta = max(cmax - cmid, cmid - cmin)
717-
color_range = [cmid - delta, cmid + delta]
718-
else:
719-
color_range = [cmin, cmax]
696+
show_colorbar = bool("color" in attrs and args["color"])
720697

721698
if "line_dash" in args:
722699
grouped_attrs.append("line.dash")
@@ -753,7 +730,7 @@ def infer_config(args, constructor, trace_patch):
753730

754731
grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]
755732
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
756-
return trace_specs, grouped_mappings, sizeref, color_range
733+
return trace_specs, grouped_mappings, sizeref, show_colorbar
757734

758735

759736
def get_orderings(args, grouper, grouped):
@@ -790,7 +767,7 @@ def get_orderings(args, grouper, grouped):
790767

791768
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
792769
apply_default_cascade(args)
793-
trace_specs, grouped_mappings, sizeref, color_range = infer_config(
770+
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
794771
args, constructor, trace_patch
795772
)
796773
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
@@ -877,13 +854,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
877854
trace.update(marker=dict(color=trace.line.color))
878855

879856
patch, fit_results = make_trace_kwargs(
880-
args,
881-
trace_spec,
882-
group,
883-
mapping_labels.copy(),
884-
sizeref,
885-
color_range=color_range,
886-
show_colorbar=(frame_name not in frames),
857+
args, trace_spec, group, mapping_labels.copy(), sizeref
887858
)
888859
trace.update(patch)
889860
if fit_results is not None:
@@ -898,6 +869,16 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
898869
frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
899870
)
900871
layout_patch = layout_patch.copy()
872+
if show_colorbar:
873+
d = len(args["color_continuous_scale"]) - 1
874+
layout_patch["coloraxis1"] = dict(
875+
colorbar=dict(title=get_decorated_label(args, args["color"], "color")),
876+
colorscale=[
877+
[(1.0 * i) / (1.0 * d), x]
878+
for i, x in enumerate(args["color_continuous_scale"])
879+
],
880+
cmid=args["color_continuous_midpoint"],
881+
)
901882
for v in ["title", "height", "width", "template"]:
902883
if args[v]:
903884
layout_patch[v] = args[v]

0 commit comments

Comments
 (0)