diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index ee043f351dbb..e50e6b576473 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -1,8 +1,3 @@ -""" -A collection of utility functions and classes. Originally, many -(but not all) were from the Python Cookbook -- hence the name cbook. -""" - import collections import collections.abc import contextlib @@ -12,7 +7,6 @@ import math import operator import os -from pathlib import Path import shlex import subprocess import sys @@ -20,18 +14,19 @@ import traceback import types import weakref +from pathlib import Path +import matplotlib import numpy as np +from codeflash.verification.codeflash_capture import codeflash_capture +from matplotlib import _api, _c_internal_utils +'\nA collection of utility functions and classes. Originally, many\n(but not all) were from the Python Cookbook -- hence the name cbook.\n' try: - from numpy.exceptions import VisibleDeprecationWarning # numpy >= 1.25 + from numpy.exceptions import VisibleDeprecationWarning except ImportError: from numpy import VisibleDeprecationWarning -import matplotlib -from matplotlib import _api, _c_internal_utils - - def _get_running_interactive_framework(): """ Return the interactive framework whose event loop is currently running, if @@ -43,52 +38,42 @@ def _get_running_interactive_framework(): One of the following values: "qt", "gtk3", "gtk4", "wx", "tk", "macosx", "headless", ``None``. """ - # Use ``sys.modules.get(name)`` rather than ``name in sys.modules`` as - # entries can also have been explicitly set to None. - QtWidgets = ( - sys.modules.get("PyQt6.QtWidgets") - or sys.modules.get("PySide6.QtWidgets") - or sys.modules.get("PyQt5.QtWidgets") - or sys.modules.get("PySide2.QtWidgets") - ) + QtWidgets = sys.modules.get('PyQt6.QtWidgets') or sys.modules.get('PySide6.QtWidgets') or sys.modules.get('PyQt5.QtWidgets') or sys.modules.get('PySide2.QtWidgets') if QtWidgets and QtWidgets.QApplication.instance(): - return "qt" - Gtk = sys.modules.get("gi.repository.Gtk") + return 'qt' + Gtk = sys.modules.get('gi.repository.Gtk') if Gtk: if Gtk.MAJOR_VERSION == 4: from gi.repository import GLib if GLib.main_depth(): - return "gtk4" + return 'gtk4' if Gtk.MAJOR_VERSION == 3 and Gtk.main_level(): - return "gtk3" - wx = sys.modules.get("wx") + return 'gtk3' + wx = sys.modules.get('wx') if wx and wx.GetApp(): - return "wx" - tkinter = sys.modules.get("tkinter") + return 'wx' + tkinter = sys.modules.get('tkinter') if tkinter: codes = {tkinter.mainloop.__code__, tkinter.Misc.mainloop.__code__} for frame in sys._current_frames().values(): while frame: if frame.f_code in codes: - return "tk" + return 'tk' frame = frame.f_back - # premetively break reference cycle between locals and the frame del frame - macosx = sys.modules.get("matplotlib.backends._macosx") + macosx = sys.modules.get('matplotlib.backends._macosx') if macosx and macosx.event_loop_is_running(): - return "macosx" + return 'macosx' if not _c_internal_utils.display_is_valid(): - return "headless" + return 'headless' return None - def _exception_printer(exc): - if _get_running_interactive_framework() in ["headless", None]: + if _get_running_interactive_framework() in ['headless', None]: raise exc else: traceback.print_exc() - class _StrongRef: """ Wrapper similar to a weakref, but keeping a strong reference to the object. @@ -106,7 +91,6 @@ def __eq__(self, other): def __hash__(self): return hash(self._obj) - def _weak_or_strong_ref(func, callback): """ Return a `WeakMethod` wrapping *func* if possible, else a `_StrongRef`. @@ -116,7 +100,6 @@ def _weak_or_strong_ref(func, callback): except TypeError: return _StrongRef(func) - class CallbackRegistry: """ Handle registering, processing, blocking, and disconnecting @@ -174,55 +157,39 @@ class CallbackRegistry: handled signals. """ - # We maintain two mappings: - # callbacks: signal -> {cid -> weakref-to-callback} - # _func_cid_map: signal -> {weakref-to-callback -> cid} - + @codeflash_capture(function_name='CallbackRegistry.__init__', tmp_dir_path='/tmp/codeflash_unmo7ca9/test_return_values', tests_root='/home/ubuntu/work/repo/lib/matplotlib/tests', is_fto=False) def __init__(self, exception_handler=_exception_printer, *, signals=None): - self._signals = None if signals is None else list(signals) # Copy it. + self._signals = None if signals is None else list(signals) self.exception_handler = exception_handler self.callbacks = {} self._cid_gen = itertools.count() self._func_cid_map = {} - # A hidden variable that marks cids that need to be pickled. self._pickled_cids = set() def __getstate__(self): - return { - **vars(self), - # In general, callbacks may not be pickled, so we just drop them, - # unless directed otherwise by self._pickled_cids. - "callbacks": {s: {cid: proxy() for cid, proxy in d.items() - if cid in self._pickled_cids} - for s, d in self.callbacks.items()}, - # It is simpler to reconstruct this from callbacks in __setstate__. - "_func_cid_map": None, - "_cid_gen": next(self._cid_gen) - } + return {**vars(self), 'callbacks': {s: {cid: proxy() for (cid, proxy) in d.items() if cid in self._pickled_cids} for (s, d) in self.callbacks.items()}, '_func_cid_map': None, '_cid_gen': next(self._cid_gen)} def __setstate__(self, state): cid_count = state.pop('_cid_gen') vars(self).update(state) - self.callbacks = { - s: {cid: _weak_or_strong_ref(func, self._remove_proxy) - for cid, func in d.items()} - for s, d in self.callbacks.items()} - self._func_cid_map = { - s: {proxy: cid for cid, proxy in d.items()} - for s, d in self.callbacks.items()} + self.callbacks = {s: {cid: _weak_or_strong_ref(func, self._remove_proxy) for (cid, func) in d.items()} for (s, d) in self.callbacks.items()} + self._func_cid_map = {s: {proxy: cid for (cid, proxy) in d.items()} for (s, d) in self.callbacks.items()} self._cid_gen = itertools.count(cid_count) def connect(self, signal, func): """Register *func* to be called when signal *signal* is generated.""" if self._signals is not None: _api.check_in_list(self._signals, signal=signal) - self._func_cid_map.setdefault(signal, {}) + if signal not in self._func_cid_map: + self._func_cid_map[signal] = {} proxy = _weak_or_strong_ref(func, self._remove_proxy) - if proxy in self._func_cid_map[signal]: - return self._func_cid_map[signal][proxy] + proxy_cid_map = self._func_cid_map[signal] + if proxy in proxy_cid_map: + return proxy_cid_map[proxy] cid = next(self._cid_gen) - self._func_cid_map[signal][proxy] = cid - self.callbacks.setdefault(signal, {}) + proxy_cid_map[proxy] = cid + if signal not in self.callbacks: + self.callbacks[signal] = {} self.callbacks[signal][cid] = proxy return cid @@ -236,22 +203,17 @@ def _connect_picklable(self, signal, func): self._pickled_cids.add(cid) return cid - # Keep a reference to sys.is_finalizing, as sys may have been cleared out - # at that point. def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing): if _is_finalizing(): - # Weakrefs can't be properly torn down at that point anymore. return - for signal, proxy_to_cid in list(self._func_cid_map.items()): + for (signal, proxy_to_cid) in list(self._func_cid_map.items()): cid = proxy_to_cid.pop(proxy, None) if cid is not None: del self.callbacks[signal][cid] self._pickled_cids.discard(cid) break else: - # Not found return - # Clean up empty dicts if len(self.callbacks[signal]) == 0: del self.callbacks[signal] del self._func_cid_map[signal] @@ -263,21 +225,17 @@ def disconnect(self, cid): No error is raised if such a callback does not exist. """ self._pickled_cids.discard(cid) - # Clean up callbacks - for signal, cid_to_proxy in list(self.callbacks.items()): + for (signal, cid_to_proxy) in self.callbacks.items(): proxy = cid_to_proxy.pop(cid, None) if proxy is not None: break else: - # Not found return - proxy_to_cid = self._func_cid_map[signal] - for current_proxy, current_cid in list(proxy_to_cid.items()): - if current_cid == cid: - assert proxy is current_proxy - del proxy_to_cid[current_proxy] - # Clean up empty dicts + current_cid = proxy_to_cid.get(proxy) + if current_cid == cid: + assert proxy is proxy + del proxy_to_cid[proxy] if len(self.callbacks[signal]) == 0: del self.callbacks[signal] del self._func_cid_map[signal] @@ -296,8 +254,6 @@ def process(self, s, *args, **kwargs): if func is not None: try: func(*args, **kwargs) - # this does not capture KeyboardInterrupt, SystemExit, - # and GeneratorExit except Exception as exc: if self.exception_handler is not None: self.exception_handler(exc) @@ -320,16 +276,13 @@ def blocked(self, *, signal=None): orig = self.callbacks try: if signal is None: - # Empty out the callbacks self.callbacks = {} else: - # Only remove the specific signal self.callbacks = {k: orig[k] for k in orig if k != signal} yield finally: self.callbacks = orig - class silent_list(list): """ A list with a short ``repr()``. @@ -359,14 +312,11 @@ def __init__(self, type, seq=None): def __repr__(self): if self.type is not None or len(self) != 0: tp = self.type if self.type is not None else type(self[0]).__name__ - return f"" + return f'' else: - return "" - + return '' -def _local_over_kwdict( - local_var, kwargs, *keys, - warning_cls=_api.MatplotlibDeprecationWarning): +def _local_over_kwdict(local_var, kwargs, *keys, warning_cls=_api.MatplotlibDeprecationWarning): out = local_var for key in keys: kwarg_val = kwargs.pop(key, None) @@ -374,34 +324,21 @@ def _local_over_kwdict( if out is None: out = kwarg_val else: - _api.warn_external(f'"{key}" keyword argument will be ignored', - warning_cls) + _api.warn_external(f'"{key}" keyword argument will be ignored', warning_cls) return out - def strip_math(s): """ Remove latex formatting from mathtext. Only handles fully math and fully non-math strings. """ - if len(s) >= 2 and s[0] == s[-1] == "$": + if len(s) >= 2 and s[0] == s[-1] == '$': s = s[1:-1] - for tex, plain in [ - (r"\times", "x"), # Specifically for Formatter support. - (r"\mathdefault", ""), - (r"\rm", ""), - (r"\cal", ""), - (r"\tt", ""), - (r"\it", ""), - ("\\", ""), - ("{", ""), - ("}", ""), - ]: + for (tex, plain) in [('\\times', 'x'), ('\\mathdefault', ''), ('\\rm', ''), ('\\cal', ''), ('\\tt', ''), ('\\it', ''), ('\\', ''), ('{', ''), ('}', '')]: s = s.replace(tex, plain) return s - def _strip_comment(s): """Strip everything from the first unquoted #.""" pos = 0 @@ -416,17 +353,13 @@ def _strip_comment(s): else: closing_quote_pos = s.find('"', quote_pos + 1) if closing_quote_pos < 0: - raise ValueError( - f"Missing closing quote in: {s!r}. If you need a double-" - 'quote inside a string, use escaping: e.g. "the \" char"') - pos = closing_quote_pos + 1 # behind closing quote - + raise ValueError(f'Missing closing quote in: {s!r}. If you need a double-quote inside a string, use escaping: e.g. "the " char"') + pos = closing_quote_pos + 1 def is_writable_file_like(obj): """Return whether *obj* looks like a file object with a *write* method.""" return callable(getattr(obj, 'write', None)) - def file_requires_unicode(x): """ Return whether the given writable file-like object requires Unicode to be @@ -439,7 +372,6 @@ def file_requires_unicode(x): else: return False - def to_filehandle(fname, flag='r', return_opened=False, encoding=None): """ Convert a path to an open file handle or pass-through a file-like object. @@ -475,8 +407,6 @@ def to_filehandle(fname, flag='r', return_opened=False, encoding=None): if fname.endswith('.gz'): fh = gzip.open(fname, flag) elif fname.endswith('.bz2'): - # python may not be compiled with bz2 support, - # bury import until we need it import bz2 fh = bz2.BZ2File(fname, flag) else: @@ -488,23 +418,19 @@ def to_filehandle(fname, flag='r', return_opened=False, encoding=None): else: raise ValueError('fname must be a PathLike or file handle') if return_opened: - return fh, opened + return (fh, opened) return fh - -def open_file_cm(path_or_file, mode="r", encoding=None): - r"""Pass through file objects and context-manage path-likes.""" - fh, opened = to_filehandle(path_or_file, mode, True, encoding) +def open_file_cm(path_or_file, mode='r', encoding=None): + """Pass through file objects and context-manage path-likes.""" + (fh, opened) = to_filehandle(path_or_file, mode, True, encoding) return fh if opened else contextlib.nullcontext(fh) - def is_scalar_or_string(val): """Return whether the given object is a scalar or string like.""" return isinstance(val, str) or not np.iterable(val) - -@_api.delete_parameter( - "3.8", "np_load", alternative="open(get_sample_data(..., asfileobj=False))") +@_api.delete_parameter('3.8', 'np_load', alternative='open(get_sample_data(..., asfileobj=False))') def get_sample_data(fname, asfileobj=True, *, np_load=True): """ Return a sample data file. *fname* is a path relative to the @@ -535,7 +461,6 @@ def get_sample_data(fname, asfileobj=True, *, np_load=True): else: return str(path) - def _get_data_path(*args): """ Return the `pathlib.Path` to a resource file provided by Matplotlib. @@ -544,7 +469,6 @@ def _get_data_path(*args): """ return Path(matplotlib.get_data_path(), *args) - def flatten(seq, scalarp=is_scalar_or_string): """ Return a generator of flattened nested containers. @@ -566,8 +490,7 @@ def flatten(seq, scalarp=is_scalar_or_string): else: yield from flatten(item, scalarp) - -@_api.deprecated("3.8") +@_api.deprecated('3.8') class Stack: """ Stack of elements with a movable cursor. @@ -673,7 +596,6 @@ def remove(self, o): if elem != o: self.push(elem) - class _Stack: """ Stack of elements with a movable cursor. @@ -728,21 +650,16 @@ def home(self): """ return self.push(self._elements[0]) if self._elements else None - def safe_masked_invalid(x, copy=False): x = np.array(x, subok=True, copy=copy) if not x.dtype.isnative: - # If we have already made a copy, do the byteswap in place, else make a - # copy with the byte order swapped. - # Swap to native order. x = x.byteswap(inplace=copy).view(x.dtype.newbyteorder('N')) try: - xm = np.ma.masked_where(~(np.isfinite(x)), x, copy=False) + xm = np.ma.masked_where(~np.isfinite(x), x, copy=False) except TypeError: return x return xm - def print_cycles(objects, outstream=sys.stdout, show_progress=False): """ Print loops of cyclic references in the given *objects*. @@ -762,56 +679,42 @@ def print_cycles(objects, outstream=sys.stdout, show_progress=False): import gc def print_path(path): - for i, step in enumerate(path): - # next "wraps around" + for (i, step) in enumerate(path): next = path[(i + 1) % len(path)] - - outstream.write(" %s -- " % type(step)) + outstream.write(' %s -- ' % type(step)) if isinstance(step, dict): - for key, val in step.items(): + for (key, val) in step.items(): if val is next: - outstream.write(f"[{key!r}]") + outstream.write(f'[{key!r}]') break if key is next: - outstream.write(f"[key] = {val!r}") + outstream.write(f'[key] = {val!r}') break elif isinstance(step, list): - outstream.write("[%d]" % step.index(next)) + outstream.write('[%d]' % step.index(next)) elif isinstance(step, tuple): - outstream.write("( tuple )") + outstream.write('( tuple )') else: outstream.write(repr(step)) - outstream.write(" ->\n") - outstream.write("\n") + outstream.write(' ->\n') + outstream.write('\n') def recurse(obj, start, all, current_path): if show_progress: - outstream.write("%d\r" % len(all)) - + outstream.write('%d\r' % len(all)) all[id(obj)] = None - referents = gc.get_referents(obj) for referent in referents: - # If we've found our way back to the start, this is - # a cycle, so print it out if referent is start: print_path(current_path) - - # Don't go back through the original list of objects, or - # through temporary references to the object, since those - # are just an artifact of the cycle detector itself. elif referent is objects or isinstance(referent, types.FrameType): continue - - # We haven't seen this object before, so recurse elif id(referent) not in all: recurse(referent, start, all, current_path + [obj]) - for obj in objects: - outstream.write(f"Examining: {obj!r}\n") + outstream.write(f'Examining: {obj!r}\n') recurse(obj, obj, {}, []) - class Grouper: """ A disjoint-set data structure. @@ -847,26 +750,19 @@ class Grouper: """ def __init__(self, init=()): - self._mapping = weakref.WeakKeyDictionary( - {x: weakref.WeakSet([x]) for x in init}) + self._mapping = weakref.WeakKeyDictionary({x: weakref.WeakSet([x]) for x in init}) def __getstate__(self): - return { - **vars(self), - # Convert weak refs to strong ones. - "_mapping": {k: set(v) for k, v in self._mapping.items()}, - } + return {**vars(self), '_mapping': {k: set(v) for (k, v) in self._mapping.items()}} def __setstate__(self, state): vars(self).update(state) - # Convert strong refs to weak ones. - self._mapping = weakref.WeakKeyDictionary( - {k: weakref.WeakSet(v) for k, v in self._mapping.items()}) + self._mapping = weakref.WeakKeyDictionary({k: weakref.WeakSet(v) for (k, v) in self._mapping.items()}) def __contains__(self, item): return item in self._mapping - @_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper") + @_api.deprecated('3.8', alternative='none, you no longer need to clean a Grouper') def clean(self): """Clean dead weak references from the dictionary.""" @@ -876,19 +772,18 @@ def join(self, a, *args): """ mapping = self._mapping set_a = mapping.setdefault(a, weakref.WeakSet([a])) - for arg in args: set_b = mapping.get(arg, weakref.WeakSet([arg])) if set_b is not set_a: if len(set_b) > len(set_a): - set_a, set_b = set_b, set_a + (set_a, set_b) = (set_b, set_a) set_a.update(set_b) for elem in set_b: mapping[elem] = set_a def joined(self, a, b): """Return whether *a* and *b* are members of the same set.""" - return (self._mapping.get(a, object()) is self._mapping.get(b)) + return self._mapping.get(a, object()) is self._mapping.get(b) def remove(self, a): """Remove *a* from the grouper, doing nothing if it is not there.""" @@ -911,16 +806,23 @@ def get_siblings(self, a): siblings = self._mapping.get(a, [a]) return [x for x in siblings] - class GrouperView: """Immutable view over a `.Grouper`.""" - def __init__(self, grouper): self._grouper = grouper - def __contains__(self, item): return item in self._grouper - def __iter__(self): return iter(self._grouper) - def joined(self, a, b): return self._grouper.joined(a, b) - def get_siblings(self, a): return self._grouper.get_siblings(a) + def __init__(self, grouper): + self._grouper = grouper + def __contains__(self, item): + return item in self._grouper + + def __iter__(self): + return iter(self._grouper) + + def joined(self, a, b): + return self._grouper.joined(a, b) + + def get_siblings(self, a): + return self._grouper.get_siblings(a) def simple_linear_interpolation(a, steps): """ @@ -942,9 +844,7 @@ def simple_linear_interpolation(a, steps): fps = a.reshape((len(a), -1)) xp = np.arange(len(a)) * steps x = np.arange((len(a) - 1) * steps + 1) - return (np.column_stack([np.interp(x, xp, fp) for fp in fps.T]) - .reshape((len(x),) + a.shape[1:])) - + return np.column_stack([np.interp(x, xp, fp) for fp in fps.T]).reshape((len(x),) + a.shape[1:]) def delete_masked_points(*args): """ @@ -981,26 +881,26 @@ def delete_masked_points(*args): if not len(args): return () if is_scalar_or_string(args[0]): - raise ValueError("First argument must be a sequence") + raise ValueError('First argument must be a sequence') nrecs = len(args[0]) margs = [] seqlist = [False] * len(args) - for i, x in enumerate(args): - if not isinstance(x, str) and np.iterable(x) and len(x) == nrecs: + for (i, x) in enumerate(args): + if not isinstance(x, str) and np.iterable(x) and (len(x) == nrecs): seqlist[i] = True if isinstance(x, np.ma.MaskedArray): if x.ndim > 1: - raise ValueError("Masked arrays must be 1-D") + raise ValueError('Masked arrays must be 1-D') else: x = np.asarray(x) margs.append(x) - masks = [] # List of masks that are True where good. - for i, x in enumerate(margs): + masks = [] + for (i, x) in enumerate(margs): if seqlist[i]: if x.ndim > 1: - continue # Don't try to get nan locations unless 1-D. + continue if isinstance(x, np.ma.MaskedArray): - masks.append(~np.ma.getmaskarray(x)) # invert the mask + masks.append(~np.ma.getmaskarray(x)) xd = x.data else: xd = x @@ -1008,21 +908,20 @@ def delete_masked_points(*args): mask = np.isfinite(xd) if isinstance(mask, np.ndarray): masks.append(mask) - except Exception: # Fixme: put in tuple of possible exceptions? + except Exception: pass if len(masks): mask = np.logical_and.reduce(masks) igood = mask.nonzero()[0] if len(igood) < nrecs: - for i, x in enumerate(margs): + for (i, x) in enumerate(margs): if seqlist[i]: margs[i] = x[igood] - for i, x in enumerate(margs): + for (i, x) in enumerate(margs): if seqlist[i] and isinstance(x, np.ma.MaskedArray): margs[i] = x.filled() return margs - def _combine_masks(*args): """ Find all masked and/or non-finite points in a set of arguments, @@ -1057,37 +956,34 @@ def _combine_masks(*args): if not len(args): return () if is_scalar_or_string(args[0]): - raise ValueError("First argument must be a sequence") + raise ValueError('First argument must be a sequence') nrecs = len(args[0]) - margs = [] # Output args; some may be modified. - seqlist = [False] * len(args) # Flags: True if output will be masked. - masks = [] # List of masks. - for i, x in enumerate(args): + margs = [] + seqlist = [False] * len(args) + masks = [] + for (i, x) in enumerate(args): if is_scalar_or_string(x) or len(x) != nrecs: - margs.append(x) # Leave it unmodified. + margs.append(x) else: if isinstance(x, np.ma.MaskedArray) and x.ndim > 1: - raise ValueError("Masked arrays must be 1-D") + raise ValueError('Masked arrays must be 1-D') try: x = np.asanyarray(x) except (VisibleDeprecationWarning, ValueError): - # NumPy 1.19 raises a warning about ragged arrays, but we want - # to accept basically anything here. x = np.asanyarray(x, dtype=object) if x.ndim == 1: x = safe_masked_invalid(x) seqlist[i] = True if np.ma.is_masked(x): masks.append(np.ma.getmaskarray(x)) - margs.append(x) # Possibly modified. + margs.append(x) if len(masks): mask = np.logical_or.reduce(masks) - for i, x in enumerate(margs): + for (i, x) in enumerate(margs): if seqlist[i]: margs[i] = np.ma.array(x, mask=mask) return margs - def _broadcast_with_masks(*args, compress=False): """ Broadcast inputs, combining all masked arrays. @@ -1105,29 +1001,22 @@ def _broadcast_with_masks(*args, compress=False): list of array-like The broadcasted and masked inputs. """ - # extract the masks, if any masks = [k.mask for k in args if isinstance(k, np.ma.MaskedArray)] - # broadcast to match the shape bcast = np.broadcast_arrays(*args, *masks) inputs = bcast[:len(args)] masks = bcast[len(args):] if masks: - # combine the masks into one mask = np.logical_or.reduce(masks) - # put mask on and compress if compress: - inputs = [np.ma.array(k, mask=mask).compressed() - for k in inputs] + inputs = [np.ma.array(k, mask=mask).compressed() for k in inputs] else: - inputs = [np.ma.array(k, mask=mask, dtype=float).filled(np.nan).ravel() - for k in inputs] + inputs = [np.ma.array(k, mask=mask, dtype=float).filled(np.nan).ravel() for k in inputs] else: inputs = [np.ravel(k) for k in inputs] return inputs - def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False): - r""" + """ Return a list of dictionaries of statistics used to draw a series of box and whisker plots using `~.Axes.bxp`. @@ -1198,7 +1087,7 @@ def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False): .. math:: - \mathrm{med} \pm 1.57 \times \frac{\mathrm{iqr}}{\sqrt{N}} + \\mathrm{med} \\pm 1.57 \\times \\frac{\\mathrm{iqr}}{\\sqrt{N}} General approach from: McGill, R., Tukey, J.W., and Larsen, W.A. (1978) "Variations of @@ -1206,59 +1095,38 @@ def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False): """ def _bootstrap_median(data, N=5000): - # determine 95% confidence intervals of the median M = len(data) percentiles = [2.5, 97.5] - bs_index = np.random.randint(M, size=(N, M)) bsData = data[bs_index] estimate = np.median(bsData, axis=1, overwrite_input=True) - CI = np.percentile(estimate, percentiles) return CI def _compute_conf_interval(data, med, iqr, bootstrap): if bootstrap is not None: - # Do a bootstrap estimate of notch locations. - # get conf. intervals around median CI = _bootstrap_median(data, N=bootstrap) notch_min = CI[0] notch_max = CI[1] else: - N = len(data) notch_min = med - 1.57 * iqr / np.sqrt(N) notch_max = med + 1.57 * iqr / np.sqrt(N) - - return notch_min, notch_max - - # output is a list of dicts + return (notch_min, notch_max) bxpstats = [] - - # convert X to a list of lists - X = _reshape_2D(X, "X") - + X = _reshape_2D(X, 'X') ncols = len(X) if labels is None: labels = itertools.repeat(None) elif len(labels) != ncols: - raise ValueError("Dimensions of labels and X must be compatible") - + raise ValueError('Dimensions of labels and X must be compatible') input_whis = whis - for ii, (x, label) in enumerate(zip(X, labels)): - - # empty dict + for (ii, (x, label)) in enumerate(zip(X, labels)): stats = {} if label is not None: stats['label'] = label - - # restore whis to the input values in case it got changed in the loop whis = input_whis - - # note tricksiness, append up here and then mutate below bxpstats.append(stats) - - # if empty, bail if len(x) == 0: stats['fliers'] = np.array([]) stats['mean'] = np.nan @@ -1271,67 +1139,36 @@ def _compute_conf_interval(data, med, iqr, bootstrap): stats['whislo'] = np.nan stats['whishi'] = np.nan continue - - # up-convert to an array, just to be safe x = np.ma.asarray(x) x = x.data[~x.mask].ravel() - - # arithmetic mean stats['mean'] = np.mean(x) - - # medians and quartiles - q1, med, q3 = np.percentile(x, [25, 50, 75]) - - # interquartile range + (q1, med, q3) = np.percentile(x, [25, 50, 75]) stats['iqr'] = q3 - q1 if stats['iqr'] == 0 and autorange: whis = (0, 100) - - # conf. interval around median - stats['cilo'], stats['cihi'] = _compute_conf_interval( - x, med, stats['iqr'], bootstrap - ) - - # lowest/highest non-outliers - if np.iterable(whis) and not isinstance(whis, str): - loval, hival = np.percentile(x, whis) + (stats['cilo'], stats['cihi']) = _compute_conf_interval(x, med, stats['iqr'], bootstrap) + if np.iterable(whis) and (not isinstance(whis, str)): + (loval, hival) = np.percentile(x, whis) elif np.isreal(whis): loval = q1 - whis * stats['iqr'] hival = q3 + whis * stats['iqr'] else: raise ValueError('whis must be a float or list of percentiles') - - # get high extreme wiskhi = x[x <= hival] if len(wiskhi) == 0 or np.max(wiskhi) < q3: stats['whishi'] = q3 else: stats['whishi'] = np.max(wiskhi) - - # get low extreme wisklo = x[x >= loval] if len(wisklo) == 0 or np.min(wisklo) > q1: stats['whislo'] = q1 else: stats['whislo'] = np.min(wisklo) - - # compute a single array of outliers - stats['fliers'] = np.concatenate([ - x[x < stats['whislo']], - x[x > stats['whishi']], - ]) - - # add in the remaining stats - stats['q1'], stats['med'], stats['q3'] = q1, med, q3 - + stats['fliers'] = np.concatenate([x[x < stats['whislo']], x[x > stats['whishi']]]) + (stats['q1'], stats['med'], stats['q3']) = (q1, med, q3) return bxpstats - - -#: Maps short codes for line style to their full name used by backends. ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} -#: Maps full names for line styles used by backends to their short codes. -ls_mapper_r = {v: k for k, v in ls_mapper.items()} - +ls_mapper_r = {v: k for (k, v) in ls_mapper.items()} def contiguous_regions(mask): """ @@ -1339,26 +1176,17 @@ def contiguous_regions(mask): True and we cover all such regions. """ mask = np.asarray(mask, dtype=bool) - if not mask.size: return [] - - # Find the indices of region changes, and correct offset - idx, = np.nonzero(mask[:-1] != mask[1:]) + (idx,) = np.nonzero(mask[:-1] != mask[1:]) idx += 1 - - # List operations are faster for moderately sized arrays idx = idx.tolist() - - # Add first and/or last index if needed if mask[0]: idx = [0] + idx if mask[-1]: idx.append(len(mask)) - return list(zip(idx[::2], idx[1::2])) - def is_math_text(s): """ Return whether the string *s* contains math expressions. @@ -1367,11 +1195,10 @@ def is_math_text(s): non-escaped dollar signs. """ s = str(s) - dollar_count = s.count(r'$') - s.count(r'\$') - even_dollars = (dollar_count > 0 and dollar_count % 2 == 0) + dollar_count = s.count('$') - s.count('\\$') + even_dollars = dollar_count > 0 and dollar_count % 2 == 0 return even_dollars - def _to_unmasked_float_array(x): """ Convert a sequence to a float array; if input was a masked array, masked @@ -1382,22 +1209,14 @@ def _to_unmasked_float_array(x): else: return np.asarray(x, float) - def _check_1d(x): """Convert scalars to 1D arrays; pass-through arrays as is.""" - # Unpack in case of e.g. Pandas or xarray object x = _unpack_to_numpy(x) - # plot requires `shape` and `ndim`. If passed an - # object that doesn't provide them, then force to numpy array. - # Note this will strip unit information. - if (not hasattr(x, 'shape') or - not hasattr(x, 'ndim') or - len(x.shape) < 1): + if not hasattr(x, 'shape') or not hasattr(x, 'ndim') or len(x.shape) < 1: return np.atleast_1d(x) else: return x - def _reshape_2D(X, name): """ Use Fortran ordering to convert ndarrays and lists of iterables to lists of @@ -1409,34 +1228,22 @@ def _reshape_2D(X, name): *name* is used to generate the error message for invalid inputs. """ - - # Unpack in case of e.g. Pandas or xarray object X = _unpack_to_numpy(X) - - # Iterate over columns for ndarrays. if isinstance(X, np.ndarray): X = X.T - if len(X) == 0: return [[]] elif X.ndim == 1 and np.ndim(X[0]) == 0: - # 1D array of scalars: directly return it. return [X] elif X.ndim in [1, 2]: - # 2D array, or 1D array of iterables: flatten them first. return [np.reshape(x, -1) for x in X] else: raise ValueError(f'{name} must have 2 or fewer dimensions') - - # Iterate over list of iterables. if len(X) == 0: return [[]] - result = [] is_1d = True for xi in X: - # check if this is iterable, except for strings which we - # treat as singletons. if not isinstance(xi, str): try: iter(xi) @@ -1449,15 +1256,11 @@ def _reshape_2D(X, name): if nd > 1: raise ValueError(f'{name} must have 2 or fewer dimensions') result.append(xi.reshape(-1)) - if is_1d: - # 1D array of scalars: directly return it. return [np.reshape(result, -1)] else: - # 2D array, or 1D array of iterables: use flattened version. return result - def violin_stats(X, method, points=100, quantiles=None): """ Return a list of dictionaries of data which can be used to draw a series @@ -1509,53 +1312,30 @@ def violin_stats(X, method, points=100, quantiles=None): - max: The maximum value for this column of data. - quantiles: The quantile values for this column of data. """ - - # List of dictionaries describing each of the violins. vpstats = [] - - # Want X to be a list of data sequences - X = _reshape_2D(X, "X") - - # Want quantiles to be as the same shape as data sequences + X = _reshape_2D(X, 'X') if quantiles is not None and len(quantiles) != 0: - quantiles = _reshape_2D(quantiles, "quantiles") - # Else, mock quantiles if it's none or empty + quantiles = _reshape_2D(quantiles, 'quantiles') else: quantiles = [[]] * len(X) - - # quantiles should have the same size as dataset if len(X) != len(quantiles): - raise ValueError("List of violinplot statistics and quantiles values" - " must have the same length") - - # Zip x and quantiles + raise ValueError('List of violinplot statistics and quantiles values must have the same length') for (x, q) in zip(X, quantiles): - # Dictionary of results for this distribution stats = {} - - # Calculate basic stats for the distribution min_val = np.min(x) max_val = np.max(x) quantile_val = np.percentile(x, 100 * q) - - # Evaluate the kernel density estimate coords = np.linspace(min_val, max_val, points) stats['vals'] = method(x, coords) stats['coords'] = coords - - # Store additional statistics for this distribution stats['mean'] = np.mean(x) stats['median'] = np.median(x) stats['min'] = min_val stats['max'] = max_val stats['quantiles'] = np.atleast_1d(quantile_val) - - # Append to output vpstats.append(stats) - return vpstats - def pts_to_prestep(x, *args): """ Convert continuous line to pre-steps. @@ -1585,15 +1365,12 @@ def pts_to_prestep(x, *args): >>> x_s, y1_s, y2_s = pts_to_prestep(x, y1, y2) """ steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0))) - # In all `pts_to_*step` functions, only assign once using *x* and *args*, - # as converting to an array may be expensive. steps[0, 0::2] = x steps[0, 1::2] = steps[0, 0:-2:2] steps[1:, 0::2] = args steps[1:, 1::2] = steps[1:, 2::2] return steps - def pts_to_poststep(x, *args): """ Convert continuous line to post-steps. @@ -1629,7 +1406,6 @@ def pts_to_poststep(x, *args): steps[1:, 1::2] = steps[1:, 0:-2:2] return steps - def pts_to_midstep(x, *args): """ Convert continuous line to mid-steps. @@ -1661,19 +1437,12 @@ def pts_to_midstep(x, *args): steps = np.zeros((1 + len(args), 2 * len(x))) x = np.asanyarray(x) steps[0, 1:-1:2] = steps[0, 2::2] = (x[:-1] + x[1:]) / 2 - steps[0, :1] = x[:1] # Also works for zero-sized input. + steps[0, :1] = x[:1] steps[0, -1:] = x[-1:] steps[1:, 0::2] = args steps[1:, 1::2] = steps[1:, 0::2] return steps - - -STEP_LOOKUP_MAP = {'default': lambda x, y: (x, y), - 'steps': pts_to_prestep, - 'steps-pre': pts_to_prestep, - 'steps-post': pts_to_poststep, - 'steps-mid': pts_to_midstep} - +STEP_LOOKUP_MAP = {'default': lambda x, y: (x, y), 'steps': pts_to_prestep, 'steps-pre': pts_to_prestep, 'steps-post': pts_to_poststep, 'steps-mid': pts_to_midstep} def index_of(y): """ @@ -1697,19 +1466,17 @@ def index_of(y): The x and y values to plot. """ try: - return y.index.to_numpy(), y.to_numpy() + return (y.index.to_numpy(), y.to_numpy()) except AttributeError: pass try: y = _check_1d(y) except (VisibleDeprecationWarning, ValueError): - # NumPy 1.19 will warn on ragged input, and we can't actually use it. pass else: - return np.arange(y.shape[0], dtype=float), y + return (np.arange(y.shape[0], dtype=float), y) raise ValueError('Input could not be cast to an at-least-1D NumPy array') - def safe_first_element(obj): """ Return the first element in *obj*. @@ -1718,18 +1485,13 @@ def safe_first_element(obj): supporting both index access and the iterator protocol. """ if isinstance(obj, collections.abc.Iterator): - # needed to accept `array.flat` as input. - # np.flatiter reports as an instance of collections.Iterator but can still be - # indexed via []. This has the side effect of re-setting the iterator, but - # that is acceptable. try: return obj[0] except TypeError: pass - raise RuntimeError("matplotlib does not support generators as input") + raise RuntimeError('matplotlib does not support generators as input') return next(iter(obj)) - def _safe_first_finite(obj): """ Return the first finite element in *obj* if one is available and skip_nonfinite is @@ -1740,42 +1502,33 @@ def _safe_first_finite(obj): This is a type-independent way of obtaining the first finite element, supporting both index access and the iterator protocol. """ + def safe_isfinite(val): if val is None: return False try: return math.isfinite(val) except (TypeError, ValueError): - # if the outer object is 2d, then val is a 1d array, and - # - math.isfinite(numpy.zeros(3)) raises TypeError - # - math.isfinite(torch.zeros(3)) raises ValueError pass try: return np.isfinite(val) if np.isscalar(val) else True except TypeError: - # This is something that NumPy cannot make heads or tails of, - # assume "finite" return True - if isinstance(obj, np.flatiter): - # TODO do the finite filtering on this return obj[0] elif isinstance(obj, collections.abc.Iterator): - raise RuntimeError("matplotlib does not support generators as input") + raise RuntimeError('matplotlib does not support generators as input') else: for val in obj: if safe_isfinite(val): return val return safe_first_element(obj) - def sanitize_sequence(data): """ Convert dictview objects to list. Other inputs are returned unchanged. """ - return (list(data) if isinstance(data, collections.abc.MappingView) - else data) - + return list(data) if isinstance(data, collections.abc.MappingView) else data def normalize_kwargs(kw, alias_mapping=None): """ @@ -1805,34 +1558,23 @@ def normalize_kwargs(kw, alias_mapping=None): passed to a callable. """ from matplotlib.artist import Artist - if kw is None: return {} - - # deal with default value of alias_mapping if alias_mapping is None: alias_mapping = {} - elif (isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist) - or isinstance(alias_mapping, Artist)): - alias_mapping = getattr(alias_mapping, "_alias_map", {}) - - to_canonical = {alias: canonical - for canonical, alias_list in alias_mapping.items() - for alias in alias_list} + elif isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist) or isinstance(alias_mapping, Artist): + alias_mapping = getattr(alias_mapping, '_alias_map', {}) + to_canonical = {alias: canonical for (canonical, alias_list) in alias_mapping.items() for alias in alias_list} canonical_to_seen = {} - ret = {} # output dictionary - - for k, v in kw.items(): + ret = {} + for (k, v) in kw.items(): canonical = to_canonical.get(k, k) if canonical in canonical_to_seen: - raise TypeError(f"Got both {canonical_to_seen[canonical]!r} and " - f"{k!r}, which are aliases of one another") + raise TypeError(f'Got both {canonical_to_seen[canonical]!r} and {k!r}, which are aliases of one another') canonical_to_seen[canonical] = k ret[canonical] = v - return ret - @contextlib.contextmanager def _lock_path(path): """ @@ -1850,31 +1592,23 @@ def _lock_path(path): directory, so that directory must exist and be writable. """ path = Path(path) - lock_path = path.with_name(path.name + ".matplotlib-lock") + lock_path = path.with_name(path.name + '.matplotlib-lock') retries = 50 sleeptime = 0.1 for _ in range(retries): try: - with lock_path.open("xb"): + with lock_path.open('xb'): break except FileExistsError: time.sleep(sleeptime) else: - raise TimeoutError("""\ -Lock error: Matplotlib failed to acquire the following lock file: - {} -This maybe due to another process holding this lock file. If you are sure no -other Matplotlib process is running, remove this file and try again.""".format( - lock_path)) + raise TimeoutError('Lock error: Matplotlib failed to acquire the following lock file:\n {}\nThis maybe due to another process holding this lock file. If you are sure no\nother Matplotlib process is running, remove this file and try again.'.format(lock_path)) try: yield finally: lock_path.unlink() - -def _topmost_artist( - artists, - _cached_max=functools.partial(max, key=operator.attrgetter("zorder"))): +def _topmost_artist(artists, _cached_max=functools.partial(max, key=operator.attrgetter('zorder'))): """ Get the topmost artist of a list. @@ -1884,7 +1618,6 @@ def _topmost_artist( """ return _cached_max(reversed(artists)) - def _str_equal(obj, s): """ Return whether *obj* is a string equal to string *s*. @@ -1895,7 +1628,6 @@ def _str_equal(obj, s): """ return isinstance(obj, str) and obj == s - def _str_lower_equal(obj, s): """ Return whether *obj* is a string equal, when lowercased, to string *s*. @@ -1906,7 +1638,6 @@ def _str_lower_equal(obj, s): """ return isinstance(obj, str) and obj.lower() == s - def _array_perimeter(arr): """ Get the elements on the perimeter of *arr*. @@ -1934,17 +1665,9 @@ def _array_perimeter(arr): >>> _array_perimeter(a) array([ 0, 1, 2, 3, 13, 23, 22, 21, 20, 10]) """ - # note we use Python's half-open ranges to avoid repeating - # the corners - forward = np.s_[0:-1] # [0 ... -1) - backward = np.s_[-1:0:-1] # [-1 ... 0) - return np.concatenate(( - arr[0, forward], - arr[forward, -1], - arr[-1, backward], - arr[backward, 0], - )) - + forward = np.s_[0:-1] + backward = np.s_[-1:0:-1] + return np.concatenate((arr[0, forward], arr[forward, -1], arr[-1, backward], arr[backward, 0])) def _unfold(arr, axis, size, step): """ @@ -1990,11 +1713,7 @@ def _unfold(arr, axis, size, step): new_strides = [*arr.strides, arr.strides[axis]] new_shape[axis] = (new_shape[axis] - size) // step + 1 new_strides[axis] = new_strides[axis] * step - return np.lib.stride_tricks.as_strided(arr, - shape=new_shape, - strides=new_strides, - writeable=False) - + return np.lib.stride_tricks.as_strided(arr, shape=new_shape, strides=new_strides, writeable=False) def _array_patch_perimeters(x, rstride, cstride): """ @@ -2020,30 +1739,11 @@ def _array_patch_perimeters(x, rstride, cstride): assert rstride > 0 and cstride > 0 assert (x.shape[0] - 1) % rstride == 0 assert (x.shape[1] - 1) % cstride == 0 - # We build up each perimeter from four half-open intervals. Here is an - # illustrated explanation for rstride == cstride == 3 - # - # T T T R - # L R - # L R - # L B B B - # - # where T means that this element will be in the top array, R for right, - # B for bottom and L for left. Each of the arrays below has a shape of: - # - # (number of perimeters that can be extracted vertically, - # number of perimeters that can be extracted horizontally, - # cstride for top and bottom and rstride for left and right) - # - # Note that _unfold doesn't incur any memory copies, so the only costly - # operation here is the np.concatenate. top = _unfold(x[:-1:rstride, :-1], 1, cstride, cstride) bottom = _unfold(x[rstride::rstride, 1:], 1, cstride, cstride)[..., ::-1] right = _unfold(x[:-1, cstride::cstride], 0, rstride, rstride) left = _unfold(x[1:, :-1:cstride], 0, rstride, rstride)[..., ::-1] - return (np.concatenate((top, right, bottom, left), axis=2) - .reshape(-1, 2 * (rstride + cstride))) - + return np.concatenate((top, right, bottom, left), axis=2).reshape(-1, 2 * (rstride + cstride)) @contextlib.contextmanager def _setattr_cm(obj, **kwargs): @@ -2055,41 +1755,26 @@ def _setattr_cm(obj, **kwargs): for attr in kwargs: orig = getattr(obj, attr, sentinel) if attr in obj.__dict__ or orig is sentinel: - # if we are pulling from the instance dict or the object - # does not have this attribute we can trust the above origs[attr] = orig else: - # if the attribute is not in the instance dict it must be - # from the class level cls_orig = getattr(type(obj), attr) - # if we are dealing with a property (but not a general descriptor) - # we want to set the original value back. if isinstance(cls_orig, property): origs[attr] = orig - # otherwise this is _something_ we are going to shadow at - # the instance dict level from higher up in the MRO. We - # are going to assume we can delattr(obj, attr) to clean - # up after ourselves. It is possible that this code will - # fail if used with a non-property custom descriptor which - # implements __set__ (and __delete__ does not act like a - # stack). However, this is an internal tool and we do not - # currently have any custom descriptors. else: origs[attr] = sentinel - try: - for attr, val in kwargs.items(): + for (attr, val) in kwargs.items(): setattr(obj, attr, val) yield finally: - for attr, orig in origs.items(): + for (attr, orig) in origs.items(): if orig is sentinel: delattr(obj, attr) else: setattr(obj, attr, orig) - class _OrderedSet(collections.abc.MutableSet): + def __init__(self): self._od = collections.OrderedDict() @@ -2109,34 +1794,23 @@ def add(self, key): def discard(self, key): self._od.pop(key, None) - -# Agg's buffers are unmultiplied RGBA8888, which neither PyQt<=5.1 nor cairo -# support; however, both do support premultiplied ARGB32. - - def _premultiplied_argb32_to_unmultiplied_rgba8888(buf): """ Convert a premultiplied ARGB32 buffer to an unmultiplied RGBA8888 buffer. """ - rgba = np.take( # .take() ensures C-contiguity of the result. - buf, - [2, 1, 0, 3] if sys.byteorder == "little" else [1, 2, 3, 0], axis=2) + rgba = np.take(buf, [2, 1, 0, 3] if sys.byteorder == 'little' else [1, 2, 3, 0], axis=2) rgb = rgba[..., :-1] alpha = rgba[..., -1] - # Un-premultiply alpha. The formula is the same as in cairo-png.c. mask = alpha != 0 for channel in np.rollaxis(rgb, -1): - channel[mask] = ( - (channel[mask].astype(int) * 255 + alpha[mask] // 2) - // alpha[mask]) + channel[mask] = (channel[mask].astype(int) * 255 + alpha[mask] // 2) // alpha[mask] return rgba - def _unmultiplied_rgba8888_to_premultiplied_argb32(rgba8888): """ Convert an unmultiplied RGBA8888 buffer to a premultiplied ARGB32 buffer. """ - if sys.byteorder == "little": + if sys.byteorder == 'little': argb32 = np.take(rgba8888, [2, 1, 0, 3], axis=2) rgb24 = argb32[..., :-1] alpha8 = argb32[..., -1:] @@ -2144,14 +1818,10 @@ def _unmultiplied_rgba8888_to_premultiplied_argb32(rgba8888): argb32 = np.take(rgba8888, [3, 0, 1, 2], axis=2) alpha8 = argb32[..., :1] rgb24 = argb32[..., 1:] - # Only bother premultiplying when the alpha channel is not fully opaque, - # as the cost is not negligible. The unsafe cast is needed to do the - # multiplication in-place in an integer buffer. - if alpha8.min() != 0xff: - np.multiply(rgb24, alpha8 / 0xff, out=rgb24, casting="unsafe") + if alpha8.min() != 255: + np.multiply(rgb24, alpha8 / 255, out=rgb24, casting='unsafe') return argb32 - def _get_nonzero_slices(buf): """ Return the bounds of the nonzero region of a 2D array as a pair of slices. @@ -2160,21 +1830,18 @@ def _get_nonzero_slices(buf): that encloses all non-zero entries in *buf*. If *buf* is fully zero, then ``(slice(0, 0), slice(0, 0))`` is returned. """ - x_nz, = buf.any(axis=0).nonzero() - y_nz, = buf.any(axis=1).nonzero() + (x_nz,) = buf.any(axis=0).nonzero() + (y_nz,) = buf.any(axis=1).nonzero() if len(x_nz) and len(y_nz): - l, r = x_nz[[0, -1]] - b, t = y_nz[[0, -1]] - return slice(b, t + 1), slice(l, r + 1) + (l, r) = x_nz[[0, -1]] + (b, t) = y_nz[[0, -1]] + return (slice(b, t + 1), slice(l, r + 1)) else: - return slice(0, 0), slice(0, 0) - + return (slice(0, 0), slice(0, 0)) def _pformat_subprocess(command): """Pretty-format a subprocess command for printing/logging purposes.""" - return (command if isinstance(command, str) - else " ".join(shlex.quote(os.fspath(arg)) for arg in command)) - + return command if isinstance(command, str) else ' '.join((shlex.quote(os.fspath(arg)) for arg in command)) def _check_and_log_subprocess(command, logger, **kwargs): """ @@ -2195,42 +1862,28 @@ def _check_and_log_subprocess(command, logger, **kwargs): stderr = proc.stderr if isinstance(stderr, bytes): stderr = stderr.decode() - raise RuntimeError( - f"The command\n" - f" {_pformat_subprocess(command)}\n" - f"failed and generated the following output:\n" - f"{stdout}\n" - f"and the following error:\n" - f"{stderr}") + raise RuntimeError(f'The command\n {_pformat_subprocess(command)}\nfailed and generated the following output:\n{stdout}\nand the following error:\n{stderr}') if proc.stdout: - logger.debug("stdout:\n%s", proc.stdout) + logger.debug('stdout:\n%s', proc.stdout) if proc.stderr: - logger.debug("stderr:\n%s", proc.stderr) + logger.debug('stderr:\n%s', proc.stderr) return proc.stdout - def _backend_module_name(name): """ Convert a backend name (either a standard backend -- "Agg", "TkAgg", ... -- or a custom backend -- "module://...") to the corresponding module name). """ - return (name[9:] if name.startswith("module://") - else f"matplotlib.backends.backend_{name.lower()}") - + return name[9:] if name.startswith('module://') else f'matplotlib.backends.backend_{name.lower()}' def _setup_new_guiapp(): """ Perform OS-dependent setup when Matplotlib creates a new GUI application. """ - # Windows: If not explicit app user model id has been set yet (so we're not - # already embedded), then set it to "matplotlib", so that taskbar icons are - # correct. try: _c_internal_utils.Win32_GetCurrentProcessExplicitAppUserModelID() except OSError: - _c_internal_utils.Win32_SetCurrentProcessExplicitAppUserModelID( - "matplotlib") - + _c_internal_utils.Win32_SetCurrentProcessExplicitAppUserModelID('matplotlib') def _format_approx(number, precision): """ @@ -2239,27 +1892,14 @@ def _format_approx(number, precision): """ return f'{number:.{precision}f}'.rstrip('0').rstrip('.') or '0' - def _g_sig_digits(value, delta): """ Return the number of significant digits to %g-format *value*, assuming that it is known with an error of *delta*. """ if delta == 0: - # delta = 0 may occur when trying to format values over a tiny range; - # in that case, replace it by the distance to the closest float. delta = abs(np.spacing(value)) - # If e.g. value = 45.67 and delta = 0.02, then we want to round to 2 digits - # after the decimal point (floor(log10(0.02)) = -2); 45.67 contributes 2 - # digits before the decimal point (floor(log10(45.67)) + 1 = 2): the total - # is 4 significant digits. A value of 0 contributes 1 "digit" before the - # decimal point. - # For inf or nan, the precision doesn't matter. - return max( - 0, - (math.floor(math.log10(abs(value))) + 1 if value else 1) - - math.floor(math.log10(delta))) if math.isfinite(value) else 0 - + return max(0, (math.floor(math.log10(abs(value))) + 1 if value else 1) - math.floor(math.log10(delta))) if math.isfinite(value) else 0 def _unikey_or_keysym_to_mplkey(unikey, keysym): """ @@ -2268,27 +1908,20 @@ def _unikey_or_keysym_to_mplkey(unikey, keysym): The Unicode key is checked first; this avoids having to list most printable keysyms such as ``EuroSign``. """ - # For non-printable characters, gtk3 passes "\0" whereas tk passes an "". if unikey and unikey.isprintable(): return unikey key = keysym.lower() - if key.startswith("kp_"): # keypad_x (including kp_enter). + if key.startswith('kp_'): key = key[3:] - if key.startswith("page_"): # page_{up,down} - key = key.replace("page_", "page") - if key.endswith(("_l", "_r")): # alt_l, ctrl_l, shift_l. + if key.startswith('page_'): + key = key.replace('page_', 'page') + if key.endswith(('_l', '_r')): key = key[:-2] - if sys.platform == "darwin" and key == "meta": - # meta should be reported as command on mac - key = "cmd" - key = { - "return": "enter", - "prior": "pageup", # Used by tk. - "next": "pagedown", # Used by tk. - }.get(key, key) + if sys.platform == 'darwin' and key == 'meta': + key = 'cmd' + key = {'return': 'enter', 'prior': 'pageup', 'next': 'pagedown'}.get(key, key) return key - @functools.cache def _make_class_factory(mixin_class, fmt, attr_name=None): """ @@ -2310,86 +1943,58 @@ def _make_class_factory(mixin_class, fmt, attr_name=None): @functools.cache def class_factory(axes_class): - # if we have already wrapped this class, declare victory! if issubclass(axes_class, mixin_class): return axes_class - - # The parameter is named "axes_class" for backcompat but is really just - # a base class; no axes semantics are used. base_class = axes_class class subcls(mixin_class, base_class): - # Better approximation than __module__ = "matplotlib.cbook". __module__ = mixin_class.__module__ def __reduce__(self): - return (_picklable_class_constructor, - (mixin_class, fmt, attr_name, base_class), - self.__getstate__()) - + return (_picklable_class_constructor, (mixin_class, fmt, attr_name, base_class), self.__getstate__()) subcls.__name__ = subcls.__qualname__ = fmt.format(base_class.__name__) if attr_name is not None: setattr(subcls, attr_name, base_class) return subcls - class_factory.__module__ = mixin_class.__module__ return class_factory - def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class): """Internal helper for _make_class_factory.""" factory = _make_class_factory(mixin_class, fmt, attr_name) cls = factory(base_class) return cls.__new__(cls) - def _is_torch_array(x): """Check if 'x' is a PyTorch Tensor.""" try: - # we're intentionally not attempting to import torch. If somebody - # has created a torch array, torch should already be in sys.modules return isinstance(x, sys.modules['torch'].Tensor) - except Exception: # TypeError, KeyError, AttributeError, maybe others? - # we're attempting to access attributes on imported modules which - # may have arbitrary user code, so we deliberately catch all exceptions + except Exception: return False - def _is_jax_array(x): """Check if 'x' is a JAX Array.""" try: - # we're intentionally not attempting to import jax. If somebody - # has created a jax array, jax should already be in sys.modules return isinstance(x, sys.modules['jax'].Array) - except Exception: # TypeError, KeyError, AttributeError, maybe others? - # we're attempting to access attributes on imported modules which - # may have arbitrary user code, so we deliberately catch all exceptions + except Exception: return False - def _unpack_to_numpy(x): """Internal helper to extract data from e.g. pandas and xarray objects.""" if isinstance(x, np.ndarray): - # If numpy, return directly return x if hasattr(x, 'to_numpy'): - # Assume that any to_numpy() method actually returns a numpy array return x.to_numpy() if hasattr(x, 'values'): xtmp = x.values - # For example a dict has a 'values' attribute, but it is not a property - # so in this case we do not want to return a function if isinstance(xtmp, np.ndarray): return xtmp if _is_torch_array(x) or _is_jax_array(x): xtmp = x.__array__() - - # In case __array__() method does not return a numpy array in future if isinstance(xtmp, np.ndarray): return xtmp return x - def _auto_format_str(fmt, value): """ Apply *value* to the format string *fmt*.