diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ac9fa43b..b205e5de 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -15,7 +15,7 @@ jobs:
       with:
         # Pin ruff version to make sure we do not break our builds at the
         # worst times
-        version: "0.14.0"
+        version: "0.14.7"
 
   test:
     # name: Test (${{ matrix.python-version }}, ${{ matrix.os }})
@@ -34,14 +34,14 @@ jobs:
         shell: bash -l {0}
     steps:
-    - uses: actions/checkout@v5
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: conda-incubator/setup-miniconda@v3
-      with:
-        auto-update-conda: true
-        environment-file: environment.yml
-        python-version: ${{ matrix.python-version }}
-    - run: conda info
-    - run: conda list
-    - run: conda config --show
-    - run: pytest
+      - uses: actions/checkout@v5
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: conda-incubator/setup-miniconda@v3
+        with:
+          auto-update-conda: true
+          environment-file: environment.yml
+          python-version: ${{ matrix.python-version }}
+      - run: conda info
+      - run: conda list
+      - run: conda config --show
+      - run: pytest
diff --git a/condarecipe/larray-editor/meta.yaml b/condarecipe/larray-editor/meta.yaml
index 256f38f2..9380696b 100644
--- a/condarecipe/larray-editor/meta.yaml
+++ b/condarecipe/larray-editor/meta.yaml
@@ -18,13 +18,18 @@ requirements:
   host:
     - python >=3.9
     - pip
+    - setuptools
 
   run:
     - python >=3.9
-    # requires larray >= 0.32 because of the LArray -> Array rename
+    # Technically, we should require larray >=0.35 because we need align_arrays
+    # for compare(), but to make larray-editor releasable, we cannot depend on
+    # larray X.Y when releasing larray-editor X.Y (see utils.py for more
+    # details)
+    # TODO: require 0.35 for next larray-editor version
    - larray >=0.32
    # it is indirectly pulled from larray, but let us be explicit about this
-    - numpy
+    - numpy >=1.22
    - matplotlib
    - pyqt >=5
    - qtpy >=2  # for Qt6 support
diff --git a/doc/source/changes/version_0_35.rst.inc b/doc/source/changes/version_0_35.rst.inc
index f4ade10e..7aeac275 100644
--- a/doc/source/changes/version_0_35.rst.inc
+++ b/doc/source/changes/version_0_35.rst.inc
@@ -1,25 +1,96 @@
 .. py:currentmodule:: larray_editor
 
-Syntax changes
-^^^^^^^^^^^^^^
-
-* renamed ``MappingEditor.old_method_name()`` to :py:obj:`MappingEditor.new_method_name()` (closes :editor_issue:`1`).
-
-* renamed ``old_argument_name`` argument of :py:obj:`MappingEditor.method_name()` to ``new_argument_name``.
-
-
-Backward incompatible changes
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-* other backward incompatible changes
-
-
 New features
 ^^^^^^^^^^^^
 
+* allow displaying *many* more kinds of objects, not only larray arrays.
+  One specific goal when developing this new feature was speed: most of these
+  viewers should be fast (when at all possible), even on (very) large
+  datasets. Only displaying (not editing) is supported for the new types.
+
+  The following types are supported so far (but adding more is relatively easy):
+
+  * Python builtin objects:
+    - tuple (including named tuple), list (sequences), dict (mappings),
+      dict views, memoryview and array
+    - text and binary files
+  * Python stdlib objects:
+    - pathlib.Path
+      * if the path points to a directory, it will display the content of the
+        directory
+      * if the path points to a file, it will try to display it, provided
+        support for that file type is implemented (see below for the list
+        of supported types).
+      - sqlite3.Connection (and its tables)
+      - pstats.Stats (results of Python's profiler)
+      - zipfile.ZipFile and zipfile.Path
+  * new objects from LArray: Axis, Excel Workbook (what you get from
+    larray.open_excel()), Sheets and Range
+  * IODE "collections" objects: Comments, Equations, Identities, Lists, Tables,
+    Scalars and Variables, as well as Table objects
+  * Pandas: DataFrame, Series and DataFrameGroupBy
+  * Polars: DataFrame and LazyFrame
+  * Numpy: ndarray
+  * PyArrow: Array, Table, RecordBatchFileReader (reader object for feather
+    files) and ParquetFile
+  * Narwhals: DataFrame and LazyFrame
+  * PyTables: File, Group (with special support for Pandas DataFrames written
+    in HDF files), Array and Table
+  * IBIS: Table
+  * DuckDB: DuckDBPyConnection and DuckDBPyRelation (what you receive from any
+    query)
+
+  File types (extensions) currently supported:
+    - Iode files: .ac, .ae, .ai, .al, .as, .at, .av, .cmt, .eqs, .idt, .lst,
+                  .scl, .tbl, .var
+    - Text files: .bat, .c, .cfg, .cpp, .h, .htm, .html, .ini, .log, .md,
+                  .py, .pyx, .pxd, .rep, .rst, .sh, .sql, .toml, .txt, .wsgi,
+                  .yaml, .yml
+    - HDF5 files: .h5, .hdf
+    - Parquet files: .parquet
+    - Stata files: .dta
+    - Feather files: .feather
+    - SAS files: .sas7bdat
+      It is limited to the first few thousand rows (the exact number depends on
+      the number of columns), because reading later rows gets increasingly slow,
+      to the point of being unusable.
+    - CSV files: .csv
+    - Gzipped CSV files: .csv.gz
+    - Excel files: .xls, .xlsx
+    - Zip files: .zip
+    - DuckDB files: .ddb, .duckdb
+
+* the editor now features a new "File Explorer" (accessible from the "File"
+  menu) so that one can more easily make use of all the above file viewers.
+
+* added a new SQL Console (next to the IPython console) for querying Polars
+  structures (DataFrame, LazyFrame and Series) as SQL tables. The console
+  features auto-completion for SQL keywords, table names and column names
+  and stores the last 1000 queries (even across sessions). Recalling a query
+  from history is done with the up and down arrows and, like in the IPython
+  console, it searches through history using the current command as a prefix.
+  This console will only be present if the polars module is installed.
+
+* allow sorting some objects by column by pressing on a horizontal label.
+  This is currently implemented for the following objects:
+    - python built-in sequences (e.g. tuples and lists)
+    - python pathlib.Path objects representing directories
+    - LArray (only for 2D arrays)
+    - Pandas DataFrame
+    - Polars DataFrame and LazyFrame
+    - Narwhals LazyFrame
+    - SQLite tables
+    - DuckDB relations
+
+* allow filtering some objects by pressing on a horizontal label.
+  This is currently implemented for the following objects:
+    - Pandas DataFrame
+    - Polars DataFrame and LazyFrame
+    - DuckDB relations
+
 * allow comparing arrays/sessions with different axes in :py:obj:`compare()`.
-  The function gained ``align`` and ``fill_value`` arguments and the interface has a new
-  combobox to change the alignment method during the comparison:
+  The function gained ``align`` and ``fill_value`` arguments and the interface
+  has a new combobox to change the alignment method during the comparison:
 
   - outer: will use a label if it is in any array (ordered like the first array).
     This is the default as it results in no information loss.
   - inner: will use a label if it is in all arrays (ordered like the first array).
@@ -28,28 +99,30 @@ New features
   - exact: raise an error when axes are not equal.
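+
+  For example, with two hypothetical test arrays of different lengths
+  (assuming larray is imported as ``la``)::
+
+      >>> a1 = la.ndtest(3)                                # doctest: +SKIP
+      >>> a2 = la.ndtest(4)                                # doctest: +SKIP
+      >>> compare(a1, a2, align='outer', fill_value=0.0)   # doctest: +SKIP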
   Closes :editor_issue:`214` and :editor_issue:`251`.
 
-* double-clicking on an array name in the list will open it in a new window
+* double-clicking on a name in the variable list will open it in a new window
   (closes :editor_issue:`143`).
 
-  .. note::
-
-     - It works for foo bar !
-     - It does not work for foo baz !
-
 
 Miscellaneous improvements
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-* made the editor interruptible by an outside program (i.e. made PyCharm stop & restart buttons work directly
-  instead of only when the editor receives the focus again). Closes :editor_issue:`257`.
+* made the editor interruptible by an outside program (i.e. made PyCharm stop &
+  restart buttons work directly instead of only when the editor receives the
+  focus again). Closes :editor_issue:`257`.
+
+* axes and vertical label columns are now resized automatically
 
-* when comparing sessions via :py:obj:`compare()`, the color of arrays in the list is now updated depending
-  on the tolerance. To reflect that the tolerance widget moved to the top of the interface.
-  Closes :editor_issue:`201`.
+* string values are now left-aligned instead of right-aligned
 
-* typing the name of a variable holding a matplotlib figure (or axes) in the console shows it
-  (previously, only expressions were displayed and *not* simple variables).
-  For example: ::
+* when comparing sessions via :py:obj:`compare()`, the color of arrays in the
+  list is now updated depending on the tolerance. To reflect that, the
+  tolerance widget has moved to the top of the interface.
+  Closes :editor_issue:`201`.
+
+* the :py:obj:`compare()` max difference label is colored red when the
+  difference is not 0
+
+* typing the name of a variable holding a matplotlib figure (or axes) in the
+  console shows it (previously, only expressions were displayed and *not*
+  simple variables). For example: ::
 
     >>> arr.plot()
 
@@ -66,6 +139,10 @@ Miscellaneous improvements
 
 Fixes
 ^^^^^
 
-* fixed :py:obj:`compare()` colors when the only difference is nans on either side.
+* fixed :py:obj:`compare()` colors when the only difference is nans on either
+  side.
+
+* fixed :py:obj:`compare()` colors and max difference label when the only
+  differences are for rows where the value is 0 in the first array.
 
-* fixed something (closes :editor_issue:`1`).
+* fixed single column plot in viewer when ticks are not strings
\ No newline at end of file
diff --git a/larray_editor/api.py b/larray_editor/api.py
index d61d2b49..9b10798c 100644
--- a/larray_editor/api.py
+++ b/larray_editor/api.py
@@ -159,7 +159,7 @@ def _get_title(obj, depth=0, maxnames=3):
 
 def create_edit_dialog(parent, obj=None, title='', minvalue=None, maxvalue=None, readonly=False, depth=0,
-                       display_caller_info=True, add_larray_functions=None):
+                       display_caller_info=True, add_larray_functions=None, **kwargs):
     """
     Open a new editor window.
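The hunks below also teach `compare()` to accept a single dict of arrays or
sessions. A minimal sketch of that form ('baseline' and 'variant' are
hypothetical names; `compare` is assumed to be importable from
`larray_editor`):

    import larray as la
    from larray_editor import compare

    runs = {'baseline': la.ndtest(3), 'variant': la.ndtest(3) + 1}
    # dict keys become the displayed names, dict values the compared arrays
    compare(runs)  # same as: compare(*runs.values(), names=list(runs.keys()))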
@@ -198,7 +198,7 @@ def create_edit_dialog(parent, obj=None, title='', minvalue=None, maxvalue=None, return MappingEditorWindow(obj, title=title, readonly=readonly, caller_info=caller_info, add_larray_functions=add_larray_functions, - parent=parent) + parent=parent, **kwargs) else: return ArrayEditorWindow(obj, title=title, readonly=readonly, caller_info=caller_info, @@ -212,11 +212,10 @@ def create_debug_dialog(parent, stack_summary, title='Debugger', stack_pos=None) def create_compare_dialog(parent, *args, title='', names=None, depth=0, display_caller_info=True, **kwargs): - caller_frame = sys._getframe(depth + 1) - if display_caller_info: - caller_info = getframeinfo(caller_frame) - else: - caller_info = None + if len(args) == 1 and isinstance(args[0], dict): + if names is None: + names = list(args[0].keys()) + args = list(args[0].values()) compare_sessions = any(isinstance(a, (la.Session, str, Path)) for a in args) default_name = 'session' if compare_sessions else 'array' @@ -232,7 +231,8 @@ def get_name(i, obj, depth=0): # list comprehension used to create their own frame but are now # inlined in Python 3.12+ extra_frame_for_comprehension = 0 if PY312 else 1 - names = [get_name(i, a, depth=depth + 1 + extra_frame_for_comprehension) for i, a in enumerate(args)] + names = [get_name(i, a, depth=depth + 1 + extra_frame_for_comprehension) + for i, a in enumerate(args)] else: assert isinstance(names, list) and len(names) == len(args) @@ -240,6 +240,12 @@ def get_name(i, obj, depth=0): args = [la.Session(a) if not isinstance(a, la.Session) else a for a in args] + if display_caller_info: + caller_frame = sys._getframe(depth + 1) + caller_info = getframeinfo(caller_frame) + else: + caller_info = None + if compare_sessions: return SessionComparatorWindow(args, names=names, title=title, caller_info=caller_info, parent=parent, @@ -381,7 +387,7 @@ def excepthook(type_, value, tback): return excepthook -def edit(obj=None, title='', minvalue=None, maxvalue=None, readonly=False, depth=0): +def edit(obj=None, title='', minvalue=None, maxvalue=None, readonly=False, depth=0, **kwargs): r""" Open a new editor window. @@ -415,7 +421,7 @@ def edit(obj=None, title='', minvalue=None, maxvalue=None, readonly=False, depth >>> edit(a1) # doctest: +SKIP """ _show_dialog("Viewer", create_edit_dialog, obj=obj, title=title, minvalue=minvalue, maxvalue=maxvalue, - readonly=readonly, depth=depth + 1) + readonly=readonly, depth=depth + 1, **kwargs) def view(obj=None, title='', depth=0): @@ -470,8 +476,11 @@ def compare(*args, depth=0, **kwargs): Parameters ---------- - *args : Arrays, Sessions, str or Path. - Arrays or sessions to compare. Strings or Path will be loaded as Sessions from the corresponding files. + *args : Arrays, Sessions, str or Path, or dict of them. + Arrays or sessions to compare. Strings or Path will be loaded as + Sessions from the corresponding files. If arrays or sessions are given + as a single dict, the keys of the dict will be used as names for the + arrays or sessions. title : str, optional Title for the window. Defaults to ''. 
names : list of str, optional diff --git a/larray_editor/arrayadapter.py b/larray_editor/arrayadapter.py index d1650e28..ccd2bda4 100644 --- a/larray_editor/arrayadapter.py +++ b/larray_editor/arrayadapter.py @@ -1,129 +1,521 @@ +# FIXME: * drag and drop axes uses set_data while changing filters do not + +# TODO: +# redesign (again) the adapter <> arraymodel boundary: +# - the adapter may return buffers of any size (the most efficient size +# which includes the requested area). It must include the requested area if +# it exists. The buffer must be reasonably small (must fit in RAM +# comfortably). In that case, the adapter must also return actual hstart +# and vstart. +# >>> on second thoughts, I am unsure this is a good idea. It might be +# better to store the entire buffer on the adapter and have a +# BufferedAdapter base class (or maybe do this in AbstractAdapter +# directly -- but doing this for in-memory containers is wasteful). +# - the buffers MUST be 2D +# - what about type? numpy or any sequence? +# * we should always have 2 buffers worth in memory +# - asking for a new buffer/chunk should be done in a Thread +# - when there are less than X lines unseen, ask for more. X should depend on +# size of buffer and time to fetch a new buffer +# TODO (long term): add support for streaming data source. In that case, +# the behavior should be mostly what we had before (when we scroll, it +# requests more data and the total length of the scrollbar is updated) +# +# TODO (even longer term): add support for streaming data source with a limit +# (ie keep last N entries) +# +# TODO: add support for "progressive" data sources, e.g. pandas SAS reader +# (from pandas.io.sas.sas7bdat import SAS7BDATReader), which can read by +# chunks but cannot read a particular offset. It would be crazy to +# re-read the whole thing up to the requested data each time, but caching +# the whole file in memory probably isn't desirable/feasible either, so +# I guess the best we can do is to cache as many chunks as we can without +# filling up the memory (the first chunk + the last few we just read +# are probably the most likely to be re-visited) and read from the file +# if the user requests some data outside of those chunks +import collections.abc +import logging +import sys +import os +import math +import importlib +import itertools +import time +# import types +from datetime import datetime +from typing import Optional +from pathlib import Path + import numpy as np import larray as la +from larray.util.misc import Product + +from larray_editor.utils import (get_sample, scale_to_01range, + is_number_value_vectorized, logger, + timed) +from larray_editor.commands import CellValueChange + +MAX_FILTER_OPTIONS = 1001 -from larray_editor.utils import Product, _LazyDimLabels, Axis, get_sample -from larray_editor.commands import ArrayValueChange + +def indirect_sort(seq, ascending): + return sorted(range(len(seq)), key=seq.__getitem__, reverse=not ascending) REGISTERED_ADAPTERS = {} +REGISTERED_ADAPTERS_USING_STRINGS = {} +REGISTERED_ADAPTER_TYPES = None +KB = 2 ** 10 +MB = 2 ** 20 + + +def register_adapter_using_string(target_type: str, adapter_creator): + """Register an adapter to display a type + + Parameters + ---------- + target_type : str + Type for which the adapter should be used, given as a string. + adapter_creator : callable + Callable which will return an Adapter instance + """ + assert '.' 
in target_type
+    top_module_name, type_name = target_type.split('.', maxsplit=1)
+    module_adapters = REGISTERED_ADAPTERS_USING_STRINGS.setdefault(top_module_name, {})
+    if type_name in module_adapters:
+        logger.warning(f"Replacing adapter for {target_type}")
+    module_adapters[type_name] = adapter_creator
+    # container = REGISTERED_ADAPTERS_USING_STRINGS
+    # parts = target_type.split('.')
+    # for i, p in enumerate(parts):
+    #     if i == len(parts) - 1:
+    #         if p in container:
+    #             print(f"Warning: replacing adapter for {target_type}")
+    #         container[p] = adapter_creator
+    #     else:
+    #         container = container.setdefault(p, {})
+
+
+# TODO: sadly we cannot use functools.singledispatch because it does not support string types,
+#       but the MRO stuff is a lot better than my own code so I could take inspiration from it.
+#       Ideally, I could add support for string types in singledispatch and propose the addition
+#       to Python
+def register_adapter(target_type, adapter_creator):
+    """Register an adapter to display a type
+
+    Parameters
+    ----------
+    target_type : str | type
+        Type for which the adapter should be used.
+    adapter_creator : callable
+        Callable which will return an Adapter instance
+    """
+    if isinstance(target_type, str):
+        register_adapter_using_string(target_type, adapter_creator)
+        return
+    if target_type in REGISTERED_ADAPTERS:
+        logger.warning(f"Replacing adapter for {target_type}")
 
-def register_adapter(target_type):
-    """Class decorator to register new adapter
+    REGISTERED_ADAPTERS[target_type] = adapter_creator
+
+    # normally, the list is created only once when the first adapter is
+    # requested, but if an adapter is registered after that point we need
+    # to update the list
+    if REGISTERED_ADAPTER_TYPES is not None:
+        update_registered_adapter_types()
+
+
+def adapter_for(target_type):
+    """Class decorator to register new adapters
 
     Parameters
     ----------
-    target_type : type
+    target_type : str | type
         Type handled by adapter class.
     """
-    def decorate_class(adapter_cls):
-        if target_type not in REGISTERED_ADAPTERS:
-            REGISTERED_ADAPTERS[target_type] = adapter_cls
-        return adapter_cls
-    return decorate_class
+    def decorate_callable(adapter_creator):
+        register_adapter(target_type, adapter_creator)
+        return adapter_creator
+    return decorate_callable
+
+
+PATH_SUFFIX_ADAPTERS = {}
+
+
+def register_path_adapter(suffixes, adapter_creator, required_module=None):
+    """Register an adapter to display a file type (extension)
+
+    Parameters
+    ----------
+    suffixes : str | list[str]
+        File extension(s) for which the adapter should be used.
+    adapter_creator : callable
+        Callable which will return an Adapter instance.
+    required_module : str
+        Name of module required to handle this file type.
+    """
+    if isinstance(suffixes, str):
+        suffixes = [suffixes]
+    for suffix in suffixes:
+        if suffix in PATH_SUFFIX_ADAPTERS:
+            logger.warning(f"Replacing path adapter for {suffix}")
+        PATH_SUFFIX_ADAPTERS[suffix] = (adapter_creator, required_module)
+
+
+def path_adapter_for(suffixes, required_module=None):
+    """Class/function decorator to register new file-type adapters
+
+    Parameters
+    ----------
+    suffixes : str | list[str]
+        File extension(s) associated with adapter class.
+    required_module : str, optional
+        Name of module required to handle this file type.
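+
+    Examples
+    --------
+    A sketch of intended usage (``FooFileAdapter`` and ``foolib`` are
+    hypothetical)::
+
+        @path_adapter_for(['.foo', '.foobar'], required_module='foolib')
+        class FooFileAdapter(AbstractAdapter):
+            ...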
+    """
+    def decorate_callable(adapter_creator):
+        register_path_adapter(suffixes, adapter_creator, required_module)
+        return adapter_creator
+    return decorate_callable
+
+
+def get_adapter_creator_for_type(data_type):  # -> AbstractAdapter | func | None:
+    # first check precise type
+    if data_type in REGISTERED_ADAPTERS:
+        return REGISTERED_ADAPTERS[data_type]
+
+    data_type_full_module_name = data_type.__module__
+    if '.' in data_type_full_module_name:
+        data_type_top_module_name, _ = data_type_full_module_name.split('.', maxsplit=1)
+    else:
+        data_type_top_module_name = data_type_full_module_name
+
+    # handle string types
+    if data_type_top_module_name in REGISTERED_ADAPTERS_USING_STRINGS:
+        assert data_type_top_module_name in sys.modules
+        module = sys.modules[data_type_top_module_name]
+        module_adapters = REGISTERED_ADAPTERS_USING_STRINGS[data_type_top_module_name]
+        # register all adapters for that module using concrete types (instead
+        # of string types)
+        for str_adapter_type, adapter in list(module_adapters.items()):
+            # submodule
+            type_name = str_adapter_type
+            while '.' in type_name and module is not None:
+                submodule_name, type_name = type_name.split('.', maxsplit=1)
+                module = getattr(module, submodule_name, None)
+            # submodule not found (probably not loaded yet)
+            if module is None:
+                continue
+
+            adapter_type = getattr(module, type_name, None)
+            if adapter_type is None:
+                continue
+
+            # cache real adapter type if we have (more) objects of that kind to
+            # display later
+            REGISTERED_ADAPTERS[adapter_type] = adapter
+
+            update_registered_adapter_types()
+
+            # remove string form from adapters mapping
+            del module_adapters[str_adapter_type]
+        if not module_adapters:
+            del REGISTERED_ADAPTERS_USING_STRINGS[data_type_top_module_name]
+
+    # then check subclasses
+    if REGISTERED_ADAPTER_TYPES is None:
+        update_registered_adapter_types()
+
+    for adapter_type in REGISTERED_ADAPTER_TYPES:
+        if issubclass(data_type, adapter_type):
+            return REGISTERED_ADAPTERS[adapter_type]
+    return None
+
+
+def get_adapter_creator(data):  # -> AbstractAdapter | str:
+    obj_type = type(data)
+    creator = get_adapter_creator_for_type(obj_type)
+    # 3 options:
+    # - the type is not handled
+    if creator is None:
+        return f"Cannot display objects of type {obj_type.__name__}"
+    # - all instances of the type are handled by the same adapter
+    elif isinstance(creator, type) and issubclass(creator, AbstractAdapter):
+        return creator
+    # - different adapters handle that type and/or not all instances are handled
+    else:
+        return creator(data)
+
+
+def update_registered_adapter_types():
+    global REGISTERED_ADAPTER_TYPES
+
+    REGISTERED_ADAPTER_TYPES = list(REGISTERED_ADAPTERS.keys())
+    # sort classes with longer MRO first, so that subclasses come before
+    # their parent class
+    def class_mro_length(cls):
+        return len(cls.mro())
+    REGISTERED_ADAPTER_TYPES.sort(key=class_mro_length, reverse=True)
 
-def get_adapter(data, bg_value):
+
+def get_adapter(data, attributes=None):
     if data is None:
         return None
-    data_type = type(data)
-    if data_type not in REGISTERED_ADAPTERS:
-        raise TypeError(f"No Adapter implemented for data with type {data_type}")
-    adapter_cls = REGISTERED_ADAPTERS[data_type]
-    return adapter_cls(data, bg_value)
+    adapter_creator = get_adapter_creator(data)
+    assert adapter_creator is not None
+    if isinstance(adapter_creator, str):
+        raise TypeError(adapter_creator)
+    resource_handle = adapter_creator.open(data)
+    return adapter_creator(resource_handle, attributes)
 
 
-class AbstractAdapter:
-    def __init__(self,
data, bg_value): - self.data = data - self.bg_value = bg_value - self.current_filter = {} - self.update_filtered_data() - self.ndim = None - self.size = None - self.dtype = None +def nd_shape_to_2d(shape, num_h_axes=1): + """ - # ===================== # - # PROPERTIES # - # ===================== # + Parameters + ---------- + shape : tuple + num_h_axes : int, optional + Defaults to 1. - @property - def data(self): - return self._original_data + Examples + -------- + >>> nd_shape_to_2d(()) + (1, 1) + >>> nd_shape_to_2d((2,)) + (1, 2) + >>> nd_shape_to_2d((0,)) + (1, 0) + >>> nd_shape_to_2d((2, 3)) + (2, 3) + >>> nd_shape_to_2d((2, 0)) + (2, 0) + >>> nd_shape_to_2d((2, 3, 4)) + (6, 4) + >>> nd_shape_to_2d((2, 3, 0)) + (6, 0) + >>> nd_shape_to_2d((2, 0, 4)) + (0, 4) + >>> nd_shape_to_2d((), num_h_axes=2) + (1, 1) + >>> nd_shape_to_2d((2,), num_h_axes=2) + (1, 2) + >>> nd_shape_to_2d((2, 3), num_h_axes=2) + (1, 6) + >>> nd_shape_to_2d((2, 3, 4), num_h_axes=2) + (2, 12) + >>> nd_shape_to_2d((), num_h_axes=0) + (1, 1) + >>> nd_shape_to_2d((2,), num_h_axes=0) + (2, 1) + >>> nd_shape_to_2d((2, 3), num_h_axes=0) + (6, 1) + >>> nd_shape_to_2d((2, 3, 4), num_h_axes=0) + (24, 1) - @data.setter - def data(self, original_data): - assert original_data is not None, f"{self.__class__} does not accept None as input data" - self._original_data = self.prepare_data(original_data) + Returns + ------- + shape: tuple of integers + 2d shape + """ + shape_v = shape[:-num_h_axes] if num_h_axes else shape + shape_h = shape[-num_h_axes:] if num_h_axes else () + return np.prod(shape_v, dtype=int), np.prod(shape_h, dtype=int) - @property - def bg_value(self): - return self._bg_value - @bg_value.setter - def bg_value(self, bg_value): - self._bg_value = self.prepare_bg_value(bg_value) +# CHECK: maybe implement decorator to mark any method as a context menu action. But what we need is not a method +# which does the action, but a method which adds a command part to the current command. - # ===================== # - # METHODS TO OVERRIDE # - # ===================== # +# @context_menu('Transpose') +# def transpose(self): +# pass - def prepare_data(self, data): - """Must be overridden if data passed to set_data need some checks and/or transformations""" - return data +class AbstractAdapter: + # TODO: we should have a way to provide other attributes: format, readonly, font (problematic for colwidth), + # align, tooltips, flags?, min_value, max_value (for the delegate), ... 
+    #       I guess data itself will need to be a dict: {'values': ...}
+    def __init__(self, data, attributes=None):
+        self.data = data
+        self.attributes = attributes
+        # CHECK: filters will probably not make it as-is after quickbar is implemented: they will need to move
+        #        to the axes area
+        #        AND possibly h/vlabels and
+        #        must update the current command
+        self.current_filter = {}
+        self._current_sort = []
 
-    def prepare_bg_value(self, bg_value):
-        """Must be overridden if bg_value passed to set_data need some checks and/or transformations"""
-        return bg_value
+        # FIXME: this is an ugly/quick&dirty workaround
+        # AFAICT, this is only used in ArrayDelegate
+        self.dtype = np.dtype(object)
+        # self.dtype = None
+        self.vmin = None
+        self.vmax = None
+        self._number_format = "%s"
+        self.sort_key = None  # (kind='axis'|'column'|'row', idx_of_kind, direction (1, -1))
+        # caching support
+        self._cached_fragment = None
+        self._cached_fragment_v_start = None
+        self._cached_fragment_h_start = None
 
-    def filter_data(self, data, filter):
-        """Return filtered data"""
-        raise NotImplementedError()
+    # ================================ #
+    # methods which MUST be overridden #
+    # ================================ #
+    # def get_values(self, h_start, v_start, h_stop, v_stop):
+    #     raise NotImplementedError()
 
-    def get_axes(self, data):
-        """Return list of :py:class:`Axis` or an empty list in case of a scalar or an empty array.
+    # TODO: split this into:
+    #       - extract_chunk_from_data (result is in native/cheapest
+    #         format to produce)
+    #       and
+    #       - native_chunk_to_2D_sequence
+    #       the goal is to cache chunks
+    def get_chunk_from_data(self, data, h_start, v_start, h_stop, v_stop):
+        """
+        Extract a subset of a data object of the type the adapter handles.
+        Must return a 2D sequence, preferably a numpy array.
         """
         raise NotImplementedError()
 
-    def _get_raw_data(self, data):
-        """Return internal data as a ND Numpy array"""
+    def shape2d(self):
         raise NotImplementedError()
 
-    def _get_bg_value(self, bg_value):
-        """Return bg_value as ND Numpy array or None.
-        It must have the same shape as data if not None.
+    # =============================== #
+    # methods which CAN be overridden #
+    # =============================== #
+
+    def cell_activated(self, row_idx, column_idx):
         """
-        raise NotImplementedError()
+        If this method returns a (not None) value, it will be used as the new
+        value for the array_editor_widget. Later this should add an operand on
+        the quickbar but we are not there yet.
+        """
+        return None
 
-    def _from_selection(self, raw_data, axes_names, vlabels, hlabels):
-        """Create and return an object of type managed by the adapter subclass.
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        return self.get_chunk_from_data(self.data, h_start, v_start, h_stop, v_stop)
 
-        Parameters
-        ----------
-        raw_data : Numpy.ndarray
-            Array of selected data.
-        axes_names : list of string
-            List of axis names
-        vlabels : nested list
-            Selected vertical labels
-        hlabels: list
-            Selected horizontal labels
+    @classmethod
+    def open(cls, data):
+        """Open the resources used by the adapter
 
-        Returns
-        -------
-        Object of the type managed by the adapter subclass.
-        """
-        raise NotImplementedError()
+        The result of this method will be stored in the .data
+        attribute and passed as argument to the adapter class"""
+        return data
 
-    def move_axis(self, data, bg_value, old_index, new_index):
-        """Move an axis of the data array and associated bg value.
+    def close(self):
+        """Close the resources used by the adapter"""
+        pass
+
+    def _is_chunk_cached(self, h_start, v_start, h_stop, v_stop):
+        cached_fragment = self._cached_fragment
+        if cached_fragment is None:
+            return False
+        cached_h_start = self._cached_fragment_h_start
+        cached_v_start = self._cached_fragment_v_start
+        cached_width = cached_fragment.shape[1]
+        cached_height = cached_fragment.shape[0]
+        return (h_start >= cached_h_start and
+                h_stop <= cached_h_start + cached_width and
+                v_start >= cached_v_start and
+                v_stop <= cached_v_start + cached_height)
+
+    def _get_fragment_via_cache(self, h_start, v_start, h_stop, v_stop):
+        clsname = self.__class__.__name__
+        logger.debug(f"{clsname}._get_fragment_via_cache({h_start, v_start, h_stop, v_stop})")
+        if self._is_chunk_cached(h_start, v_start, h_stop, v_stop):
+            fragment = self._cached_fragment
+            fragment_h_start = self._cached_fragment_h_start
+            fragment_v_start = self._cached_fragment_v_start
+            logger.debug("  -> cache hit ! "
+                         f"({fragment_h_start=} {fragment_v_start=})")
+        else:
+            fragment, fragment_h_start, fragment_v_start = (
+                self._get_fragment_from_source(h_start, v_start,
+                                               h_stop, v_stop))
+            logger.debug("  -> cache miss ! "
+                         f"({fragment_h_start=} {fragment_v_start=})")
+            self._cached_fragment = fragment
+            self._cached_fragment_h_start = fragment_h_start
+            self._cached_fragment_v_start = fragment_v_start
+        return fragment, fragment_h_start, fragment_v_start
+
+    # TODO: factorize with LArrayArrayAdapter (so that we get the attributes
+    #       handling of LArrayArrayAdapter for all types and the larray adapter
+    #       can benefit from the generic code here)
+    @timed(logger)
+    def get_data_values_and_attributes(self, h_start, v_start, h_stop, v_stop):
+        """h_stop and v_stop should *not* be included"""
+        # TODO: implement region caching
+        logger.debug(
+            f"{self.__class__.__name__}.get_data_values_and_attributes("
+            f"{h_start=}, {v_start=}, {h_stop=}, {v_stop=})"
+        )
+        height, width = self.shape2d()
+        assert v_start >= 0, f"v_start ({v_start}) is out of bounds (should be >= 0)"
+        assert h_start >= 0, f"h_start ({h_start}) is out of bounds (should be >= 0)"
+        assert v_stop >= 0, f"v_stop ({v_stop}) is out of bounds (should be >= 0)"
+        assert h_stop >= 0, f"h_stop ({h_stop}) is out of bounds (should be >= 0)"
+        if height > 0:
+            assert v_start < height, f"v_start ({v_start}) is out of bounds (should be < {height})"
+        if width > 0:
+            assert h_start < width, f"h_start ({h_start}) is out of bounds (should be < {width})"
+        assert v_stop <= height, f"v_stop ({v_stop}) is out of bounds (should be <= {height})"
+        assert h_stop <= width, f"h_stop ({h_stop}) is out of bounds (should be <= {width})"
+        chunk_values = self.get_values(h_start, v_start, h_stop, v_stop)
+        if isinstance(chunk_values, np.ndarray):
+            assert chunk_values.ndim == 2
+            logger.debug(f"  {chunk_values.shape=}")
+        elif isinstance(chunk_values, list) and len(chunk_values) == 0:
+            chunk_values = [[]]
+
+        # Without specifying dtype=object, asarray converts sequences
+        # containing both strings and numbers to all strings which then
+        # fail in get_color_value, but we do not want to convert
+        # existing numpy arrays to object dtype. This is a bit silly and
+        # inefficient for numeric-only sequences, but I do not see
+        # a better way.
+ if not isinstance(chunk_values, np.ndarray): + chunk_values = np.asarray(chunk_values, dtype=object) + finite_values = get_finite_numeric_values(chunk_values) + vmin, vmax = self.update_finite_min_max_values(finite_values, + h_start, v_start, + h_stop, v_stop) + color_value = scale_to_01range(finite_values, vmin, vmax) + chunk_format = self.get_format(chunk_values, h_start, v_start, h_stop, v_stop) + return {'data_format': chunk_format, + 'values': chunk_values, + 'bg_value': color_value} + + def get_format(self, chunk_values, h_start, v_start, h_stop, v_stop): + return [[self._number_format]] + + def set_format(self, fmt): + """Change display format""" + # print(f"setting adapter format: {fmt}") + self._number_format = fmt + + def from_clipboard_data_to_model_data(self, list_data): + return list_data + + def get_axes_labels_and_data_values(self, row_min, row_max, col_min, col_max): + axes_names = self.get_axes_area() + axes_names = axes_names['values'] if isinstance(axes_names, dict) else axes_names + hlabels = self.get_hlabels_values(col_min, col_max) + vlabels = self.get_vlabels_values(row_min, row_max) + raw_data = self.get_values(col_min, row_min, col_max, row_max) + if isinstance(raw_data, list) and len(raw_data) == 0: + raw_data = [[]] + return axes_names, vlabels, hlabels, raw_data + + def move_axis(self, data, attributes, old_index, new_index): + """Move an axis of the data array and associated attribute arrays. Parameters ---------- data : array Array to transpose - bg_value : array or None - Associated bg_value array. + attributes : dict or None + Dict of associated arrays. old_index: int Current index of axis to move. new_index: int @@ -133,420 +525,3619 @@ def move_axis(self, data, bg_value, old_index, new_index): ------- data : array Transposed input array - bg_value: array - Transposed associated bg_value + attributes: dict + Transposed associated arrays """ raise NotImplementedError() - def _map_global_to_filtered(self, data, filtered_data, filter, key): - """ - map global (unfiltered) ND key to local (filtered) 2D key + def can_filter_axis(self, axis_idx) -> bool: + return False + + def get_filter_names(self): + """return [combo_label, ...]""" + return [] + + # TODO: change to get_filter_options(filter_idx, start, stop) + # ... in the end, filters will move to axes names + # AND possibly h/vlabels and + # must update the current command + def get_filter_options(self, filter_idx) -> Optional[list]: + """return [combo_values]""" + return None + + def update_filter(self, filter_idx, indices): + """Update current filter for a given axis if labels selection from the array widget has changed Parameters ---------- - data : array - Input array. - filtered_data : array - Filtered data. - filter : dict - Current filter. - key: tuple - Labels associated with the modified element of the non-filtered array. - - Returns - ------- - tuple - Positional index (row, column) of the modified data cell. + filter_idx: int + Index of filter for which selection has changed. + indices: list of int + Indices of selected labels. """ raise NotImplementedError() - def _map_filtered_to_global(self, filtered_data, data, filter, key): + def get_current_filter_indices(self, filter_idx): + """Returns indices currently selected for a given filter. + + Must return None if that filter is not applied. + + Parameters + ---------- + filter_idx : int + Index of filter. 
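+
+        Returns
+        -------
+        list of int or None
+            Indices of selected labels, or None if that filter is not
+            applied.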
+ """ + return self.current_filter.get(filter_idx) + + def map_filtered_to_global(self, filtered_shape, filter, local2dkey): """ map local (filtered data) 2D key to global (unfiltered) ND key. Parameters ---------- - filtered_data : array - Filtered data. - data : array - Input array. + filtered_shape : tuple + Shape of filtered data. filter : dict - Current filter. - key: tuple + Current filter: {axis_idx: index_or_indices} + local2dkey: tuple Positional index (row, column) of the modified data cell. Returns ------- tuple - Labels associated with the modified element of the non-filtered array. + ND indices associated with the modified element of the non-filtered array. """ raise NotImplementedError() - def _to_excel(self, data): - """Export data to an Excel Sheet + def translate_changes(self, data_model_changes): + to_global = self.map_filtered_to_global + # FIXME: filtered_data is specific to LArray. Either make it part of the API, or do not pass it as argument + # and get it in the implementation of map_filtered_to_global + global_changes = [ + CellValueChange(to_global(self.filtered_data.shape, self.current_filter, key), + old_value, new_value) + for key, (old_value, new_value) in data_model_changes.items() + ] + return global_changes + + def get_sample(self): + """Return a sample of the internal data""" + # TODO: use default_buffer sizes instead, or, better yet, a new get_preferred_buffer_size() method + height, width = self.shape2d() + # TODO: use this instead (but it currently does not work because get_values does not always + # return a numpy array while the current code does + # return self.get_values(0, 0, min(width, 20), min(height, 20)) + return self.get_data_values_and_attributes(0, 0, min(width, 20), min(height, 20))['values'] + + @timed(logger) + def update_finite_min_max_values(self, finite_values: np.ndarray, + h_start: int, v_start: int, + h_stop: int, v_stop: int): + """can return either two floats or two arrays""" + + # we need initial to support empty arrays + vmin = np.nanmin(finite_values, initial=np.nan) + vmax = np.nanmax(finite_values, initial=np.nan) + + self.vmin = ( + np.nanmin([self.vmin, vmin]) if self.vmin is not None else vmin) + self.vmax = ( + np.nanmax([self.vmax, vmax]) if self.vmax is not None else vmax) + return self.vmin, self.vmax + + def get_axes_area(self): + # axes = self.filtered_data.axes + # test axes.size == 0 is required in case an instance built as Array([]) is passed + # test len(axes) == 0 is required when a user filters until getting a scalar (because in that case size is 1) + # TODO: store this in the adapter + # if axes.size == 0 or len(axes) == 0: + # return [[]] + # else: + shape = self.shape2d() + row_idx_names = [name if name is not None else '' + for name in self.get_vnames()] + num_v_axes = len(row_idx_names) + col_idx_names = [name if name is not None else '' + for name in self.get_hnames()] + num_h_axes = len(col_idx_names) + if (not len(row_idx_names) and not len(col_idx_names)) or any(d == 0 for d in shape): + return [[]] + names = np.full((max(num_h_axes, 1), max(num_v_axes, 1)), '', dtype=object) + if len(row_idx_names) > 1: + names[-1, :-1] = row_idx_names[:-1] + if len(col_idx_names) > 1: + names[:-1, -1] = col_idx_names[:-1] + part1 = row_idx_names[-1] if row_idx_names else '' + part2 = col_idx_names[-1] if col_idx_names else '' + sep = '\\' if part1 and part2 else '' + names[-1, -1] = f'{part1}{sep}{part2}' + + current_sort = self.get_current_sort() + sorted_axes = {axis_idx: ascending for axis_idx, label_idx, ascending 
in current_sort + if label_idx == -1} + decoration = np.full_like(names, '', dtype=object) + ascending_to_decoration = { + True: 'arrow_up', + False: 'arrow_down', + } + decoration[-1, :-1] = [ascending_to_decoration[sorted_axes[i]] if i in sorted_axes else '' + for i in range(len(row_idx_names) - 1)] + return {'values': names.tolist(), 'decoration': decoration.tolist()} + + def get_current_sort(self) -> list[tuple]: + """Return current sort + + Must be a list of tuples of the form + (axis_idx, label_idx, ascending) where + * axis_idx: is the index of the axis (of the label) + being sorted + * label_idx: is the index of the label being sorted, + or -1 if the sort is by the axis labels themselves + * ascending: bool + + Note that unsorted axes are not mentioned. + """ + return self._current_sort + + # Adapter classes *may* implement this if can_sort_axis returns True for any axis + # (otherwise they will rely on this default implementation) + def axis_sort_direction(self, axis_idx): + """must return 'ascending', 'descending' or 'unsorted'""" + for cur_sort_axis_idx, label_idx, ascending in self._current_sort: + if cur_sort_axis_idx == axis_idx: + return 'ascending' if ascending else 'descending' + return 'unsorted' + + def hlabel_sort_direction(self, row_idx, col_idx): + """must return 'ascending', 'descending' or 'unsorted'""" + cell_axis_idx = self.hlabel_row_to_axis_num(row_idx) + for axis_idx, label_idx, ascending in self._current_sort: + if axis_idx == cell_axis_idx and label_idx == col_idx: + return 'ascending' if ascending else 'descending' + return 'unsorted' + + def can_filter_hlabel(self, row_idx, col_idx) -> bool: + return False + + def can_sort_axis_labels(self, axis_idx) -> bool: + return False + + def sort_axis_labels(self, axis_idx, ascending): + pass + + # TODO: unsure a different result per label is useful. Per axis would probably be enough + def can_sort_hlabel(self, row_idx, col_idx) -> bool: + return False + + def sort_hlabel(self, row_idx, col_idx, ascending): + pass + + @timed(logger) + def get_vlabels(self, start, stop) -> dict: + chunk_values = self.get_vlabels_values(start, stop) + if isinstance(chunk_values, list) and len(chunk_values) == 0: + chunk_values = [[]] + return {'values': chunk_values} + + def get_vlabels_values(self, start, stop): + # Note that using some kind of lazy object here is pointless given that + # we will use most of it (the buffer should not be much larger than the + # viewport). It would make sense to define one big lazy object as + # self._vlabels = Product([range(len(data))]) and use return + # self._vlabels[start:stop] here but I am unsure it is worth it because + # that would be slower than what we have now. 
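+        # default implementation: plain row numbers as the (single) column
+        # of vertical labels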
+        return [[i] for i in range(start, stop)]
+
+    @timed(logger)
+    def get_hlabels(self, start, stop):
+        values = self.get_hlabels_values(start, stop)
+        return {'values': values, 'decoration': self.get_hlabels_decorations(start, stop, values)}
+
+    def get_hlabels_values(self, start, stop):
+        return [list(range(start, stop))]
+
+    def hlabel_row_to_axis_num(self, row_idx):
+        return row_idx + self.num_v_axes()
+
+    def num_v_axes(self):
+        return 1
+
+    def get_hlabels_decorations(self, start, stop, labels):
+        current_sort = self.get_current_sort()
+        sorted_labels_by_axis = {}
+        for axis_idx, label_idx, ascending in current_sort:
+            sorted_labels_by_axis.setdefault(axis_idx, {})[label_idx] = ascending
+        ascending_to_decoration = {
+            True: 'arrow_up',
+            False: 'arrow_down',
+        }
+        decorations = []
+        for row_idx in range(len(labels)):
+            row_axis_idx = self.hlabel_row_to_axis_num(row_idx)
+            axis_sorted_labels = sorted_labels_by_axis.get(row_axis_idx, {})
+            decoration_row = [
+                ascending_to_decoration[axis_sorted_labels[col_idx]] if col_idx in axis_sorted_labels else ''
+                for col_idx in range(start, stop)
+            ]
+            decorations.append(decoration_row)
+        return decorations
+
+    def get_vnames(self):
+        return ['']
+
+    def get_hnames(self):
+        return ['']
+
+    def get_vname(self):
+        return ' '.join(str(name) for name in self.get_vnames())
+
+    def get_hname(self):
+        return ' '.join(str(name) for name in self.get_hnames())
+
+    def combine_labels_and_data(self, raw_data, axes_names, vlabels, hlabels):
+        """Combine axes names, labels and data into a single list of rows
 
         Parameters
         ----------
-        data : array
-            data to export.
+        raw_data : sequence of sequence of built-in scalar types
+            Array of selected data. Supports numpy arrays, tuple, list etc.
+        axes_names : list of string
+            List of axis names
+        vlabels : nested list
+            Selected vertical labels
+        hlabels : nested list
+            Selected horizontal labels
+
+        Returns
+        -------
+        list of list of built-in Python scalars (None, bool, int, float, str)
         """
-        raise NotImplementedError()
+        # we use itertools.chain so that we can combine any iterables, not just lists
+        chain = itertools.chain
+        topheaders = [list(chain(axis_row, hlabels_row))
+                      for axis_row, hlabels_row in zip(axes_names, hlabels)]
+        datarows = [list(chain(row_labels, row_data))
+                    for row_labels, row_data in zip(vlabels, raw_data)]
+        return topheaders + datarows
 
-    def _plot(self, data):
-        """Return a matplotlib.Figure object using input data.
+    # FIXME (unsure this is still the case): this function does not support None axes_names, vlabels and hlabels
+    #       which _selection_data() produces in some cases (notably when working
+    #       on a scalar array). Unsure if we should fix _selection_data or this
+    #       method though.
+    def get_combined_values(self, row_min, row_max, col_min, col_max):
+        """Return the selected labels and data combined into rows
 
         Parameters
         ----------
-        data : array
-            Data to plot.
 
         Returns
         -------
-            A matplotlib.Figure object.
+        list of list of built-in Python scalars (None, bool, int, float, str)
         """
-        raise NotImplementedError
+        axes_names, vlabels, hlabels, raw_data = (
+            self.get_axes_labels_and_data_values(row_min, row_max,
+                                                 col_min, col_max)
+        )
+        return self.combine_labels_and_data(raw_data, axes_names,
+                                            vlabels, hlabels)
+
+    def to_string(self, row_min, row_max, col_min, col_max, sep='\t'):
+        """Return selection as text with values separated by `sep` (a tab by
+        default), e.g. to copy to the clipboard
 
-    # =========================== #
-    #        OTHER METHODS        #
-    # =========================== #
+        Returns
+        -------
+        str
+        """
+        data = self.get_combined_values(row_min, row_max, col_min, col_max)
 
-    def get_axes_filtered_data(self):
-        return self.get_axes(self.filtered_data)
+        # np.savetxt makes things more complicated, especially on py3
+        # We do not use repr for everything to avoid having extra quotes for strings.
+        # XXX: but is it really a problem? Wouldn't it allow us to copy-paste values with sep (tabs) in them?
+        #      I need to test what Excel does for strings
+        def vrepr(v):
+            if isinstance(v, float):
+                return repr(v)
+            else:
+                return str(v)
 
-    def get_finite_sample(self):
-        """Return a sample of the internal data"""
-        data = self._get_raw_data(self.filtered_data)
-        # this will yield a data sample of max 200
-        sample = get_sample(data, 200)
-        if np.issubdtype(sample.dtype, np.number):
-            return sample[np.isfinite(sample)]
-        else:
-            return sample
-
-    def get_axes_names(self, fold_last_axis=False):
-        axes_names = [axis.name for axis in self.get_axes_filtered_data()]
-        if fold_last_axis and len(axes_names) >= 2:
-            axes_names = axes_names[:-2] + [axes_names[-2] + '\\' + axes_names[-1]]
-        return axes_names
-
-    def get_vlabels(self):
-        axes = self.get_axes(self.filtered_data)
-        if len(axes) == 0:
-            vlabels = [[]]
-        elif len(axes) == 1:
-            vlabels = [['']]
-        else:
-            vlabels = [axis.labels for axis in axes[:-1]]
-            prod = Product(vlabels)
-            vlabels = [_LazyDimLabels(prod, i) for i in range(len(vlabels))]
-        return vlabels
-
-    def get_hlabels(self):
-        axes = self.get_axes(self.filtered_data)
-        if len(axes) == 0:
-            hlabels = [[]]
-        else:
-            hlabels = axes[-1].labels
-            hlabels = Product([hlabels])
-        return hlabels
-
-    def _get_shape_2D(self, np_data):
-        shape, ndim = np_data.shape, np_data.ndim
-        if ndim == 0:
-            shape_2D = (1, 1)
-        elif ndim == 1:
-            shape_2D = (1,) + shape
-        elif ndim == 2:
-            shape_2D = shape
-        else:
-            shape_2D = (np.prod(shape[:-1]), shape[-1])
-        return shape_2D
-
-    def get_raw_data(self):
-        # get filtered data as Numpy ND array
-        np_data = self._get_raw_data(self.filtered_data)
-        assert isinstance(np_data, np.ndarray)
-        # compute equivalent 2D shape
-        shape_2D = self._get_shape_2D(np_data)
-        assert shape_2D[0] * shape_2D[1] == np_data.size
-        # return data reshaped as 2D array
-        return np_data.reshape(shape_2D)
-
-    def get_bg_value(self):
-        # get filtered bg value as Numpy ND array or None
-        if self.bg_value is None:
-            return self.bg_value
-        np_bg_value = self._get_bg_value(self.filter_data(self.bg_value, self.current_filter))
-        # compute equivalent 2D shape
-        shape_2D = self._get_shape_2D(np_bg_value)
-        assert shape_2D[0] * shape_2D[1] == np_bg_value.size
-        # return bg_value reshaped as 2D array if not None
-        return np_bg_value.reshape(shape_2D)
-
-    def update_filtered_data(self):
-        self.filtered_data = self.filter_data(self.data, self.current_filter)
-
-    def change_filter(self, data, filter, axis, indices):
-        """Update current filter for a given axis if labels selection from the array widget has changed
 
         Parameters
         ----------
-        data : array
-            Input array.
-        filter: dict
-            Dictionary {axis_id: labels} representing the current selection.
-        axis: Axis
-            Axis for which selection has changed.
-        indices: list of int
-            Indices of selected labels.
         """
-        axis_id = axis.id
-        if not indices or len(indices) == len(axis):
-            if axis_id in filter:
-                del filter[axis_id]
-        else:
-            if len(indices) == 1:
-                filter[axis_id] = axis.labels[indices[0]]
-            else:
-                filter[axis_id] = axis.labels[indices]
+        return '\n'.join(sep.join(vrepr(v) for v in line) for line in data)
 
-    def update_filter(self, axis, indices):
-        self.change_filter(self.data, self.current_filter, axis, indices)
-        self.update_filtered_data()
+    def to_excel(self, row_min, row_max, col_min, col_max):
+        """Export data to an Excel Sheet
+        """
+        import xlwings as xw
 
-    def translate_changes(self, data_model_changes):
-        def to_global(key):
-            return self._map_filtered_to_global(self.filtered_data, self.data, self.current_filter, key)
+        data = self.get_combined_values(row_min, row_max, col_min, col_max)
+        # convert (row) generators to lists then array
+        # TODO: the conversion to array is currently necessary even though xlwings will translate it back to a list
+        #       anyway. The problem is that our lists contain numpy types and especially np.str_ crashes xlwings.
+        #       unsure how we should fix this properly: in xlwings, or change get_combined_values() to return only
+        #       standard Python types.
+        array = np.array([list(r) for r in data])
 
-        global_changes = [ArrayValueChange(to_global(key), old_value, new_value)
-                          for key, (old_value, new_value) in data_model_changes.items()]
-        return global_changes
+        # Create a new Excel instance. We cannot simply use xw.view(array)
+        # because it reuses the active Excel instance if any, and if that one
+        # is hidden, the user will not see anything
+        app = xw.App(visible=True)
 
-    def selection_to_chain(self, raw_data, axes_names, vlabels, hlabels):
-        """Return an itertools.chain object.
+        # Activate XLA(M) addins. By default, they are not activated when an
+        # Excel Workbook is opened via COM
+        xl_app = app.api
+        for i in range(1, xl_app.AddIns.Count + 1):
+            addin = xl_app.AddIns(i)
+            addin_path = addin.FullName
+            if addin.Installed and '.xll' not in addin_path.lower():
+                xl_app.Workbooks.Open(addin_path)
 
-        Parameters
-        ----------
-        raw_data : Numpy.ndarray
-            Array of selected data.
-        axes_names : list of string
-            List of axis names
-        vlabels : nested list
-            Selected vertical labels
-        hlabels: list
-            Selected horizontal labels
+        # Dump array to first sheet
+        book = app.books[0]
+        sheet = book.sheets[0]
+        with app.properties(screen_updating=False):
+            sheet["A1"].value = array
+            # Unsure whether we should do this or not
+            # sheet.tables.add(sheet["A1"].expand())
+            sheet.autofit()
 
-        Returns
-        -------
-        itertools.chain
+        # Move the Excel window to the front. Without steal_focus it does not
+        # seem to do anything
+        app.activate(steal_focus=True)
+
+    def plot(self, row_min, row_max, col_min, col_max):
+        """Return a matplotlib.Figure object for selected subset.
 
+        Returns
+        -------
+        A matplotlib.Figure object.
         """
-        # FIXME: this function does not support None axes_names, vlabels and hlabels
-        # which _selection_data() produces in some cases (notably when working
-        # on a scalar array). Unsure if we should fix _selection_data or this
-        # method though.
- from itertools import chain - topheaders = [axes_names + hlabels] - if self.ndim == 1: - return chain(topheaders, [chain([''], row) for row in raw_data]) - else: - assert self.ndim > 1 - return chain(topheaders, - [chain([vlabels[j][r] for j in range(len(vlabels))], row) - for r, row in enumerate(raw_data)]) - - def to_excel(self, raw_data, axes_names, vlabels, hlabels): - try: - data = self._from_selection(raw_data, axes_names, vlabels, hlabels) - if data is None: - return - self._to_excel(data) - except NotImplementedError: - data = self.selection_to_chain(raw_data, axes_names, vlabels, hlabels) - if data is None: - return - # convert (row) generators to lists then array - # TODO: the conversion to array is currently necessary even though xlwings will translate it back to a list - # anyway. The problem is that our lists contains numpy types and especially np.str_ crashes xlwings. - # unsure how we should fix this properly: in xlwings, or change _selection_data to return only - # standard Python types. - array = np.array([list(r) for r in data]) - wb = la.open_excel() - wb[0]['A1'] = array - - def plot(self, raw_data, axes_names, vlabels, hlabels): from matplotlib.figure import Figure - try: - data = self._from_selection(raw_data, axes_names, vlabels, hlabels) - if data is None: - return - return self._plot(data) - except NotImplementedError: - if raw_data is None: - return - - axes_names = self.get_axes_names() - # if there is only one dimension, ylabels is empty - if not vlabels: - ylabels = [[]] - else: - # transpose ylabels - ylabels = [[str(vlabels[i][j]) for i in range(len(vlabels))] for j in range(len(vlabels[0]))] - assert raw_data.ndim == 2 + # we do not use the axes_names part because the position of axes names is up to the adapter + _, vlabels, hlabels, raw_data = self.get_axes_labels_and_data_values(row_min, row_max, col_min, col_max) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"AbstractAdapter.plot {vlabels=} {hlabels=}") + logger.debug(f"{raw_data=}") + if not isinstance(raw_data, np.ndarray): + # Without dtype=object, in the presence of a string, raw_data will + # be entirely converted to strings which is not what we want. 
+ raw_data = np.asarray(raw_data, dtype=object) + raw_data = raw_data.reshape((raw_data.shape[0], -1)) + assert isinstance(raw_data, np.ndarray), f"got data of type {type(raw_data)}" + assert raw_data.ndim == 2, f"ndim is {raw_data.ndim}" + finite_values = get_finite_numeric_values(raw_data) + figure = Figure() - figure = Figure() + # create an axis + ax = figure.add_subplot() - # create an axis - ax = figure.add_subplot(111) + # we have a list of rows but we want a list of columns + xlabels = list(zip(*hlabels)) + ylabels = vlabels - if raw_data.shape[1] == 1: - # plot one column - xlabel = ','.join(axes_names[:-1]) - xticklabels = ['\n'.join(row) for row in ylabels] - xdata = np.arange(raw_data.shape[0]) - ax.plot(xdata, raw_data[:, 0]) - ax.set_ylabel(hlabels[0]) - else: - # plot each row as a line - xlabel = axes_names[-1] - xticklabels = [str(label) for label in hlabels] - xdata = np.arange(raw_data.shape[1]) - for row in range(len(raw_data)): - ax.plot(xdata, raw_data[row], label=' '.join(ylabels[row])) - - # set x axis + xlabel = self.get_hname() + ylabel = self.get_vname() + + height, width = finite_values.shape + if width == 1: + # plot one column + xlabels, ylabels = ylabels, xlabels + xlabel, ylabel = ylabel, xlabel + height, width = width, height + finite_values = finite_values.T + + # plot each row as a line + xticklabels = ['\n'.join(str(label) for label in label_col) + for label_col in xlabels] + xdata = np.arange(width) + for data_row, ylabels_row in zip(finite_values, ylabels): + row_label = ' '.join(str(label) for label in ylabels_row) + ax.plot(xdata, data_row, label=row_label) + + # set x axis + if xlabel: ax.set_xlabel(xlabel) - ax.set_xlim((xdata[0], xdata[-1])) - # we need to do that because matplotlib is smart enough to - # not show all ticks but a selection. However, that selection - # may include ticks outside the range of x axis - xticks = [t for t in ax.get_xticks().astype(int) if t <= len(xticklabels) - 1] - xticklabels = [xticklabels[t] for t in xticks] - ax.set_xticks(xticks) - ax.set_xticklabels(xticklabels) - - if raw_data.shape[1] != 1 and ylabels != [[]]: - # set legend - # box = ax.get_position() - # ax.set_position([box.x0, box.y0, box.width * 0.85, box.height]) - # ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) - ax.legend() - - return figure - - -@register_adapter(np.ndarray) -@register_adapter(la.Array) -class ArrayDataAdapter(AbstractAdapter): - def __init__(self, data, bg_value): - AbstractAdapter.__init__(self, data=data, bg_value=bg_value) - self.ndim = data.ndim - self.size = data.size - self.dtype = data.dtype + ax.set_xlim((0, width - 1)) + # we need to do that because matplotlib is smart enough to + # not show all ticks but a selection. 
However, that selection + # may include ticks outside the range of x axis + xticks = [t for t in ax.get_xticks().astype(int) if t < len(xticklabels)] + xticklabels = [xticklabels[t] for t in xticks] + ax.set_xticks(xticks) + ax.set_xticklabels(xticklabels) - def prepare_data(self, data): - return la.asarray(data) + # add legend + all_empty_labels = all(not label for yrow in ylabels for label in yrow) + if width != 1 and not all_empty_labels: + kwargs = {'title': ylabel} if ylabel else {} + ax.legend(**kwargs) - def prepare_bg_value(self, bg_value): - return la.asarray(bg_value) if bg_value is not None else None + return figure - def filter_data(self, data, filter): - if data is None: - return data - assert isinstance(data, la.Array) - if filter is None: - return data - else: - assert isinstance(filter, dict) - data = data[filter] - return la.asarray(data) if np.isscalar(data) else data - def get_axes(self, data): - assert isinstance(data, la.Array) - axes = data.axes - # test data.size == 0 is required in case an instance built as Array([]) is passed - # test len(axes) == 0 is required when a user filters until to get a scalar - if data.size == 0 or len(axes) == 0: - return [] - else: - return [Axis(axes.axis_id(axis), name, axis.labels) for axis, name in zip(axes, axes.display_names)] +class AbstractColumnarAdapter(AbstractAdapter): + """For adapters where color is per column""" - def _get_raw_data(self, data): - assert isinstance(data, la.Array) - return data.data - - def _get_bg_value(self, bg_value): - if bg_value is not None: - assert isinstance(bg_value, la.Array) - return bg_value.data - else: - return bg_value - - # TODO: update this method the day Array objects will also handle MultiIndex-like axes. - def _from_selection(self, raw_data, axes_names, vlabels, hlabels): - if '\\' in axes_names[-1]: - axes_names = axes_names[:-1] + axes_names[-1].split('\\') - if len(axes_names) == 2: - axes = [la.Axis(vlabels[0], axes_names[0])] - elif len(axes_names) > 2: - # combine the N-1 first axes - combined_axes_names = '_'.join(axes_names[:-1]) - combined_labels = ['_'.join([str(vlabels[i][j]) for i in range(len(vlabels))]) - for j in range(len(vlabels[0]))] - axes = [la.Axis(combined_labels, combined_axes_names)] - else: - # assuming selection represents a 1D array - axes = [] - raw_data = raw_data[0] - # last axis - axes += [la.Axis(hlabels, axes_names[-1])] - return la.Array(raw_data, axes) - - def move_axis(self, data, bg_value, old_index, new_index): - assert isinstance(data, la.Array) - new_axes = data.axes.copy() - new_axes.insert(new_index, new_axes.pop(new_axes[old_index])) - data = data.transpose(new_axes) - if bg_value is not None: - assert isinstance(bg_value, la.Array) - bg_value = bg_value.transpose(new_axes) - return data, bg_value - - def _map_filtered_to_global(self, filtered_data, data, filter, key): - # transform local positional index key to (axis_ids: label) dictionary key. 
- # Contains only displayed axes - row, col = key - labels = [filtered_data.axes[-1].labels[col]] - for axis in reversed(filtered_data.axes[:-1]): - row, position = divmod(row, len(axis)) - labels = [axis.labels[position]] + labels - axes_ids = list(filtered_data.axes.ids) - dkey = dict(zip(axes_ids, labels)) - # add the "scalar" parts of the filter to it (ie the parts of the - # filter which removed dimensions) - dkey.update({k: v for k, v in filter.items() if np.isscalar(v)}) - # re-transform it to tuple (to make it hashable/to store it in .changes) - return tuple(dkey[axis_id] for axis_id in data.axes.ids) - - def _map_global_to_filtered(self, data, filtered_data, filter, key): - assert isinstance(key, tuple) and len(key) == data.ndim - dkey = {axis_id: axis_key for axis_key, axis_id in zip(key, data.axes.ids)} - # transform global dictionary key to "local" (filtered) key by removing - # the parts of the key which are redundant with the filter - for axis_id, axis_filter in filter.items(): - axis_key = dkey[axis_id] - if np.isscalar(axis_filter) and axis_key == axis_filter: - del dkey[axis_id] - elif not np.isscalar(axis_filter) and axis_key in axis_filter: - pass + def __init__(self, data, attributes=None): + super().__init__(data, attributes) + self.vmin = {} + self.vmax = {} + + @timed(logger) + def update_finite_min_max_values(self, finite_values: np.ndarray, + h_start: int, v_start: int, + h_stop: int, v_stop: int): + + assert isinstance(self.vmin, dict) and isinstance(self.vmax, dict) + assert h_stop >= h_start + + # per column => axis=0 + local_vmin = np.nanmin(finite_values, axis=0, initial=np.nan) + local_vmax = np.nanmax(finite_values, axis=0, initial=np.nan) + num_cols = h_stop - h_start + assert local_vmin.shape == (num_cols,), \ + (f"unexpected shape: {local_vmin.shape} ({finite_values.shape=}) vs " + f"{(num_cols,)} ({h_start=} {h_stop=})") + # vmin or self.vmin can both be nan (if the whole section data + # is/was nan) + global_vmin = self.vmin + global_vmax = self.vmax + vmin_slice = np.empty(num_cols, dtype=np.float64) + vmax_slice = np.empty(num_cols, dtype=np.float64) + for global_col_idx in range(h_start, h_stop): + local_col_idx = global_col_idx - h_start + + col_min = np.nanmin([global_vmin.get(global_col_idx, np.nan), + local_vmin[local_col_idx]]) + # update the global vmin dict inplace + global_vmin[global_col_idx] = col_min + vmin_slice[local_col_idx] = col_min + + col_max = np.nanmax([global_vmax.get(global_col_idx, np.nan), + local_vmax[local_col_idx]]) + # update the global vmax dict inplace + global_vmax[global_col_idx] = col_max + vmax_slice[local_col_idx] = col_max + return vmin_slice, vmax_slice + + +class DirectoryPathAdapter(AbstractColumnarAdapter): + _COL_NAMES = ['Name', 'Type', 'Date Modified', 'Size'] + _SORT_FUNCS = { + # Technically, we should use p.stem, but p.name has a better behavior + # when there are several files with same stem but different suffixes + # (e.g. test.txt and test.csv) + 0: lambda p: (not p.is_dir(), p.name), + 1: lambda p: (not p.is_dir(), p.suffix.lower()), + 2: lambda p: (not p.is_dir(), p.stat().st_mtime), + 3: lambda p: (not p.is_dir(), p.stat().st_size), + } + + def __init__(self, data, attributes): + # taking absolute() allows going outside of the initial directory + # via double click. This is both good and bad. 
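+        # A hedged illustration (path and cwd are made up): with the current
+        # working directory at /home/user, Path('data').absolute() gives
+        # Path('/home/user/data'), whose '..' entry can then climb above the
+        # directory the explorer started in.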
+ data = data.absolute() + super().__init__(data=data, attributes=attributes) + + # sort by ascending name by default + self._current_sort = [(1, 0, True)] + self._update_sorted_path_objs() + + def shape2d(self): + return len(self._sorted_path_objs), len(self._COL_NAMES) + + def get_hlabels_values(self, start, stop): + return [self._COL_NAMES[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + parent_dir = self.data.parent + + def get_file_info(p: Path) -> tuple[str, str, str, str|int]: + + is_dir = p.is_dir() + if is_dir: + # do not strip suffixes for directories + file_name = p.name if p != parent_dir else '..' else: - # that key is invalid for/outside the current filter - return None - # transform (axis:label) dict key to positional ND key - try: - index_key = filtered_data._translated_key(dkey) - except ValueError: + file_name = p.stem + + file_type = '' if is_dir else p.suffix.lstrip('.') + try: + file_stat = p.stat() + try: + mt_time = datetime.fromtimestamp(file_stat.st_mtime) + file_mtime = mt_time.strftime('%d/%m/%Y %H:%M') + except Exception: + file_mtime = '' + file_size = file_stat.st_size if not is_dir else '' + except Exception: + file_mtime = '' + file_size = '' + return file_name, file_type, file_mtime, file_size + + return [get_file_info(p)[h_start:h_stop] + for p in self._sorted_path_objs[v_start:v_stop]] + + def can_sort_hlabel(self, row_idx, col_idx): + return True + + def sort_hlabel(self, row_idx, col_idx, ascending): + assert row_idx == 0 + assert col_idx in {0, 1, 2, 3} + self._current_sort = [(1, col_idx, ascending)] + self._update_sorted_path_objs() + + def _update_sorted_path_objs(self): + path_objs = list(self.data.iterdir()) + + assert len(self._current_sort) == 1 + _, col_idx, ascending = self._current_sort[0] + assert col_idx in {0, 1, 2, 3} + key_func = self._SORT_FUNCS[col_idx] + path_objs.sort(key=key_func, reverse=not ascending) + + # add ".." 
if needed + parent_dir = self.data.parent + if parent_dir != self.data: + path_objs.insert(0, parent_dir) + self._sorted_path_objs = path_objs + + def cell_activated(self, row_idx, column_idx): + return self._sorted_path_objs[row_idx].absolute() + + +@adapter_for('pathlib.Path') +def get_path_suffix_adapter(fpath): + logger.debug(f"get_path_suffix_adapter('{fpath}')") + suffix = fpath.suffix.lower() + if suffix in PATH_SUFFIX_ADAPTERS: + path_adapter_cls, required_module = PATH_SUFFIX_ADAPTERS[suffix] + if required_module is not None: + if required_module not in sys.modules: + try: + importlib.import_module(required_module) + except ImportError: + return (f"Cannot handle {fpath.suffix} files because the " + f"'{required_module}' module is not available ") + # 2 options: + # - either there is a single adapter for that suffix + if (isinstance(path_adapter_cls, type) and + issubclass(path_adapter_cls, AbstractAdapter)): + return path_adapter_cls + # - different adapters handle that suffix and/or not all instances can + # be handled + else: + return path_adapter_cls(fpath) + elif fpath.is_dir(): + return DirectoryPathAdapter + else: + return f"Cannot display {fpath.suffix} files" + + +class SequenceAdapter(AbstractAdapter): + def __init__(self, data, attributes): + super().__init__(data, attributes) + self.sorted_data = data + self.sorted_indices = range(len(data)) + + def shape2d(self): + return len(self.data), 1 + + def get_vnames(self): + return ['index'] + + def get_hlabels_values(self, start, stop): + return [['value']] + + def get_vlabels_values(self, start, stop): + # Note that using some kind of lazy object here is pointless given that + # we will use most of it (the buffer should not be much larger than the + # viewport). It would make sense to define one big lazy object as + # self._vlabels = Product([range(len(data))]) and use return + # self._vlabels[start:stop] here but I am unsure it is worth it because + # that would be slower than what we have now. 
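+        # Example of the expected output (hypothetical values): with
+        # sorted_indices == range(5), start=1 and stop=3, this returns
+        # [[1], [2]], i.e. one single-label row per displayed data row.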
+        return [[i] for i in self.sorted_indices[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        chunk_values = self.sorted_data[v_start:v_stop]
+        # use a numpy array to avoid confusing the model if some
+        # elements are sequences themselves
+        array_values = np.empty((len(chunk_values), 1), dtype=object)
+        array_values[:, 0] = chunk_values
+        return array_values
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        self._current_sort = [(self.num_v_axes() + row_idx, col_idx, ascending)]
+        self.sorted_indices = indirect_sort(self.data, ascending)
+        self.sorted_data = sorted(self.data, reverse=not ascending)
+
+
+class NamedTupleAdapter(AbstractAdapter):
+    def shape2d(self):
+        return len(self.data), 1
+
+    def get_vnames(self):
+        return ['attribute']
+
+    def get_hlabels_values(self, start, stop):
+        return [['value']]
+
+    def get_vlabels_values(self, start, stop):
+        return [[k] for k in self.data._fields[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        return [[v] for v in self.data[v_start:v_stop]]
+
+
+@adapter_for(collections.abc.Sequence)
+def get_sequence_adapter(data):
+    namedtuple_attrs = ['_asdict', '_field_defaults', '_fields', '_make', '_replace']
+    # We do not want to display strings and bytes as sequences
+    if isinstance(data, (bytes, str)):
+        obj_type = type(data)
+        return f"Cannot display objects of type {obj_type.__name__}"
+    # Named tuples have no special parent class, so we cannot dispatch using the type
+    # of data and need to check the presence of NamedTuple-specific attributes instead
+    elif all(hasattr(data, attr) for attr in namedtuple_attrs):
+        return NamedTupleAdapter
+    else:
+        return SequenceAdapter
+
+
+@adapter_for(collections.abc.Mapping)
+class MappingAdapter(AbstractAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data, attributes=attributes)
+        self.sorted_data = data
+
+    def shape2d(self):
+        return len(self.data), 1
+
+    def get_vnames(self):
+        return ['key']
+
+    def get_hlabels_values(self, start, stop):
+        return [['value']]
+
+    def get_vlabels_values(self, start, stop):
+        # using islice instead of caching list(data.keys()) and list(data.values()) in __init__
+        # makes things *much* faster to display the first elements of very large dicts at
+        # the expense of making the display of the last elements about twice as slow.
+        # It seems a desirable tradeoff, especially given the lower memory usage and
+        # the absence of stale cache problems. Performance-wise, we could cache keys() and
+        # values() here (instead of in __init__) if start or stop is above some threshold
+        # but I am unsure it is worth the added complexity.
+        return [[k] for k in itertools.islice(self.sorted_data.keys(), start, stop)]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        values_chunk = itertools.islice(self.sorted_data.values(), v_start, v_stop)
+        return [[v] for v in values_chunk]
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        assert row_idx == 0
+        assert col_idx == 0
+        try:
+            # first try the values themselves as sort key...
+            self.sorted_data = dict(sorted(self.data.items(),
+                                           key=lambda items: items[1],
+                                           reverse=not ascending))
+        except TypeError:
+            # ... but that will fail for unsortable types or mixed-type mappings;
+            # for those cases, we fall back to sorting by str.
+ # We should keep this (str vs repr) in sync with + # AbstractArrayModel._format_value, so that the sorting order + # matches the displayed values. + + # By using the following commented code, we could also sort all + # numbers after strings and that would allow sorting them more + # naturally (9 before 10), but might be unexpected for users. + # Unsure what is best. + # def get_key(items): + # value = items[1] + # if isinstance(value, (int, float)): + # return True, value + # else: + # return False, str(value) + self.sorted_data = dict(sorted(self.data.items(), + key=lambda items: str(items[1]), + reverse=not ascending)) + self._current_sort = [(1, 0, ascending)] + + +# @adapter_for(object) +# class ObjectAdapter(AbstractAdapter): +# def __init__(self, data, attributes): +# super().__init__(data=data, attributes=attributes) +# self._fields = [k for k in dir(data) if not k.startswith('_') and type(getattr(data, k)) not in +# {types.FunctionType, types.BuiltinFunctionType, types.BuiltinMethodType}] +# +# def shape2d(self): +# return len(self._fields), 1 +# +# def get_vnames(self): +# return ['key'] +# +# def get_hlabels(self, start, stop): +# return [['value']] +# +# def get_vlabels(self, start, stop): +# return [[f] for f in self._fields[start:stop]] +# +# def get_values(self, h_start, v_start, h_stop, v_stop): +# return [[getattr(self.data, k)] for k in self._fields[v_start:v_stop]] + + +@adapter_for(collections.abc.Collection) +class CollectionAdapter(AbstractAdapter): + def shape2d(self): + return len(self.data), 1 + + def get_hlabels_values(self, start, stop): + return [['value']] + + def get_vlabels_values(self, start, stop): + return [[''] for i in range(start, stop)] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return [[v] for v in itertools.islice(self.data, v_start, v_stop)] + + +# Specific adapter just to change the label +@adapter_for(collections.abc.KeysView) +class KeysViewAdapter(CollectionAdapter): + def get_hlabels_values(self, start, stop): + return [['key']] + + +@adapter_for(collections.abc.ItemsView) +class ItemsViewAdapter(CollectionAdapter): + def shape2d(self): + return len(self.data), 2 + + def get_hlabels_values(self, start, stop): + return [['key', 'value']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + # slicing self.data already returns "tuple" rows + return list(itertools.islice(self.data, v_start, v_stop)) + + +def get_finite_numeric_values(array: np.ndarray) -> np.ndarray: + """return a copy of array with non numeric, -inf or inf values set to nan""" + dtype = array.dtype + finite_value = array + # TODO: there are more complex dtypes than this. Is there a way to get them all in one shot? + if dtype in (np.complex64, np.complex128): + # for complex numbers, shading will be based on absolute value + # FIXME: this is fine for coloring purposes but not for determining + # format (or plotting?) 
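+        # e.g. np.abs(np.array([3+4j])) -> array([5.]): fine for shading,
+        # but the complex representation itself is lost.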
+ finite_value = np.abs(finite_value) + elif dtype.type is np.object_: + # change non numeric to nan + finite_value = np.where(is_number_value_vectorized(finite_value), + finite_value, + np.nan) + finite_value = finite_value.astype(np.float64) + elif np.issubdtype(dtype, np.bool_): + finite_value = finite_value.astype(np.int8) + elif not np.issubdtype(dtype, np.number): + # if the whole array is known to be non numeric, we do not need + # to compute anything + return np.full(array.shape, np.nan, dtype=np.float64) + + assert np.issubdtype(finite_value.dtype, np.number) + + # change inf and -inf to nan (setting them to 0 or to very large numbers is + # not an option because it would "dampen" normal values) + return np.where(np.isfinite(finite_value), finite_value, np.nan) + + +# only used in LArray adapter. it should use the same code path as the rest +# though +def get_color_value(array, global_vmin, global_vmax, axis=None): + assert isinstance(array, np.ndarray) + try: + finite_value = get_finite_numeric_values(array) + + vmin = np.nanmin(finite_value, axis=axis) + if global_vmin is not None: + # vmin or global_vmin can both be nan (if the whole section data is/was nan) + global_vmin = np.nanmin([global_vmin, vmin], axis=axis) + else: + global_vmin = vmin + vmax = np.nanmax(finite_value, axis=axis) + if global_vmax is not None: + # vmax or global_vmax can both be nan (if the whole section data is/was nan) + global_vmax = np.nanmax([global_vmax, vmax], axis=axis) + else: + global_vmax = vmax + color_value = scale_to_01range(finite_value, global_vmin, global_vmax) + except (ValueError, TypeError): + global_vmin = None + global_vmax = None + color_value = None + return color_value, global_vmin, global_vmax + + +class NumpyHomogeneousArrayAdapter(AbstractAdapter): + def shape2d(self): + return nd_shape_to_2d(self.data.shape, num_h_axes=1) + + def get_vnames(self): + return ['' for axis_len in self.data.shape[:-1]] + + def get_hlabels_values(self, start, stop): + if self.data.ndim > 0: + return [list(range(start, stop))] + else: + return [['']] + + def get_vlabels_values(self, start, stop): + if self.data.ndim > 0: + vlabels = Product([range(axis_len) for axis_len in self.data.shape[:-1]]) + return vlabels[start:stop] + else: + return [['']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + data2d = self.data.reshape(nd_shape_to_2d(self.data.shape)) + return data2d[v_start:v_stop, h_start:h_stop] + + +class NumpyStructuredArrayAdapter(AbstractColumnarAdapter): + def shape2d(self): + shape = self.data.shape + (len(self.data.dtype.names),) + return nd_shape_to_2d(shape, num_h_axes=1) + + def get_vnames(self): + return ['' for axis_len in self.data.shape] + + def get_hlabels_values(self, start, stop): + return [list(self.data.dtype.names[start:stop])] + + def get_vlabels_values(self, start, stop): + vlabels = Product([range(axis_len) for axis_len in self.data.shape]) + return vlabels[start:stop] + + def get_values(self, h_start, v_start, h_stop, v_stop): + # TODO: this works nicely but isn't any better for users because number of decimals + # is not auto-detected and cannot be changed. I think I could implement + # auto-detection *relatively* easily but at this point I don't know + # how to implement changing it. + # One option would be that the ndigits box would set the number of digits for + # all *numeric* columns (or even cells?) instead of trying to set it for all columns. 
+        # Another option would be that the ndigits box would not be the
+        # number of digits for each column but rather the "bonus" number compared to
+        # the autodetected value.
+        # Yet another option would be to keep track of the number of digits per column
+        # (or cell) and change it only for currently selected cells.
+        # Selecting the entire column would then set it "globally" for the column.
+        data1d = self.data.reshape(-1)
+        # Each field of a "row" can be accessed via either its name (row['age']) or its position
+        # (row[1]) but rows *cannot* be sliced, hence going via tuple(row_data)
+        return [tuple(row_data)[h_start:h_stop] for row_data in data1d[v_start:v_stop]]
+
+
+@adapter_for(np.ndarray)
+def get_np_array_adapter(data):
+    if data.dtype.names is not None:
+        return NumpyStructuredArrayAdapter
+    else:
+        return NumpyHomogeneousArrayAdapter
+
+
+class MemoryViewAdapter(NumpyHomogeneousArrayAdapter):
+    def __init__(self, data, attributes):
+        # no data copy is necessary for converting memoryview <-> numpy array
+        # and a memoryview >1D cannot be sliced (only indexed with a single
+        # element) so it is much easier *and* more efficient to convert to a
+        # numpy array and display that
+        super().__init__(np.asarray(data), attributes)
+
+
+@adapter_for(memoryview)
+def get_memoryview_adapter(data):
+    if len(data.format) > 1:
+        # ... because they cannot be indexed
+        return "memoryviews with 'structured' formats are not supported"
+    else:
+        return MemoryViewAdapter
+
+
+@adapter_for(la.Array)
+class LArrayArrayAdapter(AbstractAdapter):
+    num_axes_to_display_horizontally = 1
+
+    def __init__(self, data, attributes):
+        # self.num_axes_to_display_horizontally = min(data.ndim, 2)
+        data = la.asarray(data)
+        if attributes is not None:
+            attributes = {k: la.asarray(v) for k, v in attributes.items()}
+        super().__init__(data, attributes)
+        # TODO: should not be needed (this is only used in ArrayDelegate)
+        self.dtype = data.dtype
+
+        self.filtered_data = self.data
+        self.filtered_attributes = self.attributes
+        self._number_format = "%.3f"
+
+    def from_clipboard_data_to_model_data(self, list_data):
+        try:
+            # index of first cell which contains '\'
+            pos_last = next(i for i, v in enumerate(list_data[0]) if '\\' in v)
+        except StopIteration:
+            # if there isn't any, assume 1d array
+            pos_last = 0
+
+        if pos_last or '\\' in list_data[0][0]:
+            # ndim > 1
+            # strip horizontal and vertical labels (to keep only the raw data)
+            # It is not a problem if the target labels (which we do not know
+            # here) do not match the source labels that we are stripping
+            # because pasting values to the exact same cells we copied them
+            # from is of little interest
+            list_data = [line[pos_last + 1:] for line in list_data[1:]]
+        elif len(list_data) == 2 and list_data[1][0] == '':
+            # ndim == 1, horizontal
+            # strip horizontal labels (first line) and empty cell (due to axis name)
+            list_data = [list_data[1][1:]]
+        else:
+            # assume raw data
+            pass
+        return list_data
+
+    def shape2d(self):
+        return nd_shape_to_2d(self.filtered_data.shape,
+                              num_h_axes=self.num_axes_to_display_horizontally)
+
+    def can_filter_axis(self, axis_idx) -> bool:
+        return True
+
+    def get_filter_names(self):
+        return self.data.axes.display_names
+
+    def get_filter_options(self, filter_idx):
+        return self.data.axes[filter_idx].labels
+
+    def _filter_data(self, data, full_indices_filter):
+        if data is None:
+            return data
+        assert isinstance(data, la.Array)
+        data = data.i[full_indices_filter]
+        return la.asarray(data) if np.isscalar(data) else data
+
+    def
update_filter(self, filter_idx, indices): + """Update current filter for a given axis if labels selection from the array widget has changed + + Parameters + ---------- + filter_idx : int + Index of filter (axis) for which selection has changed. + indices: list of int + Indices of selected labels. + """ + cur_filter = self.current_filter + axis = self.data.axes[filter_idx] + if not indices or len(indices) == len(axis): + if filter_idx in cur_filter: + del cur_filter[filter_idx] + else: + if len(indices) == 1: + cur_filter[filter_idx] = indices[0] + else: + cur_filter[filter_idx] = indices + + # cur_filter is a {axis_idx: axis_indices} dict + # full_indices_filter is a tuple + full_indices_filter = tuple( + cur_filter[axis_idx] if axis_idx in cur_filter else slice(None) + for axis_idx in range(len(self.data.axes)) + ) + self.filtered_data = self._filter_data(self.data, full_indices_filter) + if self.attributes is not None: + self.filtered_attributes = {k: self._filter_data(v, full_indices_filter) + for k, v in self.attributes.items()} + + def can_sort_hlabel(self, row_idx, col_idx): + return self.filtered_data.ndim == 2 + + def sort_hlabel(self, row_idx, col_idx, ascending): + self._current_sort = [(self.num_v_axes() + row_idx, col_idx, ascending)] + arr = self.filtered_data + assert arr.ndim == 2 + row_axis = arr.axes[0] + col_axis = arr.axes[-1] + key = col_axis.i[col_idx] + sort_indices = arr[key].indicesofsorted(ascending=ascending) + assert sort_indices.ndim == 1 + indexer = row_axis.i[sort_indices.data] + self.filtered_data = arr[indexer] + if self.attributes is not None: + self.filtered_attributes = {k: v[indexer] + for k, v in self.filtered_attributes.items()} + + def can_sort_axis_labels(self, axis_idx) -> bool: + return True + + def sort_axis_labels(self, axis_idx, ascending): + self._current_sort = [(axis_idx, -1, ascending)] + assert isinstance(self.filtered_data, la.Array) + self.filtered_data = self.filtered_data.sort_labels(axis_idx, ascending=ascending) + if self.attributes is not None: + self.filtered_attributes = {k: v.sort_labels(axis_idx, ascending=ascending) + for k, v in self.filtered_attributes.items()} + + def get_data_values_and_attributes(self, h_start, v_start, h_stop, v_stop): + # data + # ==== + chunk_values = self.get_values(h_start, v_start, h_stop, v_stop) + chunk_data = { + 'editable': [[True]], + 'data_format': [[self._number_format]], + 'values': chunk_values + } + + # user-defined attributes (e.g. 
user-provided bg_value) + # ======================= + if self.filtered_attributes is not None: + chunk_data.update({k: self.get_chunk_from_data(v, h_start, v_start, h_stop, v_stop) + for k, v in self.filtered_attributes.items()}) + # we are not doing this above like for editable and data_format for performance reasons + # when bg_value is a user-provided value (and we do not need a computed one) + if 'bg_value' not in chunk_data: + # "default" bg_value computed on the subset asked by the model + bg_value, self.vmin, self.vmax = get_color_value(chunk_values, self.vmin, self.vmax) + chunk_data['bg_value'] = bg_value + return chunk_data + + def get_values(self, h_start, v_start, h_stop, v_stop): + return self.get_chunk_from_data(self.filtered_data, + h_start, v_start, + h_stop, v_stop) + + def get_chunk_from_data(self, data, h_start, v_start, h_stop, v_stop): + # get filtered data as Numpy 2D array + assert isinstance(data, la.Array) + np_data = data.data + assert isinstance(np_data, np.ndarray) + shape2d = nd_shape_to_2d(np_data.shape, self.num_axes_to_display_horizontally) + raw_data = np_data.reshape(shape2d) + return raw_data[v_start:v_stop, h_start:h_stop] + + def get_vnames(self): + axes = self.filtered_data.axes + num_v_axes = max(len(axes) - self.num_axes_to_display_horizontally, 0) + return axes.display_names[:num_v_axes] + + def get_hnames(self): + axes = self.filtered_data.axes + num_v_axes = max(len(axes) - self.num_axes_to_display_horizontally, 0) + return axes.display_names[num_v_axes:] + + def get_vlabels_values(self, start, stop): + axes = self.filtered_data.axes + # test data.size == 0 is required in case an instance built as Array([]) is passed + # test len(axes) == 0 is required when a user filters until getting a scalar (because in that case size is 1) + # TODO: store this in the adapter + if axes.size == 0 or len(axes) == 0: + return [[]] + elif len(axes) <= self.num_axes_to_display_horizontally: + # all axes are horizontal => a single empty vlabel + return [['']] + else: + # we must not convert the *whole* axes to raw python objects here (e.g. using tolist) because this would be + # too slow for huge axes + v_axes = axes[:-self.num_axes_to_display_horizontally] \ + if self.num_axes_to_display_horizontally else axes + # CHECK: store self._vlabels in adapter? 
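+            # Product is assumed to behave like a lazily-indexable
+            # itertools.product over label lists: e.g. (made-up labels)
+            # Product([['M', 'F'], [2024, 2025]])[1:3] would give
+            # [('M', 2025), ('F', 2024)] without materializing everything.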
+            vlabels = Product([axis.labels for axis in v_axes])
+            return vlabels[start:stop]
+
+    def get_hlabels_values(self, start, stop):
+        axes = self.filtered_data.axes
+        # test data.size == 0 is required in case an instance built as Array([]) is passed
+        # test len(axes) == 0 is required when a user filters until getting a scalar
+        # TODO: store this in the adapter
+        if axes.size == 0 or len(axes) == 0:
+            return [[]]
+        elif not self.num_axes_to_display_horizontally:
+            # all axes are vertical => a single empty hlabel
+            return [['']]
+        else:
+            haxes = axes[-self.num_axes_to_display_horizontally:]
+            hlabels = Product([axis.labels for axis in haxes])
+            section_labels = hlabels[start:stop]
+            # we have a list of columns but we need a list of rows
+            return [[label_col[row_num] for label_col in section_labels]
+                    for row_num in range(self.num_axes_to_display_horizontally)]
+
+    def get_sample(self):
+        """Return a sample of the internal data"""
+        np_data = self.filtered_data.data
+        # this will yield a data sample of max 200
+        return get_sample(np_data, 200)
+
+    def move_axis(self, data, attributes, old_index, new_index):
+        assert isinstance(data, la.Array)
+        new_axes = data.axes.copy()
+        new_axes.insert(new_index, new_axes.pop(new_axes[old_index]))
+        data = data.transpose(new_axes)
+        if attributes is not None:
+            assert isinstance(attributes, dict)
+            attributes = {k: v.transpose(new_axes) for k, v in attributes.items()}
+        return data, attributes
+
+    # TODO: move this to a DenseArrayAdapter superclass
+    def map_filtered_to_global(self, filtered_shape, filter, local2dkey):
+        """
+        transform local (filtered) 2D (row_idx, col_idx) key to global (unfiltered) ND key
+        (axis0_pos, axis1_pos, ..., axisN_pos). This is positional only (no labels).
+        """
+        row, col = local2dkey
+
+        localndkey = list(np.unravel_index(row, filtered_shape[:-1])) + [col]
+
+        # add the "scalar" parts of the filter to it (i.e. the parts of the filter which removed dimensions)
+        scalar_filter_keys = [axis_idx for axis_idx, axis_filter in filter.items()
+                              if np.isscalar(axis_filter)]
+        for axis_idx in sorted(scalar_filter_keys):
+            localndkey.insert(axis_idx, filter[axis_idx])
+
+        # translate local to global for filtered dimensions which are still present (non scalar)
+        return tuple(
+            axis_pos if axis_idx not in filter or np.isscalar(filter[axis_idx]) else filter[axis_idx][axis_pos]
+            for axis_idx, axis_pos in enumerate(localndkey)
+        )
+
+
+# cannot let the default Sequence adapter be used because axis[slice] is an
+# LGroup
+@adapter_for(la.Axis)
+class LArrayAxisAdapter(NumpyHomogeneousArrayAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data.labels, attributes)
+
+
+@adapter_for('array.array')
+class ArrayArrayAdapter(AbstractAdapter):
+    def shape2d(self):
+        return len(self.data), 1
+
+    def get_hlabels_values(self, start, stop):
+        return [['']]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        return [[v] for v in self.data[v_start:v_stop]]
+
+
+def excel_colname(col):
+    """col is a *zero*-based column number
+
+    >>> excel_colname(0)
+    'A'
+    >>> excel_colname(25)
+    'Z'
+    >>> excel_colname(26)
+    'AA'
+    >>> excel_colname(51)
+    'AZ'
+    """
+    letters = []
+    value_a = ord("A")
+    while col >= 0:
+        letters.append(chr(value_a + col % 26))
+        col = (col // 26) - 1
+    return "".join(reversed(letters))
+
+
+@adapter_for('larray.inout.xw_excel.Workbook')
+class WorkbookAdapter(SequenceAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data, attributes)
+        self._sheet_names = data.sheet_names()
+
+    def
get_hlabels_values(self, start, stop): + return [['sheet name']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return [[sheet_name] + for sheet_name in self._sheet_names[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + return self.data[row_idx] + + +@path_adapter_for(('.xlsx', '.xls'), 'xlwings') +class XlsxPathAdapter(WorkbookAdapter): + @classmethod + def open(cls, fpath): + return la.open_excel(fpath) + + def close(self): + self.data.close() + + +none_to_empty_string = np.vectorize(lambda v: v if v is not None else '') + + +@adapter_for('larray.inout.xw_excel.Sheet') +class SheetAdapter(AbstractAdapter): + def shape2d(self): + return self.data.shape + + def get_hlabels_values(self, start, stop): + return [[excel_colname(i) for i in range(start, stop)]] + + def get_vlabels_values(self, start, stop): + # +1 because excel rows are 1 based + return [[i] for i in range(start + 1, stop + 1)] + + def get_values(self, h_start, v_start, h_stop, v_stop): + range = self.data[v_start:v_stop, h_start:h_stop] + np_data = range.__array__() + # TODO: I wonder if I shouldn't change larray.Sheet.__array__ instead + # to make it always return 2D arrays (even for single column/row + # ranges) + if np_data.ndim < 2: + np_data = np_data.reshape((v_stop - v_start, h_stop - h_start)) + return none_to_empty_string(np_data) + + +@adapter_for('larray.inout.xw_excel.Range') +class RangeAdapter(AbstractAdapter): + def shape2d(self): + return nd_shape_to_2d(self.data.shape) + + def get_hlabels_values(self, start, stop): + # - 1 because data.column is 1-based (Excel) while excel_colname is 0-based + offset = self.data.column - 1 + return [[excel_colname(i) for i in range(offset + start, offset + stop)]] + + def get_vlabels_values(self, start, stop): + offset = self.data.row + return [[i] for i in range(offset + start, offset + stop)] + + def get_values(self, h_start, v_start, h_stop, v_stop): + sub_range = self.data[v_start:v_stop, h_start:h_stop] + np_data = sub_range.__array__() + # TODO: I wonder if I shouldn't change larray.Sheet.__array__ instead + # to make it always return 2D arrays (even for single column/row + # ranges) + if np_data.ndim < 2: + np_data = np_data.reshape((v_stop - v_start, h_stop - h_start)) + return np_data + + +@adapter_for('pandas.DataFrame') +class PandasDataFrameAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + pd = sys.modules['pandas'] + assert isinstance(data, pd.DataFrame) + super().__init__(data, attributes=attributes) + self.sorted_data = data + self.filtered_data = data + self._unq_values_per_column = {} + + def shape2d(self): + return self.data.shape + + def num_v_axes(self): + return self.data.index.nlevels + + def get_hnames(self): + return self.data.columns.names + + def get_vnames(self): + return self.data.index.names + + def get_vlabels_values(self, start, stop): + pd = sys.modules['pandas'] + index = self.sorted_data.index[start:stop] + if isinstance(index, pd.MultiIndex): + return [list(row) for row in index.values] + else: + return index.values[:, np.newaxis] + + def get_hlabels_values(self, start, stop): + pd = sys.modules['pandas'] + index = self.sorted_data.columns[start:stop] + if isinstance(index, pd.MultiIndex): + return [index.get_level_values(i).values + for i in range(index.nlevels)] + else: + return [index.values] + + def get_values(self, h_start, v_start, h_stop, v_stop): + # Sadly, as of Pandas 2.2.3, the previous version of this code: + # df.iloc[v_start:v_stop, h_start:h_stop].values + # first 
copies all mentioned columns in their entirety, then takes the
+        # subset of the rows (and then converts to a numpy array)
+        # As a workaround, we first take each *single* column in its entirety
+        # which, in most cases, is a view, then take the row slice
+        # (and then recombine using numpy stack)
+        df = self.sorted_data
+        columns = [df.iloc[:, i].values for i in range(h_start, h_stop)]
+        chunks = [col[v_start:v_stop] for col in columns]
+        try:
+            return np.stack(chunks, axis=1)
+        except np.exceptions.DTypePromotionError:
+            return np.stack(chunks, axis=1, dtype=object)
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        self._current_sort = [(1, col_idx, ascending)]
+        self.sorted_data = self._sort_data(self.filtered_data)
+
+    def _sort_data(self, data):
+        columns = data.columns
+        colnames = []
+        ascendings = []
+        for axis_idx, col_idx, ascending in self._current_sort:
+            assert axis_idx == 1
+            colnames.append(columns[col_idx])
+            ascendings.append(ascending)
+        return data.sort_values(colnames, ascending=ascendings)
+
+    def can_filter_hlabel(self, row_idx, col_idx) -> bool:
+        return True
+
+    def get_filter_options(self, filter_idx):
+        if filter_idx in self._unq_values_per_column:
+            return self._unq_values_per_column[filter_idx]
+        else:
+            pd = sys.modules['pandas']
+            df = self.data
+            assert isinstance(df, pd.DataFrame)
+
+            col_values = df.iloc[:, filter_idx]
+            unique_values = col_values.unique()
+            unique_values.sort()
+            unq_values = unique_values[:MAX_FILTER_OPTIONS]
+            self._unq_values_per_column[filter_idx] = unq_values
+            return unq_values
+
+    def update_filter(self, filter_idx, indices):
+        if not indices:
+            indices = list(range(len(self.get_filter_options(filter_idx))))
+        self.current_filter = cur_filter = {filter_idx: indices}
+        self.filtered_data = self._filter_data(self.data, cur_filter)
+        self.sorted_data = self._sort_data(self.filtered_data)
+        if self.attributes is not None:
+            # FIXME: need to sort attributes too
+            self.filtered_attributes = {k: self._filter_data(v, cur_filter)
+                                        for k, v in self.attributes.items()}
+
+    def _filter_data(self, data, cur_filter):
+        df = data
+        columns = df.columns
+        for col_idx, filtered_indices in cur_filter.items():
+            col_name = columns[col_idx]
+            col_unq_values = self._unq_values_per_column[col_idx]
+            filtered_values = col_unq_values[filtered_indices]
+            df = df[df[col_name].isin(filtered_values)]
+        return df
+
+
+@adapter_for('pandas.Series')
+class PandasSeriesAdapter(AbstractAdapter):
+    def __init__(self, data, attributes):
+        pd = sys.modules['pandas']
+        assert isinstance(data, pd.Series)
+        super().__init__(data=data, attributes=attributes)
+
+    def shape2d(self):
+        return len(self.data), 1
+
+    def get_vnames(self):
+        return self.data.index.names
+
+    def get_vlabels_values(self, start, stop):
+        pd = sys.modules['pandas']
+        index = self.data.index[start:stop]
+        if isinstance(index, pd.MultiIndex):
+            # returns a 1D array of tuples
+            return index.values
+        else:
+            return index.values[:, np.newaxis]
+
+    def get_hlabels_values(self, start, stop):
+        return [['']]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        assert h_start == 0
+        return self.data.iloc[v_start:v_stop].values.reshape(-1, 1)
+
+
+@adapter_for('pandas.core.groupby.generic.DataFrameGroupBy')
+class PandasDataFrameGroupByAdapter(PandasDataFrameAdapter):
+    def __init__(self, data, attributes):
+        original_df = data.obj
+        gb_keys = data.keys
+        if not isinstance(gb_keys, list):
+            gb_keys = [gb_keys]
+        numeric_columns =
original_df.select_dtypes(['number', 'bool']).columns + agg_df = data[numeric_columns.difference(gb_keys)].sum() + super().__init__(agg_df, attributes={}) + + def get_hlabels_values(self, start, stop): + label_rows = super().get_hlabels_values(start, stop) + last_row = [f'{label}\n(sum)' for label in label_rows[-1]] + return label_rows[:-1] + [last_row] + + +@adapter_for('pyarrow.Array') +class PyArrowArrayAdapter(AbstractAdapter): + def shape2d(self): + return len(self.data), 1 + + def get_hlabels_values(self, start, stop): + return [['value']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return self.data[v_start:v_stop].to_numpy(zero_copy_only=False).reshape(-1, 1) + + +# Contrary to other Path adapters, this one is both a File *and* Path adapter +# because it is more efficient to NOT keep the file open (because the pyarrow +# API only allows limiting which columns are read when opening the file) +@path_adapter_for(('.feather', '.ipc', '.arrow'), 'pyarrow.ipc') +@adapter_for('pyarrow.RecordBatchFileReader') +class FeatherFileAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + if isinstance(data, str): + data = Path(data) + # assert isinstance(data, (Path, pyarrow.RecordBatchFileReader)) + super().__init__(data=data, attributes=attributes) + + # TODO: take pandas metadata index columns into account: + # - display those columns as labels + # - remove those columns from shape + # - do not read those columns in get_values + with self._open_file() as f: + self._colnames = f.schema.names + self._num_columns = len(f.schema) + self._num_record_batches = f.num_record_batches + + self._batch_nrows = np.full(self._num_record_batches, -1, dtype=np.int64) + maxint = np.iinfo(np.int64).max + self._batch_ends = np.full(self._num_record_batches, maxint, dtype=np.int64) + # TODO: get this from somewhere else. We could use + # AbstractArrayModel.default_buffer_rows, but that would create + # a dependency on arraymodel which I would rather avoid. 
+ # I guess I should move that to a shared constants module + DEFAULT_BUFFER_ROWS = 40 + self._num_batches_indexed = 0 + self._num_rows = None + self._index_rows_up_to(DEFAULT_BUFFER_ROWS) + + def _open_file(self, col_indices=None): + """col_indices is only taken into account if self.data is a Path""" + import pyarrow.ipc as ipc + if isinstance(self.data, Path): + if col_indices is not None: + options = ipc.IpcReadOptions(included_fields=col_indices) + else: + options = None + return ipc.open_file(self.data, options=options) + else: + assert isinstance(self.data, ipc.RecordBatchFileReader) + return self.data + + def _get_batches(self, start_batch, stop_batch, col_indices: list[int]) -> list: + """stop_batch is not included""" + logger.debug(f"FeatherFileAdapter._get_batches({start_batch}, " + f"{stop_batch}, {col_indices})") + batch_indices = range(start_batch, stop_batch) + if isinstance(self.data, Path): + with self._open_file(col_indices=col_indices) as f: + return [f.get_batch(i) for i in batch_indices] + else: + return [self.data.get_batch(i).select(col_indices) + for i in batch_indices] + + def shape2d(self): + nrows = self._num_rows if self._num_rows is not None else self._estimated_num_rows + return nrows, self._num_columns + + def get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def _index_rows_up_to(self, num_rows): + if self._num_batches_indexed == 0: + last_indexed_batch_end = 0 + else: + last_indexed_batch_end = self._batch_ends[self._num_batches_indexed - 1] + + if num_rows <= last_indexed_batch_end: + return + + with self._open_file(col_indices=[0]) as f: + while (num_rows > last_indexed_batch_end and + self._num_batches_indexed < self._num_record_batches): + batch_num = self._num_batches_indexed + batch_rows = self._get_batch_nrows(f, batch_num) + last_indexed_batch_end += batch_rows + self._batch_ends[batch_num] = last_indexed_batch_end + self._num_batches_indexed += 1 + + if self._num_batches_indexed == self._num_record_batches: + # we are fully indexed + self._num_rows = last_indexed_batch_end + self._estimated_num_rows = None + else: + # we are not fully indexed + self._num_rows = None + last_batch_nrows = ( + self._get_batch_nrows(f, self._num_record_batches - 1)) + + # since we are not fully indexed, we are guaranteed to not + # count the last batch which usually has a different length + estimated_rows_per_batch = np.mean(self._batch_nrows[:self._num_batches_indexed]) + self._estimated_num_rows = int(estimated_rows_per_batch * + (self._num_record_batches - 1) + + last_batch_nrows) + + def _get_batch_nrows(self, f, batch_num): + if self._batch_nrows[batch_num] == -1: + batch_rows = f.get_batch(batch_num).num_rows + assert batch_rows >= 0 + self._batch_nrows[batch_num] = batch_rows + else: + batch_rows = self._batch_nrows[batch_num] + return batch_rows + + def get_values(self, h_start, v_start, h_stop, v_stop): + pyarrow = sys.modules['pyarrow'] + self._index_rows_up_to(v_stop) + # - 1 because the last row is not included + start_batch, stop_batch = np.searchsorted(self._batch_ends, + v=[v_start, v_stop - 1], + side='right') + # stop_batch is not included + stop_batch += 1 + chunk_start = self._batch_ends[start_batch - 1] if start_batch > 0 else 0 + col_indices = list(range(h_start, h_stop)) + batches = self._get_batches(start_batch, stop_batch, col_indices) + if len(batches) > 1: + combined = pyarrow.concat_batches(batches) + else: + combined = batches[0] + + chunk = combined[v_start - chunk_start:v_stop - chunk_start] + + # not going 
via to_pandas() because it "eats" index columns + columns = chunk.columns + np_columns = [c.to_numpy(zero_copy_only=False) for c in columns] + try: + return np.stack(np_columns, axis=1) + except np.exceptions.DTypePromotionError: + return np.stack(np_columns, axis=1, dtype=object) + + +@adapter_for('pyarrow.parquet.ParquetFile') +class PyArrowParquetFileAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + import json + super().__init__(data=data, attributes=attributes) + self._schema = data.schema + meta = data.metadata + self._num_cols = meta.num_columns + self._num_rows = meta.num_rows + meta_meta = data.metadata.metadata + col_names = self._schema.names + self._col_names = col_names + self._pandas_idx_cols = [] + if b'pandas' in meta_meta: + pd_meta = json.loads(meta_meta[b'pandas']) + + idx_col_names = pd_meta['index_columns'] + if all(isinstance(col_name, str) for col_name in idx_col_names): + idx_col_indices = [col_names.index(col_name) + for col_name in idx_col_names] + # We only support the case where index columns are at the end + # and are sorted. It is the case in all files I have seen so + # far but I don't know whether it is always the case + expected_first_idx_col = self._num_cols - len(idx_col_indices) + idx_cols_at_the_end = all(idx >= expected_first_idx_col + for idx in idx_col_indices) + idx_cols_sorted = sorted(idx_col_indices) == idx_col_indices + if idx_cols_at_the_end and idx_cols_sorted: + self._pandas_idx_cols = idx_col_indices + self._num_cols -= len(idx_col_indices) + + meta = data.metadata + num_rows_per_group = np.array([meta.row_group(i).num_rows + for i in range(data.num_row_groups)]) + self._group_ends = num_rows_per_group.cumsum() + assert self._group_ends[-1] == meta.num_rows + + def shape2d(self): + return self._num_rows, self._num_cols + + def get_hlabels_values(self, start, stop): + return [self._col_names[start:stop]] + + def get_vnames(self): + if self._pandas_idx_cols: + return [self._col_names[i] for i in self._pandas_idx_cols] + else: + return [''] + + def get_vlabels_values(self, start, stop): + if self._pandas_idx_cols: + # This assumes that index columns are contiguous (which is + # implicitly tested in __init__ via the "all index at the end" test) + # and sorted (tested in __init__) + h_start = self._pandas_idx_cols[0] + h_stop = self._pandas_idx_cols[-1] + 1 + return self.get_values(h_start, start, h_stop, stop) + else: + return [[i] for i in range(start, stop)] + + def get_values(self, h_start, v_start, h_stop, v_stop): + # fragment is a pyarrow.Table + fragment, fragment_h_start, fragment_v_start = ( + self._get_fragment_via_cache(h_start, v_start, h_stop, v_stop)) + + # chunk is a list of pyarrow.ChunkedArray + chunk = self._fragment_to_chunk(fragment, + fragment_h_start, fragment_v_start, + h_start, v_start, h_stop, v_stop) + return self._chunk_to_numpy(chunk) + + def _get_fragment_from_source(self, h_start, v_start, h_stop, v_stop): + start_row_group, stop_row_group = ( + # - 1 because the last row is not included + np.searchsorted(self._group_ends, [v_start, v_stop - 1], + side='right')) + # - 1 because _group_ends stores row group ends and we want the start + table_h_start = h_start + table_v_start = ( + self._group_ends[start_row_group - 1] if start_row_group > 0 else 0) + row_groups = range(start_row_group, stop_row_group + 1) + column_names = self._schema.names[h_start:h_stop] + f = self.data + table = f.read_row_groups(row_groups, columns=column_names) + return table, table_h_start, table_v_start + + # fragment 
is a native object representing the smallest buffer which can + # hold the requested rows and columns (not sliced in memory) + # chunk is the actual requested slice from that fragment, still in + # whatever format is most convenient for the adapter + def _fragment_to_chunk(self, fragment, fragment_h_start, fragment_v_start, + h_start, v_start, h_stop, v_stop): + + chunk = fragment[v_start - fragment_v_start:v_stop - fragment_v_start] + h_start_in_chunk = h_start - fragment_h_start + h_stop_in_chunk = h_stop - fragment_h_start + # not going via to_pandas() because it "eats" index columns + return chunk.columns[h_start_in_chunk:h_stop_in_chunk] + + def _chunk_to_numpy(self, chunk): + # chunk is a list of pyarrow.ChunkedArray + np_columns = [c.to_numpy() for c in chunk] + try: + return np.stack(np_columns, axis=1) + except np.exceptions.DTypePromotionError: + return np.stack(np_columns, axis=1, dtype=object) + + +@adapter_for('pyarrow.Table') +class PyArrowTableAdapter(AbstractColumnarAdapter): + def shape2d(self): + # TODO: take pandas metadata index columns into account: + # - display those columns as labels + # - remove those columns from shape + # - do not read those columns in get_values + # self.data.schema.pandas_metadata + return self.data.shape + + def get_hlabels_values(self, start, stop): + return [self.data.column_names[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + chunk = self.data[v_start:v_stop].select(range(h_start, h_stop)) + # not going via to_pandas() because it "eats" index columns + np_columns = [c.to_numpy() for c in chunk.columns] + try: + return np.stack(np_columns, axis=1) + except np.exceptions.DTypePromotionError: + return np.stack(np_columns, axis=1, dtype=object) + + +@adapter_for('polars.DataFrame') +@adapter_for('narwhals.DataFrame') +class PolarsDataFrameAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data, attributes=attributes) + self.filtered_data = data + self.sorted_data = data + self._unq_values_per_column = {} + + def shape2d(self): + return self.sorted_data.shape + + def get_hlabels_values(self, start, stop): + return [self.sorted_data.columns[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + # Going via Pandas instead of directly using to_numpy() because this + # has a better behavior for datetime columns (e.g. pl_df3). 
+        # Otherwise, Polars converts datetimes to floats instead of using a
+        # numpy object array
+        return self.sorted_data[v_start:v_stop, h_start:h_stop].to_pandas().values
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        self._current_sort = [(1, col_idx, ascending)]
+        self.sorted_data = self._sort_data(self.filtered_data)
+
+    def _sort_data(self, data):
+        for axis_idx, col_idx, ascending in self._current_sort:
+            assert axis_idx == 1
+            col_name = data.columns[col_idx]
+            data = data.sort(col_name, descending=not ascending)
+        return data
+
+    def can_filter_hlabel(self, row_idx, col_idx) -> bool:
+        return True
+
+    def get_filter_options(self, filter_idx):
+        if filter_idx in self._unq_values_per_column:
+            return self._unq_values_per_column[filter_idx]
+        else:
+            df = self.data
+            column = df.get_column(df.columns[filter_idx])
+            unq_values = column.unique().sort().limit(MAX_FILTER_OPTIONS).to_numpy()
+            self._unq_values_per_column[filter_idx] = unq_values
+            return unq_values
+
+    def update_filter(self, filter_idx, indices):
+        """Update current filter for a given axis if labels selection from the array widget has changed
+
+        Parameters
+        ----------
+        filter_idx : int
+            Index of filter (axis) for which selection has changed.
+        indices: list of int
+            Indices of selected labels.
+        """
+        # only allow filtering a single column for now (by not keeping previous
+        # filters)
+        if not indices:
+            indices = list(range(len(self.get_filter_options(filter_idx))))
+        self.current_filter = cur_filter = {filter_idx: indices}
+        self.filtered_data = self._filter_data(self.data, cur_filter)
+        self.sorted_data = self._sort_data(self.filtered_data)
+        if self.attributes is not None:
+            # FIXME: need to sort attributes too
+            self.filtered_attributes = {k: self._filter_data(v, cur_filter)
+                                        for k, v in self.attributes.items()}
+
+    def _filter_data(self, data, cur_filter):
+        import polars as pl
+        df = data
+        columns = df.columns
+        for col_idx, filtered_indices in cur_filter.items():
+            col_name = columns[col_idx]
+            col_unq_values = self._unq_values_per_column[col_idx]
+            filtered_values = col_unq_values[filtered_indices]
+            df = df.filter(pl.col(col_name).is_in(filtered_values))
+        return df
+
+
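+# A hedged usage sketch for the adapter above (the values and the
+# attributes=None argument are assumptions for illustration only):
+#   import polars as pl
+#   df = pl.DataFrame({'city': ['a', 'b'], 'pop': [1, 2]})
+#   adapter = PolarsDataFrameAdapter(df, attributes=None)
+#   adapter.sort_hlabel(0, 1, ascending=False)  # sort by 'pop', descending
+#   adapter.get_values(0, 0, 2, 2)              # -> [['b', 2], ['a', 1]]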
+@adapter_for('polars.LazyFrame')
+class PolarsLazyFrameAdapter(AbstractColumnarAdapter):
+    def __init__(self, data, attributes):
+        import polars as pl
+        assert isinstance(data, pl.LazyFrame)
+
+        super().__init__(data=data, attributes=attributes)
+        self._schema = data.collect_schema()
+        self._columns = self._schema.names()
+        # TODO: this is often slower than computing the "first window" data
+        #       so we could try to use a temporary value and
+        #       fill the real height as we go like for CSV files
+
+        self.filtered_data = data
+        len_query = self.filtered_data.select(pl.len())
+        self._height = len_query.collect(engine='streaming').item()
+        self.sorted_data = data
+        self._unq_values_per_column = {}
+
+    def shape2d(self):
+        return self._height, len(self._schema)
+
+    def get_hlabels_values(self, start, stop):
+        return [self._columns[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        lf = self.sorted_data
+        row_subset = lf[v_start:v_stop]
+        subset = row_subset.select(self._columns[h_start:h_stop])
+        df = subset.collect(engine='streaming')
+        # Going via Pandas instead of directly using to_numpy() because this
+        # has a better behavior for datetime columns (e.g. pl_df3).
+        # Otherwise, Polars converts datetimes to floats instead of using a
+        # numpy object array
+        return df.to_pandas().values
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        self._current_sort = [(1, col_idx, ascending)]
+        self.sorted_data = self._sort_data(self.filtered_data)
+
+    def _sort_data(self, data):
+        for axis_idx, col_idx, ascending in self._current_sort:
+            assert axis_idx == 1
+            col_name = self._columns[col_idx]
+            data = data.sort(col_name, descending=not ascending)
+        return data
+
+    def can_filter_hlabel(self, row_idx, col_idx) -> bool:
+        return True
+
+    def get_filter_options(self, filter_idx):
+        if filter_idx in self._unq_values_per_column:
+            return self._unq_values_per_column[filter_idx]
+        else:
+            lf = self.data
+            colname = self._columns[filter_idx]
+            query = (lf.select(colname)
+                     .unique(colname)
+                     .sort(colname)
+                     .limit(MAX_FILTER_OPTIONS))
+            unq_values_df = query.collect(engine='streaming')
+            unq_values = unq_values_df[:, 0].to_numpy()
+            self._unq_values_per_column[filter_idx] = unq_values
+            return unq_values
+
+    def update_filter(self, filter_idx, indices):
+        """Update current filter for a given axis if labels selection from the array widget has changed
+
+        Parameters
+        ----------
+        filter_idx : int
+            Index of filter (axis) for which selection has changed.
+        indices: list of int
+            Indices of selected labels.
+        """
+        # cur_filter = self.current_filter
+        # cur_filter[filter_idx] = indices
+        # only allow filtering a single column for now (by not keeping previous
+        # filters)
+        import polars as pl
+        if not indices:
+            indices = list(range(len(self.get_filter_options(filter_idx))))
+        self.current_filter = cur_filter = {filter_idx: indices}
+        self.filtered_data = self._filter_data(self.data, cur_filter)
+        self.sorted_data = self._sort_data(self.filtered_data)
+        len_query = self.filtered_data.select(pl.len())
+        self._height = len_query.collect(engine='streaming').item()
+        if self.attributes is not None:
+            # FIXME: need to sort attributes too
+            self.filtered_attributes = {k: self._filter_data(v, cur_filter)
+                                        for k, v in self.attributes.items()}
+
+    def _filter_data(self, data, cur_filter):
+        import polars as pl
+        df = data
+        columns = df.columns
+        for col_idx, filtered_indices in cur_filter.items():
+            col_name = columns[col_idx]
+            col_unq_values = self._unq_values_per_column[col_idx]
+            filtered_values = col_unq_values[filtered_indices]
+            df = df.filter(pl.col(col_name).is_in(filtered_values))
+        return df
+
+
+# we need an explicit adapter (instead of reusing PolarsLazyFrameAdapter as-is)
+# because Narwhals lazy frames are not indexable, so we need the row_index
+# trick
+@adapter_for('narwhals.LazyFrame')
+class NarwhalsLazyFrameAdapter(PolarsLazyFrameAdapter):
+    def __init__(self, data, attributes):
+        import narwhals as nw
+        assert isinstance(data, nw.LazyFrame)
+
+        # do not use super().__init__ to avoid the isinstance check from
+        # the parent class
+        AbstractColumnarAdapter.__init__(self, data=data, attributes=attributes)
+        self._schema = data.collect_schema()
+        self._columns = self._schema.names()
+        # TODO: this is often slower than computing the "first window" data
+        #       so we could try to use a temporary value and
+        #       fill the real height as we go like for CSV files
+        # TODO: engine='streaming' is not part of the narwhals API (it
+        #       is forwarded to the underlying engine) so it will work with
+        #       Polars but probably not other engines
+        self._height =
data.select(nw.len()).collect(engine='streaming').item() + self.data = data + self.filtered_data = data + # this is almost a noop initially (no sort) but does add a row index + # column, which is necessary to slice the lazyframe (see below) + self.sorted_data = self._sort_data(self.filtered_data) + self._unq_values_per_column = {} + + def get_values(self, h_start, v_start, h_stop, v_stop): + # if not self._current_sort: + # FIXME: this breaks column width detection code + # return [['narwhals lazyframes must be sorted to display them'] + # + [''] * (h_stop - h_start - 1)] + + nw = sys.modules['narwhals'] + # narwhals LazyFrame does not support slicing, so we have to + # resort to this awful workaround which is MUCH slower than native + # polars slicing. I suspect it must evaluate the whole dataset. + filter_ = (nw.col('_index') >= v_start) & (nw.col('_index') < v_stop) + row_subset = self.sorted_data.filter(filter_) + # .select also implicitly drops _index + subset = row_subset.select(self._columns[h_start:h_stop]) + df = subset.collect(engine='streaming') + # Going via Pandas instead of directly using to_numpy() because this + # has a better behavior for datetime columns (e.g. pl_df3). + # Otherwise, Polars converts datetimes to floats instead using a numpy + # object array + return df.to_pandas().values + + def _sort_data(self, data): + col_names = [] + for axis_idx, col_idx, ascending in self._current_sort: + assert axis_idx == 1 + col_name = self._columns[col_idx] + col_names.append(col_name) + data = data.sort(col_name, descending=not ascending) + # We need to add a row index to be able to slice the lazyframe (see + # below). This needs to be done *after* filtering and sorting. + # FIXME: using order_by=None because narwhals API requires the order_by + # argument for lazyframes but its implementation (currently) + # allows None for polars-backed lazyframes. + # We use it instead of order_by=col_names, even though this + # probably breaks in non Polars backends because narwhals does + # not seem to support specifying descending order for the row + # index. 
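+        # Sketch of the expected effect (assuming a polars-backed frame):
+        # a frame with columns ['a'] becomes ['_index', 'a'] where _index
+        # numbers rows 0, 1, 2, ... in the current (post-sort) order, so
+        # the _index >= v_start filter in get_values can emulate slicing.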
+        return data.with_row_index('_index', order_by=None)
+
+    def get_filter_options(self, filter_idx):
+        if filter_idx in self._unq_values_per_column:
+            return self._unq_values_per_column[filter_idx]
+        else:
+            lf = self.data
+            colname = self._columns[filter_idx]
+            query = (lf.select(colname)
+                     .unique(colname)
+                     .sort(colname)
+                     # narwhals does not support .limit() on LazyFrame
+                     .head(MAX_FILTER_OPTIONS))
+            unq_values_df = query.collect(engine='streaming')
+            unq_values = unq_values_df[:, 0].to_numpy()
+            self._unq_values_per_column[filter_idx] = unq_values
+            return unq_values
+
+    # overridden just to use nw.len() instead of pl.len()
+    def update_filter(self, filter_idx, indices):
+        nw = sys.modules['narwhals']
+        if not indices:
+            indices = list(range(len(self.get_filter_options(filter_idx))))
+        self.current_filter = cur_filter = {filter_idx: indices}
+        self.filtered_data = self._filter_data(self.data, cur_filter)
+        self.sorted_data = self._sort_data(self.filtered_data)
+        len_query = self.filtered_data.select(nw.len())
+        self._height = len_query.collect(engine='streaming').item()
+        if self.attributes is not None:
+            # FIXME: need to sort attributes too
+            self.filtered_attributes = {k: self._filter_data(v, cur_filter)
+                                        for k, v in self.attributes.items()}
+
+    # overridden just to use nw.col() instead of pl.col()
+    def _filter_data(self, data, cur_filter):
+        nw = sys.modules['narwhals']
+        df = data
+        columns = df.columns
+        for col_idx, filtered_indices in cur_filter.items():
+            col_name = columns[col_idx]
+            col_unq_values = self._unq_values_per_column[col_idx]
+            filtered_values = col_unq_values[filtered_indices]
+            df = df.filter(nw.col(col_name).is_in(filtered_values))
+        return df
+
+
+@adapter_for('iode.Variables')
+class IodeVariablesAdapter(AbstractAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+        self._periods = data.periods
+
+    def shape2d(self):
+        return len(self.data), len(self._periods)
+
+    def get_hlabels_values(self, start, stop):
+        return [[str(p) for p in self._periods[start:stop]]]
+
+    def get_vlabels_values(self, start, stop):
+        get_name = self.data.get_name
+        return [[get_name(i)] for i in range(start, stop)]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        get_name = self.data.get_name
+        names = [get_name(i) for i in range(v_start, v_stop)]
+        first_period = self._periods[h_start]
+        # - 1 because h_stop itself is exclusive while iode stop period is inclusive
+        last_period = self._periods[h_stop - 1]
+        return self.data[names, first_period:last_period].to_numpy()
+
+
+@adapter_for('iode.Table')
+class IodeTableAdapter(AbstractAdapter):
+    def shape2d(self):
+        # TODO: ideally, width should be self.data.nb_columns, but we need
+        #       to handle column spans in the model for that
+        #       see: https://runebook.dev/en/articles/qt/qtableview/columnSpan
+        return len(self.data), 1
+
+    def get_hlabels_values(self, start, stop):
+        return [['']]
+
+    def get_vlabels_values(self, start, stop):
+        return [[str(self.data[i].line_type)] for i in range(start, stop)]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        return [[str(self.data[i])] for i in range(v_start, v_stop)]
+
+
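+# A hedged sketch of what the simple-list adapters below display (the
+# workspace content is made up): for an iode Comments workspace holding
+# {'ACAF': 'Costs', 'ACAG': 'Wages'}, shape2d() would be (2, 1), the vertical
+# labels the names and the single 'Comment' column the str() of each object.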
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        indices_getter = self.data.i
+        return [[str(indices_getter[i])]
+                for i in range(v_start, v_stop)]
+
+
+class AbstractIodeObjectListAdapter(AbstractAdapter):
+    _ATTRIBUTES = []
+
+    def shape2d(self):
+        return len(self.data), len(self._ATTRIBUTES)
+
+    def get_hlabels_values(self, start, stop):
+        return [[attr.capitalize() for attr in self._ATTRIBUTES[start:stop]]]
+
+    def get_vlabels_values(self, start, stop):
+        get_name = self.data.get_name
+        return [[get_name(i)] for i in range(start, stop)]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        attrs = self._ATTRIBUTES[h_start:h_stop]
+        indices_getter = self.data.i
+        objects = [indices_getter[i] for i in range(v_start, v_stop)]
+        return [[getattr(obj, attr) for attr in attrs]
+                for obj in objects]
+
+
+@adapter_for('iode.Comments')
+class IodeCommentsAdapter(AbstractIodeSimpleListAdapter):
+    _COLUMN_NAME = 'Comment'
+
+
+@adapter_for('iode.Identities')
+class IodeIdentitiesAdapter(AbstractIodeSimpleListAdapter):
+    _COLUMN_NAME = 'Identity'
+
+
+@adapter_for('iode.Lists')
+class IodeListsAdapter(AbstractIodeSimpleListAdapter):
+    _COLUMN_NAME = 'List'
+
+
+@adapter_for('iode.Tables')
+class IodeTablesAdapter(AbstractIodeObjectListAdapter):
+    _ATTRIBUTES = ['title', 'language']
+
+    def cell_activated(self, row_idx, column_idx):
+        return self.data.i[row_idx]
+
+
+@adapter_for('iode.Scalars')
+class IodeScalarsAdapter(AbstractIodeObjectListAdapter):
+    _ATTRIBUTES = ['value', 'std', 'relax']
+
+
+@adapter_for('iode.Equations')
+class IodeEquationsAdapter(AbstractAdapter):
+    _COLNAMES = ['lec', 'method', 'sample', 'block',
+                 'fstat', 'r2adj', 'dw', 'loglik',
+                 'date']
+    _SIMPLE_ATTRS = {'block', 'date', 'lec', 'method', 'sample'}
+
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+
+    def shape2d(self):
+        return len(self.data), len(self._COLNAMES)
+
+    def get_hlabels_values(self, start, stop):
+        return [self._COLNAMES[start:stop]]
+
+    def get_vlabels_values(self, start, stop):
+        get_name = self.data.get_name
+        return [[get_name(i)] for i in range(start, stop)]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        """*_stop are exclusive"""
+        colnames = self._COLNAMES[h_start:h_stop]
+        indices_getter = self.data.i
+        simple_attrs = self._SIMPLE_ATTRS
+        res = []
+        for i in range(v_start, v_stop):
+            try:
+                eq = indices_getter[i]
+                tests = eq.tests
+                row = [str(getattr(eq, colname)).replace('\n', '')
+                       if colname in simple_attrs else tests[colname]
+                       for colname in colnames]
+            except Exception:
+                row = [''
+                       if colname in simple_attrs else np.nan
+                       for colname in colnames]
+            res.append(row)
+        return np.array(res, dtype=object)
+
+
+@path_adapter_for(('.av', '.var'), 'iode')
+class IodeVariablesPathAdapter(IodeVariablesAdapter):
+    @classmethod
+    def open(cls, fpath):
+        iode = sys.modules['iode']
+        iode.variables.load(str(fpath))
+        return iode.variables
+
+
+@path_adapter_for(('.as', '.scl'), 'iode')
+class IodeScalarsPathAdapter(IodeScalarsAdapter):
+    @classmethod
+    def open(cls, fpath):
+        iode = sys.modules['iode']
+        iode.scalars.load(str(fpath))
+        return iode.scalars
+
+
+@path_adapter_for(('.ac', '.cmt'), 'iode')
+class IodeCommentsPathAdapter(IodeCommentsAdapter):
+    @classmethod
+    def open(cls, fpath):
+        iode = sys.modules['iode']
+        iode.comments.load(str(fpath))
+        return iode.comments
+
+
+@path_adapter_for(('.at', '.tbl'), 'iode')
+class IodeTablesPathAdapter(IodeTablesAdapter):
+    @classmethod
+    def open(cls, fpath):
+        iode = sys.modules['iode']
iode.tables.load(str(fpath)) + return iode.tables + + +@path_adapter_for(('.ae', '.eqs'), 'iode') +class IodeEquationsPathAdapter(IodeEquationsAdapter): + @classmethod + def open(cls, fpath): + iode = sys.modules['iode'] + iode.equations.load(str(fpath)) + return iode.equations + + +@path_adapter_for(('.ai', '.idt'), 'iode') +class IodeIdentitiesPathAdapter(IodeIdentitiesAdapter): + @classmethod + def open(cls, fpath): + iode = sys.modules['iode'] + iode.identities.load(str(fpath)) + return iode.identities + + +@path_adapter_for(('.al', '.lst'), 'iode') +class IodeListsPathAdapter(IodeListsAdapter): + @classmethod + def open(cls, fpath): + iode = sys.modules['iode'] + iode.lists.load(str(fpath)) + return iode.lists + + +@adapter_for('ibis.Table') +class IbisTableAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + self._columns = data.columns + self._height = data.count().execute() + + def shape2d(self): + return self._height, len(self._columns) + + def get_hlabels_values(self, start, stop): + return [self._columns[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + lazy_sub_df = self.data[v_start:v_stop].select(self._columns[h_start:h_stop]) + return lazy_sub_df.to_pandas().values + + +# TODO: reuse NumpyStructuredArrayAdapter +@adapter_for('tables.File') +class PyTablesFileAdapter(AbstractColumnarAdapter): + _COLNAMES = ['Name'] + + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + + def shape2d(self): + return self.data.root._v_nchildren, 1 + + def get_hlabels_values(self, start, stop): + return [self._COLNAMES[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + subnodes = self.data.list_nodes('/') + return [[group._v_name][h_start:h_stop] + for group in subnodes[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + groups = self.data.list_nodes('/') + return groups[row_idx] + + +class PyTablesGroupAdapter(AbstractColumnarAdapter): + _COLNAMES = ['Name'] + + def shape2d(self): + return self.data._v_nchildren, 1 + + def get_hlabels_values(self, start, stop): + return [self._COLNAMES[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + subnodes = self.data._f_list_nodes() + return [[group._v_name][h_start:h_stop] + for group in subnodes[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + subnodes = self.data._f_list_nodes() + return subnodes[row_idx] + + +class PyTablesPandasFrameAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + attrs = data._v_attrs + assert hasattr(attrs, 'nblocks') + assert hasattr(attrs, 'axis0_variety') and attrs.axis0_variety in {'regular', 'multi'} + assert hasattr(attrs, 'axis1_variety') and attrs.axis1_variety in {'regular', 'multi'} + self._axis0_variety = attrs.axis0_variety + self._axis1_variety = attrs.axis1_variety + self._encoding = getattr(attrs, 'encoding', None) + nblocks = attrs.nblocks + self._block_values_nodes = [data._f_get_child(f'block{i}_values') + for i in range(nblocks)] + assert not (nblocks > 1 and attrs.axis0_variety == 'multi'), \ + ("loading mixed type DataFrames with a multi-index in columns " + "from HDF5 is not implemented yet") + + # data.block0_values.shape[0] is not always correct (if multiblocks) + if attrs.axis1_variety == 'multi': + self._num_rows = data.axis1_label0.shape[0] + else: + self._num_rows = data.axis1.shape[0] + + if nblocks 
> 1: + import tables + + axis_node = data._f_get_child('axis0') + col_names = axis_node.read().tolist() + # {col_idx: (block_idx, idx_in_block)} + column_source = {} + cached_string_blocks = {} + for block_idx in range(nblocks): + block_values_node = data._f_get_child(f'block{block_idx}_values') + block_items = data._f_get_child(f'block{block_idx}_items').read() + + if isinstance(block_values_node, tables.VLArray): + # This is very unfortunate but we cannot slice those blocks + # on disk because they are stored as a single blob + # We load the full block and kept it cached in + # memory so that we do not reload the whole block on each + # scroll + block_values = block_values_node.read()[0] + cached_string_blocks[block_idx] = block_values + for idx_in_block, col_name in enumerate(block_items): + col_idx = col_names.index(col_name) + column_source[col_idx] = (block_idx, idx_in_block) + self._cached_string_blocks = cached_string_blocks + self._column_source = column_source + self._num_columns = len(column_source) + else: + self._cached_string_blocks = None + self._column_source = None + self._num_columns = data.block0_values.shape[1] + + + def shape2d(self): + return self._num_rows, self._num_columns + + def _get_axis_names(self, axis_num: int) -> list[str]: + group = self.data + attrs = group._v_attrs + if getattr(attrs, f'axis{axis_num}_variety') == 'regular': + axis_node = group._f_get_child(f'axis{axis_num}') + return [axis_node._v_attrs.name] + else: + nlevels = getattr(attrs, f'axis{axis_num}_nlevels') + return [ + group._f_get_child(f'axis{axis_num}_level{i}')._v_attrs.name + for i in range(nlevels) + ] + + def get_hnames(self): + return self._get_axis_names(0) + + def get_vnames(self): + return self._get_axis_names(1) + + def _get_axis_labels(self, axis_num: int, start: int, stop: int) -> np.ndarray: + group = self.data + attrs = group._v_attrs + if getattr(attrs, f'axis{axis_num}_variety') == 'regular': + axis_node = group._f_get_child(f'axis{axis_num}') + labels = axis_node[start:stop].reshape(1, -1) + kind = axis_node._v_attrs.kind + if kind == 'string' and self._encoding is not None: + labels = np.char.decode(labels, encoding=self._encoding) + else: + chunks = [] + has_strings = False + has_non_strings = False + nlevels = getattr(attrs, f'axis{axis_num}_nlevels') + for i in range(nlevels): + label_node = group._f_get_child(f'axis{axis_num}_label{i}') + chunk_label_x = label_node[start:stop] + max_label_x = chunk_label_x.max() + level_node = group._f_get_child(f'axis{axis_num}_level{i}') + axis_level_x = level_node[:max_label_x + 1] + chunk_level_x = axis_level_x[chunk_label_x] + kind = level_node._v_attrs.kind + if kind == 'string': + has_strings = True + if self._encoding is not None: + chunk_level_x = np.char.decode(chunk_level_x, encoding=self._encoding) + else: + has_non_strings = True + chunks.append(chunk_level_x) + if has_strings and has_non_strings: + labels = np.stack(chunks, axis=0, dtype=object) + else: + labels = np.stack(chunks, axis=0) + return labels + + def get_hlabels_values(self, start, stop): + return self._get_axis_labels(0, start, stop) + + def get_vlabels_values(self, start, stop): + return self._get_axis_labels(1, start, stop).transpose() + + def get_values(self, h_start, v_start, h_stop, v_stop): + data = self.data + attrs = data._v_attrs + if attrs.nblocks == 1: + return data.block0_values[v_start:v_stop, h_start:h_stop] + else: + import tables + block_nodes = self._block_values_nodes + # TODO: for performance, we should probably read all columns from + # 
the same block at once + np_columns = [] + for col_idx in range(h_start, h_stop): + block_idx, idx_in_block = self._column_source[col_idx] + block_node = block_nodes[block_idx] + if isinstance(block_node, tables.VLArray): + block_node = self._cached_string_blocks[block_idx] + chunk = block_node[v_start:v_stop, idx_in_block] + np_columns.append(chunk) + + return np.stack(np_columns, axis=1, dtype=object) + + +@adapter_for('tables.Group') +def dispatch_pytables_group_to_adapter(data): + # distinguish between "normal" pytables Group and Pandas frames + attrs = data._v_attrs + if hasattr(attrs, 'pandas_type') and attrs.pandas_type == 'frame': + return PyTablesPandasFrameAdapter + else: + return PyTablesGroupAdapter + + +@adapter_for('tables.Table') +class PyTablesTableAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + + def shape2d(self): + return len(self.data), len(self.data.dtype.names) + + def get_hlabels_values(self, start, stop): + return [self._get_col_names(start, stop)] + + def _get_col_names(self, start, stop): + return list(self.data.dtype.names[start:stop]) + + def get_values(self, h_start, v_start, h_stop, v_stop): + # TODO: when we scroll horizontally, we fetch the data over + # and over while we could only fetch it once + # given that pytables fetches entire rows anyway. + # Several solutions: + # * cache "current" rows in the adapter + # * have a way for the arraymodel to ask the adapter for the minimum buffer size + # * allow the adapter to return more data than what the model asked for and have the model actually + # use/take that extra data into account. This would require the adapter to return + # real_h_start, real_v_start (stop values can be deduced) in addition to actual values + array = self.data[v_start:v_stop] + return [tuple(row_data)[h_start:h_stop] for row_data in array] + + +@adapter_for('tables.Array') +class PyTablesArrayAdapter(NumpyHomogeneousArrayAdapter): + def shape2d(self): + if self.data.ndim == 1: + return self.data.shape + (1,) + else: + return nd_shape_to_2d(self.data.shape) + + def get_vlabels_values(self, start, stop): + shape = self.data.shape + ndim = self.data.ndim + if ndim == 1: + shape += (1,) + if ndim > 0: + vlabels = Product([range(axis_len) for axis_len in shape[:-1]]) + return vlabels[start:stop] + else: + return [['']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + data = self.data + if data.ndim == 1: + return data[v_start:v_stop].reshape(-1, 1) + elif data.ndim == 2: + return data[v_start:v_stop, h_start:h_stop] + else: + raise NotImplementedError('>2d not implemented yet') + + +@path_adapter_for(('.h5', '.hdf'), 'tables') +class H5PathAdapter(PyTablesFileAdapter): + @classmethod + def open(cls, fpath): + tables = sys.modules['tables'] + return tables.open_file(fpath) + + def close(self): + self.data.close() + + +# TODO: options to display as hex or decimal +# >>> s = f.read(10) +# >>> s +# b'\x00\x00\xc2\xea\x81\xb3\x14\x11\xcf\xbd +@adapter_for('_io.BufferedReader') +class BinaryFileAdapter(AbstractAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + self._nbytes = os.path.getsize(data.name) + self._width = 16 + + def shape2d(self): + return math.ceil(self._nbytes / self._width), self._width + + def get_vlabels_values(self, start, stop): + start, stop, step = slice(start, stop).indices(self.shape2d()[0]) + return [[i * self._width] for i in range(start, stop)] + + def get_values(self, 
h_start, v_start, h_stop, v_stop):
+        f = self.data
+        width = self._width
+
+        backup_pos = f.tell()
+
+        # read data (ignoring horizontal bounds at this point)
+        start_pos = v_start * width
+        stop_pos = v_stop * width
+        f.seek(start_pos)
+        s = f.read(stop_pos - start_pos)
+
+        # restore file position
+        f.seek(backup_pos)
+
+        # load the bytes as an array of unsigned bytes
+        buffer1d = np.frombuffer(s, dtype='u1')
+
+        # enlarge the array so that its length is divisible by width
+        # (so that we can reshape it)
+        buffer_size = len(buffer1d)
+        size_remainder = buffer_size % width
+        if size_remainder != 0:
+            filler_size = width - size_remainder
+            rounded_size = buffer_size + filler_size
+            try:
+                # first try inplace resize
+                buffer1d.resize(rounded_size, refcheck=False)
+            except Exception:
+                buffer1d = np.append(buffer1d, np.zeros(filler_size, dtype='u1'))
+
+        # change undisplayable characters to '.'
+        buffer1d = np.where((buffer1d < 32) | (buffer1d >= 128),
+                            ord('.'),
+                            buffer1d).view('S1')
+
+        # reshape to 2d
+        buffer2d = buffer1d.reshape((-1, width))
+
+        # take what we were asked for
+        return buffer2d[:, h_start:h_stop]
+
+
+def index_line_ends(s, index=None, offset=0, c='\n'):
+    r"""returns a list of line end positions
+
+    It does NOT add an implicit line end at the end of the string.
+
+    >>> index_line_ends("0\n234\n6\n8")
+    [1, 5, 7]
+    >>> chunks = ["0\n234\n6", "", "\n", "8"]
+    >>> pos = 0
+    >>> idx = []
+    >>> for chunk in chunks:
+    ...     _ = index_line_ends(chunk, idx, pos)
+    ...     pos += len(chunk)
+    >>> idx
+    [1, 5, 7]
+    """
+    if index is None:
+        index = []
+    if not len(s):
+        return index
+    line_start = 0
+    find = s.find
+    append = index.append
+    while True:
+        line_end = find(c, line_start)
+        if line_end == -1:
+            break
+        append(line_end + offset)
+        line_start = line_end + len(c)
+    return index
+
+
+def detect_encoding(chunk: bytes):
+    try:
+        import charset_normalizer
+        charset_match = charset_normalizer.from_bytes(chunk).best()
+        if charset_match is None:
+            return None
+        else:
+            return charset_match.encoding
+    except ImportError:
+        logger.debug("could not import 'charset_normalizer' "
+                     "=> using basic encoding detection")
+    for encoding in ('utf8', 'cp1252', 'ascii'):
+        try:
+            chunk.decode(encoding)
+            return encoding
+        except UnicodeDecodeError:
+            pass
+    # failed to detect an encoding
+    return None
+
+
+@adapter_for('_io.TextIOWrapper')
+class TextFileAdapter(AbstractAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+        # TODO: we should check at regular intervals that this hasn't changed
+        self._nbytes = os.path.getsize(self.data.name)
+        self._lines_end_index = []
+        self._fully_indexed = False
+        self._encoding = None
+        self._lines_end_char = None
+
+        # sniff a small chunk so that we can compute an approximate number
+        # of lines
+        # TODO: instead of opening and closing the file over and over, we
+        #       should implement a mechanism to keep the file open while it
+        #       is displayed and close it if another variable is selected.
+        #       That might prevent the file from being deleted (by an
+        #       external tool), which could be both annoying and practical.
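+        # (the sniff below reads at most 64 KB and spends at most ~50 ms,
+        # so opening even huge files stays fast)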
+        with self._binary_file as f:
+            self._index_up_to(f, 1, chunk_size=64 * KB, max_time=0.05)
+
+    @property
+    def _binary_file(self):
+        return open(self.data.name, 'rb')
+
+    @property
+    def _avg_bytes_per_line(self):
+        lines_end_index = self._lines_end_index
+        if lines_end_index:
+            return lines_end_index[-1] / len(lines_end_index)
+        elif self._nbytes:
+            return self._nbytes
+        else:
+            return 1
+
+    @property
+    def _num_lines(self):
+        """returns estimated number of lines"""
+        if self._nbytes == 0:
+            return 0
+        if self._fully_indexed:
+            return len(self._lines_end_index)
+        else:
+            return math.ceil(self._nbytes / self._avg_bytes_per_line)
+
+    def shape2d(self):
+        return self._num_lines, 1
+
+    def _index_up_to(self, f, approx_v_stop, chunk_size=4 * MB, max_time=0.5):
+        # If the size of the index ever becomes a problem, we could store
+        # only one line end out of every X, but we are not there yet.
+        # We also need to limit line length (to something like 256 KB?).
+        # Beyond that, it is probably not a line-based file.
+        if len(self._lines_end_index):
+            lines_to_index = max(approx_v_stop - len(self._lines_end_index), 0)
+            data_to_index = lines_to_index * self._avg_bytes_per_line
+            must_index = 0 < data_to_index < 512 * MB
+        else:
+            # we have not indexed anything yet
+            must_index = True
+
+        if must_index:
+            logger.debug(f"trying to index up to {approx_v_stop}")
+            start_time = time.perf_counter()
+            chunk_start = self._lines_end_index[-1] if self._lines_end_index else 0
+            f.seek(chunk_start)
+            # TODO: check for off-by-one error with v_stop
+            while (time.perf_counter() - start_time < max_time) and \
+                    (len(self._lines_end_index) < approx_v_stop) and \
+                    not self._fully_indexed:
+
+                # TODO: if we are beyond v_start, we should store the chunks
+                #       to avoid reading them twice from disk (once for
+                #       indexing, then again for getting the data)
+                chunk = f.read(chunk_size)
+                if chunk_start == 0:
+                    self._analyze_first_chunk(chunk)
+                self._analyze_chunk(chunk, chunk_start)
+                length_read = len(chunk)
+                # FIXME: this test is buggy.
+ # * if there was exactly chunk_size left to read, the file might never + # be marked as fully indexed + # * I think there are other (rare) reasons why a read can return + # less bytes than asked for + if length_read < chunk_size: + self._fully_indexed = True + # add implicit line end at the end of the file if there isn't an explicit one + file_length = chunk_start + length_read + file_last_char_pos = file_length - len(self._lines_end_char) + if not self._lines_end_index or self._lines_end_index[-1] != file_last_char_pos: + self._lines_end_index.append(file_length) + chunk_start += length_read + + def _analyze_first_chunk(self, chunk): + """first chunk-specific analyses""" + if self._encoding is None: + self._detect_encoding(chunk) + + if self._lines_end_char is None: + self._detect_lines_end_char(chunk) + + def _analyze_chunk(self, chunk, chunk_start): + """analyzes a chunk (including first chunk)""" + index_line_ends(chunk, self._lines_end_index, offset=chunk_start, + c=self._lines_end_char) + + def _detect_lines_end_char(self, chunk): + # Try to detect between: + # * CRLF - \r\n - 0D0A (Windows) + # * LF - \n - 0A (Linux/Unix) + # * CR - \r - 0D (Classic macOS before transitioning to LF) + for line_end_char in (b'\r\n', b'\n', b'\r'): + if line_end_char in chunk: + logger.debug(f'detected line endings as {line_end_char!r}') + break + else: + # if the loop did not break, we fallback to \n + line_end_char = b'\n' + logger.debug('failed to detect line endings, falling back to ' + f'{line_end_char!r}') + self._lines_end_char = line_end_char + + def _detect_encoding(self, chunk: bytes): + encoding = detect_encoding(chunk) + if encoding is None: + logger.debug("could not detect encoding from chunk, using ascii") + encoding = 'ascii' + else: + logger.debug(f"encoding detected as {encoding}") + self._encoding = encoding + + def get_vlabels_values(self, start, stop): + # we need to trigger indexing too (because get_vlabels happens before get_data) so that lines_indexed is correct + # FIXME: get_data should not trigger indexing too if start/stop are the same + with self._binary_file as f: + self._index_up_to(f, stop) + + start, stop, step = slice(start, stop).indices(self._num_lines) + lines_indexed = len(self._lines_end_index) + return [[str(i) if i < lines_indexed else '~' + str(i)] + for i in range(start, stop)] + + def _get_lines(self, start_line, stop_line): + """stop is exclusive""" + assert start_line >= 0 and stop_line >= 0 + with self._binary_file as f: + self._index_up_to(f, stop_line) + num_indexed_lines = len(self._lines_end_index) + if self._fully_indexed and stop_line > num_indexed_lines: + stop_line = num_indexed_lines + + # if we are entirely in indexed lines, we can use exact pos + if stop_line <= num_indexed_lines: + # position of start_line is one byte after the end of the line + # preceding it (if any) + if start_line >= 1: + start_pos = (self._lines_end_index[start_line - 1] + + len(self._lines_end_char)) + else: + start_pos = 0 + # stop_line should be excluded (=> -1) + stop_pos = self._lines_end_index[stop_line - 1] + f.seek(start_pos) + chunk = f.read(stop_pos - start_pos) + num_lines = stop_line - start_line + lines = self._decode_chunks_to_lines([chunk], num_lines) + # lines = chunk.split(b'\n') + # assert len(lines) == num_required_lines + return lines + else: + pos_last_end = self._lines_end_index[-1] + # start_line is indexed + if start_line - 1 < num_indexed_lines: + approx_start = False + start_pos = (self._lines_end_index[start_line - 1] + + 
len(self._lines_end_char) + if start_line >= 1 else 0) + else: + approx_start = True + # use approximate pos for start + non_indexed_lines_before_start = start_line - num_indexed_lines + estim_non_indexed_bytes_before_start = ( + int(non_indexed_lines_before_start * + self._avg_bytes_per_line)) + start_pos = pos_last_end + 1 + estim_non_indexed_bytes_before_start + # read one more line before expected start_pos to have more + # chance of getting the line entirely + start_pos = max(start_pos - int(self._avg_bytes_per_line), 0) + + num_lines = 0 + num_lines_required = stop_line - start_line + + f.seek(start_pos) + # use approximate pos for stop + chunks = [] + CHUNK_SIZE = 1 * MB + non_indexed_lines_before_stop = stop_line - num_indexed_lines + estim_non_indexed_bytes_before_stop = ( + math.ceil(non_indexed_lines_before_stop * + self._avg_bytes_per_line)) + stop_pos = pos_last_end + estim_non_indexed_bytes_before_stop + # read maximum 4Mb and do not read beyond file end + max_stop_pos = min(stop_pos + 4 * MB, self._nbytes) + # first chunk size is what we *think* is necessary to get + # num_lines_required + chunk_size = stop_pos - start_pos + # but then, if the number of lines we actually got (num_lines) + # is not enough we will ask for more + while num_lines < num_lines_required and stop_pos < max_stop_pos: + chunk = f.read(chunk_size) + chunks.append(chunk) + num_lines += chunk.count(self._lines_end_char) + stop_pos += len(chunk) + chunk_size = CHUNK_SIZE + + if approx_start: + # +1 and [1:] to remove first line so that we are sure the first line is complete + n_req_lines = num_lines_required + 1 + lines = self._decode_chunks_to_lines(chunks, n_req_lines)[1:] + else: + lines = self._decode_chunks_to_lines(chunks, num_lines_required) + return lines + + def _decode_chunk(self, chunk: bytes): + try: + return chunk.decode(self._encoding) + except UnicodeDecodeError: + old_encoding = self._encoding + # try to find another encoding + self._detect_encoding(chunk) + logger.debug(f"Could not decode chunk using {old_encoding}") + logger.debug(f"Trying again using {self._encoding} and ignoring " + f"errors") + return chunk.decode(self._encoding, errors='replace') + + def _decode_chunks_to_lines(self, chunks: list, num_required_lines: int): + r""" + Parameters + ---------- + chunks : list + List of bytes. 
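+        num_required_lines : int
+            Number of lines needed by the caller; decoding stops once at
+            least this many lines have been produced.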
+ """ + if not chunks: + return [] + + lines = [] + chunk_idx = 0 + last_line = '' + while len(lines) < num_required_lines and chunk_idx < len(chunks): + chunk = chunks[chunk_idx] + decoded_chunk = self._decode_chunk(chunk) + line_ending = self._lines_end_char.decode(self._encoding) + chunk_lines = decoded_chunk.split(line_ending) + first_line = chunk_lines[0] + if last_line: + first_line = last_line + first_line + lines.append(first_line) + last_line = chunk_lines[-1] + lines.extend(chunk_lines[1:-1]) + chunk_idx += 1 + if len(lines) < num_required_lines: + lines.append(last_line) + return lines + + def get_values(self, h_start, v_start, h_stop, v_stop): + """*_stop are exclusive""" + return [[line] for line in self._get_lines(v_start, v_stop)] + + +TEXT_FILE_SUFFIXES = ( + '.bat', + '.c', + '.cfg', + '.cpp', + '.h', + '.htm', # web + '.html', # web + '.ini', + '.log', + '.md', + '.py', + '.pyx', # cython + '.pxd', # cython + '.rep', + '.rst', + '.sh', + '.sql', + '.toml', + '.txt', + '.wsgi', + '.yaml', + '.yml', +) + + +@path_adapter_for(TEXT_FILE_SUFFIXES) +class TextPathAdapter(TextFileAdapter): + @classmethod + def open(cls, fpath): + return open(fpath, 'rt') + + +class CsvFileAdapter(TextFileAdapter): + DEFAULT_DELIMITER = ',' + + def __init__(self, data, attributes): + self._dialect = None + TextFileAdapter.__init__(self, data=data, attributes=attributes) + if self._nbytes > 0: + first_line = self._get_lines(0, 1) + assert len(first_line) == 1, f"{len(first_line)}" + reader = self._get_reader([first_line[0]]) + self._colnames = next(reader) + else: + self._colnames = [] + + def _analyze_first_chunk(self, chunk): + import csv + # detects encoding and line endings + super()._analyze_first_chunk(chunk) + # try to detect dialect + str_chunk = chunk.decode(self._encoding) + sniffer = csv.Sniffer() + # make sure the default delimiter is tried first + sniffer.preferred.insert(0, self.DEFAULT_DELIMITER) + try: + dialect = sniffer.sniff(str_chunk) + logger.debug("CSV dialect detected: " + f"delimiter={dialect.delimiter!r}, " + f"quotechar={dialect.quotechar!r}") + except csv.Error as e: + dialect = None + logger.debug(f"Could not detect CSV dialect: {e}, " + f"using default delimiter: {self.DEFAULT_DELIMITER}") + self._dialect = dialect + + # for large files, this is approximate + def shape2d(self): + # - 1 for header row + return self._num_lines - 1, len(self._colnames) + + def get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def get_vlabels_values(self, start, stop): + # + 1 for header row + return super().get_vlabels_values(start + 1, stop + 1) + + def get_values(self, h_start, v_start, h_stop, v_stop): + """*_stop are exclusive""" + # + 1 because the header row is not part of the data but _get_lines works + # on the actual file lines + lines = self._get_lines(v_start + 1, v_stop + 1) + if not lines: + return [] + reader = self._get_reader(lines) + return [line[h_start:h_stop] for line in reader] + + def _get_reader(self, lines): + import csv + # Note that csv reader actually needs a line-based input + if self._dialect is not None: + reader = csv.reader(lines, dialect=self._dialect) + else: + reader = csv.reader(lines, delimiter=self.DEFAULT_DELIMITER) + return reader + + +@path_adapter_for('.csv', 'csv') +class CsvPathAdapter(CsvFileAdapter): + @classmethod + def open(cls, fpath): + return open(fpath, 'rt') + + +class TsvFileAdapter(CsvFileAdapter): + DEFAULT_DELIMITER = '\t' + + +@path_adapter_for('.tsv', 'csv') +class TsvPathAdapter(TsvFileAdapter): + 
@classmethod + def open(cls, fpath): + return open(fpath, 'rt') + + + +class PolarsParquetPathAdapter(PolarsLazyFrameAdapter): + @classmethod + def open(cls, fpath): + import polars as pl + return pl.scan_parquet(fpath) + + +class PyArrowParquetPathAdapter(PyArrowParquetFileAdapter): + @classmethod + def open(cls, fpath): + import pyarrow.parquet as pq + return pq.ParquetFile(fpath) + + +@path_adapter_for('.parquet') +def dispatch_parquet_path_adapter(fpath): + # the polars adapter is first as it has more features + return dispatch_file_suffix_by_available_module('parquet',{ + 'polars': PolarsParquetPathAdapter, + 'pyarrow.parquet': PyArrowParquetPathAdapter + }) + + +# modules are tried in the order they are defined +def dispatch_file_suffix_by_available_module(suffix, module_dict: dict): + for module_name, adapter_cls in module_dict.items(): + # We need this special case because find_spec can only safely check the + # presence of top level modules. For submodules, it actually imports + # the parent module and only then checks for the submodule, which + # breaks if the parent module is not available. + if '.' in module_name: + top_module = module_name.split('.')[0] + if importlib.util.find_spec(top_module) is None: + continue + if importlib.util.find_spec(module_name) is not None: + return adapter_cls + module_names = ', '.join(module_dict.keys()) + return (f'Cannot handle {suffix} file because none of the required modules ' + f'are available. Please install at least one of: {module_names}.') + + +# This is a Path adapter (it handles Path objects) because pyreadstat has no +# object representing open files. It does provide an uniform interface across +# formats, hence the abstract base class +class AbstractPyReadStatPathAdapter(AbstractColumnarAdapter): + READ_FUNC_NAME = None + + # data must be a Path object + def __init__(self, data, attributes=None): + assert isinstance(data, Path) + # we know the module is loaded but it is not in the current namespace + pyreadstat = sys.modules['pyreadstat'] + super().__init__(data, attributes=attributes) + self._read_func = getattr(pyreadstat, self.READ_FUNC_NAME) + empty_df, meta = self._read_func(data, metadataonly=True) + self._colnames = meta.column_names + self._numrows = meta.number_rows + + def shape2d(self): + return self._numrows, len(self._colnames) + + def get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + used_cols = self._colnames[h_start:h_stop] + df, meta = self._read_func(self.data, row_offset=v_start, + row_limit=v_stop - v_start, + usecols=used_cols) + return df.values + + +class PyReadstatSas7BdatPathAdapter(AbstractPyReadStatPathAdapter): + READ_FUNC_NAME = 'read_sas7bdat' + + +class PyReadstatDtaPathAdapter(AbstractPyReadStatPathAdapter): + READ_FUNC_NAME = 'read_dta' + + +@adapter_for('pandas.io.sas.sas7bdat.SAS7BDATReader') +class PandasSAS7BDATReaderAdapter(AbstractColumnarAdapter): + MAX_BYTES_TO_SKIP = 10_000_000 + + def __init__(self, data, attributes=None): + super().__init__(data, attributes=attributes) + reader = data + index_cols = set(reader.index) if reader.index is not None else set() + default_chunksize = max(1_000_000 // reader.row_length, 1) + chunksize = reader.chunksize or default_chunksize + logger.debug(f'{chunksize=}') + self._chunk_size = chunksize + self._colnames = [col for col in reader.column_names + if col not in index_cols] + + def shape2d(self): + return self.data.row_count, len(self._colnames) + + def 
get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + reader = self.data + current_row = reader._current_row_in_file_index + if current_row > v_start: + logger.debug("must reset Pandas SAS7BDATReader") + # reset reader to the beginning of the file by closing it and + # re-initializing it + fpath = reader.handles.handle.name + kwargs = dict( + index=reader.index, + convert_dates=reader.convert_dates, + blank_missing=reader.blank_missing, + chunksize=reader.chunksize, + encoding=reader.encoding, + convert_text=reader.convert_text, + convert_header_text=reader.convert_header_text, + # cannot be easily retrieved + # compression=reader.compression, + ) + reader.close() + reader.__init__(fpath, **kwargs) + current_row = reader._current_row_in_file_index + expected_num_rows = v_stop - v_start + expected_num_cols = h_stop - h_start + + # skip to v_start + num_rows_to_skip = v_start - self._chunk_size - current_row + if num_rows_to_skip > 0: + bytes_to_skip = num_rows_to_skip * reader.row_length + logger.debug(f"must skip {num_rows_to_skip} rows " + f"(~{bytes_to_skip:_} bytes)") + if bytes_to_skip > self.MAX_BYTES_TO_SKIP: + # An exception would be eaten by the adapter so the user + # would never see it + msg = 'File is too large to display non top rows' + first_row = [msg] + [''] * (expected_num_cols - 1) + second_row = [''] * expected_num_cols + all_rows = [first_row, second_row] * (expected_num_rows // 2) + if expected_num_rows % 2 == 1: + all_rows.append(first_row) + return all_rows + + while current_row < v_start - self._chunk_size: + reader.read(self._chunk_size) + current_row = reader._current_row_in_file_index + # read up to v_stop + num_rows_to_read = v_stop - current_row + df = reader.read(num_rows_to_read) + logger.debug(f'{len(df)} rows read') + assert v_start >= current_row, f"{v_start} < {current_row}" + chunk = df.iloc[v_start - current_row:] + assert len(chunk) == expected_num_rows, \ + f"{len(chunk)=} != {expected_num_rows=}" + + chunk_columns = [chunk.iloc[:, i].values + for i in range(h_start, h_stop)] + try: + return np.stack(chunk_columns, axis=1) + except np.exceptions.DTypePromotionError: + return np.stack(chunk_columns, axis=1, dtype=object) + + +class PandasSAS7BDATPathAdapter(PandasSAS7BDATReaderAdapter): + @classmethod + def open(cls, fpath): + import pandas as pd + # * iterator=True so that Pandas returns a SAS7BDATReader instead of a + # DataFrame + # * encoding='infer' to avoid having string columns returned + # as raw bytes + return pd.read_sas(fpath, iterator=True, encoding='infer') + + +@adapter_for('pandas.io.stata.StataReader') +class PandasStataReaderAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes=None): + super().__init__(data, attributes=attributes) + reader = data + reader._ensure_open() + + # monkey-patch Pandas StataReader to fix column selection (only + # the first column selection of a reader works in the original version) + def _do_select_columns(self, data, columns): + if not hasattr(self, '_full_dtyplist'): + self._full_dtyplist = self._dtyplist + self._full_typlist = self._typlist + self._full_fmtlist = self._fmtlist + self._full_lbllist = self._lbllist + + column_set = set(columns) + if len(column_set) != len(columns): + raise ValueError("columns contains duplicate entries") + unmatched = column_set.difference(data.columns) + if unmatched: + joined = ", ".join(list(unmatched)) + raise ValueError( + "The following columns were not " + f"found 
in the Stata data set: {joined}"
+                )
+            # Copy information for retained columns for later processing
+            get_loc = data.columns.get_loc
+            col_indices = [get_loc(col) for col in columns]
+            self._dtyplist = [self._full_dtyplist[i] for i in col_indices]
+            self._typlist = [self._full_typlist[i] for i in col_indices]
+            self._fmtlist = [self._full_fmtlist[i] for i in col_indices]
+            self._lbllist = [self._full_lbllist[i] for i in col_indices]
+            self._column_selector_set = True
+            return data[columns]
+
+        reader.__class__._do_select_columns = _do_select_columns
+
+    def shape2d(self):
+        reader = self.data
+        return reader._nobs, reader._nvar
+
+    def get_hlabels_values(self, start, stop):
+        return [self.data._varlist[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        reader = self.data
+        columns = reader._varlist[h_start:h_stop]
+
+        reader._lines_read = v_start
+        chunk = reader.read(v_stop - v_start, columns=columns)
+
+        chunk_columns = [chunk.iloc[:, i].values
+                         for i in range(h_stop - h_start)]
+        try:
+            return np.stack(chunk_columns, axis=1)
+        except np.exceptions.DTypePromotionError:
+            return np.stack(chunk_columns, axis=1, dtype=object)
+
+
+class PandasDTAPathAdapter(PandasStataReaderAdapter):
+    @classmethod
+    def open(cls, fpath):
+        import pandas as pd
+        # iterator=True so that Pandas returns a StataReader instead of a
+        # DataFrame
+        return pd.read_stata(fpath, iterator=True)
+
+
+@path_adapter_for('.sas7bdat')
+def dispatch_sas7bdat_path_adapter(fpath):
+    # the pandas adapter is first as it is (much) faster for reading the
+    # first lines of large files. In practice, Pandas is always available
+    # because it is currently a hard dependency of larray-editor
+    return dispatch_file_suffix_by_available_module('sas7bdat', {
+        'pandas': PandasSAS7BDATPathAdapter,
+        'pyreadstat': PyReadstatSas7BdatPathAdapter
+    })
+
+
+@path_adapter_for('.dta')
+def dispatch_dta_path_adapter(fpath):
+    # the pandas adapter is first as it is (much) faster for reading large
+    # files. In practice, Pandas is always available because it is
+    # currently a hard dependency of larray-editor
+    return dispatch_file_suffix_by_available_module('dta', {
+        'pandas': PandasDTAPathAdapter,
+        'pyreadstat': PyReadstatDtaPathAdapter
+    })
+
+
+@adapter_for('pstats.Stats')
+class ProfilingStatsAdapter(AbstractColumnarAdapter):
+    # we display everything except callers
+    _COLNAMES = ['filepath', 'line num', 'func. name',
+                 'ncalls (non rec)', 'ncalls (total)',
+                 'tottime', 'cumtime']
+
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+        self._keys = list(data.stats.keys())
+
+    def shape2d(self):
+        return len(self._keys), len(self._COLNAMES)
+
+    def get_hlabels_values(self, start, stop):
+        return [self._COLNAMES[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        """*_stop are exclusive"""
+        func_calls = self._keys[v_start:v_stop]
+        stats = self.data.stats
+        call_details = [stats[k] for k in func_calls]
+        # we display everything except callers
+        return [(filepath, line_num, func_name,
+                 ncalls_primitive, ncalls_tot, tottime, cumtime)[h_start:h_stop]
+                for ((filepath, line_num, func_name),
+                     (ncalls_primitive, ncalls_tot, tottime, cumtime, callers))
+                in zip(func_calls, call_details)]
+
+
+SQLITE_LIST_TABLES_QUERY = ("SELECT name "
+                            "FROM sqlite_schema "
+                            "WHERE type='table' AND name NOT LIKE 'sqlite_%'")
+
+
+class SQLiteTable:
+    def __init__(self, con, name):
+        self.con = con
+        self.name = name
+
+    def __repr__(self):
+        return f"<SQLiteTable {self.name}>"
+
+
+@adapter_for(SQLiteTable)
+class SQLiteTableAdapter(AbstractColumnarAdapter):
+    def __init__(self, data, attributes):
+        assert isinstance(data, SQLiteTable)
+        super().__init__(data=data, attributes=attributes)
+        table_name = self.data.name
+        cur = self.data.con.cursor()
+        cur.execute(f"SELECT count(*) FROM {table_name}")
+        self._numrows = cur.fetchone()[0]
+        cur.execute(f"SELECT * FROM {table_name} LIMIT 1")
+        self._columns = [col_descr[0] for col_descr in cur.description]
+        cur.close()
+
+    def shape2d(self):
+        return self._numrows, len(self._columns)
+
+    def get_hlabels_values(self, start, stop):
+        return [self._columns[start:stop]]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        cur = self.data.con.cursor()
+        cols = self._columns[h_start:h_stop]
+        if self._current_sort:
+            col_names = [
+                f"{self._columns[col_idx]}{'' if ascending else ' DESC'}"
+                for axis, col_idx, ascending in self._current_sort
+            ]
+            order_by = f" ORDER BY {', '.join(col_names)}"
+        else:
+            order_by = ""
+        query = f"""\
+SELECT {', '.join(cols)} FROM {self.data.name}{order_by}
+LIMIT {v_stop - v_start} OFFSET {v_start}"""
+        cur.execute(query)
+        rows = cur.fetchall()
+        cur.close()
+        return rows
+
+    def can_sort_hlabel(self, row_idx, col_idx):
+        return True
+
+    def sort_hlabel(self, row_idx, col_idx, ascending):
+        assert row_idx == 0
+        self._current_sort = [(1, col_idx, ascending)]
+
+
+@adapter_for('sqlite3.Connection')
+class SQLiteConnectionAdapter(AbstractAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+        # as of python3.12, sqlite3.Cursor is not context manager friendly
+        cur = data.cursor()
+        cur.execute(SQLITE_LIST_TABLES_QUERY)
+        self._table_names = [row[0] for row in cur.fetchall()]
+        cur.close()
+
+    def shape2d(self):
+        return len(self._table_names), 1
+
+    def get_hlabels_values(self, start, stop):
+        return [['Name']]
+
+    def get_values(self, h_start, v_start, h_stop, v_stop):
+        return [[name] for name in self._table_names[v_start:v_stop]]
+
+    def cell_activated(self, row_idx, column_idx):
+        table_name = self._table_names[row_idx]
+        return SQLiteTable(self.data, table_name)
+
+
+DUCKDB_LIST_TABLES_QUERY = "SHOW TABLES"
+
+
+@adapter_for('duckdb.DuckDBPyRelation')
+class DuckDBRelationAdapter(AbstractColumnarAdapter):
+    def __init__(self, data, attributes):
+        super().__init__(data=data, attributes=attributes)
+        self._numrows = len(data)
+        self._columns =
data.columns + self._unq_values_per_column = {} + + def shape2d(self): + return self._numrows, len(self._columns) + + def get_hlabels_values(self, start, stop): + return [self._columns[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + cols = self._columns[h_start:h_stop] + num_rows = v_stop - v_start + query = self.data + query = self._add_filters(query) + for axis, col_idx, ascending in self._current_sort: + assert axis == 1 + colname = self._columns[col_idx] + desc = "" if ascending else " DESC" + query = query.order(f"{colname}{desc}") + rows = query.limit(num_rows, offset=v_start) + subset = rows.select(*cols) + return subset.fetchall() + + def _add_filters(self, query): + for filter_idx, filter_indices in self.current_filter.items(): + colname = self._columns[filter_idx] + col_unq_values = self.get_filter_options(filter_idx) + filter_values = col_unq_values[filter_indices] + + # Sadly, duckdb relation API does not support named parameters/ + # prepared statements so we have to inline the values + if isinstance(filter_values[0], str): + # this is hopefully enough to avoid creating an SQL injection + # vulnerability + assert not any(isinstance(v, str) and "'" in v + for v in filter_values) + filter_values = [f"'{v}'" for v in filter_values] + else: + filter_values = [str(v) for v in filter_values] + assert not any("'" in v for v in filter_values) + query = query.filter(f"{colname} IN ({', '.join(filter_values)})") + return query + + def can_sort_hlabel(self, row_idx, col_idx): + return True + + def sort_hlabel(self, row_idx, col_idx, ascending): + assert row_idx == 0 + self._current_sort = [(1, col_idx, ascending)] + + def can_filter_hlabel(self, row_idx, col_idx) -> bool: + return True + + def get_filter_options(self, filter_idx): + if filter_idx in self._unq_values_per_column: + return self._unq_values_per_column[filter_idx] + else: + colname = self._columns[filter_idx] + query = (self.data.select(colname) + .distinct() + .order(colname) + .limit(MAX_FILTER_OPTIONS)) + unq_values = query.fetchnumpy()[colname] + self._unq_values_per_column[filter_idx] = unq_values + return unq_values + + def update_filter(self, filter_idx, indices): + """Update current filter for a given axis if labels selection from the array widget has changed + + Parameters + ---------- + filter_idx : int + Index of filter (axis) for which selection has changed. + indices: list of int + Indices of selected labels. 
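            An empty list selects all filter options for that column.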
+ """ + # only allow filtering a single columns for now (by not keeping previous + # filters) + if not indices: + indices = list(range(len(self.get_filter_options(filter_idx)))) + self.current_filter = {filter_idx: indices} + self._numrows = len(self._add_filters(self.data)) + + +@adapter_for('duckdb.DuckDBPyConnection') +class DuckDBConnectionAdapter(AbstractAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + self._table_names = [ + row[0] for row in data.sql(DUCKDB_LIST_TABLES_QUERY).fetchall() + ] + + def shape2d(self): + return len(self._table_names), 1 + + def get_hlabels_values(self, start, stop): + return [['Table Name']] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return [[name] for name in self._table_names[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + table_name = self._table_names[row_idx] + return self.data.table(table_name) + + +@path_adapter_for('.ddb', 'duckdb') +@path_adapter_for('.duckdb', 'duckdb') +class DuckDBPathAdapter(DuckDBConnectionAdapter): + @classmethod + def open(cls, fpath): + duckdb = sys.modules['duckdb'] + return duckdb.connect(fpath) + + +class CSVGZPathAdapater(CsvFileAdapter): + @classmethod + def open(cls, fpath): + import gzip + # not specifying an encoding is not an option because in that case + # we would get bytes and not str, which makes csv reader unhappy + return gzip.open(fpath, mode='rt', encoding='utf-8') + + @property + def _binary_file(self): + import gzip + return gzip.open(self.data.name, mode='rb') + + +@path_adapter_for('.gz', 'gzip') +def dispatch_gzip_path_adapter(gz_path): + # strip .gz extension and dispatch to appropriate adapter + fpath = gz_path.with_name(gz_path.stem) + suffix = fpath.suffix.lower() + if suffix == '.csv': + return CSVGZPathAdapater + else: + return None + + +@adapter_for('zipfile.ZipFile') +class ZipFileAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + + infolist = data.infolist() + infolist.sort(key=lambda info: (not info.is_dir(), info.filename)) + self._infolist = infolist + self._list = [(info.filename, + datetime(*info.date_time).strftime('%d/%m/%Y %H:%M'), + '' if info.is_dir() else info.file_size) + for info in infolist] + self._colnames = ['Name', 'Time Modified', 'Size'] + + def shape2d(self): + return len(self._list), len(self._colnames) + + def get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return [row[h_start:h_stop] + for row in self._list[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + import zipfile + info = self._infolist[row_idx] + if info.is_dir(): + return zipfile.Path(self.data, info.filename) + else: + # do nothing for now + return None + # TODO: this returns a zipfile.ZipExtFile which is a file-like + # object but it does not inherit from io.BufferedReader so no + # adapter corresponds. 
We should add an adapter for + # zipfile.ZipExtFile + # return self.data.open(info.filename) + + +@path_adapter_for('.zip', 'zipfile') +class ZipPathAdapter(ZipFileAdapter): + @classmethod + def open(cls, fpath): + zipfile = sys.modules['zipfile'] + return zipfile.ZipFile(fpath) + + +@adapter_for('zipfile.Path') +class ZipfilePathAdapter(AbstractColumnarAdapter): + def __init__(self, data, attributes): + super().__init__(data=data, attributes=attributes) + zpath_objs = list(data.iterdir()) + zpath_objs.sort(key=lambda p: (not p.is_dir(), p.name)) + self._zpath_objs = zpath_objs + self._list = [(p.name, '' if p.is_dir() else '') + for p in zpath_objs] + self._colnames = ['Name', 'Type'] + + def shape2d(self): + return len(self._list), len(self._colnames) + + def get_hlabels_values(self, start, stop): + return [self._colnames[start:stop]] + + def get_values(self, h_start, v_start, h_stop, v_stop): + return [row[h_start:h_stop] + for row in self._list[v_start:v_stop]] + + def cell_activated(self, row_idx, column_idx): + child_path = self._zpath_objs[row_idx] + if child_path.is_dir(): + return child_path + else: + # for now, do nothing return None - # transform positional ND key to positional 2D key - strides = np.append(1, np.cumprod(filtered_data.shape[1:-1][::-1], dtype=int))[::-1] - return (index_key[:-1] * strides).sum(), index_key[-1] diff --git a/larray_editor/arraymodel.py b/larray_editor/arraymodel.py index 9610076a..c7ba214e 100644 --- a/larray_editor/arraymodel.py +++ b/larray_editor/arraymodel.py @@ -1,17 +1,170 @@ -from os.path import basename -import logging -from inspect import stack -import numpy as np -from larray_editor.utils import (get_default_font, - is_float, is_number, LinearGradient, SUPPORTED_FORMATS, scale_to_01range, - Product, is_number_value, get_sample_indices, logger) +import itertools + from qtpy.QtCore import Qt, QModelIndex, QAbstractTableModel, Signal from qtpy.QtGui import QColor -from qtpy.QtWidgets import QMessageBox +from qtpy.QtWidgets import QMessageBox, QStyle + +import numpy as np -LARGE_SIZE = 5e5 -LARGE_NROWS = 1e5 -LARGE_COLS = 60 +from larray_editor.utils import (get_default_font, + is_number_value, is_float_dtype, is_number_dtype, + LinearGradient, logger, broadcast_get, + format_exception, log_caller) + +# TODO before using the widget in other projects: +# * cleanup adapter filters API: +# - get_filter_options works for either larray-style filters and +# hlabel filters but cannot work for both at the same time (what the given +# index means). +# - How does axis filters enter the picture? +# - I think we should go towards get_hlabel_actions(row_idx, col_idx) -> +# some kind of Form definition. On any update of the form values, +# the method corresponding to the part of the form that changed is called. + +# The simplest form would be: +# {'filters': (self.change_filters, [label_values]), +# 'sort': (self.change_sort, ..., +# 'group_by: ...}. +# this would also help support custom actions: +# {'Label in menu': callable_to_perform_action | dict for submenu | } +# apply auto gui on it: +# def add_op_from_label(self, sort: ascending|descending|unsorted, +# filter: list_of_labels) -> QuickBarOp: +# ... 
+# +# * move ndigits/format detection to adapter +# but the trick is to avoid using current column width and just +# target a "reasonable number" of digits +# * update format on offset change (colwidth is updated but not format) +# * support thousand separators by default (see PROMES) +# * editing values on filtered arrays does not always work: when a filter +# contains at least one dimension with *several* labels selected (but not +# all), after editing a cell, the new value is visible only after the filter +# is changed again. The cause of this is very ugly. There are actually +# two distinct bugs at work here, usually hidden by the fact that filters +# on larray usually return views all the way to the model. But when several +# labels are selected, self.filtered_data in the LArray adapter is a *copy* +# of the original data and not a view, so even if the model.reset() +# (in EditObjecCommand) caused the model to re-ask the data from the +# adapter (which it probably should but does NOT currently do -- it only +# re-processes the data it already has -- which seems to work when the +# data is a view), the adapter would still return the wrong data. +# * allow adapters to send more data than requested (unsure about this, maybe +# implementing buffering in adapters is better) +# * implement the generic edit mechanism for the quickbar (to be able to edit +# the original array while viewing a filtered array or, even, editing after +# any other custom function, not just filtering) +# - each adapter may provide a method: +# can_edit_through_operation(op_name, args, kwargs) +# unsure whether to use *args and **kwargs instead of args and kwargs +# - if the adapter returns True for the above method (AbstractAdapter +# must return False whatever the operation is), it must also implement a +# transform_changes_through_inverse_of_operation(op_name, ..., changes) +# method (whatever its final name). op could be '__getitem__' but also +# 'growth_rate', or whatever. +# - the UI must clearly display whether an array is editable. Adding +# "readonly" in the window title is a good start but unsure it is enough. +# - I wonder if filtering could be done generically (if adapters implement +# an helper function, we get filtering *with* editing passthrough) +# - in a bluesky world, the inverse op could ask for options (e.g. editing a +# summed cell can be dispersed on the individual unsummed values in several +# ways: proportional to original value, constant fraction, ...) +# * adapters should provide a method to get position with axes area of each axis +# and a method for the opposite (given an xy position, get the axis number). +# The abstract adapter should provide a default implementation and possibly +# the choice between several common options (larray style, pandas style, ...) +# larray style is annoying for filters though, so I am unsure I want to +# support it, even though it is the most compact and visually appealing. +# Maybe using it when there is no need for a \ (no two *named* axes in the +# same cell) is a good option??? +# * do not convert to a single numpy array in Adapter.get_data_values_and_attributes +# because it converts mixed string/number datasets to all strings, which +# in turn breaks a lot of functionalities (alignment of numeric cells +# is wrong, plots on numeric cells are wrong -- though this one suffers from +# its own explicit conversion to numpy array --, etc.) 
+# Ideally, we should support per-cell dtype, but if we support +# dense arrays, per column dtype and per row dtype that would be already much +# better than what we have now. +# * changing from an array to another is sometimes broken (new array not +# displayed, old array axes still present) +# - I *think* this only happens when an Exception is raised in the adapter +# or arraymodel. These exceptions should not happen in the first place, but +# the widget should handle them gracefully. +# * allow fixing frac_digits or scientific from the API (see PROMES, I think) +# * better readonly behavior from PROMES +# * include all other widgets from PROMES +# * allow having no filters nor gradient chooser (see PROMES) +# * massive cleanup (many methods would probably be better in either their +# superclass or one of their subclasses) + +# TODO post release (move to issues): +# * mouse selection on "edges" should move the buffer +# (it scrolls the internal viewport but does not change the offset) + +h_align_map = { + 'left': Qt.AlignLeft, + 'center': Qt.AlignHCenter, + 'right': Qt.AlignRight, +} +v_align_map = { + 'top': Qt.AlignTop, + 'center': Qt.AlignVCenter, + 'bottom': Qt.AlignBottom, +} + +# h_offset +# -------- +# | | +# v v +# |----------------------| <-| +# | total data | | v_offset +# | |------------------| | <-| +# | | data in model | | +# | | |--------------| | | +# | | | visible area | | | +# | | |--------------| | | +# | |------------------| | +# |----------------------| + +def homogenous_shape(seq) -> tuple: + """ + Returns the shape (size of each dimension) of nested sequences. + Checks that nested sequences are homogeneous; if not, + treats children as scalars. + """ + # we cannot rely on .shape for object arrays because they could + # contain sequences themselves + if isinstance(seq, np.ndarray) and not seq.dtype.kind == 'O': + return seq.shape + elif isinstance(seq, (list, tuple, np.ndarray)): + parent_length = len(seq) + if parent_length == 0: + return (0,) + elif parent_length == 1: + return (parent_length,) + homogenous_shape(seq[0]) + res = [parent_length] + child_shapes = [homogenous_shape(child) + for child in seq] + # zip length will be determined by the shortest shape, which is + # exactly what we need + for depth_lengths in zip(*child_shapes[1:]): + first_child_length = depth_lengths[0] + if all(length == first_child_length + for length in depth_lengths[1:]): + res.append(first_child_length) + return tuple(res) + else: + return () + + +def homogenous_ndim(seq) -> int: + return len(homogenous_shape(seq)) + + +def assert_at_least_2d_or_empty(seq): + seq_shape = homogenous_shape(seq) + assert len(seq_shape) >= 2 or 0 in seq_shape, ( + f"sequence:\n{seq}\nshould be >=2D or empty but has shape {seq_shape}") class AbstractArrayModel(QAbstractTableModel): @@ -26,93 +179,424 @@ class AbstractArrayModel(QAbstractTableModel): font : QFont, optional Font. Default is `Calibri` with size 11. 
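    adapter : AbstractAdapter, optional
        Adapter providing the data to display. Can also be set later via
        set_adapter().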
""" - ROWS_TO_LOAD = 500 - COLS_TO_LOAD = 40 + default_buffer_rows = 40 + default_buffer_cols = 40 - def __init__(self, parent=None, readonly=False, font=None): + def __init__(self, parent=None, adapter=None): QAbstractTableModel.__init__(self) self.dialog = parent - self.readonly = readonly - - if font is None: - font = get_default_font() - self.font = font - - self._data = None - self.rows_loaded = 0 - self.cols_loaded = 0 - self.total_rows = 0 - self.total_cols = 0 + self.adapter = adapter + + self.h_offset = 0 + self.v_offset = 0 + self.nrows = 0 + self.ncols = 0 + + self.raw_data = {} + self.processed_data = {} + self.role_defaults = {} + # FIXME: unused + self.flags_defaults = Qt.NoItemFlags + self.processed_flags_data = None #Qt.NoItemFlags + self.bg_gradient = None + self.default_v_align = [[Qt.AlignVCenter]] + + def set_adapter(self, adapter): + self.adapter = adapter + self.h_offset = 0 + self.v_offset = 0 + self.nrows = 0 + self.ncols = 0 + self._get_current_data() + self.reset() + + def set_h_offset(self, offset): + # TODO: when moving in one direction only, we should make sure to only request data we do not have already + # (if there is overlap between the old "window" and the new one). + self.set_offset(self.v_offset, offset) + + def set_v_offset(self, offset): + self.set_offset(offset, self.h_offset) + + def set_offset(self, v_offset, h_offset): + # TODO: the implementation of this method should use set_bounds instead + logger.debug(f"{self.__class__.__name__}.set_offset({v_offset=}, {h_offset=})") + assert v_offset is not None and h_offset is not None + assert v_offset >= 0 and h_offset >= 0 + self.v_offset = v_offset + self.h_offset = h_offset + old_shape = self.nrows, self.ncols + self._get_current_data() + self._process_data() + new_shape = self.nrows, self.ncols + if new_shape != old_shape: + self.reset() + else: + top_left = self.index(0, 0) + # -1 because Qt index end bounds are inclusive + bottom_right = self.index(self.nrows - 1, self.ncols - 1) + self.dataChanged.emit(top_left, bottom_right) - def _set_data(self, data): + def _begin_insert_remove(self, action, target, parent, start, stop): + if start >= stop: + return False + funcs = { + ('remove', 'rows'): self.beginRemoveRows, + ('insert', 'rows'): self.beginInsertRows, + ('remove', 'columns'): self.beginRemoveColumns, + ('insert', 'columns'): self.beginInsertColumns, + } + funcs[action, target](parent, start, stop - 1) + return True + + def _end_insert_remove(self, action, target): + funcs = { + ('remove', 'rows'): self.endRemoveRows, + ('insert', 'rows'): self.endInsertRows, + ('remove', 'columns'): self.endRemoveColumns, + ('insert', 'columns'): self.endInsertColumns, + } + funcs[action, target]() + + # FIXME: review all API methods and be consistent in argument order: h, v or v, h. 
+    # I think Qt always uses v, h but we sometimes use h, v
+
+    # TODO: make this a private method (it is not called anymore but SHOULD be
+    # called (or inlined) in set_offset)
+    def set_bounds(self, v_start=None, h_start=None, v_stop=None, h_stop=None):
+        """stop bounds are *exclusive*
+        any None is replaced by its previous value"""
+
+        oldvstart, oldhstart = self.v_offset, self.h_offset
+        oldvstop, oldhstop = oldvstart + self.nrows, oldhstart + self.ncols
+        newvstart, newhstart, newvstop, newhstop = v_start, h_start, v_stop, h_stop
+        logger.debug(f"set_bounds({v_start=}, {h_start=}, {v_stop=}, {h_stop=})")
+        if newvstart is None:
+            newvstart = oldvstart
+        if newhstart is None:
+            newhstart = oldhstart
+        if newvstop is None:
+            newvstop = oldvstop
+        if newhstop is None:
+            newhstop = oldhstop
+
+        nrows = newvstop - newvstart
+        ncols = newhstop - newhstart
+        # new_shape = nrows, ncols
+
+        # if new_shape != old_shape:
+        #     self.reset()
+        # else:
+
+        # we could generalize this to allow moving the "viewport" and shrinking/enlarging it at the same time
+        # but this is a BAD idea as we should very rarely shrink/enlarge the buffer
+
+        # assert we_are_enlarging or shrinking in one direction only
+        # we have 9 cases total: same shape or 4 cases for each direction: enlarging|shrinking * moving start|stop
+        # ensure we have some overlap between old and new
+
+        # ENLARGE
+        # -------
+
+        # start     stop         start     stop        start     stop          start     stop
+        # |---old---|            |---old---|           |---old---|             |---old---|
+        # v         v            v         v           v         v             v         v
+        # -----------------  OR  ----------------  OR  ------------------  OR  ------------------
+        # ^           ^          ^           ^         ^           ^           ^           ^
+        # |----new----|          |----new----|         |----new----|           |----new----|
+        # start       stop       start       stop      start       stop        start       stop
+
+        # SHRINK
+        # ------
+
+        # start       stop       start       stop      start       stop        start       stop
+        # |----old----|          |----old----|         |----old----|           |----old----|
+        # v           v          v           v         v           v           v           v
+        # ---------------    OR  ----------------  OR  ------------------  OR  ------------------
+        # ^         ^            ^         ^           ^         ^             ^         ^
+        # |---new---|            |---new---|           |---new---|             |---new---|
+        # start     stop         start     stop        start     stop          start     stop
+
+        parent = QModelIndex()
+
+        end_todo = {}
+        for action in ('remove', 'insert'):
+            for target in ('rows', 'columns'):
+                end_todo[action, target] = False
+
+        target = 'rows'
+        oldstart, oldstop, newstart, newstop = oldvstart, oldvstop, newvstart, newvstop
+
+        # remove oldstart:newstart
+        end_todo['remove', target] |= self._begin_insert_remove('remove', target, parent,
+                                                                oldstart, min(newstart, oldstop))
+        # remove newstop:oldstop
+        end_todo['remove', target] |= self._begin_insert_remove('remove', target, parent,
+                                                                max(newstop, oldstart), oldstop)
+        # insert newstart:oldstart
+        end_todo['insert', target] |= self._begin_insert_remove('insert', target, parent,
+                                                                newstart, min(oldstart, newstop))
+        # insert oldstop:newstop
+        end_todo['insert', target] |= self._begin_insert_remove('insert', target, parent,
+                                                                max(oldstop, newstart), newstop)
+
+        target = 'columns'
+        oldstart, oldstop, newstart, newstop = oldhstart, oldhstop, newhstart, newhstop
+
+        # remove oldstart:newstart
+        end_todo['remove', target] |= self._begin_insert_remove('remove', target, parent,
+                                                                oldstart, min(newstart, oldstop))
+        # remove newstop:oldstop
+        end_todo['remove', target] |= self._begin_insert_remove('remove', target, parent,
+                                                                max(newstop, oldstart), oldstop)
+        # insert newstart:oldstart
+        end_todo['insert', target] |= self._begin_insert_remove('insert', target, parent,
+                                                                newstart, min(oldstart, newstop))
+        # insert oldstop:newstop
+        end_todo['insert', target] 
|= self._begin_insert_remove('insert', target, parent, + max(oldstop, newstart), newstop) + + assert newvstart is not None and newhstart is not None + self.v_offset, self.h_offset = newvstart, newhstart + self.nrows, self.ncols = nrows, ncols + + # TODO: we should only get data we do not have yet + self._get_current_data() + # TODO: we should only process data we do not have yet + self._process_data() + + for action in ('remove', 'insert'): + for target in ('rows', 'columns'): + if end_todo[action, target]: + self._end_insert_remove(action, target) + + # removed_rows_start, remove_rows_stop = ... # correspond to old_rows - nrows + # changed_rows_start, changed_rows_stop = ... # other rows + # if v_stop > oldvstop and v_start < oldvstop: # 10 < 11 => 10.. & ..10 => intersection = 10 + # insertRows + # elif v_start < oldvstart and v_stop > oldvstart: # 11 > 10 => ..10 & 10.. => intersection = 10 + # insertRows + # elif v_stop < oldvstop and v_stop > oldvstart: # 11 > 10 => ..10 & 10.. => intersection = 10 + # removeRows + # elif v_start < oldvstart and v_stop > oldvstart: # 11 > 10 => ..10 & 10.. => intersection = 10 + # removeRows + # if new_shape == old_shape: + # # FIXME: this method will never be called for this case, unless I suppress set_offset + # top_left = self.index(0, 0) + # # -1 because Qt index end bounds are inclusive + # bottom_right = self.index(nrows - 1, ncols - 1) + # self.dataChanged.emit(top_left, bottom_right) + # elif nrows == old_rows and ncols > old_cols: + # if ...: + # pass + # else: + # pass + # elif nrows == old_rows and ncols < old_cols: + # if ...: + # pass + # else: + # ... + # elif nrows == old_rows and ncols < old_cols: + # self.beginRemoveRows(parent, first, last) + # self.beginInsertRows(QModelIndex(), self.nrows, self.nrows + items_to_fetch - 1) + # self.nrows += items_to_fetch + # self._get_data() + # self._process_data() + # self.endInsertRows() + # top_left = self.index(0, 0) + # # -1 because Qt index end bounds are inclusive + # bottom_right = self.index(nrows - 1, ncols - 1) + # self.dataChanged.emit(top_left, bottom_right) + + def _format_value(self, args): + fmt, value = args + # using str(value) instead of '%s' % value makes it work for + # tuple value + # try: + return fmt % value if is_number_value(value) else str(value) + # except Exception as e: + # print("YAAAAAAAAAAAAAAA") + # return '' + + def _value_to_h_align(self, elem_value): + if is_number_value(elem_value): + return Qt.AlignRight + else: + return Qt.AlignLeft + + def _process_data(self): + # None format => %user_format if number else %s + + # format per cell (in data) => decimal select will not work and that's fine + # we could make decimal select change adapter-provided format per cell + # which means the adapter can decide to ignore it or not + # default adapter implementation should affect only numeric cells and use %s for non numeric + raw_data = self.raw_data + values = raw_data.get('values', [['']]) + assert_at_least_2d_or_empty(values) + + data_format = raw_data.get('data_format', [['%s']]) + assert_at_least_2d_or_empty(data_format) + + format_and_values = seq_zip_broadcast(data_format, values, ndim=2) + formatted_values = map_nested_sequence(self._format_value, + format_and_values, + ndim=2) + + editable = raw_data.get('editable', [[False]]) + + assert_at_least_2d_or_empty(editable) + + def editable_to_flags(elem_editable): + editable_flag = Qt.ItemIsEditable if elem_editable else Qt.NoItemFlags + return Qt.ItemIsEnabled | Qt.ItemIsSelectable | editable_flag + + # FIXME: use 
self.flags_defaults + # self.processed_flags_data = self.flags_defaults + self.processed_flags_data = map_nested_sequence(editable_to_flags, editable, 2) + self.processed_data = { + Qt.DisplayRole: formatted_values, + # XXX: maybe split this into font_name, font_size, font_flags, or make that data item a dict itself + # Qt.FontRole: None, + # Qt.ToolTipRole: None, + } + + if 'h_align' in raw_data: + h_align = map_nested_sequence(h_align_map.__getitem__, raw_data['h_align'], 2) + else: + h_align = map_nested_sequence(self._value_to_h_align, values, 2) + if 'v_align' in raw_data: + v_align = map_nested_sequence(v_align_map.__getitem__, raw_data['v_align'], 2) + else: + v_align = self.default_v_align + self.processed_data[Qt.TextAlignmentRole] = [ + [int(ha | va) for ha, va in seq_zip_broadcast(ha_row, va_row)] for + ha_row, va_row in seq_zip_broadcast(h_align, v_align) + ] + + if 'bg_value' in raw_data and self.bg_gradient is not None: + bg_value = raw_data['bg_value'] + # TODO: implement a way to specify bg_gradient per row or per column + bg_color = self.bg_gradient[bg_value] if bg_value is not None else None + self.processed_data[Qt.BackgroundColorRole] = bg_color + if 'decoration' in raw_data: + standardIcon = self.dialog.style().standardIcon + def make_icon(decoration): + return standardIcon(DECORATION_MAPPING[decoration]) if decoration else None + decoration_data = map_nested_sequence(make_icon, raw_data['decoration'], 2) + self.processed_data[Qt.DecorationRole] = decoration_data + + def _fetch_data(self, h_start, v_start, h_stop, v_stop): raise NotImplementedError() - def set_data(self, data, reset=True): - self._set_data(data) - if reset: - self.reset() + def _get_current_data(self): + max_row, max_col = self.adapter.shape2d() + + # TODO: I don't think we should ever *ask* for more rows or columns + # than the default, but we should support *receiving* more, + # including before the current v_offset/h_offset. + # The only requirement should be that the asked for region + # is included in what we receive. + rows_to_ask = max(self.nrows, self.default_buffer_rows) + cols_to_ask = max(self.ncols, self.default_buffer_cols) + h_stop = min(self.h_offset + cols_to_ask, max_col) + v_stop = min(self.v_offset + rows_to_ask, max_row) + # print(f"asking {rows_to_ask} rows / {cols_to_ask} columns of data ({self.__class__.__name__})") + try: + raw_data = self._fetch_data(self.h_offset, self.v_offset, + h_stop, v_stop) + except Exception as e: + try: + logger.error(f"could not fetch data from adapter:\n" + f"{''.join(format_exception(e))}") + except Exception: + # sometimes the exception message itself contains unprintable + # data (e.g. 
unicode string with characters which cannot be + # encoded to the stderr encoding) + logger.error("could not fetch data from adapter " + "(and cannot log exception)") + raw_data = np.array([[]]) + if not isinstance(raw_data, dict): + raw_data = {'values': raw_data} + self.raw_data = raw_data + + # XXX: currently this can be a view on the original data + values = self.raw_data['values'] + assert_at_least_2d_or_empty(values) + self.nrows = len(values) + # FIXME: this is problematic for list of sequences + first_row = values[0] if len(values) else [] + self.ncols = len(first_row) if isinstance(first_row, (tuple, list, np.ndarray)) else 1 + # print(f" > received {self.nrows} rows / {self.ncols} cols") + if self.nrows > max_row: + print(f"WARNING: received too many rows ({self.nrows} > {max_row})!") + if self.ncols > max_col: + print(f"WARNING: received too many columns ({self.ncols} > {max_col})!") + + def set_bg_gradient(self, bg_gradient): + if bg_gradient is not None and not isinstance(bg_gradient, LinearGradient): + raise ValueError("Expected None or LinearGradient instance for `bg_gradient` argument") + self.bg_gradient = bg_gradient + self.reset() def rowCount(self, parent=QModelIndex()): - return self.rows_loaded + return self.nrows def columnCount(self, parent=QModelIndex()): - return self.cols_loaded - - def fetch_more_rows(self): - if self.total_rows > self.rows_loaded: - remainder = self.total_rows - self.rows_loaded - items_to_fetch = min(remainder, self.ROWS_TO_LOAD) - self.beginInsertRows(QModelIndex(), self.rows_loaded, - self.rows_loaded + items_to_fetch - 1) - self.rows_loaded += items_to_fetch - self.endInsertRows() - - def fetch_more_columns(self): - if self.total_cols > self.cols_loaded: - remainder = self.total_cols - self.cols_loaded - items_to_fetch = min(remainder, self.COLS_TO_LOAD) - self.beginInsertColumns(QModelIndex(), self.cols_loaded, - self.cols_loaded + items_to_fetch - 1) - self.cols_loaded += items_to_fetch - self.endInsertColumns() + return self.ncols + # AFAICT, this is only used in the ArrayDelegate def get_value(self, index): - raise NotImplementedError() - - def _compute_rows_cols_loaded(self): - # Use paging when the total size, number of rows or number of - # columns is too large - size = self.total_rows * self.total_cols - if size > LARGE_SIZE: - self.rows_loaded = min(self.ROWS_TO_LOAD, self.total_rows) - self.cols_loaded = min(self.COLS_TO_LOAD, self.total_cols) - else: - if self.total_rows > LARGE_NROWS: - self.rows_loaded = self.ROWS_TO_LOAD - else: - self.rows_loaded = self.total_rows - if self.total_cols > LARGE_COLS: - self.cols_loaded = self.COLS_TO_LOAD - else: - self.cols_loaded = self.total_cols + return broadcast_get(self.raw_data['values'], index.row(), index.column()) def flags(self, index): - raise NotImplementedError() + if not index.isValid(): + return Qt.NoItemFlags + + row = index.row() + column = index.column() + if row >= self.nrows or column >= self.ncols: + assert False, "should not have happened" + return QAbstractTableModel.flags(self, index) + return broadcast_get(self.processed_flags_data, row, column) def headerData(self, section, orientation, role=Qt.DisplayRole): return None def data(self, index, role=Qt.DisplayRole): - raise NotImplementedError() + if not index.isValid(): + return None + + role_map = self.processed_data + if role in role_map: + role_data = role_map[role] + else: + role_data = self.role_defaults.get(role) + + row = index.row() + column = index.column() + if row >= self.nrows or column >= self.ncols: + 
return None + return broadcast_get(role_data, row, column) + # res = broadcast_get_index(role_data, index) + # if role == Qt.DisplayRole: + # print("data", index.row(), index.column(), "=>", res) + # return res def reset(self): self.beginResetModel() + self._process_data() self.endResetModel() - if logger.isEnabledFor(logging.DEBUG): - caller = stack()[1] - logger.debug(f"model {self.__class__} has been reset after call of {caller.function} from module " - f"{basename(caller.filename)} at line {caller.lineno}") + log_caller() + + +DECORATION_MAPPING = { + # QStyle.SP_TitleBarUnshadeButton + # QStyle.SP_TitleBarShadeButton + 'arrow_down': QStyle.SP_ArrowDown, + 'arrow_up': QStyle.SP_ArrowUp, +} class AxesArrayModel(AbstractArrayModel): @@ -127,52 +611,35 @@ class AxesArrayModel(AbstractArrayModel): font : QFont, optional Font. Default is `Calibri` with size 11. """ - def __init__(self, parent=None, readonly=False, font=None): - AbstractArrayModel.__init__(self, parent, readonly, font) - self.font.setBold(True) - - def _set_data(self, data): - # TODO: use sequence instead - if not isinstance(data, (list, tuple)): - QMessageBox.critical(self.dialog, "Error", "Expected list or tuple") - data = [] - self._data = data - self.total_rows = 1 - self.total_cols = len(data) - self._compute_rows_cols_loaded() - - def flags(self, index): - """Set editable flag""" - return Qt.ItemIsEnabled - - def get_value(self, index): - i = index.column() - return str(self._data[i]) - - def get_values(self, left=0, right=None): - if right is None: - right = self.total_cols - values = self._data[left:right] - return values - - def data(self, index, role=Qt.DisplayRole): - if not index.isValid(): - return None - - if role == Qt.TextAlignmentRole: - return int(Qt.AlignCenter | Qt.AlignVCenter) - elif role == Qt.FontRole: - return self.font - elif role == Qt.BackgroundColorRole: - color = QColor(Qt.lightGray) - color.setAlphaF(.4) - return color - elif role == Qt.DisplayRole: - return self.get_value(index) - # elif role == Qt.ToolTipRole: - # return None - else: - return None + def __init__(self, parent=None, adapter=None): + AbstractArrayModel.__init__(self, parent, adapter) + # TODO: move defaults to class attributes, not instances' + default_font = get_default_font() + default_font.setBold(True) + default_background = QColor(Qt.lightGray) + default_background.setAlphaF(.4) + self.role_defaults = { + Qt.TextAlignmentRole: int(Qt.AlignCenter | Qt.AlignVCenter), + Qt.FontRole: default_font, + Qt.BackgroundColorRole: default_background, + # Qt.DisplayRole: '', + # Qt.ToolTipRole: + } + self.flags_defaults = Qt.ItemIsEnabled + + def _fetch_data(self, h_start, v_start, h_stop, v_stop): + axes_area = self.adapter.get_axes_area() + # print(f"{axes_area=}") + return axes_area + + def _value_to_h_align(self, elem_value): + return Qt.AlignCenter + + # def get_values(self, left=0, right=None): + # if right is None: + # right = self.total_cols + # values = self._data[left:right] + # return values class LabelsArrayModel(AbstractArrayModel): @@ -187,57 +654,149 @@ class LabelsArrayModel(AbstractArrayModel): font : QFont, optional Font. Default is `Calibri` with size 11. 
""" - def __init__(self, parent=None, readonly=False, font=None): - AbstractArrayModel.__init__(self, parent, readonly, font) - self.font.setBold(True) - - def _set_data(self, data): - # TODO: use sequence instead - if not isinstance(data, (list, tuple, Product)): - QMessageBox.critical(self.dialog, "Error", "Expected list, tuple or Product") - data = [[]] - self._data = data - self.total_rows = len(data[0]) - self.total_cols = len(data) if self.total_rows > 0 else 0 - self._compute_rows_cols_loaded() + def __init__(self, parent=None, adapter=None): + AbstractArrayModel.__init__(self, parent, adapter) + default_font = get_default_font() + default_font.setBold(True) + default_background = QColor(Qt.lightGray) + default_background.setAlphaF(.4) + self.role_defaults = { + Qt.TextAlignmentRole: int(Qt.AlignCenter | Qt.AlignVCenter), + Qt.FontRole: default_font, + Qt.BackgroundColorRole: default_background, + # Qt.ToolTipRole: + } + self.flags_defaults = Qt.ItemIsEnabled + + def _value_to_h_align(self, elem_value): + return Qt.AlignCenter - def flags(self, index): - """Set editable flag""" - return Qt.ItemIsEnabled + # XXX: I wonder if we shouldn't return a 2D Numpy array of strings? + # def get_values(self, left=0, top=0, right=None, bottom=None): + # if right is None: + # right = self.total_rows + # if bottom is None: + # bottom = self.total_cols + # values = [list(line[left:right]) for line in self._data[top:bottom]] + # return values - def get_value(self, index): - i = index.row() - j = index.column() - # we need to inverse column and row because of the way vlabels are generated - return str(self._data[j][i]) - # XXX: I wonder if we shouldn't return a 2D Numpy array of strings? - def get_values(self, left=0, top=0, right=None, bottom=None): - if right is None: - right = self.total_rows - if bottom is None: - bottom = self.total_cols - values = [list(line[left:right]) for line in self._data[top:bottom]] - return values +class VLabelsArrayModel(LabelsArrayModel): + def _fetch_data(self, h_start, v_start, h_stop, v_stop): + return self.adapter.get_vlabels(v_start, v_stop) - def data(self, index, role=Qt.DisplayRole): - if not index.isValid(): - return None - if role == Qt.TextAlignmentRole: - return int(Qt.AlignCenter | Qt.AlignVCenter) - elif role == Qt.FontRole: - return self.font - elif role == Qt.BackgroundColorRole: - color = QColor(Qt.lightGray) - color.setAlphaF(.4) - return color - elif role == Qt.DisplayRole: - return self.get_value(index) - # elif role == Qt.ToolTipRole: - # return None - else: - return None +class HLabelsArrayModel(LabelsArrayModel): + def _fetch_data(self, h_start, v_start, h_stop, v_stop): + return self.adapter.get_hlabels(h_start, h_stop) + + +def seq_broadcast(*seqs): + """ + Examples + -------- + >>> seq_broadcast(["a"], ["b1", "b2"]) + (['a', 'a'], ['b1', 'b2']) + >>> seq_broadcast(["a1", "a2"], ["b"]) + (['a1', 'a2'], ['b', 'b']) + >>> seq_broadcast(["a1", "a2"], ["b1", "b2"]) + (['a1', 'a2'], ['b1', 'b2']) + >>> seq_broadcast(["a1", "a2"], ["b1", "b2", "b3"]) + Traceback (most recent call last): + ... 
+ ValueError: all sequences lengths must be 1 or the same + """ + seqs = [seq if isinstance(seq, (tuple, list)) else [seq] + for seq in seqs] + assert all(hasattr(seq, '__getitem__') for seq in seqs) + length = max(len(seq) for seq in seqs) + if not all(len(seq) == 1 or len(seq) == length for seq in seqs): + raise ValueError("all sequences lengths must be 1 or the same") + return tuple(seq * length if len(seq) == 1 else seq + for seq in seqs) + + +def seq_zip_broadcast(*seqs, ndim=1): + """ + Zip sequences but broadcasting (repeating) s_len 1 sequences to the s_len + of the longest sequence. + + Examples + -------- + >>> list(seq_zip_broadcast(["a"], ["b1", "b2"])) + [('a', 'b1'), ('a', 'b2')] + >>> list(seq_zip_broadcast(["a1", "a2"], ["b"])) + [('a1', 'b'), ('a2', 'b')] + >>> list(seq_zip_broadcast(["a1", "a2"], ["b1", "b2"])) + [('a1', 'b1'), ('a2', 'b2')] + >>> list(seq_zip_broadcast(["a1", "a2"], ["b1", "b2", "b3"])) + Traceback (most recent call last): + ... + ValueError: all sequences lengths must be 1 or the same: + (['a1', 'a2'], ['b1', 'b2', 'b3']) + >>> list(seq_zip_broadcast([[1], [2]], [[1, 2]], ndim=2)) + [[(1, 1), (1, 2)], [(2, 1), (2, 2)]] + """ + if ndim == 1: + assert len(seqs) > 0 + + # "if s_len != 1" and "default=1" are necessary to support combining + # an empty sequence with a length 1 sequence + seq_lengths = [len(seq) for seq in seqs] + max_length = max([s_len for s_len in seq_lengths if s_len != 1], + default=1) + if not all(s_len in {1, max_length} for s_len in seq_lengths): + raise ValueError(f"all sequences lengths must be 1 or the same:\n" + f"{seqs}") + return zip(*(itertools.repeat(seq[0], max_length) if len(seq) == 1 else seq + for seq in seqs)) + else: + assert ndim > 1 + broadcasted = seq_zip_broadcast(*seqs, ndim=1) + return [list(seq_zip_broadcast(*seq, ndim=ndim - 1)) + for seq in broadcasted] + + +def map_nested_sequence(func, seq, ndim): + """ + Apply a function to elements of a (nested) sequence at a specified depth. + + Parameters + ---------- + func : callable + Function to apply to elements at the target dimension level. + Should accept a single argument and return a transformed value. + seq : sequence + The (potentially nested) sequence to process. + ndim : int + Target dimension depth. Must be >= 1. + When ndim=1, applies func to each element of seq. + When ndim>1, recursively processes nested sub-sequences. + + Returns + ------- + list + Returns a list with the same depth as the input with func applied to + each element at the target depth. + + Examples + -------- + >>> # 1D sequence + >>> map_nested_sequence(lambda x: x * 2, [1, 2, 3], 1) + [2, 4, 6] + >>> # 2D sequence + >>> map_nested_sequence(lambda x: x * 10, [[1, 2], [3, 4]], 2) + [[10, 20], [30, 40]] + >>> # Apply function to 2D sequence at depth 1 + >>> map_nested_sequence(str, [[1, 2], [3, 4]], 1) + ['[1, 2]', '[3, 4]'] + """ + assert ndim >= 1 + if ndim == 1: + return [func(elem) for elem in seq] + else: + return [map_nested_sequence(func, elem, ndim - 1) + for elem in seq] class DataArrayModel(AbstractArrayModel): @@ -247,6 +806,7 @@ class DataArrayModel(AbstractArrayModel): ---------- parent : QWidget, optional Parent Widget. + FIXME: update this readonly : bool, optional If True, data cannot be changed. False by default. format : str, optional @@ -254,203 +814,35 @@ class DataArrayModel(AbstractArrayModel): By default, they are represented as floats with 3 decimal points. font : QFont, optional Font. Default is `Calibri` with size 11. 
- bg_gradient : LinearGradient, optional - Background color gradient - bg_value : Numpy ndarray, optional - Background color value. Must have the shape as data - minvalue : scalar, optional - Minimum value allowed. - maxvalue : scalar, optional - Maximum value allowed. """ - ROWS_TO_LOAD = 500 - COLS_TO_LOAD = 40 newChanges = Signal(dict) - def __init__(self, parent=None, readonly=False, format="%.3f", font=None, minvalue=None, maxvalue=None): - AbstractArrayModel.__init__(self, parent, readonly, font) - self._format = format - - self.minvalue = minvalue - self.maxvalue = maxvalue - - self.color_func = None - - self.vmin = None - self.vmax = None - self.bgcolor_possible = False - - self.bg_value = None - self.bg_gradient = None - - def get_format(self): - """Return current format""" - # Avoid accessing the private attribute _format from outside - return self._format - - def get_data(self): - """Return data""" - return self._data - - def _set_data(self, data): - # TODO: check that data respects minvalue/maxvalue - assert isinstance(data, np.ndarray) and data.ndim == 2 - self._data = data - - dtype = data.dtype - if dtype.names is None: - dtn = dtype.name - if dtn not in SUPPORTED_FORMATS and not dtn.startswith('str') \ - and not dtn.startswith('unicode'): - QMessageBox.critical(self.dialog, "Error", f"{dtn} arrays are currently not supported") - return - # for complex numbers, shading will be based on absolute value - # but for all other types it will be the real part - # TODO: there are a lot more complex dtypes than this. Is there a way to get them all in one shot? - if dtype in (np.complex64, np.complex128): - self.color_func = np.abs - else: - self.color_func = None - # -------------------------------------- - self.total_rows, self.total_cols = self._data.shape - self._compute_rows_cols_loaded() - - def reset_minmax(self): - try: - data = self.get_values(sample=True) - color_value = self.color_func(data) if self.color_func is not None else data - if color_value.dtype.type == np.object_: - color_value = color_value[is_number_value(color_value)] - # this is probably broken if we have complex numbers stored as objects but I don't foresee - # this case happening anytime soon. 
- color_value = color_value.astype(float) - # ignore nan, -inf, inf (setting them to 0 or to very large numbers is not an option) - color_value = color_value[np.isfinite(color_value)] - if color_value.size: - self.vmin = float(np.min(color_value)) - self.vmax = float(np.max(color_value)) - else: - self.vmin = np.nan - self.vmax = np.nan - - self.bgcolor_possible = True - # ValueError for empty arrays, TypeError for object/string arrays - except (TypeError, ValueError): - self.vmin = None - self.vmax = None - self.bgcolor_possible = False - - def set_format(self, format, reset=True): - """Change display format""" - self._format = format - if reset: - self.reset() - - def set_bg_gradient(self, bg_gradient, reset=True): - if bg_gradient is not None and not isinstance(bg_gradient, LinearGradient): - raise ValueError("Expected None or LinearGradient instance for `bg_gradient` argument") - self.bg_gradient = bg_gradient - if reset: - self.reset() - - def set_bg_value(self, bg_value, reset=True): - if bg_value is not None and not (isinstance(bg_value, np.ndarray) and bg_value.shape == self._data.shape): - raise ValueError(f"Expected None or 2D Numpy ndarray with shape {self._data.shape} for `bg_value` argument") - self.bg_value = bg_value - if reset: - self.reset() - - def get_value(self, index): - i, j = index.row(), index.column() - return self._data[i, j] - - def flags(self, index): - """Set editable flag""" - if not index.isValid(): - return Qt.ItemIsEnabled - flags = QAbstractTableModel.flags(self, index) - if not self.readonly: - flags |= Qt.ItemIsEditable - return flags - - def data(self, index, role=Qt.DisplayRole): - """Cell content""" - if not index.isValid(): - return None - # if role == Qt.DecorationRole: - # return ima.icon('editcopy') - # if role == Qt.DisplayRole: - # return "" - - if role == Qt.TextAlignmentRole: - return int(Qt.AlignRight | Qt.AlignVCenter) - elif role == Qt.FontRole: - return self.font - - value = self.get_value(index) - if role == Qt.DisplayRole: - if value is np.ma.masked: - return '' - # for headers - elif isinstance(value, str) and not isinstance(value, np.str_): - return value - else: - return self._format % value - elif role == Qt.BackgroundColorRole: - if self.bgcolor_possible and self.bg_gradient is not None and value is not np.ma.masked: - if self.bg_value is None: - try: - v = self.color_func(value) if self.color_func is not None else value - if np.isnan(v): - v = np.nan - else: - do_reset = False - if np.isnan(self.vmin) or -np.inf < v < self.vmin: - # TODO: this is suboptimal, as it can reset many times (though in practice, it is - # usually ok). 
When we get buffering, we will need to compute vmin/vmax on the - # whole buffer at once, eliminating this problem (and we could even compute final - # colors directly all at once) - self.vmin = v - do_reset = True - if np.isnan(self.vmax) or self.vmax < v < np.inf: - self.vmax = v - do_reset = True - - if do_reset: - self.reset() - v = scale_to_01range(v, self.vmin, self.vmax) - except TypeError: - v = np.nan - else: - i, j = index.row(), index.column() - v = self.bg_value[i, j] - return self.bg_gradient[v] - # elif role == Qt.ToolTipRole: - # return f"{repr(value)}\n{self.get_labels(index)}" - return None - - def get_values(self, left=0, top=0, right=None, bottom=None, sample=False): - width, height = self.total_rows, self.total_cols - if right is None: - right = width - if bottom is None: - bottom = height - values = self._data[left:right, top:bottom] - if sample: - sample_indices = get_sample_indices(values, 500) - # we need to keep the dtype, otherwise numpy might convert mixed object arrays to strings - return np.array([values[i, j] for i, j in zip(*sample_indices)], dtype=values.dtype) - else: - return values - + def __init__(self, parent=None, adapter=None): + # readonly=False, format="%.3f", font=None): + AbstractArrayModel.__init__(self, parent, adapter) + default_font = get_default_font() + self.role_defaults = { + Qt.TextAlignmentRole: int(Qt.AlignRight | Qt.AlignVCenter), + Qt.FontRole: default_font, + # Qt.ToolTipRole: + } + + def _fetch_data(self, h_start, v_start, h_stop, v_stop): + return self.adapter.get_data_values_and_attributes(h_start, v_start, + h_stop, v_stop) + + # TODO: use ast.literal_eval instead of convert_value? + # TODO: do this in the adapter def convert_value(self, value): """ Parameters ---------- value : str """ - dtype = self._data.dtype + # TODO: this assumes the adapter sends us a numpy array. Is it + # in the contract? I thought other sequence were accepted too? + dtype = self.raw_data['values'].dtype if dtype.name == "bool": try: return bool(float(value)) @@ -458,20 +850,32 @@ def convert_value(self, value): return value.lower() == "true" elif dtype.name.startswith("string") or dtype.name.startswith("unicode"): return str(value) - elif is_float(dtype): + elif is_float_dtype(dtype): return float(value) - elif is_number(dtype): + elif is_number_dtype(dtype): return int(value) else: return complex(value) def convert_values(self, values): + # TODO: do not do this, as it might change the dtype along the way. For example: + # >>> print(np.asarray([1, 3.0, "toto"])) + # ['1' '3.0' 'toto'] values = np.asarray(values) - res = np.empty_like(values, dtype=self._data.dtype) + # FIXME: for some adapters, we cannot rely on having a single dtype + # the dtype could be per-column, per-row, per-cell, or even, for some adapters + # (e.g. list), not fixed/changeable dynamically + # => we need to ask the adapter for the dtype + # => we need to know *here* which cells are impacted + # TODO: maybe ask the adapter to convert_values instead (there should be some base + # functionality in the parent class though) + dtype = self.raw_data['values'].dtype + res = np.empty_like(values, dtype=dtype) try: # TODO: use array/vectorized conversion functions (but watch out # for bool) # new_data = str_array.astype(data.dtype) + # TODO: do this in two steps. 
Get conversion_func for the dtype then call it
             for i, v in enumerate(values.flat):
                 res.flat[i] = self.convert_value(v)
         except ValueError as e:
@@ -482,32 +886,33 @@
             return None
         return res
 
-    # TODO: I wonder if set_values should not actually change the data. In that case, ArrayEdtiorWidget.paste
-    # and DataArrayModel.setData should call another method "queueValueChange" or something like that. In any case
-    # it must be absolutely clear from either the method name, an argument (eg. update_data=False) or from the
-    # class name that the data is not changed directly.
-    # I am also unsure how this all thing will interect with the big adapter/model refactor in the buffer branch.
-    def set_values(self, left, top, right, bottom, values):
+    # TODO: I wonder if set_values should not actually change the data (but QUndoCommand would make this weird).
+    #       If we do this, we might need ArrayEditorWidget.paste and DataArrayModel.setData to call another method
+    #       "queueValueChange" or something like that.
+    #       In any case it must be absolutely clear from either the method name, an argument (eg. update_data=False)
+    #       or from the class name that the data is not changed directly.
+    def set_values(self, top, left, bottom, right, values):
         """
-        This does NOT actually change any data directly. It will emit a signal that the data was changed,
-        which is intercepted by the undo-redo system which creates a command to change the values, execute it and
-        call .reset() on this model, which fetches and displays the new data. It is apparently NOT possible to add a
-        QUndoCommand onto the QUndoStack without executing it.
+        This does NOT actually change any data directly. It will emit a signal
+        that the data was changed, which is intercepted by the undo-redo
+        system, which creates a command to change the values, executes it and
+        calls .reset() on this model, which fetches and displays the new data.
 
-        To add to the strangeness, this method updates self.vmin and self.vmax immediately, which leads to very odd
-        results (the color is updated but not the value) if one forgets to connect the newChanges signal to the
-        undo-redo system.
+        It is apparently NOT possible to add a QUndoCommand onto the QUndoStack
+        without executing it. 
Parameters ---------- - left : int top : int - right : int - exclusive + in global filtered coordinates + (already includes v_offset but is not filter-aware) + left : int bottom : int exclusive + right : int + exclusive values : ndarray - must not be of the correct type + may be of incorrect type Returns ------- @@ -519,60 +924,57 @@ def set_values(self, left, top, right, bottom, values): if values is None: return values = np.atleast_2d(values) - vshape = values.shape - vwidth, vheight = vshape - width, height = right - left, bottom - top - assert vwidth == 1 or vwidth == width - assert vheight == 1 or vheight == height - - # Add change to self.changes - # requires numpy 1.10 + values_height, values_width = values.shape + selection_height, selection_width = bottom - top, right - left + # paste should make sure this is the case + assert values_height == 1 or values_height == selection_height + assert values_width == 1 or values_width == selection_width + + # convert to local coordinates + local_top = top - self.v_offset + local_left = left - self.h_offset + local_bottom = bottom - self.v_offset + local_right = right - self.h_offset + assert (local_top >= 0 and local_bottom >= 0 and + local_left >= 0 and local_right >= 0) + + # compute changes dict changes = {} - newvalues = np.broadcast_to(values, (width, height)) - oldvalues = np.empty_like(newvalues) - for i in range(width): - for j in range(height): - pos = left + i, top + j - old_value = self._data[pos] - oldvalues[i, j] = old_value - new_value = newvalues[i, j] + # requires numpy 1.10 + new_values = np.broadcast_to(values, (selection_height, selection_width)) + old_values = self.raw_data['values'] + for j in range(selection_height): + for i in range(selection_width): + old_value = old_values[local_top + j, local_left + i] + new_value = new_values[j, i] if new_value != old_value: - changes[pos] = (old_value, new_value) - - # Update vmin/vmax if necessary - if self.vmin is not None and self.vmax is not None: - # FIXME: -inf/+inf and non-number values should be ignored here too - colorval = self.color_func(values) if self.color_func is not None else values - old_colorval = self.color_func(oldvalues) if self.color_func is not None else oldvalues - # we need to lower vmax or increase vmin - if np.any(((old_colorval == self.vmax) & (colorval < self.vmax)) | - ((old_colorval == self.vmin) & (colorval > self.vmin))): - self.reset_minmax() - self.reset() - # this is faster, when the condition is False (which should be most of the cases) than computing - # subset_max and checking if subset_max > self.vmax - if np.any(colorval > self.vmax): - self.vmax = float(np.nanmax(colorval)) - self.reset() - if np.any(colorval < self.vmin): - self.vmin = float(np.nanmin(colorval)) - self.reset() - - # DataArrayModel should have a reference to an adapter? + changes[top + j, left + i] = (old_value, new_value) + if len(changes) > 0: + # changes take into account the viewport/offsets but not the filter + # the array widget will use the adapter to translate those changes + # to global changes then push them to the undo/redo stack, which + # will execute them and that will actually modify the array self.newChanges.emit(changes) - # XXX: I wonder if emitting dataChanged makes any sense since data has not actually changed! 
- top_left = self.index(left, top) + top_left = self.index(local_top, local_left) # -1 because Qt index end bounds are inclusive - bottom_right = self.index(right - 1, bottom - 1) + bottom_right = self.index(local_bottom - 1, local_right - 1) + + # emitting dataChanged only makes sense because a signal .emit call only returns when all its + # slots have executed, so the newChanges signal emitted above has already triggered the whole + # chain of code which effectively changes the data self.dataChanged.emit(top_left, bottom_right) return top_left, bottom_right def setData(self, index, value, role=Qt.EditRole): - """Cell content change""" - if not index.isValid() or self.readonly: + """Cell content change + index is in local 2D coordinates + """ + if not index.isValid(): return False - i, j = index.row(), index.column() - result = self.set_values(i, j, i + 1, j + 1, value) + row, col = index.row(), index.column() + row += self.v_offset + col += self.h_offset + result = self.set_values(row, col, row + 1, col + 1, value) return result is not None diff --git a/larray_editor/arraywidget.py b/larray_editor/arraywidget.py index 0376b7e1..d87d70ae 100644 --- a/larray_editor/arraywidget.py +++ b/larray_editor/arraywidget.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright © 2009-2012 Pierre Raybaut -# Copyright © 2015-2016 Gaëtan de Menten +# Copyright © 2015-2025 Gaëtan de Menten # Licensed under the terms of the MIT License # based on @@ -19,23 +19,16 @@ # Note that the canonical way to implement filters in a TableView would # be to use a QSortFilterProxyModel. In this case, we would need to reimplement # its filterAcceptsColumn and filterAcceptsRow methods. The problem is that -# it does seem to be really designed for very large arrays and it would +# it does not seem to be really designed for very large arrays and it would # probably be too slow on those (I have read quite a few people complaining # about speed issues with those) possibly because it suppose you have the whole # array in your model. It would also probably not play well with the # partial/progressive load we have currently implemented. # TODO: -# * drag & drop to reorder axes -# http://zetcode.com/gui/pyqt4/dragdrop/ -# http://stackoverflow.com/questions/10264040/ -# how-to-drag-and-drop-into-a-qtablewidget-pyqt -# http://stackoverflow.com/questions/3458542/multiple-drag-and-drop-in-pyqt4 +# * make it more obvious one can drag & drop axes names to reorder axes # http://ux.stackexchange.com/questions/34158/ # how-to-make-it-obvious-that-you-can-drag-things-that-you-normally-cant -# * keep header columns & rows visible ("frozen") -# http://doc.qt.io/qt-5/qtwidgets-itemviews-frozencolumn-example.html -# * document default icons situation (limitations) # * document paint speed experiments # * filter on headers. In fact this is not a good idea, because that prevents # selecting whole columns, which is handy. So a separate row for headers, @@ -53,9 +46,6 @@ # => different format per column, which is problematic UI-wise # * keyboard shortcut for filter each dim # * tab in a filter combo, brings up next filter combo -# * view/edit DataFrames too -# * view/edit LArray over Pandas (ie sparse) -# * resubmit editor back for inclusion in Spyder # ? custom delegates for each type (spinner for int, checkbox for bool, ...) # ? "light" headers (do not repeat the same header several times (on the screen) # it would be nicer but I am not sure it is a good idea because with many @@ -69,44 +59,339 @@ # worth it for a while. 
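+# (editor's illustration, not part of the original changeset: the canonical
+# QSortFilterProxyModel approach discussed at the top of this comment block
+# would look roughly like the sketch below, with `accepted_rows` standing in
+# for some precomputed filter result; the once-per-source-row callback over
+# the whole array is what makes it too slow for very large datasets:
+#
+#     class FilterProxy(QSortFilterProxyModel):
+#         def filterAcceptsRow(self, source_row, source_parent):
+#             # Qt calls this for every row of the source model
+#             return source_row in self.accepted_rows
+# )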
import math -import logging +from pathlib import Path import numpy as np -from qtpy.QtCore import Qt, QPoint, QItemSelection, QItemSelectionModel, Signal, QSize -from qtpy.QtGui import QDoubleValidator, QIntValidator, QKeySequence, QFontMetrics, QCursor, QPixmap, QPainter, QIcon +from qtpy import QtCore +from qtpy.QtCore import (Qt, QPoint, QItemSelection, QItemSelectionModel, + Signal, QSize, QModelIndex, QTimer) +from qtpy.QtGui import (QDoubleValidator, QIntValidator, QKeySequence, QFontMetrics, QCursor, QPixmap, QPainter, QIcon, + QWheelEvent, QMouseEvent) from qtpy.QtWidgets import (QApplication, QTableView, QItemDelegate, QLineEdit, QCheckBox, QMessageBox, QMenu, QLabel, QSpinBox, QWidget, QToolTip, QShortcut, QScrollBar, - QHBoxLayout, QVBoxLayout, QGridLayout, QSizePolicy, QFrame, QComboBox) + QHBoxLayout, QVBoxLayout, QGridLayout, QSizePolicy, QFrame, QComboBox, + QStyleOptionViewItem, QPushButton) -from larray_editor.utils import (keybinding, create_action, clear_layout, get_default_font, is_number, is_float, _, - ima, LinearGradient, logger, cached_property) -from larray_editor.arrayadapter import get_adapter -from larray_editor.arraymodel import LabelsArrayModel, AxesArrayModel, DataArrayModel -from larray_editor.combo import FilterComboBox +from larray_editor.utils import (keybinding, create_action, clear_layout, get_default_font, + is_number_dtype, is_float_dtype, _, + LinearGradient, logger, cached_property, data_frac_digits, + num_int_digits) +from larray_editor.arrayadapter import (get_adapter, get_adapter_creator, + AbstractAdapter, MAX_FILTER_OPTIONS) +from larray_editor.arraymodel import (HLabelsArrayModel, VLabelsArrayModel, LabelsArrayModel, + AxesArrayModel, DataArrayModel) +from larray_editor.combo import FilterComboBox, CombinedSortFilterMenu + +MORE_OPTIONS_NOT_SHOWN = "" + +# mime-type we use when drag and dropping axes (x- prefix is for unregistered +# types) +LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE = "application/x-larray-axis-index" + + +def display_selection(selection: QtCore.QItemSelection): + return ', '.join(f"<{idx.row()}, {idx.column()}>" for idx in selection.indexes()) + + +def clip(value, minimum, maximum): + if value < minimum: + return minimum + elif value > maximum: + return maximum + else: + return value # XXX: define Enum instead ? 
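+# (editor's sketch of the Enum idea from the XXX note above, not part of the
+# original changeset; IntEnum is assumed so that comparisons with the plain
+# int constants below keep working:
+#
+#     from enum import IntEnum
+#
+#     class VPos(IntEnum):
+#         TOP = 0
+#         BOTTOM = 1
+#
+#     class HPos(IntEnum):
+#         LEFT = 0
+#         RIGHT = 1
+# )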
TOP, BOTTOM = 0, 1 LEFT, RIGHT = 0, 1 +MIN_COLUMN_WIDTH = 30 +MAX_COLUMN_WIDTH = 800 DEFAULT_COLUMN_WIDTH = 64 DEFAULT_ROW_HEIGHT = 20 + +class FilterBar(QWidget): + def __init__(self, array_widget): + super().__init__() + # we are not passing array_widget as parent for QHBoxLayout because + # we could have the filterbar outside the widget + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + self.array_widget = array_widget + + # See https://www.pythonguis.com/faq/pyqt-drag-drop-widgets/ + # and https://zetcode.com/pyqt6/dragdrop/ + self.setAcceptDrops(True) + self.drag_label = None + self.drag_start_pos = None + + def reset_to_defaults(self): + layout = self.layout() + clear_layout(layout) + data_adapter = self.array_widget.data_adapter + if data_adapter is None: + return + assert isinstance(data_adapter, AbstractAdapter), \ + f"unexpected data_adapter type: {type(data_adapter)}" + filter_names = data_adapter.get_filter_names() + # size > 0 to avoid arrays with length 0 axes and len(axes) > 0 to avoid scalars (scalar.size == 1) + if filter_names: #self.data_adapter.size > 0 and len(filters) > 0: + layout.addWidget(QLabel(_("Filters"))) + for filter_idx, filter_name in enumerate(filter_names): + layout.addWidget(QLabel(filter_name)) + filter_labels = data_adapter.get_filter_options(filter_idx) + # FIXME: on very large axes, this is getting too slow. Ideally the combobox should use a model which + # only fetch labels when they are needed to be displayed + # this needs a whole new widget though + if len(filter_labels) < 10000: + layout.addWidget(self.create_filter_combo(filter_idx, filter_labels)) + else: + layout.addWidget(QLabel("too big to be filtered")) + layout.addStretch() + + def create_filter_combo(self, filter_idx, filter_labels): + def filter_changed(checked_items): + self.change_filter(filter_idx, checked_items) + + combo = FilterComboBox(self) + combo.addItems([str(label) for label in filter_labels]) + combo.checked_items_changed.connect(filter_changed) + return combo + + def change_filter(self, filter_idx, indices): + logger.debug(f"FilterBar.change_filter({filter_idx}, {indices})") + # FIXME: the method can be called from the outside, and in that case + # the combos checked items need be synchronized too + array_widget = self.array_widget + data_adapter = array_widget.data_adapter + vscrollbar: ScrollBar = array_widget.vscrollbar + hscrollbar: ScrollBar = array_widget.hscrollbar + old_v_pos = vscrollbar.value() + old_h_pos = hscrollbar.value() + old_nrows, old_ncols = data_adapter.shape2d() + data_adapter.update_filter(filter_idx, indices) + data_adapter._current_sort = [] + # TODO: this does too much work (it sets the adapters even + # if those do not change and sets v_offset/h_offset to 0 when we + # do not *always* want to do so) and maybe too little + # (update_range should probably be done elsewhere) + # this also reset() each model. 
+        # For DataArrayModel it causes an extra (compared to the one
+        # below) update_range (via the modelReset signal)
+        array_widget._set_models_adapter()
+        new_nrows, new_ncols = data_adapter.shape2d()
+        hscrollbar.update_range()
+        vscrollbar.update_range()
+        array_widget.update_cell_sizes_from_content()
+        if old_v_pos == 0 and old_h_pos == 0:
+            # if the old values were already 0, visible_v/hscroll_changed will
+            # not be triggered and update_*_column_widths has no chance to run
+            # unless we call them explicitly
+            assert isinstance(array_widget, ArrayEditorWidget)
+            array_widget.update_cell_sizes_from_content()
+        else:
+            # TODO: would be nice to implement some clever positioning algorithm
+            # here when new_X != old_X so that the visible rows stay visible.
+            # Currently, this does not change the scrollbar value at all if
+            # the old value fits in the new range. When changing from one
+            # specific label to another of an larray, this does not change
+            # the shape of the result and is thus what we want but there
+            # are cases where we could do better.
+            # TODO: the setValue(0) should not be necessary in the case of
+            # new_nrows == old_nrows but it is currently because
+            # v/h_offset is set to 0 by the call to _set_models_adapter
+            # above but the scrollbar values do not change, so
+            # setValue(old_v_pos) does not trigger a valueChanged signal,
+            # and thus the v/h_offset is not set back to its old value
+            # if we don't first change the scrollbar values
+            vscrollbar.setValue(0)
+            hscrollbar.setValue(0)
+            # if the old value was already at 0, we do not need to set it again
+            if new_nrows == old_nrows and old_v_pos != 0:
+                vscrollbar.setValue(old_v_pos)
+            if new_ncols == old_ncols and old_h_pos != 0:
+                hscrollbar.setValue(old_h_pos)
+
+    # Check for left button mouse press events on axis labels
+    def mousePressEvent(self, event):
+        if event.button() != Qt.LeftButton:
+            return
+        click_pos = event.pos()
+        child = self.childAt(click_pos)
+        assert self.drag_label is None
+        if isinstance(child, QLabel) and child.text() != "Filters":
+            self.drag_label = child
+            self.drag_start_pos = click_pos
+
+    # If we release the left button before we have moved the mouse enough to
+    # trigger the "real" dragging sequence (see mouseMoveEvent), we need to
+    # forget the drag_label and drag_start_pos
+    def mouseReleaseEvent(self, event):
+        if event.button() != Qt.LeftButton:
+            return
+        self.drag_label = None
+        self.drag_start_pos = None
+
+    # Mouse move events will occur only when a mouse button is pressed down,
+    # unless mouse tracking has been enabled with QWidget.setMouseTracking()
+    def mouseMoveEvent(self, event):
+        # We did not click on an axis label yet
+        drag_label = self.drag_label
+        if drag_label is None:
+            return
+
+        # We do not check the event button. 
The left button should still be + # pressed but event.button() will always be NoButton: "If the event type + # is MouseMove, the appropriate button for this event is Qt::NoButton" + + # We are too close to where we initially clicked + drag_delta = event.pos() - self.drag_start_pos + if drag_delta.manhattanLength() < QApplication.startDragDistance(): + return + + from qtpy.QtCore import QMimeData, QByteArray + from qtpy.QtGui import QDrag + + axis_index = self.layout().indexOf(drag_label) // 2 + + mimeData = QMimeData() + mimeData.setData(LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE, + QByteArray.number(axis_index)) + pixmap = QPixmap(drag_label.size()) + drag_label.render(pixmap) + + # We will initiate a real dragging sequence, we don't need these anymore + self.drag_label = None + self.drag_start_pos = None + + drag = QDrag(self) + drag.setMimeData(mimeData) + drag.setPixmap(pixmap) + drag.setHotSpot(drag_delta) + drag.exec_(Qt.MoveAction) + + # Tell whether the filter bar is an acceptable target for a particular + # dragging event (which could come from another app) + def dragEnterEvent(self, event): + if event.mimeData().hasFormat(LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE): + event.setDropAction(Qt.MoveAction) + event.accept() + else: + event.ignore() + + # Inside the filter bar, inform Qt whether some particular position + # is a good final target or not + def dragMoveEvent(self, event): + if event.mimeData().hasFormat(LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE): + child = self.childAt(event.pos()) + if isinstance(child, QLabel) and child.text() != "Filters": + event.setDropAction(Qt.MoveAction) + event.accept() + else: + event.ignore() + else: + event.ignore() + + # If the user dropped on a valid target, we need to handle the event + def dropEvent(self, event): + mime_data = event.mimeData() + if mime_data.hasFormat(LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE): + old_index_byte_array = mime_data.data(LARRAY_AXIS_INDEX_DRAG_AND_DROP_MIMETYPE) + old_index, success = old_index_byte_array.toInt() + child = self.childAt(event.pos()) + new_index = self.layout().indexOf(child) // 2 + data_adapter = self.array_widget.data_adapter + data, attributes = data_adapter.move_axis(data_adapter.data, + data_adapter.attributes, + old_index, + new_index) + self.array_widget.set_data(data, attributes) + event.setDropAction(Qt.MoveAction) + event.accept() + else: + event.ignore() + + +class BackButtonBar(QWidget): + def __init__(self, array_widget): + super().__init__() + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + button = QPushButton('Back') + button.clicked.connect(self.on_clicked) + layout.addWidget(button) + self.array_widget = array_widget + self._back_data = [] + self._back_data_adapters = [] + layout.addStretch() + self.hide() + + def add_back(self, data, data_adapter): + self._back_data.append(data) + # We need to keep the data_adapter around because some resource + # are created in the adapter (e.g. duckdb connection when viewing a + # .ddb file) and if the adapter is garbage collected, the resource + # is deleted (e.g. 
the duckdb connection dies - contrary to other libs, + # a duckdb table object does not keep the connection alive) + self._back_data_adapters.append(data_adapter) + if not self.isVisible(): + self.show() + + def clear(self): + for adapter in self._back_data_adapters[::-1]: + self._close_adapter(adapter) + + self._back_data_adapters = [] + self._back_data = [] + self.hide() + + @staticmethod + def _close_adapter(adapter): + clsname = type(adapter).__name__ + logger.debug(f"closing data adapter ({clsname})") + adapter.close() + + def on_clicked(self): + if not len(self._back_data): + logger.warn("Back button has no target to go to") + return + target_data = self._back_data.pop() + data_adapter = self._back_data_adapters.pop() + if not len(self._back_data): + self.hide() + array_widget: ArrayEditorWidget = self.array_widget + # We are not using array_widget.set_data(target_data) so that we can + # reuse the same data_adapter instead of recreating a new one + array_widget.data = target_data + array_widget.set_data_adapter(data_adapter, frac_digits=None) + + class AbstractView(QTableView): """Abstract view class""" def __init__(self, parent, model, hpos, vpos): + assert isinstance(parent, ArrayEditorWidget) QTableView.__init__(self, parent) # set model self.setModel(model) # set position - if not (hpos == LEFT or hpos == RIGHT): + if hpos not in {LEFT, RIGHT}: raise TypeError(f"Value of hpos must be {LEFT} or {RIGHT}") self.hpos = hpos - if not (vpos == TOP or vpos == BOTTOM): + if vpos not in {TOP, BOTTOM}: raise TypeError(f"Value of vpos must be {TOP} or {BOTTOM}") self.vpos = vpos + self.first_selection_corner = None + # handling a second selection corner is necessary to implement the + # "select entire row/column" functionality because in that case the + # second corner is not necessarily in the viewport, but it is a real + # cell (i.e. 
the coordinates are inclusive) + self.second_selection_corner = None # set selection mode self.setSelectionMode(QTableView.ContiguousSelection) @@ -121,9 +406,15 @@ def __init__(self, parent, model, hpos, vpos): if vpos == BOTTOM: self.horizontalHeader().hide() - # to fetch more rows/columns when required - self.horizontalScrollBar().valueChanged.connect(self.on_horizontal_scroll_changed) - self.verticalScrollBar().valueChanged.connect(self.on_vertical_scroll_changed) + # XXX: this might help if we want the widget to be focusable using "tab" + # self.setFocusPolicy(Qt.StrongFocus) + + # These 4 lines are only useful for debugging + # hscrollbar = self.horizontalScrollBar() + # hscrollbar.valueChanged.connect(self.on_horizontal_scroll_changed) + # vscrollbar = self.verticalScrollBar() + # vscrollbar.valueChanged.connect(self.on_vertical_scroll_changed) + # Hide scrollbars self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) @@ -134,50 +425,269 @@ def __init__(self, parent, model, hpos, vpos): self.horizontalHeader().sectionResized.connect(self.updateGeometry) self.verticalHeader().sectionResized.connect(self.updateGeometry) + # def on_vertical_scroll_changed(self, value): + # log_caller() + # print(f"hidden vscroll on {self.__class__.__name__} changed to {value}") + + # def on_horizontal_scroll_changed(self, value): + # log_caller() + # print(f"hidden hscroll on {self.__class__.__name__} changed to {value}") + + # def selectionChanged(self, selected: QtCore.QItemSelection, deselected: QtCore.QItemSelection) -> None: + # super().selectionChanged(selected, deselected) + # print(f"selectionChanged:\n" + # f" -> selected({display_selection(selected)}),\n" + # f" -> deselected({display_selection(deselected)})") + + def reset_to_defaults(self): + """ + reset widget to initial state (when the ArrayEditorWidget is switching + from one object to another) + """ + self.set_default_size() + self.first_selection_corner = None + self.second_selection_corner = None + def set_default_size(self): + # logger.debug(f"{self.__class__.__name__}.set_default_size()") + # make the grid a bit more compact - self.horizontalHeader().setDefaultSectionSize(DEFAULT_COLUMN_WIDTH) + horizontal_header = self.horizontalHeader() + horizontal_header.blockSignals(True) + horizontal_header.setDefaultSectionSize(DEFAULT_COLUMN_WIDTH) + + if horizontal_header.sectionSize(0) != DEFAULT_COLUMN_WIDTH: + # Explicitly set all columns to the default width to override any + # custom sizes + for col in range(self.model().columnCount()): + self.setColumnWidth(col, DEFAULT_COLUMN_WIDTH) + horizontal_header.blockSignals(False) + self.verticalHeader().setDefaultSectionSize(DEFAULT_ROW_HEIGHT) if self.vpos == TOP: - self.horizontalHeader().setFixedHeight(10) + horizontal_header.setFixedHeight(10) if self.hpos == LEFT: self.verticalHeader().setFixedWidth(10) - def on_vertical_scroll_changed(self, value): - if value == self.verticalScrollBar().maximum(): - self.model().fetch_more_rows() + # We need to have this here (in AbstractView) and not only on DataView, so that we + # catch them for vlabels too. For axes and hlabels, it is a bit of a weird + # behavior since they are not affected themselves but that is really a nitpick + # Also, overriding the general event() method for this does not work as it is + # handled behind us (by the ScrollArea I assume) and we do not even see the event + # unless we are at the buffer boundary. 
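+    # (editor's note: QWheelEvent.angleDelta() is expressed in eighths of a
+    # degree, so one notch of a typical mouse wheel (15 degrees) gives a
+    # delta of +/-120 on the corresponding axis)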
+    def wheelEvent(self, event: QWheelEvent):
+        """Catch wheel events and send them to the corresponding visible scrollbar"""
+        delta = event.angleDelta()
+        logger.debug(f"wheelEvent on {self.__class__.__name__} ({delta})")
+        editor_widget = self.parent().parent()
+        if delta.x() != 0:
+            editor_widget.hscrollbar.wheelEvent(event)
+        if delta.y() != 0:
+            editor_widget.vscrollbar.wheelEvent(event)
+        event.accept()

-    def on_horizontal_scroll_changed(self, value):
-        if value == self.horizontalScrollBar().maximum():
-            self.model().fetch_more_columns()
+    def keyPressEvent(self, event):
+        key = event.key()
+        if key in {Qt.Key_Home, Qt.Key_End, Qt.Key_Up, Qt.Key_Down, Qt.Key_Left, Qt.Key_Right,
+                   Qt.Key_PageUp, Qt.Key_PageDown}:
+            event.accept()
+            self.navigate_key_event(event)
+        else:
+            QTableView.keyPressEvent(self, event)

-    def updateSectionHeight(self, logicalIndex, oldSize, newSize):
-        self.setRowHeight(logicalIndex, newSize)
+    def navigate_key_event(self, event):
+        logger.debug("")
+        logger.debug("navigate")
+        logger.debug("========")
+        model = self.model()
+        widget = self.parent().parent()
+        assert isinstance(widget, ArrayEditorWidget)
+
+        event_modifiers = event.modifiers()
+        event_key = event.key()
+        if event_modifiers & Qt.ShiftModifier:
+            # remove shift from modifiers so the Ctrl+Key combos are still detected
+            event_modifiers ^= Qt.ShiftModifier
+            shift = True
+        else:
+            shift = False

-    def updateSectionWidth(self, logicalIndex, oldSize, newSize):
-        self.setColumnWidth(logicalIndex, newSize)
+        try:
+            # qt6
+            modifiers_value = event_modifiers.value
+        except AttributeError:
+            # qt5
+            modifiers_value = event_modifiers
+        keyseq = QKeySequence(modifiers_value | event_key)
+        page_step = self.verticalScrollBar().pageStep()
+        cursor_global_pos = self.get_cursor_global_pos()
+        if cursor_global_pos is None:
+            cursor_global_v_pos, cursor_global_h_pos = 0, 0
+            logger.debug("No previous cursor position: using 0, 0")
+        else:
+            cursor_global_v_pos, cursor_global_h_pos = cursor_global_pos
+        logger.debug(f"old global cursor {cursor_global_v_pos} {cursor_global_h_pos}")
+
+        # TODO: for some adapters, shape2d() is not reliable (it is a best guess), so we should
+        #       make sure we gracefully handle wrong info
+        total_v_size, total_h_size = model.adapter.shape2d()
+        key2delta = {
+            Qt.Key_Home: (0, -cursor_global_h_pos),
+            Qt.Key_End: (0, total_h_size - cursor_global_h_pos - 1),
+            Qt.Key_Up: (-1, 0),
+            Qt.Key_Down: (1, 0),
+            Qt.Key_Left: (0, -1),
+            Qt.Key_Right: (0, 1),
+            Qt.Key_PageUp: (-page_step, 0),
+            Qt.Key_PageDown: (page_step, 0),
+        }
+
+        # Ctrl+arrow does not mean anything by default, so dispatching does not help
+        # TODO: use another dict for this. dict[keyseq] does not work even if keyseq == key works.
+        #       Using a different dict and checking the modifier explicitly should work.
+        #       Or maybe getting the string representation of the keyseq is possible too.
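+        # A possible shape for the dict-based dispatch suggested in the TODO
+        # above (untested sketch; ctrl_key2pos and target_pos are illustrative
+        # names), mapping key sequences to absolute cursor positions instead
+        # of deltas:
+        #     ctrl_key2pos = {
+        #         "Ctrl+Home": (0, 0),
+        #         "Ctrl+End": (total_v_size - 1, total_h_size - 1),
+        #         "Ctrl+Left": (cursor_global_v_pos, 0),
+        #         "Ctrl+Right": (cursor_global_v_pos, total_h_size - 1),
+        #         "Ctrl+Up": (0, cursor_global_h_pos),
+        #         "Ctrl+Down": (total_v_size - 1, cursor_global_h_pos),
+        #     }
+        #     target_pos = ctrl_key2pos.get(keyseq.toString())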
+ # TODO: it might be simpler to set the cursor_global_pos values directly rather than using delta + if keyseq == "Ctrl+Home": + v_delta, h_delta = (-cursor_global_v_pos, -cursor_global_h_pos) + elif keyseq == "Ctrl+End": + v_delta, h_delta = (total_v_size - cursor_global_v_pos - 1, total_h_size - cursor_global_h_pos - 1) + elif keyseq == "Ctrl+Left": + v_delta, h_delta = (0, -cursor_global_h_pos) + elif keyseq == "Ctrl+Right": + v_delta, h_delta = (0, total_h_size - cursor_global_h_pos - 1) + elif keyseq == "Ctrl+Up": + v_delta, h_delta = (-cursor_global_v_pos, 0) + elif keyseq == "Ctrl+Down": + v_delta, h_delta = (total_v_size - cursor_global_v_pos - 1, 0) + else: + v_delta, h_delta = key2delta[event_key] + + # TODO: internal scroll => change value of visible scrollbar (or avoid internal scroll) + cursor_new_global_v_pos = clip(cursor_global_v_pos + v_delta, 0, total_v_size - 1) + cursor_new_global_h_pos = clip(cursor_global_h_pos + h_delta, 0, total_h_size - 1) + logger.debug(f"new global cursor {cursor_new_global_v_pos} {cursor_new_global_h_pos}") + + self.scroll_to_global_pos(cursor_new_global_v_pos, cursor_new_global_h_pos) + + new_v_posinbuffer = cursor_new_global_v_pos - model.v_offset + new_h_posinbuffer = cursor_new_global_h_pos - model.h_offset + + local_cursor_index = model.index(new_v_posinbuffer, new_h_posinbuffer) + if shift: + if self.first_selection_corner is None: + # This can happen when using navigation keys before + # selecting any cell using the mouse (but after getting focus + # on the widget which can be done at least by clicking inside + # the widget area but outside "valid" cells) + self.first_selection_corner = (cursor_global_v_pos, cursor_global_h_pos) + self.second_selection_corner = cursor_new_global_v_pos, cursor_new_global_h_pos + selection_v_pos1, selection_h_pos1 = self.first_selection_corner + selection_v_pos2, selection_h_pos2 = self.second_selection_corner + row_min = min(selection_v_pos1, selection_v_pos2) + row_max = max(selection_v_pos1, selection_v_pos2) + col_min = min(selection_h_pos1, selection_h_pos2) + col_max = max(selection_h_pos1, selection_h_pos2) + + selection_model = self.selectionModel() + selection_model.setCurrentIndex(local_cursor_index, QItemSelectionModel.Current) + # we need to clip local coordinates in case the selection corners are outside the viewport + local_top = max(row_min - model.v_offset, 0) + local_left = max(col_min - model.h_offset, 0) + local_bottom = min(row_max - model.v_offset, model.nrows - 1) + local_right = min(col_max - model.h_offset, model.ncols - 1) + selection = QItemSelection(model.index(local_top, local_left), + model.index(local_bottom, local_right)) + selection_model.select(selection, QItemSelectionModel.ClearAndSelect) + else: + self.first_selection_corner = cursor_new_global_v_pos, cursor_new_global_h_pos + self.second_selection_corner = cursor_new_global_v_pos, cursor_new_global_h_pos + self.setCurrentIndex(local_cursor_index) + logger.debug(f"after navigate_key_event: {self.first_selection_corner=} " + f"{self.second_selection_corner=}") + + # after we drop support for Python < 3.10, we should use: + # def get_cursor_global_pos(self) -> tuple[int, int] | None: + def get_cursor_global_pos(self): + model = self.model() + current_index = self.currentIndex() + if not current_index.isValid(): + return None + v_posinbuffer = current_index.row() + h_posinbuffer = current_index.column() + assert v_posinbuffer >= 0 + assert h_posinbuffer >= 0 + cursor_global_v_pos = model.v_offset + v_posinbuffer + 
cursor_global_h_pos = model.h_offset + h_posinbuffer + return cursor_global_v_pos, cursor_global_h_pos + + def scroll_to_global_pos(self, global_v_pos, global_h_pos): + """ + Change visible scrollbars value so that vpos/hpos is visible + without changing the cursor position + """ + model = self.model() + widget = self.parent().parent() + assert isinstance(widget, ArrayEditorWidget) + visible_cols = widget.visible_cols() + visible_rows = widget.visible_rows() + + hidden_v_offset = self.verticalScrollBar().value() + hidden_h_offset = self.horizontalScrollBar().value() + total_v_offset = model.v_offset + hidden_v_offset + total_h_offset = model.h_offset + hidden_h_offset + + if global_v_pos < total_v_offset: + new_total_v_offset = global_v_pos + # TODO: document where those +2 come from + elif global_v_pos > total_v_offset + visible_rows - 2: + new_total_v_offset = global_v_pos - visible_rows + 2 + else: + # do not change offset + new_total_v_offset = total_v_offset + + if global_h_pos < total_h_offset: + new_total_h_offset = global_h_pos + elif global_h_pos > total_h_offset + visible_cols - 2: + new_total_h_offset = global_h_pos - visible_cols + 2 + else: + # do not change offset + new_total_h_offset = total_h_offset + + # change visible scrollbars value + widget.vscrollbar.setValue(new_total_v_offset) + widget.hscrollbar.setValue(new_total_h_offset) def autofit_columns(self): """Resize cells to contents""" + # print(f"{self.__class__.__name__}.autofit_columns()") QApplication.setOverrideCursor(QCursor(Qt.WaitCursor)) - # Spyder loads more columns before resizing, but since it does not - # load all columns anyway, I do not see the point - # self.model().fetch_more_columns() + # for column in range(self.model_axes.columnCount()): + # self.resize_axes_column_to_contents(column) + self.resizeColumnsToContents() + # If the resized columns would make the whole view smaller or larger, + # the view size itself (not its columns) is changed. This allows, + # for example, other views (e.g. hlabels) to be moved accordingly. 
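+        # the explicit updateGeometry() call below goes through the override
+        # defined further down, which recomputes the view's fixed width/height
+        # from the new column sizes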
+        self.updateGeometry()
         QApplication.restoreOverrideCursor()

     def updateGeometry(self):
-        # Set maximum height
+        # vpos = "TOP" if self.vpos == TOP else "BOTTOM"
+        # hpos = "LEFT" if self.hpos == LEFT else "RIGHT"
+        # print(f"{self.__class__.__name__}.updateGeometry() ({vpos=}, {hpos=})")
+
+        # Set total height (for the whole view, not a particular row)
         if self.vpos == TOP:
-            maximum_height = self.horizontalHeader().height() + \
+            total_height = self.horizontalHeader().height() + \
                 sum(self.rowHeight(r) for r in range(self.model().rowCount()))
-            self.setFixedHeight(maximum_height)
-        # Set maximum width
+            # print(f"    TOP => {total_height=}")
+            self.setFixedHeight(total_height)
+        # Set total width (for the whole view, not a particular column)
         if self.hpos == LEFT:
-            maximum_width = self.verticalHeader().width() + \
+            total_width = self.verticalHeader().width() + \
                 sum(self.columnWidth(c) for c in range(self.model().columnCount()))
-            self.setFixedWidth(maximum_width)
+            # print(f"    LEFT => {total_width=}")
+            self.setFixedWidth(total_width)

         # update geometry
         super().updateGeometry()
@@ -194,6 +704,76 @@ def __init__(self, parent, model):
                             f"Received {type(model).__name__} instead")
         AbstractView.__init__(self, parent, model, LEFT, TOP)

+        # FIXME: only have this if the adapter supports any extra action on axes
+        # self.clicked.connect(self.on_clicked)
+
+    def on_clicked(self, index: QModelIndex):
+        row_idx = index.row()
+        column_idx = index.column()
+
+        # FIXME: column_idx works fine for the unfiltered/initial array but on an already filtered
+        #        array it breaks because column_idx is the idx of the *filtered* array which can
+        #        contain fewer axes while change_filter (via create_filter_menu) wants the index
+        #        of the *unfiltered* array
+        try:
+            adapter = self.model().adapter
+            filtrable = adapter.can_filter_axis(column_idx)
+            sortable = adapter.can_sort_axis_labels(column_idx)
+            if sortable:
+                sort_direction = adapter.axis_sort_direction(column_idx)
+            else:
+                sort_direction = 'unsorted'
+            filter_labels = adapter.get_filter_options(column_idx)
+        except IndexError:
+            filtrable = False
+            filter_labels = []
+            sortable = False
+            sort_direction = 'unsorted'
+        if filtrable or sortable:
+            menu = self.create_filter_menu(column_idx,
+                                           filtrable,
+                                           filter_labels,
+                                           sortable,
+                                           sort_direction)
+            x = (self.columnViewportPosition(column_idx) +
+                 self.verticalHeader().width())
+            y = (self.rowViewportPosition(row_idx) + self.rowHeight(row_idx) +
+                 self.horizontalHeader().height())
+            menu.exec_(self.mapToGlobal(QPoint(x, y)))
+
+    def create_filter_menu(self,
+                           axis_idx,
+                           filtrable,
+                           filter_labels,
+                           sortable=False,
+                           sort_direction='unsorted'):
+        def filter_changed(checked_items):
+            # print("filter_changed", axis_idx, checked_items)
+            array_widget = self.parent().parent()
+            array_widget.filter_bar.change_filter(axis_idx, checked_items)
+
+        def sort_changed(ascending):
+            array_widget = self.parent().parent()
+            array_widget.sort_axis_labels(axis_idx, ascending)
+
+        menu = CombinedSortFilterMenu(self,
+                                      filtrable=filtrable,
+                                      sortable=sortable,
+                                      sort_direction=sort_direction)
+        if filtrable:
+            menu.addItems([str(label) for label in filter_labels])
+            menu.checked_items_changed.connect(filter_changed)
+        if sortable:
+            menu.sort_signal.connect(sort_changed)
+        return menu
+
+    # override viewOptions so that cell decorations (i.e. axis name arrows)
+    # are drawn to the right of cells instead of to the left
+    def viewOptions(self):
+        option = QTableView.viewOptions(self)
+        option.decorationPosition = QStyleOptionViewItem.Right
+        return option
+
     def selectAll(self):
         self.allSelected.emit()
@@ -208,46 +788,161 @@ def __init__(self, parent, model, hpos, vpos):
                             f"Received {type(model).__name__} instead")
         AbstractView.__init__(self, parent, model, hpos, vpos)

+        # FIXME: only have this if the adapter supports any extra action on axes
+        if self.vpos == TOP:
+            self.clicked.connect(self.on_clicked)
+
+    def on_clicked(self, index: QModelIndex):
+        if not index.isValid():
+            return
+
+        row_idx = index.row()
+        local_col_idx = index.column()
+        model: LabelsArrayModel = self.model()
+        global_col_idx = model.h_offset + local_col_idx
+
+        assert self.vpos == TOP
+
+        # FIXME: global_col_idx works fine for the unfiltered/initial array but on
+        #        an already filtered array it breaks because global_col_idx is the
+        #        idx of the *filtered* array which can contain fewer axes while
+        #        change_filter (via create_filter_menu) wants the index of the
+        #        *unfiltered* array
+        adapter = model.adapter
+        filtrable = adapter.can_filter_hlabel(1, global_col_idx)
+        sortable = adapter.can_sort_hlabel(row_idx, global_col_idx)
+        if sortable:
+            sort_direction = adapter.hlabel_sort_direction(row_idx, global_col_idx)
+            def sort_changed(ascending):
+                # TODO: the chain for this is kinda convoluted:
+                #           local signal handler
+                #        -> ArrayWidget method
+                #        -> adapter method+model reset
+                array_widget = self.parent().parent()
+                array_widget.sort_hlabel(row_idx, global_col_idx, ascending)
+        else:
+            sort_direction = 'unsorted'
+            sort_changed = None
+
+        if filtrable:
+            filter_labels = adapter.get_filter_options(global_col_idx)
+            if len(filter_labels) == MAX_FILTER_OPTIONS:
+                filter_labels = filter_labels.tolist()
+                filter_labels[-1] = MORE_OPTIONS_NOT_SHOWN
+            filter_indices = adapter.get_current_filter_indices(global_col_idx)
+            def filter_changed(checked_items):
+                # TODO: the chain for this is kinda convoluted:
+                #           local signal handler (this function)
+                #        -> ArrayWidget method
+                #        -> adapter method+model reset
+                array_widget = self.parent().parent()
+                assert isinstance(array_widget, ArrayEditorWidget)
+                array_widget.filter_bar.change_filter(global_col_idx, checked_items)
+        else:
+            filter_labels = []
+            filter_changed = None
+            filter_indices = None
+
+        if filtrable or sortable:
+            # because of the local vs global idx, we cannot cache/reuse the
+            # filter menu widget (we would need to remove the items and re-add
+            # the correct ones), so it is easier to just recreate the whole
+            # widget. We need to take the already ticked indices into account
+            # though.
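+            # the x/y computed below place the menu just under the clicked
+            # label cell (viewport coordinates shifted by the vertical and
+            # horizontal header sizes before mapping to global coordinates)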
+            menu = self.create_filter_menu(global_col_idx,
+                                           filter_labels,
+                                           filter_indices,
+                                           filter_changed,
+                                           sort_changed,
+                                           sort_direction)
+            x = (self.columnViewportPosition(local_col_idx) +
+                 self.verticalHeader().width())
+            y = (self.rowViewportPosition(row_idx) + self.rowHeight(row_idx) +
+                 self.horizontalHeader().height())
+            menu.exec_(self.mapToGlobal(QPoint(x, y)))
+
+    def create_filter_menu(self,
+                           filter_idx,
+                           filter_labels,
+                           filter_indices,
+                           filter_changed,
+                           sort_changed,
+                           sort_direction):
+        filtrable = filter_changed is not None
+        sortable = sort_changed is not None
+        menu = CombinedSortFilterMenu(self,
+                                      filtrable=filtrable,
+                                      sortable=sortable,
+                                      sort_direction=sort_direction)
+        if filtrable:
+            menu.addItems([str(label) for label in filter_labels],
+                          filter_indices)
+            # disable last item if there are too many options
+            if len(filter_labels) == MAX_FILTER_OPTIONS:
+                # this is correct (MAX - 1 to get the last item, + 1 because
+                # of the "Select all" item at the beginning)
+                last_item = menu._model[MAX_FILTER_OPTIONS - 1 + 1]
+                last_item.setFlags(QtCore.Qt.NoItemFlags)
+
+            menu.checked_items_changed.connect(filter_changed)
+        if sortable:
+            menu.sort_signal.connect(sort_changed)
+        return menu
+
+    # override viewOptions so that cell decorations (i.e. axis name arrows)
+    # are drawn to the right of cells instead of to the left
+    def viewOptions(self):
+        option = QTableView.viewOptions(self)
+        option.decorationPosition = QStyleOptionViewItem.Right
+        return option
+

 class ArrayDelegate(QItemDelegate):
     """Array Editor Item Delegate"""
-    def __init__(self, dtype, parent=None, font=None, minvalue=None, maxvalue=None):
+    def __init__(self, parent=None, font=None, minvalue=None, maxvalue=None):
+        # parent is the DataView instance
         QItemDelegate.__init__(self, parent)
-        self.dtype = dtype
         if font is None:
             font = get_default_font()
         self.font = font
         self.minvalue = minvalue
         self.maxvalue = maxvalue
-        # We must keep a count instead of the "current" one, because when
-        # switching from one cell to the next, the new editor is created
-        # before the old one is destroyed, which means it would be set to None
-        # when the old one is destroyed.
+        # keep track of whether at least one editor is already open (so that
+        # pressing Enter in DataView only opens a new editor if none is open yet)
+
+        # We must keep a count instead of keeping a reference to the "current" one, because when switching
+        # from one cell to the next, the new editor is created before the old one is destroyed, which means
+        # it would be set to None when the old one is destroyed, instead of to the new current editor.
         self.editor_count = 0

     def createEditor(self, parent, option, index):
         """Create editor widget"""
         model = index.model()
-        # TODO: dtype should be taken from the adapter instead. Only the adapter knows whether the dtype is per cell
+        # TODO: dtype should be asked per cell. Only the adapter knows whether the dtype is per cell
         #       (e.g. list), per column (e.g. Dataframe) or homogenous for the whole table (e.g. la.Array)
         #       dtype = model.adapter.get_dtype(hpos, vpos)
-        dtype = self.dtype
+
+        dtype = model.adapter.dtype
         value = model.get_value(index)
+        # this would return a string!
+        # value = model.data(index, Qt.DisplayRole)
         if dtype.name == "bool":
-            # toggle value
-            value = not value
-            model.setData(index, value)
-            return
+            # directly toggle value and do not actually create an editor
+            model.setData(index, not value)
+            return None
         elif value is not np.ma.masked:
             # Not using a QSpinBox for integer inputs because I could not find
             # a way to prevent the spinbox/editor from closing if the value is
             # invalid. Using the builtin minimum/maximum of the spinbox works
             # but that provides no message so it is less clear.
             editor = QLineEdit(parent)
-            if is_number(dtype):
+            if is_number_dtype(dtype):
+                # FIXME: get minvalue & maxvalue from somewhere... the adapter?
+                #        or the model? another specific adapter for minvalue,
+                #        one for maxvalue, one for bg_value, etc.?
                 minvalue, maxvalue = self.minvalue, self.maxvalue
-                validator = QDoubleValidator(editor) if is_float(dtype) else QIntValidator(editor)
+                validator = QDoubleValidator(editor) if is_float_dtype(dtype) else QIntValidator(editor)
                 if minvalue is not None:
                     validator.setBottom(minvalue)
                 if maxvalue is not None:
@@ -287,6 +982,24 @@ def setEditorData(self, editor, index):
         text = index.model().data(index, Qt.DisplayRole)
         editor.setText(text)

+    def setModelData(self, editor, model, index):
+        parent = self.parent()
+        assert isinstance(parent, DataView)
+        # We save and restore scrollbar positions because the
+        # model_data.reset() we do in EditObjectCommand sets the hidden
+        # scrollbars to 0.
+        hscrollbar = parent.horizontalScrollBar()
+        vscrollbar = parent.verticalScrollBar()
+        h_pos_before = hscrollbar.value()
+        v_pos_before = vscrollbar.value()
+
+        # This is the only thing we should be doing
+        model.setData(index, editor.text())
+
+        # restore original scrollbar positions
+        hscrollbar.setValue(h_pos_before)
+        vscrollbar.setValue(v_pos_before)
+

 class DataView(AbstractView):
     """Data array view class"""
@@ -303,68 +1016,109 @@ def __init__(self, parent, model):
                             f"Received {type(model).__name__} instead")
         AbstractView.__init__(self, parent, model, RIGHT, BOTTOM)

+        # adapter = model.adapter
+        # available_actions = adapter.get_available_actions()
         self.context_menu = self.setup_context_menu()

-        # TODO: find a cleaner way to do this
-        # For some reason the shortcuts in the context menu are not available if the widget does not have the focus,
-        # EVEN when using action.setShortcutContext(Qt.ApplicationShortcut) (or Qt.WindowShortcut) so we redefine them
-        # here. I was also unable to get the function an action.triggered is connected to, so I couldn't do this via
-        # a loop on self.context_menu.actions.
- shortcuts = [ - (keybinding('Copy'), self.parent().copy), - (QKeySequence("Ctrl+E"), self.parent().to_excel), - (keybinding('Paste'), self.parent().paste), - (keybinding('Print'), self.parent().plot) - ] - for key_seq, target in shortcuts: - shortcut = QShortcut(key_seq, self) - shortcut.activated.connect(target) - - def set_dtype(self, dtype): - model = self.model() - delegate = ArrayDelegate(dtype, self, minvalue=model.minvalue, maxvalue=model.maxvalue) + delegate = ArrayDelegate(self) self.setItemDelegate(delegate) - - def selectNewRow(self, row_index): + self.doubleClicked.connect(self.activate_cell) + + def selectRow(self, buffer_v_pos: int): + assert isinstance(buffer_v_pos, int) + super().selectRow(buffer_v_pos) + model: DataArrayModel = self.model() + total_v_size, total_h_size = model.adapter.shape2d() + global_v_pos = model.v_offset + buffer_v_pos + self.first_selection_corner = (global_v_pos, 0) + self.second_selection_corner = (global_v_pos, total_h_size - 1) + + def selectNewRow(self, buffer_v_pos: int): + assert isinstance(buffer_v_pos, int) # if not MultiSelection mode activated, selectRow will unselect previously # selected rows (unless SHIFT or CTRL key is pressed) - # this produces a selection with multiple QItemSelectionRange. We could merge them here, but it is - # easier to handle in _selection_bounds + # this produces a selection with multiple QItemSelectionRange. + # We could merge them here, but it is easier to handle in selection_bounds self.setSelectionMode(QTableView.MultiSelection) - self.selectRow(row_index) + # do not call self.selectRow to avoid updating first_selection_corner + super().selectRow(buffer_v_pos) self.setSelectionMode(QTableView.ContiguousSelection) - def selectNewColumn(self, column_index): + model = self.model() + total_v_size, total_h_size = model.adapter.shape2d() + global_v_pos = model.v_offset + buffer_v_pos + self.second_selection_corner = (global_v_pos, total_h_size - 1) + + def selectColumn(self, buffer_h_pos: int): + assert isinstance(buffer_h_pos, int) + super().selectColumn(buffer_h_pos) + model = self.model() + total_v_size, total_h_size = model.adapter.shape2d() + global_h_pos = model.h_offset + buffer_h_pos + self.first_selection_corner = (0, global_h_pos) + self.second_selection_corner = (total_v_size - 1, global_h_pos) + + def selectNewColumn(self, buffer_h_pos: int): + assert isinstance(buffer_h_pos, int) + # if not MultiSelection mode activated, selectColumn will unselect previously # selected columns (unless SHIFT or CTRL key is pressed) - # this produces a selection with multiple QItemSelectionRange. 
We could merge them here, but it is - # easier to handle in _selection_bounds + # easier to handle in selection_bounds self.setSelectionMode(QTableView.MultiSelection) - self.selectColumn(column_index) + # do not call self.selectColumn to avoid updating first_selection_corner + super().selectColumn(buffer_h_pos) self.setSelectionMode(QTableView.ContiguousSelection) + model = self.model() + total_v_size, total_h_size = model.adapter.shape2d() + global_h_pos = model.h_offset + buffer_h_pos + self.second_selection_corner = (total_v_size - 1, global_h_pos) + + def selectAll(self): + super().selectAll() + total_v_size, total_h_size = self.model().adapter.shape2d() + self.first_selection_corner = (0, 0) + self.second_selection_corner = (total_v_size - 1, total_h_size - 1) + def setup_context_menu(self): """Setup context menu""" - self.copy_action = create_action(self, _('Copy'), - shortcut=keybinding('Copy'), - icon=ima.icon('edit-copy'), - triggered=lambda: self.signal_copy.emit()) - self.excel_action = create_action(self, _('Copy to Excel'), - shortcut="Ctrl+E", - # icon=ima.icon('edit-copy'), - triggered=lambda: self.signal_excel.emit()) - self.paste_action = create_action(self, _('Paste'), - shortcut=keybinding('Paste'), - icon=ima.icon('edit-paste'), - triggered=lambda: self.signal_paste.emit()) - self.plot_action = create_action(self, _('Plot'), - shortcut=keybinding('Print'), - # icon=ima.icon('editcopy'), - triggered=lambda: self.signal_plot.emit()) + actions_def = [ + (_('Copy'), keybinding('Copy'), 'edit-copy', + lambda: self.signal_copy.emit()), + (_('Copy to Excel'), "Ctrl+E", None, + lambda: self.signal_excel.emit()), + (_('Plot'), keybinding('Print'), None, + lambda: self.signal_plot.emit()), + (_('Paste'), keybinding('Paste'), 'edit-paste', + lambda: self.signal_paste.emit()), + ] + actions = [ + create_action(self, label, shortcut=shortcut, icon=icon, + triggered=function) + for label, shortcut, icon, function in actions_def + ] menu = QMenu(self) - menu.addActions([self.copy_action, self.excel_action, self.plot_action, self.paste_action]) + menu.addActions(actions) + + # TODO: For some reason, when I wrote the context_menu code, the + # shortcuts from the actions in the context menu only worked + # if the widget had focus, EVEN when using + # action.setShortcutContext(Qt.ApplicationShortcut) + # (or Qt.WindowShortcut) so I had to redefine them here. 
+ # I should revisit this code to see if that is still the case + # and even if so, I should do this in a cleaner way (probably by + # reusing the actions_def list above) + shortcuts = [ + (keybinding('Copy'), self.parent().copy), + (QKeySequence("Ctrl+E"), self.parent().to_excel), + (keybinding('Paste'), self.parent().paste), + (keybinding('Print'), self.parent().plot) + ] + for key_seq, target in shortcuts: + shortcut = QShortcut(key_seq, self) + shortcut.activated.connect(target) return menu def contextMenuEvent(self, event): @@ -372,98 +1126,173 @@ def contextMenuEvent(self, event): self.context_menu.popup(event.globalPos()) event.accept() + def mousePressEvent(self, event: QMouseEvent) -> None: + """Reimplement Qt method""" + super().mousePressEvent(event) + if event.button() == Qt.LeftButton: + cursor_global_pos = self.get_cursor_global_pos() + if cursor_global_pos is not None: + self.first_selection_corner = cursor_global_pos + + def mouseReleaseEvent(self, event: QMouseEvent) -> None: + """Reimplement Qt method""" + super().mouseReleaseEvent(event) + if event.button() == Qt.LeftButton: + cursor_global_pos = self.get_cursor_global_pos() + if cursor_global_pos is not None: + if self.first_selection_corner is not None: + # this is the normal case where we just finished a selection + self.second_selection_corner = cursor_global_pos + else: + # this can happen when the array_widget is reset between + # a mouse button press and its release, e.g. when + # double-clicking in the explorer to open a dataset but + # keeping the button pressed during the second click, + # moving the mouse a bit to select the cell, then releasing + # the mouse button + self.first_selection_corner = None + def keyPressEvent(self, event): """Reimplement Qt method""" # allow to start editing cells by pressing Enter - if event.key() == Qt.Key_Return and not self.model().readonly: + if event.key() == Qt.Key_Return: index = self.currentIndex() - try: - # qt6 - delegate = self.itemDelegateForIndex(index) - except AttributeError: - # qt5 - delegate = self.itemDelegate(index) - if delegate.editor_count == 0: - self.edit(index) + # TODO: we should check whether the object is readonly + # before trying to activate. If an object is both + # editable and activatable, it will be a problem + if not self.activate_cell(index): + try: + # qt6 + delegate = self.itemDelegateForIndex(index) + except AttributeError: + # qt5 + delegate = self.itemDelegate(index) + if delegate.editor_count == 0: + self.edit(index) else: - QTableView.keyPressEvent(self, event) + AbstractView.keyPressEvent(self, event) - def _selection_bounds(self, none_selects_all=True): + def selection_bounds(self): """ - Parameters - ---------- - none_selects_all : bool, optional - If True (default) and selection is empty, returns all data. - Returns ------- - tuple - selection bounds. end bound is exclusive + selection bounds (row_min, row_max, col_min, col_max -- end bounds are *exclusive*) + If selection is empty, returns all data. 
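+
+        For example, a selection spanning rows 2..3 and column 0 yields
+        (2, 4, 0, 1).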
""" model = self.model() - selection_model = self.selectionModel() - assert isinstance(selection_model, QItemSelectionModel) - selection = selection_model.selection() - assert isinstance(selection, QItemSelection) - if not selection: - if none_selects_all: - return 0, model.total_rows, 0, model.total_cols - else: - return None - # merge potentially multiple selections into one big rect - row_min = min(srange.top() for srange in selection) - row_max = max(srange.bottom() for srange in selection) - col_min = min(srange.left() for srange in selection) - col_max = max(srange.right() for srange in selection) - - # if not all rows/columns have been loaded - if row_min == 0 and row_max == self.model().rows_loaded - 1: - row_max = self.model().total_rows - 1 - if col_min == 0 and col_max == self.model().cols_loaded - 1: - col_max = self.model().total_cols - 1 - return row_min, row_max + 1, col_min, col_max + 1 + # We do not check/use the local "Qt" selection model, which can even + # be empty even when something is selected if the view was scrolled + # enough (via v/h_offset) that it went out of the buffer area + if self.first_selection_corner is None: + assert self.second_selection_corner is None + total_rows, total_cols = model.adapter.shape2d() + return 0, total_rows, 0, total_cols + + assert self.first_selection_corner is not None + assert self.second_selection_corner is not None + selection_v_pos1, selection_h_pos1 = self.first_selection_corner + selection_v_pos2, selection_h_pos2 = self.second_selection_corner + + row_min = min(selection_v_pos1, selection_v_pos2) + row_max = max(selection_v_pos1, selection_v_pos2) + col_min = min(selection_h_pos1, selection_h_pos2) + col_max = max(selection_h_pos1, selection_h_pos2) + return row_min, row_max + 1, col_min, col_max + 1 -MAX_INT_DIGITS = 308 - -def num_int_digits(value): - """ - Number of integer digits. Completely ignores the fractional part. Does not take sign into account. - - >>> num_int_digits(1) - 1 - >>> num_int_digits(99) - 2 - >>> num_int_digits(-99.1) - 2 - """ - value = abs(value) - log10 = math.log10(value) if value > 0 else 0 - if log10 == np.inf: - return MAX_INT_DIGITS - else: - # max(1, ...) because there is at least one integer digit. 
- # explicit conversion to int for Python2.x - return max(1, int(math.floor(log10)) + 1) + def activate_cell(self, index: QModelIndex): + model = self.model() + global_v_pos = model.v_offset + index.row() + global_h_pos = model.h_offset + index.column() + new_data = model.adapter.cell_activated(global_v_pos, global_h_pos) + if new_data is not None: + # the adapter wants us to open a sub-element + array_widget = self.parent().parent() + assert isinstance(array_widget, ArrayEditorWidget) + adapter_creator = get_adapter_creator(new_data) + assert adapter_creator is not None + if isinstance(adapter_creator, str): + QMessageBox.information(self, "Cannot display object", + adapter_creator) + return True + + from larray_editor.editor import AbstractEditorWindow, MappingEditorWindow + widget = self + while (widget is not None and + not isinstance(widget, AbstractEditorWindow) and + callable(widget.parent)): + widget = widget.parent() + if isinstance(widget, MappingEditorWindow): + kernel = widget.ipython_kernel + if kernel is not None: + # make the current object available in the console + kernel.shell.push({ + '__current__': new_data + }) + if not (isinstance(new_data, Path) and new_data.is_dir()): + # TODO: we should add an operand on the future quickbar instead + array_widget.back_button_bar.add_back(array_widget.data, + array_widget.data_adapter) + # TODO: we should open a new window instead (see above) + array_widget.set_data(new_data) + return True + return False class ScrollBar(QScrollBar): """ A specialised scrollbar. """ - def __init__(self, parent, data_scrollbar): - super().__init__(data_scrollbar.orientation(), parent) - self.setMinimum(data_scrollbar.minimum()) - self.setMaximum(data_scrollbar.maximum()) - self.setSingleStep(data_scrollbar.singleStep()) - self.setPageStep(data_scrollbar.pageStep()) - - data_scrollbar.valueChanged.connect(self.setValue) - self.valueChanged.connect(data_scrollbar.setValue) - - data_scrollbar.rangeChanged.connect(self.setRange) - self.rangeChanged.connect(data_scrollbar.setRange) + def __init__(self, parent, orientation, data_model, widget): + super().__init__(orientation, parent) + assert isinstance(data_model, DataArrayModel) + assert isinstance(widget, ArrayEditorWidget) + + self.model = data_model + self.widget = widget + + # We need to update_range when the *total* number of rows/columns + # change, not when the loaded rows change so connecting to the + # rowsInserted and columnsInserted signals is useless here + data_model.modelReset.connect(self.update_range) + + def update_range(self): + adapter = self.model.adapter + if adapter is None: + return + # TODO: for some adapters shape2d is not reliable (it is a best guess), + # we should make sure we handle that + total_rows, total_cols = adapter.shape2d() + view_data = self.widget.view_data + + if self.orientation() == Qt.Horizontal: + buffer_ncols = self.model.ncols + hidden_hscroll_max = view_data.horizontalScrollBar().maximum() + max_value = total_cols - buffer_ncols + hidden_hscroll_max + logger.debug(f"update_range horizontal {total_cols=} {buffer_ncols=} {hidden_hscroll_max=} => {max_value=}") + if total_cols == 0 and max_value != 0: + logger.warn(f"empty data but {max_value=}. 
We let it pass for " + f"now (set it to 0).") + max_value = 0 + else: + buffer_nrows = self.model.nrows + hidden_vscroll_max = view_data.verticalScrollBar().maximum() + max_value = total_rows - buffer_nrows + hidden_vscroll_max + logger.debug(f"update_range vertical {total_rows=} {buffer_nrows=} {hidden_vscroll_max=} => {max_value=}") + if total_rows == 0 and max_value != 0: + logger.warn(f"empty data but {max_value=}. We let it pass for " + f"now (set it to 0).") + max_value = 0 + assert max_value >= 0, "max_value should not be negative" + value_before = self.value() + min_before = self.minimum() + max_before = self.maximum() + self.setMinimum(0) + self.setMaximum(max_value) + logger.debug(f" min: {min_before} -> 0 / " + f"max: {max_before} -> {max_value} / " + f"value: {value_before} -> {self.value()}") available_gradients = [ @@ -490,18 +1319,24 @@ def __init__(self, parent, data_scrollbar): class FontMetrics: def __init__(self, data_model): self.data_model = data_model - self._used_font = data_model.font + self._cached_font = self.model_font + + @property + def model_font(self): + return self.data_model.role_defaults[Qt.FontRole] def font_changed(self): - model_font = self.data_model.font - changed = model_font is not self._used_font and model_font != self._used_font - if changed: - self._used_font = model_font + model_font = self.model_font + if model_font is self._cached_font: + return False + changed = model_font == self._cached_font + # update cached font even if not changed so that the "is" check is enough next time + self._cached_font = model_font return changed @cached_property(font_changed) def str_width(self): - return QFontMetrics(self._used_font).horizontalAdvance + return QFontMetrics(self._cached_font).horizontalAdvance @cached_property(font_changed) def digit_width(self): @@ -537,8 +1372,10 @@ def get_numbers_width(self, int_digits, frac_digits=0, need_sign=False, scientif class ArrayEditorWidget(QWidget): dataChanged = Signal(list) + # milliseconds between a scroll event and updating cell sizes + UPDATE_SIZES_FROM_CONTENT_DELAY = 100 - def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient='blue-red', + def __init__(self, parent, data=None, readonly=False, attributes=None, bg_gradient='blue-red', minvalue=None, maxvalue=None, digits=None): QWidget.__init__(self, parent) assert bg_gradient in gradient_map @@ -547,16 +1384,16 @@ def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient self.readonly = readonly # prepare internal views and models - self.model_axes = AxesArrayModel(parent=self, readonly=readonly) + self.model_axes = AxesArrayModel(parent=self) #, readonly=readonly) self.view_axes = AxesView(parent=self, model=self.model_axes) - self.model_hlabels = LabelsArrayModel(parent=self, readonly=readonly) + self.model_hlabels = HLabelsArrayModel(parent=self) #, readonly=readonly) self.view_hlabels = LabelsView(parent=self, model=self.model_hlabels, hpos=RIGHT, vpos=TOP) - self.model_vlabels = LabelsArrayModel(parent=self, readonly=readonly) + self.model_vlabels = VLabelsArrayModel(parent=self) #, readonly=readonly) self.view_vlabels = LabelsView(parent=self, model=self.model_vlabels, hpos=LEFT, vpos=BOTTOM) - self.model_data = DataArrayModel(parent=self, readonly=readonly, minvalue=minvalue, maxvalue=maxvalue) + self.model_data = DataArrayModel(parent=self) #, readonly=readonly, minvalue=minvalue, maxvalue=maxvalue) self.view_data = DataView(parent=self, model=self.model_data) self.font_metrics = 
FontMetrics(self.model_data) @@ -564,20 +1401,35 @@ def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient # in case data is None self.data_adapter = None - # Create vertical and horizontal scrollbars - self.vscrollbar = ScrollBar(self, self.view_data.verticalScrollBar()) - self.hscrollbar = ScrollBar(self, self.view_data.horizontalScrollBar()) - - # Synchronize resizing - self.view_axes.horizontalHeader().sectionResized.connect(self.view_vlabels.updateSectionWidth) - self.view_axes.verticalHeader().sectionResized.connect(self.view_hlabels.updateSectionHeight) - self.view_hlabels.horizontalHeader().sectionResized.connect(self.view_data.updateSectionWidth) - self.view_vlabels.verticalHeader().sectionResized.connect(self.view_data.updateSectionHeight) - # Synchronize auto-resizing - self.view_axes.horizontalHeader().sectionHandleDoubleClicked.connect(self.resize_axes_column_to_contents) - self.view_hlabels.horizontalHeader().sectionHandleDoubleClicked.connect(self.resize_hlabels_column_to_contents) - self.view_axes.verticalHeader().sectionHandleDoubleClicked.connect(self.resize_axes_row_to_contents) - self.view_vlabels.verticalHeader().sectionHandleDoubleClicked.connect(self.resize_vlabels_row_to_contents) + # Create visible vertical and horizontal scrollbars + # TODO: when models "total" shape change (this is NOT model.nrows/ncols), we should update the range of + # vscrollbar/hscrollbar. this is already partially done in ScrollBar (it listens to modelReset signal) but + # this is not enough + self.vscrollbar = ScrollBar(self, Qt.Vertical, self.model_data, self) + self.vscrollbar.valueChanged.connect(self.visible_vscroll_changed) + self.hscrollbar = ScrollBar(self, Qt.Horizontal, self.model_data, self) + self.hscrollbar.valueChanged.connect(self.visible_hscroll_changed) + + axes_h_header = self.view_axes.horizontalHeader() + axes_v_header = self.view_axes.verticalHeader() + hlabels_h_header = self.view_hlabels.horizontalHeader() + vlabels_v_header = self.view_vlabels.verticalHeader() + + # Propagate section resizing (left -> right and top -> bottom) + axes_h_header.sectionResized.connect(self.on_axes_column_resized) + axes_v_header.sectionResized.connect(self.on_axes_row_resized) + hlabels_h_header.sectionResized.connect(self.on_hlabels_column_resized) + vlabels_v_header.sectionResized.connect(self.on_vlabels_row_resized) + + # only useful for debugging + # data_h_header = self.view_data.horizontalHeader() + # data_h_header.sectionResized.connect(self.on_data_column_resized) + + # Propagate auto-resizing requests + axes_h_header.sectionHandleDoubleClicked.connect(self.resize_axes_column_to_contents) + hlabels_h_header.sectionHandleDoubleClicked.connect(self.resize_hlabels_column_to_contents) + axes_v_header.sectionHandleDoubleClicked.connect(self.resize_axes_row_to_contents) + vlabels_v_header.sectionHandleDoubleClicked.connect(self.resize_vlabels_row_to_contents) # synchronize specific methods self.view_axes.allSelected.connect(self.view_data.selectAll) @@ -589,19 +1441,45 @@ def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient # propagate changes (add new items in the QUndoStack attribute of MappingEditor) self.model_data.newChanges.connect(self.data_changed) - # Synchronize scrolling + # Synchronize scrolling of the different hidden scrollbars # data <--> hlabels - self.view_data.horizontalScrollBar().valueChanged.connect(self.view_hlabels.horizontalScrollBar().setValue) - 
self.view_hlabels.horizontalScrollBar().valueChanged.connect(self.view_data.horizontalScrollBar().setValue) + hidden_data_hscrollbar = self.view_data.horizontalScrollBar() + hidden_hlabels_hscrollbar = self.view_hlabels.horizontalScrollBar() + hidden_data_hscrollbar.valueChanged.connect( + hidden_hlabels_hscrollbar.setValue + ) + hidden_hlabels_hscrollbar.valueChanged.connect( + hidden_data_hscrollbar.setValue + ) + # data <--> vlabels - self.view_data.verticalScrollBar().valueChanged.connect(self.view_vlabels.verticalScrollBar().setValue) - self.view_vlabels.verticalScrollBar().valueChanged.connect(self.view_data.verticalScrollBar().setValue) + hidden_data_vscrollbar = self.view_data.verticalScrollBar() + hidden_vlabels_vscrollbar = self.view_vlabels.verticalScrollBar() + hidden_data_vscrollbar.valueChanged.connect( + hidden_vlabels_vscrollbar.setValue + ) + hidden_vlabels_vscrollbar.valueChanged.connect( + hidden_data_vscrollbar.setValue + ) + + # Propagate range updates from hidden scrollbars to visible scrollbars + # The ranges are updated when we resize the window or some columns/rows + # and some of them quit or re-enter the viewport. + # We do NOT need to propagate the hidden scrollbar value changes + # because we scroll by entire columns/rows, so resizing columns/rows + # does not change the scrollbar values + def hidden_hscroll_range_changed(min_value: int, max_value: int): + self.hscrollbar.update_range() + hidden_data_hscrollbar.rangeChanged.connect(hidden_hscroll_range_changed) + def hidden_vscroll_range_changed(min_value: int, max_value: int): + self.vscrollbar.update_range() + hidden_data_vscrollbar.rangeChanged.connect(hidden_vscroll_range_changed) # Synchronize selecting columns(rows) via hor.(vert.) header of x(y)labels view - self.view_hlabels.horizontalHeader().sectionPressed.connect(self.view_data.selectColumn) - self.view_hlabels.horizontalHeader().sectionEntered.connect(self.view_data.selectNewColumn) - self.view_vlabels.verticalHeader().sectionPressed.connect(self.view_data.selectRow) - self.view_vlabels.verticalHeader().sectionEntered.connect(self.view_data.selectNewRow) + hlabels_h_header.sectionPressed.connect(self.view_data.selectColumn) + hlabels_h_header.sectionEntered.connect(self.view_data.selectNewColumn) + vlabels_v_header.sectionPressed.connect(self.view_data.selectRow) + vlabels_v_header.sectionEntered.connect(self.view_data.selectNewRow) # following lines are required to keep usual selection color # when selecting rows/columns via headers of label views. 
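+        # Big picture of the scrollbar wiring above: the *visible* scrollbars
+        # span the total data size while each view's *hidden* scrollbars only
+        # span the currently buffered rows/columns; scrolling beyond what the
+        # buffer can absorb moves the buffer itself (see
+        # visible_vscroll_changed/visible_hscroll_changed below)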
@@ -636,18 +1514,22 @@ def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient array_frame.setLayout(array_layout) # Set filters and buttons layout - self.filters_layout = QHBoxLayout() + self.back_button_bar = BackButtonBar(self) + self.filter_bar = FilterBar(self) self.btn_layout = QHBoxLayout() self.btn_layout.setAlignment(Qt.AlignLeft) # sometimes also called "Fractional digits" or "scale" label = QLabel("Decimal Places") self.btn_layout.addWidget(label) + # default range is 0-99 spin = QSpinBox(self) spin.valueChanged.connect(self.frac_digits_changed) + # spin.setRange(-1, 99) + # this is used when the widget has its minimum value + # spin.setSpecialValueText("auto") self.digits_spinbox = spin self.btn_layout.addWidget(spin) - self.frac_digits = 0 scientific = QCheckBox(_('Scientific')) scientific.stateChanged.connect(self.scientific_changed) @@ -691,165 +1573,387 @@ def __init__(self, parent, data=None, readonly=False, bg_value=None, bg_gradient # Set widget layout layout = QVBoxLayout() - layout.addLayout(self.filters_layout) + layout.addWidget(self.back_button_bar) + layout.addWidget(self.filter_bar) layout.addWidget(array_frame) layout.addLayout(self.btn_layout) + # left, top, right, bottom layout.setContentsMargins(0, 0, 0, 0) self.setLayout(layout) # set gradient self.model_data.set_bg_gradient(gradient_map[bg_gradient]) + # TODO: store detected_column_widths too so that it does not vary so + # much on scroll. Viewing test_api_larray.py is a good test for + # this. + self.user_defined_hlabels_column_widths = {} + self.user_defined_axes_column_widths = {} + self.user_defined_vlabels_row_heights = {} + self.user_defined_axes_row_heights = {} + self.detected_hlabels_column_widths = {} + self.detected_axes_column_widths = {} + # TODO: find some more efficient structure to store them. 
99.9% + # of rows will use the default height + self.detected_vlabels_row_heights = {} + self.detected_axes_row_heights = {} + self._updating_hlabels_column_widths = False + self._updating_axes_column_widths = False + self._updating_vlabels_row_heights = False + self._updating_axes_row_heights = False + + update_timer = QTimer(self) + update_timer.setSingleShot(True) + update_timer.setInterval(self.UPDATE_SIZES_FROM_CONTENT_DELAY) + update_timer.timeout.connect(self.update_cell_sizes_from_content) + self.update_cell_sizes_timer = update_timer + # set data if data is not None: - self.set_data(data, bg_value=bg_value, frac_digits=digits) - - # See http://doc.qt.io/qt-4.8/qt-draganddrop-fridgemagnets-dragwidget-cpp.html for an example - self.setAcceptDrops(True) - - def gradient_changed(self, index): - gradient = self.gradient_chooser.itemData(index) if index > 0 else None - self.model_data.set_bg_gradient(gradient) - - def data_changed(self, data_model_changes): - changes = self.data_adapter.translate_changes(data_model_changes) - self.dataChanged.emit(changes) - - def mousePressEvent(self, event): - self.dragLabel = self.childAt(event.pos()) if event.button() == Qt.LeftButton else None - self.dragStartPosition = event.pos() - - def mouseMoveEvent(self, event): - from qtpy.QtCore import QMimeData, QByteArray - from qtpy.QtGui import QPixmap, QDrag - - if not (event.button() != Qt.LeftButton and isinstance(self.dragLabel, QLabel)): + self.set_data(data, attributes=attributes, frac_digits=digits) + + def visible_cols(self, include_partial=True): + """number of visible columns *including* partially visible ones""" + + view_data = self.view_data + hidden_h_offset = view_data.horizontalScrollBar().value() + view_width = view_data.width() + last_visible_col_idx = view_data.columnAt(view_width - 1) + + # +1 because last_visible_col_idx is a 0-based index + # if last_visible_col_idx == -1 it means the visible area is larger than + # the array + num_cols = last_visible_col_idx + 1 if last_visible_col_idx != -1 else self.model_data.ncols + # clsname = self.__class__.__name__ + # logger.debug(f"{clsname}.visible_cols({include_partial=})") + # logger.debug(f" {hidden_h_offset=} {view_width=} " + # f"{last_visible_col_idx=} => {num_cols=}") + if not include_partial and last_visible_col_idx != -1: + last_visible_col_width = view_data.columnWidth(last_visible_col_idx) + # no - 1 necessary here + next_to_last_col_idx = view_data.columnAt(view_width - + last_visible_col_width) + has_partial = next_to_last_col_idx < last_visible_col_idx + if has_partial: + num_cols -= 1 + return num_cols - hidden_h_offset + + def visible_rows(self, include_partial=True): + """number of visible rows *including* partially visible ones""" + + view_data = self.view_data + hidden_v_offset = view_data.verticalScrollBar().value() + view_height = view_data.height() + last_visible_row_idx = view_data.rowAt(view_height - 1) + + # +1 because last_visible_row_idx is a 0-based index + # if last_visible_row_idx == -1 it means the visible area is larger than + # the array + num_rows = last_visible_row_idx + 1 if last_visible_row_idx != -1 else self.model_data.nrows + # clsname = self.__class__.__name__ + # logger.debug(f"{clsname}.visible_rows({include_partial=})") + # logger.debug(f" {hidden_v_offset=} {view_height=} " + # f"{last_visible_row_idx=} => {num_rows=}") + if not include_partial and last_visible_row_idx != -1: + last_visible_row_height = view_data.rowHeight(last_visible_row_idx) + # no - 1 necessary here + next_to_last_row_idx = 
view_data.rowAt(view_height - + last_visible_row_height) + has_partial = next_to_last_row_idx < last_visible_row_idx + if has_partial: + num_rows -= 1 + # logger.debug(f" {has_partial=} => {num_rows=}") + visible_rows = num_rows - hidden_v_offset + # logger.debug(f" {hidden_v_offset=} => {visible_rows=}") + return visible_rows + + # Update the local/Qt selection, if needed + def _update_selection(self, new_h_offset, new_v_offset): + view_data = self.view_data + + selection_model = view_data.selectionModel() + assert isinstance(selection_model, QItemSelectionModel) + local_selection = selection_model.selection() + assert isinstance(local_selection, QItemSelection) + # if there is a local selection, we always need to move it (and + # sometimes shrink it); if there is a global selection and no local + # selection we may need to create a local selection when the global + # selection intersects with the viewport + global_selection_set = view_data.first_selection_corner is not None + if local_selection or global_selection_set: + model_data = self.model_data + row_min, row_max, col_min, col_max = view_data.selection_bounds() + + # we need to clip local coordinates in case the selection + # corners are outside the viewport + local_top = max(row_min - new_v_offset, 0) + local_left = max(col_min - new_h_offset, 0) + # -1 because selection_bounds are exclusive while Qt use + # inclusive bounds + local_bottom = min(row_max - 1 - new_v_offset, model_data.nrows - 1) + local_right = min(col_max - 1 - new_h_offset, model_data.ncols - 1) + local_selection = QItemSelection( + model_data.index(local_top, local_left), + model_data.index(local_bottom, local_right) + ) + selection_model.select(local_selection, + QItemSelectionModel.ClearAndSelect) + + def visible_vscroll_changed(self, value): + # 'value' will be the first visible row + assert value >= 0, f"value must be >= 0 but is {value!r}" + model_data = self.model_data + hidden_vscroll = self.view_data.verticalScrollBar() + # hidden_vscroll_max is the margin we got before we must move the buffer + hidden_vscroll_max = hidden_vscroll.maximum() + v_offset = model_data.v_offset + extra_move = hidden_vscroll_max // 2 + logger.debug(f"visible vscroll changed({value=}, {v_offset=}, " + f"hidden_max={hidden_vscroll_max}, {extra_move=})") + + # buffer is beyond what is asked to display, we need to move it back + if value < v_offset: + # we could simply set it to value but we want to move more to avoid + # fetching data for each row + new_v_offset = max(value - extra_move, 0) + msg = " value < v_offset (min)" + + # we don't need to move the buffer (we can absorb the scroll change + # entirely with the hidden scroll) + elif value <= v_offset + hidden_vscroll_max: + new_v_offset = v_offset + msg = " min <= value <= max => change hidden only" + + # buffer is before what is asked to display, we need to move it further + # <-visible_rows-> + # <------nrows----------> + # | |------buffer---------| | | | + # ^ ^ ^ ^ ^ + # 0 v_offset value max_value total_rows + else: + # we could simply set it to "value - hidden_vscroll_max" to move as + # little as possible (this would place the visible rows at the end + # of the buffer) but we want to move more to avoid fetching data + # each time we move a single row + new_v_offset = value - hidden_vscroll_max + extra_move + # make sure we always have an entire buffer + total_rows, total_cols = self.data_adapter.shape2d() + new_v_offset = min(new_v_offset, total_rows - model_data.nrows) + msg = " value > v_offset + invis (max)" + + 
assert new_v_offset >= 0 + assert new_v_offset <= value <= new_v_offset + hidden_vscroll_max + + new_hidden_offset = value - new_v_offset + logger.debug(f"{msg} => {new_hidden_offset=}, {new_v_offset=}") + if new_v_offset != v_offset: + model_data.set_v_offset(new_v_offset) + self.model_vlabels.set_v_offset(new_v_offset) + self._update_selection(model_data.h_offset, new_v_offset) + self.update_cell_sizes_timer.start() + + hidden_vscroll.setValue(new_hidden_offset) + + def update_cell_sizes_from_content(self): + logger.debug("ArrayEditorWidget.update_cell_sizes_from_content()") + # TODO: having this in a timer alleviates the scrolling speed issue + # but we could also try to make this faster: + # * Would computing the sizeHint ourselves help? + # * For many in-memory (especially numerical) containers, + # it would be cheaper to compute that once for the + # whole array instead of after each scroll + self._update_hlabels_column_widths_from_content() + # we do not need to update axes cell size on scroll but vlabels + # width can change on scroll (and they are linked to axes widths) + self._update_vlabels_row_heights_from_content() + self._update_axes_column_widths_from_content() + self._update_axes_row_heights_from_content() + + def visible_hscroll_changed(self, value): + # 'value' will be the first visible column + assert value >= 0, f"value must be >= 0 but is {value!r}" + model_data = self.model_data + hidden_hscroll = self.view_data.horizontalScrollBar() + # hidden_hscroll_max is the margin we got before we must move the buffer + hidden_hscroll_max = hidden_hscroll.maximum() + extra_move = hidden_hscroll_max // 2 + h_offset = model_data.h_offset + logger.debug(f"visible hscroll changed ({value=}, {h_offset=}, " + f"hidden_max={hidden_hscroll_max}, {extra_move=})") + + # buffer is beyond what is asked to display, we need to move it back + if value < h_offset: + # we could simply set it to value but we want to move more to avoid + # fetching data for each row + new_h_offset = max(value - extra_move, 0) + msg = "value < h_offset (min)" + + # we don't need to move the buffer (we can absorb the scroll change + # entirely with the hidden scroll) + elif value <= h_offset + hidden_hscroll_max: + new_h_offset = h_offset + msg = "min <= value <= max (hidden only)" + + # buffer is before what is asked to display, we need to move it further + # <-visible_cols-> + # <------ncols----------> + # | |------buffer---------| | | | + # ^ ^ ^ ^ ^ + # 0 h_offset value max_value total_cols + else: + # we could simply set it to "value - hidden_hscroll_max" to move as + # little as possible (this would place the visible cols at the end + # of the buffer) but we want to move more to avoid fetching data + # each time we move a single col + new_h_offset = value - hidden_hscroll_max + extra_move + # make sure we always have an entire buffer + total_rows, total_cols = self.data_adapter.shape2d() + new_h_offset = min(new_h_offset, total_cols - model_data.ncols) + msg = "value > h_offset + invis (max)" + + assert new_h_offset >= 0 + assert new_h_offset <= value <= new_h_offset + hidden_hscroll_max + + new_hidden_offset = value - new_h_offset + logger.debug(f"{msg} => {new_hidden_offset=}, {new_h_offset=}") + if new_h_offset != h_offset: + model_data.set_h_offset(new_h_offset) + self.model_hlabels.set_h_offset(new_h_offset) + self._update_selection(new_h_offset, model_data.v_offset) + self.update_cell_sizes_timer.start() + hidden_hscroll.setValue(new_hidden_offset) + + def on_axes_column_resized(self, logical_index, 
old_size, new_size): + # synchronize with linked view + # equivalent (AFAICT) to: + # view_vlabels.horizontalHeader().resizeSection(logical_index, new_size) + self.view_vlabels.setColumnWidth(logical_index, new_size) + if self._updating_axes_column_widths: return + self.user_defined_axes_column_widths[logical_index] = new_size - if (event.pos() - self.dragStartPosition).manhattanLength() < QApplication.startDragDistance(): + def on_axes_row_resized(self, logical_index, old_size, new_size): + # synchronize with linked view + self.view_hlabels.setRowHeight(logical_index, new_size) + if self._updating_axes_row_heights: return + self.user_defined_axes_row_heights[logical_index] = new_size - axis_index = self.filters_layout.indexOf(self.dragLabel) // 2 - - # prepare hotSpot, mimeData and pixmap objects - mimeData = QMimeData() - mimeData.setText(self.dragLabel.text()) - mimeData.setData("application/x-axis-index", QByteArray.number(axis_index)) - pixmap = QPixmap(self.dragLabel.size()) - self.dragLabel.render(pixmap) - - # prepare drag object - drag = QDrag(self) - drag.setMimeData(mimeData) - drag.setPixmap(pixmap) - drag.setHotSpot(event.pos() - self.dragStartPosition) - - drag.exec_(Qt.MoveAction | Qt.CopyAction, Qt.CopyAction) + def on_hlabels_column_resized(self, logical_index, old_size, new_size): + # synchronize with linked view + # logger.debug(f"on_hlabels_column_resized {logical_index=} {new_size=}") + self.view_data.setColumnWidth(logical_index, new_size) + if self._updating_hlabels_column_widths: + return + h_offset = self.model_data.h_offset + self.user_defined_hlabels_column_widths[logical_index + h_offset] = new_size - def dragEnterEvent(self, event): - if event.mimeData().hasText(): - if self.filters_layout.geometry().contains(event.pos()): - event.setDropAction(Qt.MoveAction) - event.accept() - else: - event.acceptProposedAction() - else: - event.ignore() + # def on_data_column_resized(self, logical_index, old_size, new_size): + # log_caller() + # logger.debug(f"on_data_column_resized {logical_index=} {new_size=}") - def dragMoveEvent(self, event): - if event.mimeData().hasText() and self.filters_layout.geometry().contains(event.pos()): - child = self.childAt(event.pos()) - if isinstance(child, QLabel) and child.text() != "Filters": - event.setDropAction(Qt.MoveAction) - event.accept() - else: - event.ignore() - else: - event.ignore() + def on_vlabels_row_resized(self, logical_index, old_size, new_size): + # synchronize with linked view + self.view_data.setRowHeight(logical_index, new_size) + if self._updating_vlabels_row_heights: + return + v_offset = self.model_data.v_offset + self.user_defined_vlabels_row_heights[logical_index + v_offset] = new_size - def dropEvent(self, event): - if event.mimeData().hasText(): - if self.filters_layout.geometry().contains(event.pos()): - old_index, success = event.mimeData().data("application/x-axis-index").toInt() - new_index = self.filters_layout.indexOf(self.childAt(event.pos())) // 2 + def gradient_changed(self, index): + gradient = self.gradient_chooser.itemData(index) if index > 0 else None + self.model_data.set_bg_gradient(gradient) - data, bg_value = self.data_adapter.data, self.data_adapter.bg_value - data, bg_value = self.data_adapter.move_axis(data, bg_value, old_index, new_index) - self.set_data(data, bg_value) + def data_changed(self, data_model_changes): + global_changes = self.data_adapter.translate_changes(data_model_changes) + self.dataChanged.emit(global_changes) - event.setDropAction(Qt.MoveAction) - event.accept() - else: 
-                event.acceptProposedAction()
-        else:
-            event.ignore()
+    def _set_models_adapter(self):
+        self.model_axes.set_adapter(self.data_adapter)
+        self.model_hlabels.set_adapter(self.data_adapter)
+        self.model_vlabels.set_adapter(self.data_adapter)
+        self.model_data.set_adapter(self.data_adapter)
-    def _update_models(self, reset_model_data, reset_minmax):
-        # axes names
-        axes_names = self.data_adapter.get_axes_names(fold_last_axis=True)
-        self.model_axes.set_data(axes_names)
-        # horizontal labels
-        hlabels = self.data_adapter.get_hlabels()
-        self.model_hlabels.set_data(hlabels)
-        # vertical labels
-        vlabels = self.data_adapter.get_vlabels()
-        self.model_vlabels.set_data(vlabels)
-        # raw data
-        # use flag reset=False to avoid calling reset() several times
-        raw_data = self.data_adapter.get_raw_data()
-        self.model_data.set_data(raw_data, reset=False)
-        # bg value
-        # use flag reset=False to avoid calling reset() several times
-        bg_value = self.data_adapter.get_bg_value()
-        self.model_data.set_bg_value(bg_value, reset=False)
-        # reset min and max values if required
-        if reset_minmax:
-            self.model_data.reset_minmax()
-        # reset the data model if required
-        if reset_model_data:
-            self.model_data.reset()
-
-    def set_data(self, data, bg_value=None, frac_digits=None):
+    def set_data(self, data, attributes=None, frac_digits=None):
         # get new adapter instance + set data
-        self.data_adapter = get_adapter(data=data, bg_value=bg_value)
-        # update filters
-        self._update_filter()
-        # update models
-        # Note: model_data is reset by call of set_format below
-        self._update_models(reset_model_data=False, reset_minmax=True)
-        # reset default size
-        self._reset_default_size()
-        # update data format
-        self.set_format(frac_digits=frac_digits, scientific=None)
-        # update gradient_chooser
-        self.gradient_chooser.setEnabled(self.model_data.bgcolor_possible)
-        # update dtype in view_data
-        self.view_data.set_dtype(self.data_adapter.dtype)
-
-    def _reset_default_size(self):
-        self.view_axes.set_default_size()
-        self.view_vlabels.set_default_size()
-        self.view_hlabels.set_default_size()
-        self.view_data.set_default_size()
-
-    def _update_filter(self):
-        filters_layout = self.filters_layout
-        clear_layout(filters_layout)
-        axes = self.data_adapter.get_axes_filtered_data()
-        # size > 0 to avoid arrays with length 0 axes and len(axes) > 0 to avoid scalars (scalar.size == 1)
-        if self.data_adapter.size > 0 and len(axes) > 0:
-            filters_layout.addWidget(QLabel(_("Filters")))
-            for axis in axes:
-                filters_layout.addWidget(QLabel(axis.name))
-                # FIXME: on very large axes, this is getting too slow. Ideally the combobox should use a model which
-                #        only fetch labels when they are needed to be displayed
-                if len(axis) < 10000:
-                    filters_layout.addWidget(self.create_filter_combo(axis))
-                else:
-                    filters_layout.addWidget(QLabel("too big to be filtered"))
-            filters_layout.addStretch()
+        # TODO: add a mechanism that adapters can use to tell whether they support a
+        # particular *instance* of a data structure. This should probably be a
+        # class method.
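+        # A hypothetical sketch of such a check (no such method exists yet,
+        # the name is made up):
+        #
+        #     @classmethod
+        #     def supports_instance(cls, data) -> bool:
+        #         # return False for instances this adapter cannot display
+        #         ...
+        #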
+        # For example for memoryview, "structured"
+        # memoryviews are not supported and get_adapter currently returns None
+        data_adapter = get_adapter(data, attributes)
+        if data_adapter is None:
+            return
+        self.data = data
+        self.set_data_adapter(data_adapter, frac_digits)
+
+    def close(self):
+        logger.debug("ArrayEditorWidget.close()")
+        if self.data_adapter is not None:
+            self._close_adapter(self.data_adapter)
+        self.back_button_bar.clear()
+        super().close()
+
+    @staticmethod
+    def _close_adapter(adapter):
+        clsname = type(adapter).__name__
+        logger.debug(f"closing data adapter ({clsname})")
+        adapter.close()
+
+    def set_data_adapter(self, data_adapter: AbstractAdapter, frac_digits):
+        old_adapter = self.data_adapter
+        if old_adapter is not None:
+            # We only need to close it if that adapter is not used in any
+            # "back button"
+            if not any(adapter is old_adapter
+                       for adapter in self.back_button_bar._back_data_adapters):
+                self._close_adapter(old_adapter)
+        self.data_adapter = data_adapter
-    def set_format(self, frac_digits=None, scientific=None):
+        # update models
+        self._set_models_adapter()
+
+        # reset widget to initial state
+        self.reset_to_defaults()
+
+        # update data format & autosize all cells
+        # view_data and view_hlabels columns are resized automatically in
+        # set_frac_digits_or_scientific, so using self.autofit_columns()
+        # (which resizes columns of the 4 different views) is overkill but we
+        # still need to resize view_axes and view_vlabels columns.
+        self.set_frac_digits_or_scientific(frac_digits=frac_digits,
+                                           scientific=None)
+        self._update_axes_column_widths_from_content()
+        self._update_axes_row_heights_from_content()
+        self._update_vlabels_row_heights_from_content()
+
+    def reset_to_defaults(self):
+        logger.debug(f"{self.__class__.__name__}.reset_to_defaults()")
+
+        # reset visible scrollbars
+        self.vscrollbar.setValue(0)
+        self.hscrollbar.setValue(0)
+
+        # reset filters
+        self.filter_bar.reset_to_defaults()
+
+        # reset default sizes and clear selection
+        self.view_axes.reset_to_defaults()
+        self.view_vlabels.reset_to_defaults()
+        self.view_hlabels.reset_to_defaults()
+        self.view_data.reset_to_defaults()
+
+        # clear user defined & detected column widths & row heights
+        self.user_defined_axes_column_widths = {}
+        self.user_defined_axes_row_heights = {}
+        self.user_defined_hlabels_column_widths = {}
+        self.user_defined_vlabels_row_heights = {}
+        self.detected_hlabels_column_widths = {}
+        self.detected_axes_column_widths = {}
+        self.detected_vlabels_row_heights = {}
+        self.detected_axes_row_heights = {}
+
+    def set_frac_digits_or_scientific(self, frac_digits=None, scientific=None):
         """Set format.

         Parameters
@@ -857,29 +1961,146 @@ def set_format(self, frac_digits=None, scientific=None):
         frac_digits : int, optional
             Number of decimals to display. Defaults to None (autodetect).
         scientific : boolean, optional
-            Whether or not to display values in scientific format. Defaults to None (autodetect).
+            Whether or not to display values in scientific format.
+            Defaults to None (autodetect).
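+
+        For example, with ``frac_digits=2``, 1234.5 is displayed as
+        ``1234.50`` in non-scientific format and as ``1.23e+03`` in
+        scientific format.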
+
+        Currently, it is called from 3 places/cases:
+        - set_data => frac_digits=None, scientific=None
+        - user changes scientific => frac_digits=None, scientific=bool
+        - user changes frac_digits => frac_digits=int, scientific=bool
+
+        ON NEW DATA
+            compute vmin/vmax on "start" buffer
+            - can be per buffer, per column, or per row depending on adapter
+
+            if api-provided ndigits is None and scientific is None (default)
+                autodetect scientific & ndigits
+                - can be per buffer, per column, or per row
+                autodetect column width
+                - can be per buffer, or per column
+                autodetect row height
+                - can be per buffer or per row
+            elif ndigits is a dict and scientific is a dict:
+                autodetect column widths:
+                - should be per column, but unsure it is worth blocking
+                  per buffer even though I do not think it makes sense
+            elif ndigits is an int and scientific is a bool:
+                autodetect column widths:
+                - can be per buffer or per column (depending on adapter)
+                  if per column, ndigits is a modifier to autodetected
+                  value
+            elif ndigits is an int and scientific is None:
+                autodetect scientific
+                - can be per buffer, per column, or per row
+                autodetect column widths:
+                - can be per buffer or per column (depending on adapter)
+                  if per column, ndigits is a modifier to autodetected
+                  value
+            elif ndigits is None and scientific is a bool:
+                autodetect ndigits
+                - can be per buffer, per column, or per row
+                autodetect column widths:
+                - can be per buffer or per column (depending on adapter)
+                  if per column, GUI ndigits is a modifier to autodetected
+                  value
+        ON V_OFFSET CHANGE:
+            update vmin/vmax (do not recompute on *just* the current buffer)
+            - can be per buffer, per column or per row
+            if LARRAY:
+                IF still the "autodetected" ndigits,
+                    re-detect ndigits given current window
+                    (I don't think we should touch scientific in this case)
+            elif dataframe:
+                re-detect ndigits given current window
+                (I don't think we should touch scientific in this case)
+                add/subtract ndigits offset
+            update column widths if needed
+        ON H_OFFSET change:
+            update vmin/vmax
+            update (invisible) column widths (do not change ndigits, scientific)
+        ON SCIENTIFIC CHANGE:
+            do NOT update column widths
+            autodetect ndigits
+            - can be per buffer, per column, per row or per cell
+        ON NDIGITS CHANGE:
+            update column widths
+        ON COLWIDTH CHANGE:
+            if dataframe or per-column array (*):
+                re-detect ndigits only for the changed column
+                (I don't think we should touch scientific in this case)
+            elif homogeneous array:
+                synchronize exact column width for all columns
+                doing this implicitly via ndigits will result in "unpleasant"
+                resizing I think
+                re-detect scientific format for all columns
+
+
+        one option: move format determination into the adapter or model
+
+        when set_frac_digits_or_scientific is called from the UI:
+            call corresponding method on adapter passing current column widths
+            then force-fetch resulting data from the model to compute final
+            column widths
         """
+        logger.debug(f"ArrayEditorWidget.set_frac_digits_or_scientific("
+                     f"{frac_digits=}, {scientific=})")
         assert frac_digits is None or isinstance(frac_digits, int)
         assert scientific is None or isinstance(scientific, bool)
         scientific_toggled = scientific is not None and scientific != self.use_scientific
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"ArrayEditorWidget.set_format(frac_digits={frac_digits}, scientific={scientific})")
-        data_sample = self.data_adapter.get_finite_sample()
-        is_number_dtype = np.issubdtype(data_sample.dtype, np.number)
+        data_sample = self.data_adapter.get_sample()
+        if not isinstance(data_sample, np.ndarray):
+            # TODO: for non numpy homogeneous data types, this is suboptimal
+            data_sample = np.asarray(data_sample, dtype=object)
+        is_number_dtype = (isinstance(data_sample, np.ndarray) and
+                           np.issubdtype(data_sample.dtype, np.number))
         cur_colwidth = self._get_current_min_col_width()
+
         if is_number_dtype and data_sample.size:
-            # TODO: this should come from the adapter or from the data_model (were it is already computed!!!)
-            #       (but modified whenever the data changes)
+            # TODO: vmin/vmax should come from the adapter (where it is already
+            #       computed and modified whenever the data changes)
+            # TODO: some (all?) of this should be done in the adapter because
+            #       it knows whether vmin/vmax should be per column or global
+            #       and, in the end, whether format and colwidth should be the
+            #       same for the whole array or per col but I am still unsure
+            #       of the boundary because font_metrics should not be used in
+            #       the adapter.
+            #       * The adapter also knows how expensive it is to compute some
+            #         stuff and whether we can compute vmin/vmax on the full
+            #         array or have to rely on sample + "rolling" vmin/vmax.
+            #       * If vmin/vmax are arrays, we need to know which
+            #         rows/columns (v_offset/h_offset) they correspond to.
            vmin, vmax = np.min(data_sample), np.max(data_sample)
-            int_digits = max(num_int_digits(vmin), num_int_digits(vmax))
-            has_negative = vmin < 0
+            is_finite_data = np.isfinite(vmin) and np.isfinite(vmax)
+            # logger.debug(f"    {data_sample=}")
+            if is_finite_data:
+                finite_sample = data_sample
+                finite_vmin, finite_vmax = vmin, vmax
+            else:
+                isfinite = np.isfinite(data_sample)
+                if isfinite.any():
+                    finite_sample = data_sample[isfinite]
+                    finite_vmin = np.min(finite_sample)
+                    finite_vmax = np.max(finite_sample)
+                else:
+                    finite_sample = None
+                    finite_vmin, finite_vmax = 0, 0
+                    scientific = False
+                    frac_digits = 0
+            # logger.debug(f"    {finite_sample=}")
+            logger.debug(f"    {finite_vmin=}, {finite_vmax=}")
+            absmax = max(abs(finite_vmin), abs(finite_vmax))
+            int_digits = num_int_digits(absmax)
+            logger.debug(f"    {absmax=} {int_digits=}")
+            has_negative = finite_vmin < 0
             font_metrics = self.font_metrics

            # choose whether or not to use scientific notation
            # ================================================
            if scientific is None:
+                # TODO: use numpy ops so that it works for array inputs too
+
                # use scientific format if there are more integer digits than we can display or if we can display
                # more information that way (scientific format "uses" 4 digits, so we have a net win if we have
                # >= 4 zeros -- *including the integer one*)
@@ -887,41 +2108,67 @@ def set_format(self, frac_digits=None, scientific=None):
                # 0.00001 can be displayed with 8 chars
                # 1e-05
                # would
-                absmax = max(abs(vmin), abs(vmax))
+                # logabsmax = np.where(absmax > 0, np.log10(absmax), 0)
                logabsmax = math.log10(absmax) if absmax else 0
                # minimum number of zeros before meaningful fractional part
+                # frac_zeros = np.where(logabsmax < 0, np.ceil(-logabsmax) - 1, 0)
                frac_zeros = math.ceil(-logabsmax) - 1 if logabsmax < 0 else 0
                non_scientific_int_width = font_metrics.get_numbers_width(int_digits, need_sign=has_negative)
+                # with the current default width and font size, this accepts up
+                # to 8 digits for positive numbers (7 for negative)
+                # TODO: change that to accept up to 12 digits for positive
+                #       numbers (11 for negative) so that with the thousand
+                #       separators we can display values up to 999 billion
+                #       without using scientific notation
+                # scientific = (non_scientific_int_width > cur_colwidth) | (frac_zeros >= 4)
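+                # e.g. for absmax = 0.00001, logabsmax is -5 and frac_zeros
+                # is 4, so scientific notation wins: '1e-05' carries more
+                # information than the truncated '0.0000'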
                scientific = non_scientific_int_width > cur_colwidth or frac_zeros >= 4
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(f"    -> detected scientific={scientific}")
+                logger.debug(f"    {logabsmax=} {frac_zeros=} {non_scientific_int_width=}")
+                logger.info(f"    -> detected scientific={scientific}")

            # determine best number of decimals to display
            # ============================================
            if frac_digits is None:
                int_part_width = font_metrics.get_numbers_width(int_digits, need_sign=has_negative,
                                                                scientific=scientific)
+                # logger.debug(f"    {int_digits=} {has_negative=} {scientific=} => {int_part_width=}")
                # since we are computing the number of frac digits, we always need the dot
                avail_width_for_frac_part = max(cur_colwidth - int_part_width - font_metrics.dot_width, 0)
+                # logger.debug(f"    {cur_colwidth=} {font_metrics.dot_width=} => {avail_width_for_frac_part=}")
                max_frac_digits = avail_width_for_frac_part // font_metrics.digit_width
-                frac_digits = self._data_frac_digits(data_sample, max_frac_digits=max_frac_digits)
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(f"    -> detected frac_digits={frac_digits}")
+                # logger.debug(f"    {font_metrics.digit_width=} => {max_frac_digits=}")
+                frac_digits = data_frac_digits(finite_sample, max_frac_digits=max_frac_digits)
+                # logger.info(f"    -> detected {frac_digits=}")

            format_letter = 'e' if scientific else 'f'
            fmt = '%%.%d%s' % (frac_digits, format_letter)
-            data_colwidth = font_metrics.get_numbers_width(int_digits, frac_digits, need_sign=has_negative,
-                                                           scientific=scientific)
+            data_colwidth = (
+                font_metrics.get_numbers_width(int_digits,
+                                               frac_digits,
+                                               need_sign=has_negative,
+                                               scientific=scientific))
+            if not is_finite_data:
+                # We have nans or infs, so we have to make sure we have enough
+                # room to display "nan" or "inf".
+                # Ideally we should add a finite_sample argument to
+                # get_numbers_width so that we take the actual "nan" and "inf"
+                # string widths but I am unsure it is worth it.
+                # Especially given this whole thing is almost useless (it only
+                # serves to trigger the data_colwidth > cur_colwidth condition
+                # so that the column widths are re-computed & updated)
+                inf_nan_colwidth = (
+                    font_metrics.get_numbers_width(3,
+                                                   need_sign=vmin == -math.inf))
+                data_colwidth = max(data_colwidth, inf_nan_colwidth)
        else:
            frac_digits = 0
            scientific = False
            fmt = '%s'
-            # TODO: compute actual column width using data
-            data_colwidth = 60
+            data_colwidth = 0

-        self.model_data.set_format(fmt, reset=True)
+        self.data_adapter.set_format(fmt)
+        self.model_data._get_current_data()
+        self.model_data.reset()

-        self.frac_digits = frac_digits
        self.use_scientific = scientific

        # avoid triggering frac_digits_changed which would cause a useless redraw
@@ -936,26 +2183,76 @@ def set_format(self, frac_digits=None, scientific=None):
        self.scientific_checkbox.setEnabled(is_number_dtype)
        self.scientific_checkbox.blockSignals(False)

-        if not scientific_toggled or data_colwidth > cur_colwidth:
-            header = self.view_hlabels.horizontalHeader()
-
-            # FIXME: this will set width of the 40 first columns (otherwise it gets very slow, eg. big1d)
-            #        but I am not eager to fix this before merging the buffer branch
-            num_cols = min(header.count(), 40)
-            hlabels = self.model_hlabels.get_values(bottom=num_cols)
-            str_width = FontMetrics(self.model_hlabels).str_width
-
-            MIN_COLWITH = 30
-            data_colwidth = max(data_colwidth, MIN_COLWITH)
-
-            MARGIN_WIDTH = 8  # empirically measured
-
-            def get_header_width(i):
-                return MARGIN_WIDTH + max(str_width(str(label)) for label in hlabels[i])
-
-            for i in range(num_cols):
-                colwidth = max(get_header_width(i), data_colwidth)
-                header.resizeSection(i, colwidth)
+        frac_digits_changed = not scientific_toggled
+        # frac digits changed => set new column width
+        if frac_digits_changed or data_colwidth == 0 or data_colwidth > cur_colwidth:
+            self._update_hlabels_column_widths_from_content()
+
+    def _update_hlabels_column_widths_from_content(self):
+        h_offset = self.model_data.h_offset
+        hlabels_header = self.view_hlabels.horizontalHeader()
+
+        # TODO: I wonder if we should only set the widths for the visible
+        #       columns + some margin as the buffer could become relatively
+        #       large if we let adapters decide their size
+        user_def_col_width = self.user_defined_hlabels_column_widths
+
+        # This prevents the auto column width changes below from updating
+        # user_defined_column_widths (via the sectionResized signal).
+        # This is ugly but I found no other way to avoid that because using
+        # hlabels_header.blockSignals(True) breaks updating the column width
+        # in response to ndigits changes (it updates the header widths but not
+        # the hlabels cell widths).
+        self._updating_hlabels_column_widths = True
+        for local_col_idx in range(self.model_data.columnCount()):
+            global_col_idx = h_offset + local_col_idx
+
+            # We should NOT take the max of the user defined column width and
+            # the computed column width because sometimes the user wants
+            # a _smaller_ width than the auto-detected one
+            if global_col_idx in user_def_col_width:
+                hlabels_header.resizeSection(local_col_idx,
+                                             user_def_col_width[global_col_idx])
+            else:
+                self.resize_hlabels_column_to_contents(local_col_idx,
+                                                       MIN_COLUMN_WIDTH, MAX_COLUMN_WIDTH)
+        self._updating_hlabels_column_widths = False
+
+    def _update_vlabels_row_heights_from_content(self):
+        vlabels_header = self.view_vlabels.verticalHeader()
+        v_offset = self.model_data.v_offset
+        self._updating_vlabels_row_heights = True
+        user_def_row_heights = self.user_defined_vlabels_row_heights
+        for local_row_idx in range(self.model_vlabels.rowCount()):
+            global_row_idx = v_offset + local_row_idx
+            if global_row_idx in user_def_row_heights:
+                vlabels_header.resizeSection(local_row_idx,
+                                             user_def_row_heights[global_row_idx])
+            else:
+                self.resize_vlabels_row_to_contents(local_row_idx)
+        self._updating_vlabels_row_heights = False
+
+    def _update_axes_column_widths_from_content(self):
+        self._updating_axes_column_widths = True
+        user_widths = self.user_defined_axes_column_widths
+        for local_col_idx in range(self.model_axes.columnCount()):
+            # Since there is no h_offset for axes, the column width never
+            # actually changes unless the user explicitly changes it, so just
+            # preventing the auto-sizing code from running is enough
+            if local_col_idx not in user_widths:
+                self.resize_axes_column_to_contents(local_col_idx)
+        self._updating_axes_column_widths = False
+
+    def _update_axes_row_heights_from_content(self):
+        self._updating_axes_row_heights = True
+        user_def_row_heights = self.user_defined_axes_row_heights
+        for local_row_idx in range(self.model_axes.rowCount()):
+            # Since there is no v_offset for axes, the row height never
+            # actually changes unless the user explicitly changes it, so just
+            # preventing the auto-sizing code from running is enough
+            if local_row_idx not in user_def_row_heights:
+                self.resize_axes_row_to_contents(local_row_idx)
+        self._updating_axes_row_heights = False

    def _get_current_min_col_width(self):
        header = self.view_hlabels.horizontalHeader()
@@ -964,167 +2261,205 @@ def _get_current_min_col_width(self):
        else:
            return 0

-    def _data_frac_digits(self, data, max_frac_digits):
-        if not data.size:
-            return 0
-        threshold = 10 ** -(max_frac_digits + 1)
-        for frac_digits in range(max_frac_digits):
-            maxdiff = np.max(np.abs(data - np.round(data, frac_digits)))
-            if maxdiff < threshold:
-                return frac_digits
-        return max_frac_digits
-
-    def autofit_columns(self):
-        self.view_axes.autofit_columns()
-        for column in range(self.model_axes.columnCount()):
-            self.resize_axes_column_to_contents(column)
-        self.view_hlabels.autofit_columns()
-        for column in range(self.model_hlabels.columnCount()):
-            self.resize_hlabels_column_to_contents(column)
-
-    def resize_axes_column_to_contents(self, column):
-        # must be connected to view_axes.horizontalHeader().sectionHandleDoubleClicked signal
-        width = max(self.view_axes.horizontalHeader().sectionSize(column),
-                    self.view_vlabels.sizeHintForColumn(column))
-        # no need to call resizeSection on view_vlabels (see synchronization lines in init)
-        self.view_axes.horizontalHeader().resizeSection(column, width)
-
-    def resize_hlabels_column_to_contents(self, column):
-        # must be connected to view_labels.horizontalHeader().sectionHandleDoubleClicked signal
-        width = max(self.view_hlabels.horizontalHeader().sectionSize(column),
-                    self.view_data.sizeHintForColumn(column))
-        # no need to call resizeSection on view_data (see synchronization lines in init)
-        self.view_hlabels.horizontalHeader().resizeSection(column, width)
-
-    def resize_axes_row_to_contents(self, row):
-        # must be connected to view_axes.verticalHeader().sectionHandleDoubleClicked
-        height = max(self.view_axes.verticalHeader().sectionSize(row),
-                     self.view_hlabels.sizeHintForRow(row))
-        # no need to call resizeSection on view_hlabels (see synchronization lines in init)
-        self.view_axes.verticalHeader().resizeSection(row, height)
-
-    def resize_vlabels_row_to_contents(self, row):
-        # must be connected to view_labels.verticalHeader().sectionHandleDoubleClicked
-        height = max(self.view_vlabels.verticalHeader().sectionSize(row),
-                     self.view_data.sizeHintForRow(row))
-        # no need to call resizeSection on view_data (see synchronization lines in init)
-        self.view_vlabels.verticalHeader().resizeSection(row, height)
+    # must be connected to signal:
+    # view_axes.horizontalHeader().sectionHandleDoubleClicked
+    def resize_axes_column_to_contents(self, col_idx):
+        # clsname = self.__class__.__name__
+        # print(f"{clsname}.resize_axes_column_to_contents({col_idx})")
+        # TODO:
+        # * maybe reimplement resizeColumnToContents(column) instead? Though
+        #   the doc says it only resizes visible columns, so that might not
+        #   work
+        # * reimplementing sizeHintForColumn on AxesView would be cleaner
+        #   but that would require making it aware of the view_vlabels instance
+        prev_width = self.detected_axes_column_widths.get(col_idx, 0)
+        width = max(self.view_axes.sizeHintForColumn(col_idx),
+                    self.view_vlabels.sizeHintForColumn(col_idx),
+                    prev_width)
+        # view_vlabels column width will be synchronized automatically
+        self.view_axes.horizontalHeader().resizeSection(col_idx, width)
+
+        # set that column's width back to "automatic width"
+        if col_idx in self.user_defined_axes_column_widths:
+            del self.user_defined_axes_column_widths[col_idx]
+        if width > prev_width:
+            self.detected_axes_column_widths[col_idx] = width
+
+    # must be connected to signal:
+    # view_hlabels.horizontalHeader().sectionHandleDoubleClicked
+    def resize_hlabels_column_to_contents(self, local_col_idx,
+                                          min_width=None, max_width=None):
+        global_col_idx = self.model_data.h_offset + local_col_idx
+        prev_width = self.detected_hlabels_column_widths.get(global_col_idx, 0)
+        # logger.debug("ArrayEditorWidget.resize_hlabels_column_to_contents("
+        #              f"{local_col_idx=}, {min_width=}, {max_width=})")
+        width = max(self.view_hlabels.sizeHintForColumn(local_col_idx),
+                    self.view_data.sizeHintForColumn(local_col_idx),
+                    prev_width)
+        # logger.debug(f"    {global_col_idx=} {prev_width=} => (before clip) "
+        #              f"{width=} ")
+        if min_width is not None:
+            width = max(width, min_width)
+        if max_width is not None:
+            width = min(width, max_width)
+        # logger.debug(f"    -> (after clip) {width=}")
+        # view_data column width will be synchronized automatically
+        self.view_hlabels.horizontalHeader().resizeSection(local_col_idx, width)
+
+        # set that column's width back to "automatic width"
+        if global_col_idx in self.user_defined_hlabels_column_widths:
+            del self.user_defined_hlabels_column_widths[global_col_idx]
+        if width > prev_width:
+            # logger.debug(f"    -> width > prev_width (updating detected)")
+            self.detected_hlabels_column_widths[global_col_idx] = width
+
+    # must be connected to signal:
+    # view_axes.verticalHeader().sectionHandleDoubleClicked
+    def resize_axes_row_to_contents(self, row_idx):
+        # clsname = self.__class__.__name__
+        # print(f"{clsname}.resize_axes_row_to_contents({row_idx})")
+        prev_height = self.detected_axes_row_heights.get(row_idx, 0)
+        height = max(self.view_axes.sizeHintForRow(row_idx),
+                     self.view_hlabels.sizeHintForRow(row_idx),
+                     prev_height)
+        # view_hlabels row height will be synchronized automatically
+        self.view_axes.verticalHeader().resizeSection(row_idx, height)
+        # set that row's height back to "automatic height"
+        if row_idx in self.user_defined_axes_row_heights:
+            del self.user_defined_axes_row_heights[row_idx]
+        if height > prev_height:
+            self.detected_axes_row_heights[row_idx] = height
+
+    # must be connected to signal:
+    # view_vlabels.verticalHeader().sectionHandleDoubleClicked
+    def resize_vlabels_row_to_contents(self, local_row_idx):
+        # clsname = self.__class__.__name__
+        # print(f"{clsname}.resize_vlabels_row_to_contents({local_row_idx})")
+        global_row_idx = self.model_data.v_offset + local_row_idx
+        prev_height = self.detected_vlabels_row_heights.get(global_row_idx, 0)
+        height = max(self.view_vlabels.sizeHintForRow(local_row_idx),
+                     self.view_data.sizeHintForRow(local_row_idx),
+                     prev_height)
+        # view_data row height will be synchronized automatically
+        self.view_vlabels.verticalHeader().resizeSection(local_row_idx, height)
+        # set that row's height back to "automatic height"
"automatic height" + if global_row_idx in self.user_defined_vlabels_row_heights: + del self.user_defined_vlabels_row_heights[global_row_idx] + if height > prev_height: + self.detected_vlabels_row_heights[global_row_idx] = height def scientific_changed(self, value): # auto-detect frac_digits - self.set_format(frac_digits=None, scientific=bool(value)) + self.set_frac_digits_or_scientific(frac_digits=None, scientific=bool(value)) def frac_digits_changed(self, value): - self.set_format(value, self.use_scientific) - - def change_filter(self, axis, indices): - self.data_adapter.update_filter(axis, indices) - self._update_models(reset_model_data=True, reset_minmax=False) - - def create_filter_combo(self, axis): - def filter_changed(checked_items): - self.change_filter(axis, checked_items) - combo = FilterComboBox(self) - combo.addItems([str(label) for label in axis.labels]) - combo.checkedItemsChanged.connect(filter_changed) - return combo + # TODO: I should probably drop the use_scientific field and just + # retrieve the checkbox value + self.set_frac_digits_or_scientific(value, self.use_scientific) - def _selection_data(self, headers=True, none_selects_all=True): - """ - Return selected labels as lists and raw data as Numpy ndarray - if headers=True or only the raw data otherwise + def sort_axis_labels(self, axis_idx, ascending): + self.data_adapter.sort_axis_labels(axis_idx, ascending) + self._set_models_adapter() - Parameters - ---------- - headers : bool, optional - Labels are also returned if True. - none_selects_all : bool, optional - If True (default) and selection is empty, returns all data. - - Returns - ------- - raw_data: numpy.ndarray - axes_names: list - vlabels: nested list - hlabels: list - """ - bounds = self.view_data._selection_bounds(none_selects_all=none_selects_all) - if bounds is None: - return None - row_min, row_max, col_min, col_max = bounds - raw_data = self.model_data.get_values(row_min, col_min, row_max, col_max) - if headers: - # FIXME: using data_adapter.ndim here and in the vlabels line below is - # inherently buggy, because this does not take filter into account, - # which should be the case for selection-related stuff which work - # on visible data - if not self.data_adapter.ndim: - return raw_data, None, None, None - axes_names = self.model_axes.get_values() - if len(axes_names): - hlabels = [label[0] - for label in self.model_hlabels.get_values(top=col_min, bottom=col_max)] - else: - hlabels = [] - vlabels = self.model_vlabels.get_values(left=row_min, right=row_max) if self.data_adapter.ndim > 1 else [] - return raw_data, axes_names, vlabels, hlabels - else: - return raw_data + def sort_hlabel(self, row_idx, col_idx, ascending): + self.data_adapter.sort_hlabel(row_idx, col_idx, ascending) + self._set_models_adapter() + # since we will probably display different rows, they can have different + # column widths + self.update_cell_sizes_from_content() def copy(self): """Copy selection as text to clipboard""" - raw_data, axes_names, vlabels, hlabels = self._selection_data() - data = self.data_adapter.selection_to_chain(raw_data, axes_names, vlabels, hlabels) - if data is None: - return - - # np.savetxt make things more complicated, especially on py3 - # XXX: why don't we use repr for everything? 
-        def vrepr(v):
-            if isinstance(v, float):
-                return repr(v)
-            else:
-                return str(v)
-        text = '\n'.join('\t'.join(vrepr(v) for v in line) for line in data)
+        text = self.data_adapter.to_string(*self.view_data.selection_bounds())
        clipboard = QApplication.clipboard()
        clipboard.setText(text)

    def to_excel(self):
        """Export selection in Excel"""
-        raw_data, axes_names, vlabels, hlabels = self._selection_data()
        try:
-            self.data_adapter.to_excel(raw_data, axes_names, vlabels, hlabels)
+            self.data_adapter.to_excel(*self.view_data.selection_bounds())
        except ImportError:
-            QMessageBox.critical(self, "Error", "to_excel() is not available because xlwings is not installed")
+            msg = "to_excel() is not available because xlwings is not installed"
+            QMessageBox.critical(self, "Error", msg)

    def paste(self):
-        bounds = self.view_data._selection_bounds()
-        if bounds is None:
-            return
-        row_min, row_max, col_min, col_max = bounds
+        # FIXME: this now returns coordinates in global space while the rest of
+        #        this function assumes local/buffer space coordinates. But this
+        #        whole "set_values" code should be revisited entirely anyway
+        row_min, row_max, col_min, col_max = self.view_data.selection_bounds()
        clipboard = QApplication.clipboard()
        text = str(clipboard.text())
        list_data = [line.split('\t') for line in text.splitlines()]
-        try:
-            # take the first cell which contains '\'
-            pos_last = next(i for i, v in enumerate(list_data[0]) if '\\' in v)
-        except StopIteration:
-            # if there isn't any, assume 1d array
-            pos_last = 0
-        if pos_last or '\\' in list_data[0][0]:
-            # ndim > 1
-            list_data = [line[pos_last + 1:] for line in list_data[1:]]
-        elif len(list_data) == 2 and list_data[1][0] == '':
-            # ndim == 1
-            list_data = [list_data[1][1:]]
+        list_data = self.data_adapter.from_clipboard_data_to_model_data(list_data)
        new_data = np.array(list_data)
        if new_data.shape[0] > 1:
            row_max = row_min + new_data.shape[0]
        if new_data.shape[1] > 1:
            col_max = col_min + new_data.shape[1]
+        # FIXME: the way to change data is extremely convoluted (and slightly wrong):
+        #   * either Widget.paste or ArrayDelegate.setModelData (via model.setData)
+        #     * calls model.set_values on the model
+        #       that does not change any data but computes a
+        #       {global_filtered_2D_coords: (old_value, new_value)}
+        #       dict of changes
+        #     * emits a newChanges signal and then a dataChanged (builtin)
+        #       signal (which only works/makes sense because a signal .emit
+        #       call only returns when all its slots have executed, hence the
+        #       whole chain below has already been executed when that second
+        #       signal is emitted).
+        #   * the newChanges signal is caught by the widget, which
+        #     * asks the adapter to transform the changes from 2d global (but
+        #       potentially filtered) positional keys to ND global positional
+        #       keys, then
+        #     * re-emits a dataChanged signal with a list of those changes,
+        #   * the editor catches that signal and
+        #     * pushes those changes to the edit_undo_stack which actually
+        #     * applies each change by using
+        #       kernel.shell.run_cell(f"{self.target}.i[{key}] = {new_value}")
+        #       OR
+        #       self.target.i[key] = new_value
+        #       and there,
+        #     * editor.array_widget.model_data.reset() is called which
+        #       * notifies qt the whole thing needs to be refreshed (including
+        #         reprocessing the data via _process_data but does *NOT* fetch
+        #         the actual new data!!!)
+        #         and it actually only appears to work in the simple case of
+        #         editing an unfiltered array because we are using array *views*
+        #         all the way so when we edit the array, the "raw_data" in the
+        #         model is updated directly too and _process_data is indeed
+        #         enough.
+        #   >
+        #   Since we can *NOT* push a command on the edit_undo_stack
+        #   without executing it, we should:
+        #   * create widget.set_values method, call it from paste and the
+        #     ArrayDelegate
+        #     - ask the adapter to create an edit_undo_stack command (which
+        #       will change the real data)
+        #       * create a {NDkey: changes}
+        #     - push command
+        #     - call a new method on the model akin to reset() but which
+        #       *fetches* the data in addition to processing it
+        #     - we will probably need to emit/use signals in there but this
+        #       can come later
+        #   I am still undecided on whether the commands should actually
+        #   update the live object or add changes to a "changes layer",
+        #   which can later be all applied to the real objects. For
+        #   in-memory objects, updating the objects directly seems better
+        #   so that e.g. console commands stay consistent with what we see
+        #   but for on-disk data, writing each change directly to disk
+        #   seems inefficient and surprising. I suppose that decision
+        #   should be made (and implemented) by the adapter but I have no
+        #   idea how to convey the difference to users. It should be
+        #   obvious but unobtrusive and users will need a way to trigger
+        #   a "save". In any case, this can come later as we currently
+        #   do not have any disk-backed adapter anywhere close to
+        #   supporting editing values.
+        #
+        #   as a side note, the (visible) Scrollbar is connected to the
+        #   reset event and updates its range in that case, which is
+        #   useless
        result = self.model_data.set_values(row_min, col_min, row_max, col_max, new_data)
-
        if result is None:
            return
@@ -1136,13 +2471,12 @@ def paste(self):

    def plot(self):
        from larray_editor.utils import show_figure
        from larray_editor.editor import AbstractEditorWindow, MappingEditorWindow
-        raw_data, axes_names, vlabels, hlabels = self._selection_data()
        try:
-            figure = self.data_adapter.plot(raw_data, axes_names, vlabels, hlabels)
+            figure = self.data_adapter.plot(*self.view_data.selection_bounds())
            widget = self
            while widget is not None and not isinstance(widget, AbstractEditorWindow) and callable(widget.parent):
                widget = widget.parent()
            title = widget.current_expr_text if isinstance(widget, MappingEditorWindow) else None
-            show_figure(self, figure, title)
+            show_figure(figure, title, parent=self)
        except ImportError:
            QMessageBox.critical(self, "Error", "plot() is not available because matplotlib is not installed")
diff --git a/larray_editor/combo.py b/larray_editor/combo.py
index a0091170..24a47548 100644
--- a/larray_editor/combo.py
+++ b/larray_editor/combo.py
@@ -1,6 +1,8 @@
 from qtpy import QtGui, QtCore, QtWidgets
 from qtpy.QtCore import QPoint, Qt

+from larray_editor.utils import create_action, _
+

 class StandardItemModelIterator:
     def __init__(self, model):
@@ -62,14 +64,43 @@ def set_checked(self, value):
     checked = property(get_checked, set_checked)


-class FilterMenu(QtWidgets.QMenu):
+class CombinedSortFilterMenu(QtWidgets.QMenu):
     activate = QtCore.Signal(int)
-    checkedItemsChanged = QtCore.Signal(list)
+    checked_items_changed = QtCore.Signal(list)
+    sort_signal = QtCore.Signal(bool)  # bool argument is for ascending

-    def __init__(self, parent=None):
+    def __init__(self, parent=None,
+                 sortable: bool = False,
+                 sort_direction: str = 'unsorted',
+                 filtrable=False):
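+        """Popup menu combining optional sort actions with an optional
+        filter list.
+
+        Parameters
+        ----------
+        parent : QWidget, optional
+            Parent widget. Defaults to None.
+        sortable : bool, optional
+            If True, add "Sort A-Z"/"Sort Z-A" actions, which emit
+            `sort_signal` with ascending=True/False. Defaults to False.
+        sort_direction : str, optional
+            Current sort state ('unsorted', 'ascending' or 'descending'),
+            used to check the corresponding sort action.
+            Defaults to 'unsorted'.
+        filtrable : bool, optional
+            If True, add a checkable list of items (populated via addItems),
+            which emits `checked_items_changed`. Defaults to False.
+        """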
         super().__init__(parent)
-        self._list_view = QtWidgets.QListView(parent)
+        self._model, self._list_view = None, None
+
+        if sortable:
+            self.addAction(create_action(self, _('Sort A-Z'),
+                                         triggered=lambda: self.sort_signal.emit(True),
+                                         checkable=True,
+                                         checked=sort_direction == 'ascending'))
+            self.addAction(create_action(self, _('Sort Z-A'),
+                                         triggered=lambda: self.sort_signal.emit(False),
+                                         checkable=True,
+                                         checked=sort_direction == 'descending'))
+            if filtrable:
+                self.addSeparator()
+
+        if filtrable:
+            self.setup_list_view()
+
+        self.installEventFilter(self)
+        self.activate.connect(self.on_activate)
+
+    def setup_list_view(self):
+        # search_widget = QtWidgets.QLineEdit()
+        # search_widget.setPlaceholderText('Search')
+        # self.add_action_widget(search_widget)
+
+        self._list_view = QtWidgets.QListView(self)
         self._list_view.setFrameStyle(0)
         model = SequenceStandardItemModel()
         self._list_view.setModel(model)
@@ -80,17 +111,25 @@ def __init__(self, parent=None):
         except AttributeError:
             # this is the new name for qt6+
             model[0].setUserTristate(True)
-
-        action = QtWidgets.QWidgetAction(self)
-        action.setDefaultWidget(self._list_view)
-        self.addAction(action)
-        self.installEventFilter(self)
         self._list_view.installEventFilter(self)
         self._list_view.window().installEventFilter(self)
-
         model.itemChanged.connect(self.on_model_item_changed)
         self._list_view.pressed.connect(self.on_list_view_pressed)
-        self.activate.connect(self.on_activate)
+
+        # filters_layout = QtWidgets.QVBoxLayout(parent)
+        # filters_layout.addWidget(QtWidgets.QLabel("Filters"))
+        # filters_layout.addWidget(self._list_view)
+        self.add_action_widget(self._list_view)
+
+    def add_action_widget(self, action_widget):
+        if isinstance(action_widget, QtWidgets.QLayout):
+            # you can't add a layout directly in an action, so we have to wrap it in a widget
+            widget = QtWidgets.QWidget()
+            widget.setLayout(action_widget)
+            action_widget = widget
+        widget_action = QtWidgets.QWidgetAction(self)
+        widget_action.setDefaultWidget(action_widget)
+        self.addAction(widget_action)

     def on_list_view_pressed(self, index):
         item = self._model.itemFromIndex(index)
@@ -123,7 +162,7 @@ def on_model_item_changed(self, item):
                 model[0].checked = 'partial'
             model.blockSignals(False)
             checked_indices = [i for i, item in enumerate(model[1:]) if item.checked]
-            self.checkedItemsChanged.emit(checked_indices)
+            self.checked_items_changed.emit(checked_indices)

     # function is called to implement wheel scrolling (select prev/next label)
     def select_offset(self, offset):
@@ -150,18 +189,24 @@ def select_offset(self, offset):
             for i, item in enumerate(model[1:], start=1):
                 item.checked = i == to_check
             model.blockSignals(False)
-            self.checkedItemsChanged.emit([to_check - 1])
+            self.checked_items_changed.emit([to_check - 1])

-    def addItem(self, text):
+    def addItem(self, text, checked=True):
         item = StandardItem(text)
         # not editable
-        item.setFlags(Qt.ItemIsUserCheckable | Qt.ItemIsEnabled)
-        item.checked = True
+        item.setFlags(QtCore.Qt.ItemIsUserCheckable | QtCore.Qt.ItemIsEnabled)
+        item.checked = checked
         self._model.appendRow(item)

-    def addItems(self, items):
-        for item in items:
-            self.addItem(item)
+    def addItems(self, items, items_checked=None):
+        if items_checked is None:
+            for item_label in items:
+                self.addItem(item_label)
+        else:
+            assert 0 <= len(items_checked) <= len(items)
+            checked_indices_set = set(items_checked)
+            for idx, item_label in enumerate(items):
+                self.addItem(item_label, idx in checked_indices_set)

     def eventFilter(self, obj, event):
         event_type = event.type()
@@ -170,13 +215,14 @@ def eventFilter(self, obj, event):
             key = event.key()

             # tab key closes the popup
-            if obj == self._list_view.window() and key == Qt.Key_Tab:
+            if obj == self._list_view.window() and key == QtCore.Qt.Key_Tab:
                 self.hide()

             # return key activates *one* item and closes the popup
             # first time the key is sent to the menu, afterwards to
             # list_view
-            elif obj == self._list_view and key in (Qt.Key_Enter, Qt.Key_Return):
+            elif (obj == self._list_view and
+                  key in (QtCore.Qt.Key_Enter, QtCore.Qt.Key_Return)):
                 self.activate.emit(self._list_view.currentIndex().row())
                 self.hide()
                 return True
@@ -185,7 +231,7 @@


 class FilterComboBox(QtWidgets.QToolButton):
-    checkedItemsChanged = QtCore.Signal(list)
+    checked_items_changed = QtCore.Signal(list)

     def __init__(self, parent=None):
         super().__init__(parent)
@@ -195,10 +241,10 @@ def __init__(self, parent=None):
         # uglier
         self.setPopupMode(QtWidgets.QToolButton.MenuButtonPopup)

-        menu = FilterMenu(self)
+        menu = CombinedSortFilterMenu(self, filtrable=True)
         self.setMenu(menu)
         self._menu = menu
-        menu.checkedItemsChanged.connect(self.on_checked_items_changed)
+        menu.checked_items_changed.connect(self.on_checked_items_changed)
         self.installEventFilter(self)

     def on_checked_items_changed(self, indices_checked):
@@ -210,7 +256,7 @@ def on_checked_items_changed(self, indices_checked):
             self.setText(model[indices_checked[0] + 1].text())
         else:
             self.setText("multi")
-        self.checkedItemsChanged.emit(indices_checked)
+        self.checked_items_changed.emit(indices_checked)

     def addItem(self, text):
         self._menu.addItem(text)
@@ -240,10 +286,19 @@ def eventFilter(self, obj, event):

         # return key activates *one* item and closes the popup
         # first time the key is sent to self, afterwards to list_view
-        elif obj == self and key in (Qt.Key_Enter, Qt.Key_Return):
-            self._menu.activate.emit(self._list_view.currentIndex().row())
-            self._menu.hide()
-            return True
+        # elif obj == self and key in (Qt.Key_Enter, Qt.Key_Return):
+        #     print(f'FilterComboBox.eventFilter')
+        #     # this cannot work (there is no _list_view attribute)
+        #     # probably meant as self._menu._list_view BUT
+        #     # this case currently does not seem to happen anyway
+        #     # I am not removing this code entirely because the
+        #     # combo does not seem to get focus which could explain
+        #     # why this is never reached
+        #     current_index = self._list_view.currentIndex().row()
+        #     print(f'current_index={current_index}')
+        #     self._menu.activate.emit(current_index)
+        #     self._menu.hide()
+        #     return True

         if event_type == QtCore.QEvent.MouseButtonRelease:
             # clicking anywhere (not just arrow) on the button shows the popup
diff --git a/larray_editor/commands.py b/larray_editor/commands.py
index 8a004d2b..352b7e9c 100644
--- a/larray_editor/commands.py
+++ b/larray_editor/commands.py
@@ -11,12 +11,13 @@
 from larray_editor.utils import logger


-class ArrayValueChange:
+class CellValueChange:
     """
     Class representing the change of one value of an array.

     Parameters
     ----------
+    # FIXME: key is a tuple of indices
     key: list/tuple of str
         Key associated with the value
     old_value: scalar
@@ -32,17 +33,17 @@ def __init__(self, key, old_value, new_value):

 # XXX: we need to handle the case of several changes at once because the method paste()
 #      of ArrayEditorWidget can be used on objects not handling MultiIndex axes (LArray, Numpy).
-class EditArrayCommand(QUndoCommand):
+class EditObjectCommand(QUndoCommand):
     """
     Class representing the change of one or several value(s) of an array.
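+
+    Subclasses implement `get_description(target, changes)` and
+    `apply_change(key, new_value)`; `undo()` and `redo()` replay the recorded
+    list of `CellValueChange` through `apply_change`.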

     Parameters
     ----------
-    editor: MappingEditor
-        Instance of MappingEditor
+    editor: AbstractEditorWindow
+        Instance of AbstractEditorWindow
     target : object
-        target array to edit. Can be given under any form.
-    changes: (list of) instance(s) of ArrayValueChange
+        target object to edit. Can be given under any form.
+    changes: list of CellValueChange
         List of changes
     """

@@ -61,12 +62,18 @@ def __init__(self, editor, target, changes):
     def undo(self):
         for change in self.changes:
             self.apply_change(change.key, change.old_value)
-        self.editor.arraywidget.model_data.reset()
+        # FIXME: a full reset is bad, see comment below
+        self.editor.array_widget.model_data.reset()

     def redo(self):
         for change in self.changes:
             self.apply_change(change.key, change.new_value)
-        self.editor.arraywidget.model_data.reset()
+        # FIXME: a full reset is wasteful and causes hidden scrollbars
+        #        to jump back to 0 after each cell change, which is very
+        #        annoying. We have an awful workaround for this in
+        #        ArrayDelegate.setModelData but the issue should still be fixed
+        #        properly
+        self.editor.array_widget.model_data.reset()

     def get_description(self, target, changes):
         raise NotImplementedError()
@@ -75,7 +82,7 @@ def apply_change(self, key, new_value):
         raise NotImplementedError()


-class EditSessionArrayCommand(EditArrayCommand):
+class EditSessionArrayCommand(EditObjectCommand):
     """
     Class representing the change of one or several value(s) of an array.

@@ -85,20 +92,21 @@ class EditSessionArrayCommand(EditArrayCommand):
         Instance of MappingEditor
     target : str
         name of array to edit
-    changes: (list of) instance(s) of ArrayValueChange
+    changes: (list of) instance(s) of CellValueChange
         List of changes
     """

-    def get_description(self, target, changes):
+    def get_description(self, target: str, changes: list[CellValueChange]):
         if len(changes) == 1:
             return f"Editing Cell {changes[0].key} of {target}"
         else:
             return f"Pasting {len(changes)} Cells in {target}"

     def apply_change(self, key, new_value):
-        self.editor.kernel.shell.run_cell(f"{self.target}[{key}] = {new_value}")
+        # FIXME: we should pass via the adapter to have something generic
+        self.editor.ipython_kernel.shell.run_cell(f"{self.target}.i[{key}] = {new_value}")


-class EditCurrentArrayCommand(EditArrayCommand):
+class EditCurrentArrayCommand(EditObjectCommand):
     """
     Class representing the change of one or several value(s) of the current array.
@@ -108,7 +116,7 @@ class EditCurrentArrayCommand(EditArrayCommand):
         Instance of ArrayEditor
     target : Array
         array to edit
-    changes : (list of) instance(s) of ArrayValueChange
+    changes : (list of) CellValueChange
         List of changes
     """
     def get_description(self, target, changes):
@@ -118,4 +126,5 @@ def get_description(self, target, changes):
         return f"Pasting {len(changes)} Cells"

     def apply_change(self, key, new_value):
-        self.target[key] = new_value
+        # FIXME: we should pass via the adapter to have something generic
+        self.target.i[key] = new_value
diff --git a/larray_editor/comparator.py b/larray_editor/comparator.py
index 4e00c1aa..f668f72f 100644
--- a/larray_editor/comparator.py
+++ b/larray_editor/comparator.py
@@ -1,20 +1,23 @@
 import numpy as np
 import larray as la
+import pandas as pd

 from qtpy.QtCore import Qt
 from qtpy.QtWidgets import (QWidget, QVBoxLayout, QListWidget, QSplitter, QHBoxLayout, QLabel, QCheckBox,
                             QLineEdit, QComboBox, QMessageBox)

-from larray_editor.utils import replace_inf, _
+from larray_editor.utils import _, print_exception, align_arrays
 from larray_editor.arraywidget import ArrayEditorWidget
-from larray_editor.editor import AbstractEditorWindow, DISPLAY_IN_GRID
+from larray_editor.editor import AbstractEditorWindow
+
+
+CAN_CONVERT_TO_LARRAY = (la.Array, np.ndarray, pd.DataFrame)


 class ComparatorWidget(QWidget):
     """Comparator Widget"""
-    # FIXME: rtol, atol are unused, and align and fill_value are only partially used
-    def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, nans_equal=True,
-                 align='outer', fill_value=np.nan):
+    def __init__(self, parent=None, bg_gradient='red-white-blue',
+                 nans_equal=True, fill_value=np.nan):
         QWidget.__init__(self, parent)
         layout = QVBoxLayout()
@@ -30,9 +33,9 @@ def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, na
         maxdiff_layout.addStretch()
         layout.addLayout(maxdiff_layout)

-        # arraywidget
-        self.arraywidget = ArrayEditorWidget(self, data=None, readonly=True, bg_gradient=bg_gradient)
-        layout.addWidget(self.arraywidget)
+        # array widget
+        self.array_widget = ArrayEditorWidget(self, data=None, readonly=True, bg_gradient=bg_gradient)
+        layout.addWidget(self.array_widget)

         self._combined_array = None
         self._array0 = None
@@ -41,7 +44,6 @@ def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, na
         self.stack_axis = None

         self.nans_equal = nans_equal
-        self.align_method = align
         self.fill_value = fill_value

         # TODO: we might want to use self.align_method, etc instead of using arguments?
@@ -64,7 +66,7 @@ def get_comparison_options_layout(self, align, atol, rtol):
                      abs(array1[i] - array2[i]) <= (absolute_tol + relative_tol * abs(array2[i]))"""
         tolerance_label = QLabel("Tolerance:")
         tolerance_label.setToolTip(tooltip)
-        # self.arraywidget.btn_layout.addWidget(tolerance_label)
+        # self.array_widget.btn_layout.addWidget(tolerance_label)
         layout.addWidget(tolerance_label)
         tolerance_combobox = QComboBox()
         tolerance_combobox.addItems(["absolute", "relative"])
@@ -100,6 +102,9 @@ def get_comparison_options_layout(self, align, atol, rtol):
         layout.addStretch()
         return layout

+    def get_align_method(self):
+        return self.align_method_combo.currentText()
+
     def _get_atol_rtol(self):
         try:
             tol_str = self.tolerance_line_edit.text()
@@ -112,7 +117,8 @@ def _get_atol_rtol(self):
             self.tolerance_line_edit.setText('')
             tol = 0
             QMessageBox.critical(self, "Error", str(e))
-        return (tol, 0) if self.tolerance_combobox.currentText() == "absolute" else (0, tol)
+        is_absolute = self.tolerance_combobox.currentText() == "absolute"
+        return (tol, 0) if is_absolute else (0, tol)

     # override keyPressEvent to prevent pressing Enter after changing the tolerance value
     # in associated QLineEdit to close the parent dialog box
@@ -136,19 +142,20 @@ def set_data(self, arrays, stack_axis):
         self._update_from_arrays()

     def update_align_method(self, align):
-        self.align_method = align
         self._update_from_arrays()

     def _update_from_arrays(self):
         # TODO: implement align in stack instead
         stack_axis = self.stack_axis
+        align_method = self.get_align_method()
         try:
-            aligned_arrays = align_all(self.arrays,
-                                       join=self.align_method,
-                                       fill_value=self.fill_value)
+            aligned_arrays = align_arrays(self.arrays,
+                                          join=align_method,
+                                          fill_value=self.fill_value)
             self._combined_array = la.stack(aligned_arrays, stack_axis)
             self._array0 = self._combined_array[stack_axis.i[0]]
         except Exception as e:
+            print_exception(e)
             QMessageBox.critical(self, "Error", str(e))
             self._combined_array = la.Array([''])
             self._array0 = self._combined_array
@@ -160,31 +167,69 @@ def _update_from_combined_array(self):
         atol, rtol = self._get_atol_rtol()
         try:
-            self._diff_below_tolerance = self._combined_array.eq(self._array0, rtol=rtol, atol=atol, nans_equal=self.nans_equal)
+            # eq does not take atol and rtol into account
+            eq = self._combined_array.eq(self._array0,
+                                         nans_equal=self.nans_equal)
+            isclose = self._combined_array.eq(self._array0,
+                                              rtol=rtol, atol=atol,
+                                              nans_equal=self.nans_equal)
         except TypeError:
-            self._diff_below_tolerance = self._combined_array == self._array0
+            # object arrays
+            eq = self._combined_array == self._array0
+            isclose = eq
+        self._diff_below_tolerance = isclose

         try:
             with np.errstate(divide='ignore', invalid='ignore'):
                 diff = self._combined_array - self._array0
                 reldiff = diff / self._array0
-
+                # make reldiff 0 where the values are the same as array0 even for
+                # special values (0, nan, inf, -inf)
+                # at this point reldiff can still contain nan and infs
+                reldiff = la.where(eq, 0, reldiff)
+
+                # 1) compute maxabsreldiff for the label
+                #    this should NOT exclude nans or infs
+                relmin = reldiff.min(skipna=False)
+                relmax = reldiff.max(skipna=False)
+                maxabsreldiff = max(abs(relmin), abs(relmax))
+
+                # 2) compute bg_value
                 # replace -inf by min(reldiff), +inf by max(reldiff)
-                finite_reldiff, finite_relmin, finite_relmax = replace_inf(reldiff)
-                maxabsreldiff = max(abs(finite_relmin), abs(finite_relmax))
-
-                # We need a separate version for bg and the label, so that when we modify atol/rtol, the background
-                # color is updated but not the maxreldiff label
-                # this is necessary for nan, inf and -inf (because inf - inf = nan, not 0)
-                # this is more precise than divnot0, it only ignore 0 / 0, not x / 0
-                reldiff_for_bg = la.where(self._diff_below_tolerance, 0, finite_reldiff)
-                maxabsreldiff_for_bg = max(abs(np.nanmin(reldiff_for_bg)), abs(np.nanmax(reldiff_for_bg)))
+                reldiff_for_bg = reldiff.copy()
+                isneginf = reldiff == -np.inf
+                isposinf = reldiff == np.inf
+                isinf = isneginf | isposinf
+
+                # given the way reldiff is constructed, it cannot contain only
+                # infs (because inf/inf is nan). It can contain only infs and
+                # nans though, in which case finite_relXXX will be nan, so
+                # unless the array is empty, finite_relXXX should never be inf
+                finite_relmin = np.nanmin(reldiff, where=~isinf, initial=np.inf)
+                finite_relmax = np.nanmax(reldiff, where=~isinf, initial=-np.inf)
+                # special case when reldiff contains only 0 and infs (to avoid
+                # coloring the inf cells white in that case)
+                if finite_relmin == 0 and finite_relmax == 0 and isinf.any():
+                    finite_relmin = -1
+                    finite_relmax = 1
+                reldiff_for_bg[isneginf] = finite_relmin
+                reldiff_for_bg[isposinf] = finite_relmax
+
+                # make sure that "acceptable" differences show as white
+                reldiff_for_bg = la.where(isclose, 0, reldiff_for_bg)
+
+                # We need a separate version for bg and the label, so that when we
+                # modify atol/rtol, the background color is updated but not the
+                # maxreldiff label
+                maxabsreldiff_for_bg = max(abs(np.nanmin(reldiff_for_bg)),
+                                           abs(np.nanmax(reldiff_for_bg)))

                 if maxabsreldiff_for_bg:
                     # scale reldiff to range 0-1 with 0.5 for reldiff = 0
                     self._bg_value = (reldiff_for_bg / maxabsreldiff_for_bg) / 2 + 0.5
                 # if the only differences are nans on either side
-                elif not self._diff_below_tolerance.all():
-                    # use white (0.5) everywhere except where reldiff is nan, so that nans are grey
+                elif not isclose.all():
+                    # use white (0.5) everywhere except where reldiff is nan, so
+                    # that nans are grey
                     self._bg_value = reldiff_for_bg + 0.5
                 else:
                     # do NOT use full_like as we don't want to inherit array dtype
@@ -195,7 +240,10 @@ def _update_from_combined_array(self):
             # do NOT use full_like as we don't want to inherit array dtype
             self._bg_value = la.full(self._combined_array.axes, 0.5)

+        # using percents does not look good when the numbers are very small
         self.maxdiff_label.setText(str(maxabsreldiff))
+        color = 'red' if maxabsreldiff != 0.0 else 'black'
+        self.maxdiff_label.setStyleSheet(f"QLabel {{ color: {color}; }}")
         self._update_from_bg_value_and_diff_below_tol(self.diff_checkbox.isChecked())

     def _update_from_bg_value_and_diff_below_tol(self, diff_only):
@@ -211,19 +259,7 @@ def _update_from_bg_value_and_diff_below_tol(self, diff_only):
             row_filter = (~self._diff_below_tolerance).any(self.stack_axis)
             array = array[row_filter]
             bg_value = bg_value[row_filter]
-        self.arraywidget.set_data(array, bg_value=bg_value)
-
-
-def align_all(arrays, join='outer', fill_value=la.nan):
-    if len(arrays) > 2:
-        raise NotImplementedError("not yet implemented")
-    first_array = arrays[0]
-    def is_raw(array):
-        return all(axis.iswildcard and axis.name is None
-                   for axis in array.axes)
-    if all(is_raw(array) and array.shape == first_array.shape for array in arrays[1:]):
-        return arrays
-    return first_array.align(arrays[1], join=join, fill_value=fill_value)
+        self.array_widget.set_data(array, attributes={'bg_value': bg_value})


 class ArrayComparatorWindow(AbstractEditorWindow):
@@ -262,7 +298,7 @@ def __init__(self, data, title='', caller_info=None, parent=None,
         widget = self.centralWidget()
         arrays = [la.asarray(array) for array in data
-                  if isinstance(array, DISPLAY_IN_GRID)]
+                  if isinstance(array, CAN_CONVERT_TO_LARRAY)]
         if names is None:
             names = [f"Array{i}" for i in range(len(arrays))]
@@ -270,9 +306,7 @@ def __init__(self, data, title='', caller_info=None, parent=None,
         widget.setLayout(layout)

         comparator_widget = ComparatorWidget(self, bg_gradient=bg_gradient,
-                                             rtol=rtol, atol=atol,
                                              nans_equal=nans_equal,
-                                             align=align,
                                              fill_value=fill_value)

         comparison_options_layout = (
             comparator_widget.get_comparison_options_layout(align=align,
@@ -339,7 +373,7 @@ def __init__(self, data, title='', caller_info=None, parent=None,
         self.atol = atol
         self.rtol = rtol

-        array_names = sorted(set.union(*[set(s.filter(kind=DISPLAY_IN_GRID).names) for s in self.sessions]))
+        array_names = sorted(set.union(*[set(s.filter(kind=CAN_CONVERT_TO_LARRAY).names) for s in self.sessions]))
         self.array_names = array_names
         listwidget = QListWidget(self)
         listwidget.addItems(array_names)
@@ -354,9 +388,7 @@ def __init__(self, data, title='', caller_info=None, parent=None,
         left_widget.setLayout(left_layout)

         comparator_widget = ComparatorWidget(self, bg_gradient=bg_gradient,
-                                             rtol=rtol, atol=atol,
                                              nans_equal=nans_equal,
-                                             align=align,
                                              fill_value=fill_value)
         # do not call set_data on the comparator_widget as it will be done by the setCurrentRow below
         self.comparator_widget = comparator_widget
@@ -381,7 +413,7 @@ def __init__(self, data, title='', caller_info=None, parent=None,
         main_splitter.addWidget(comparator_widget)
         main_splitter.setSizes([5, 95])
         main_splitter.setCollapsible(1, False)
-        self.widget_state_settings['main_splitter'] = main_splitter
+        self.widgets_to_save_to_settings['main_splitter'] = main_splitter

         main_layout.addLayout(comparison_options_layout)
         main_layout.addWidget(main_splitter)
@@ -391,19 +423,25 @@ def __init__(self, data, title='', caller_info=None, parent=None,
     def update_listwidget_colors(self):
         atol, rtol = self.comparator_widget._get_atol_rtol()
         listwidget = self.listwidget
-        # TODO: this functionality is super useful but can also be super slow when
-        # the sessions contain large arrays. It would be great if we
+        align_method = self.comparator_widget.get_align_method()
+        fill_value = self.comparator_widget.fill_value
+        nans_equal = self.comparator_widget.nans_equal
+        # TODO: this functionality is super useful but can also be super slow
+        # when the sessions contain large arrays. It would be great if we
It would be great if we # could do this asynchronously for i, name in enumerate(self.array_names): - align_method = self.comparator_widget.align_method - fill_value = self.comparator_widget.fill_value arrays = self.get_arrays(name) try: - aligned_arrays = align_all(arrays, join=align_method, fill_value=fill_value) + aligned_arrays = align_arrays(arrays, join=align_method, + fill_value=fill_value) first_array = aligned_arrays[0] - all_equal = all(a.equals(first_array, rtol=rtol, atol=atol, nans_equal=True) + all_equal = all(a.equals(first_array, + rtol=rtol, atol=atol, + nans_equal=nans_equal) for a in aligned_arrays[1:]) except Exception: + # print_exception(e) + all_equal = False item = listwidget.item(i) item.setForeground(Qt.black if all_equal else Qt.red) diff --git a/larray_editor/editor.py b/larray_editor/editor.py index b6e59bd0..b59e49d2 100644 --- a/larray_editor/editor.py +++ b/larray_editor/editor.py @@ -1,4 +1,6 @@ +import importlib import io +import logging import os import re import sys @@ -7,10 +9,9 @@ from pathlib import Path from typing import Union - # Python3.8 switched from a Selector to a Proactor based event loop for asyncio but they do not offer the same # features, which breaks Tornado and all projects depending on it, including Jupyter consoles -# refs: https://github.com/larray-project/larray-editor/issues/208 +# ref: https://github.com/larray-project/larray-editor/issues/208 if sys.platform.startswith("win") and sys.version_info >= (3, 8): import asyncio @@ -27,21 +28,20 @@ # explicitly request Qt backend (fixes #278) matplotlib.use('QtAgg') import matplotlib.axes +import matplotlib.figure + import numpy as np import larray as la -from larray_editor.traceback_tools import StackSummary -from larray_editor.utils import (_, create_action, show_figure, ima, commonpath, DEPENDENCIES, - get_versions, get_documentation_url, URLS, RecentlyUsedList) -from larray_editor.arraywidget import ArrayEditorWidget -from larray_editor.commands import EditSessionArrayCommand, EditCurrentArrayCommand - from qtpy.QtCore import Qt, QUrl, QSettings from qtpy.QtGui import QDesktopServices, QKeySequence from qtpy.QtWidgets import (QMainWindow, QWidget, QListWidget, QListWidgetItem, QSplitter, QFileDialog, QPushButton, - QDialogButtonBox, QShortcut, QVBoxLayout, QGridLayout, QLineEdit, - QCheckBox, QComboBox, QMessageBox, QDialog, QInputDialog, QLabel, QGroupBox, QRadioButton) + QDialogButtonBox, QShortcut, + QHBoxLayout, QVBoxLayout, QGridLayout, QLineEdit, + QCheckBox, QComboBox, QMessageBox, QDialog, + QInputDialog, QLabel, QGroupBox, QRadioButton, + QTabWidget) try: from qtpy.QtWidgets import QUndoStack @@ -50,6 +50,23 @@ # unsure qtpy has been fixed yet (see https://github.com/spyder-ide/qtpy/pull/366 for the fix for QUndoCommand) from qtpy.QtGui import QUndoStack +from larray_editor.traceback_tools import StackSummary +from larray_editor.utils import (_, + create_action, + show_figure, + ima, + commonpath, + DEPENDENCIES, + get_versions, + get_documentation_url, + URLS, + RecentlyUsedList, + logger, list_drives) +from larray_editor.arraywidget import ArrayEditorWidget +from larray_editor import arrayadapter +from larray_editor.commands import EditSessionArrayCommand, EditCurrentArrayCommand +from larray_editor.sql import SQLWidget + try: from qtconsole.rich_jupyter_widget import RichJupyterWidget from qtconsole.inprocess import QtInProcessKernelManager @@ -83,16 +100,13 @@ r'([-+*/%&|^><]|//|\*\*|>>|<<)?' 
r'=\s*[^=].*') HISTORY_VARS_PATTERN = re.compile(r'_i?\d+') -# XXX: add all scalars except strings (from numpy or plain Python)? -# (long) strings are not handled correctly so should NOT be in this list -# tuple, list -DISPLAY_IN_GRID = (la.Array, np.ndarray) +opened_secondary_windows = [] # TODO: remember its size # like MappingEditor via self.set_window_size_and_geometry() class EditorWindow(QWidget): - default_width = 1000 + default_width = 800 default_height = 600 # This is more or less the minimum space required to display a 1D array minimum_width = 300 @@ -104,7 +118,11 @@ def __init__(self, data, title=None, readonly=False): super().__init__(parent=None) layout = QVBoxLayout() self.setLayout(layout) + header_layout = self.setup_header_layout() + if header_layout is not None: + layout.addLayout(header_layout) array_widget = ArrayEditorWidget(self, data=data, readonly=readonly) + self.array_widget = array_widget layout.addWidget(array_widget) icon = ima.icon('larray') @@ -117,6 +135,47 @@ def __init__(self, data, title=None, readonly=False): # TODO: somehow determine better width self.resize(self.default_width, self.default_height) + def setup_header_layout(self): + return None + + def closeEvent(self, event): + logger.debug('EditorWindow.closeEvent()') + if self in opened_secondary_windows: + opened_secondary_windows.remove(self) + super().closeEvent(event) + self.array_widget.close() + + +class FileExplorerWindow(EditorWindow): + name = "File Explorer" + + def create_drive_button_clicked_callback(self, drive): + def callback(): + path = Path(drive) + if not path.exists(): + msg = f"The {drive} drive is currently unavailable !" + QMessageBox.critical(self, "Error", msg) + return + self.array_widget.set_data(path) + return callback + + def setup_header_layout(self): + drives = list_drives() + if not drives: + return None + + layout = QHBoxLayout() + for drive in drives: + if drive.endswith('\\'): + drive = drive[:-1] + button = QPushButton(drive) + button.clicked.connect( + self.create_drive_button_clicked_callback(drive) + ) + layout.addWidget(button) + layout.addStretch() + return layout + class AbstractEditorWindow(QMainWindow): """Abstract Editor Window""" @@ -145,7 +204,7 @@ def __init__(self, title='', readonly=False, caller_info=None, parent=None): self.edit_undo_stack = QUndoStack(self) self.settings_group_name = self.name.lower().replace(' ', '_') - self.widget_state_settings = {} + self.widgets_to_save_to_settings = {} # set icon icon = ima.icon('larray') @@ -298,30 +357,52 @@ def about(self): message += "" QMessageBox.about(self, _("About LArray Editor"), message.format(**kwargs)) - def _update_title(self, title, array, name): + def _update_title(self, title, value, name): if title is None: title = [] - if array is not None: - dtype = array.dtype.name - # current file (if not None) - if isinstance(array, la.Array): - # array info - shape = [f'{display_name} ({len(axis)})' - for display_name, axis in zip(array.axes.display_names, array.axes)] + if value is not None: + # TODO: the type-specific information added to the title should be + # computed by a method on the adapter + # (self.array_widget.data_adapter) + if hasattr(value, 'dtype'): + try: + dtype_str = f' [{value.dtype.name}]' + except Exception: + dtype_str = '' else: - # if it's not an Array, it must be a Numpy ndarray - assert isinstance(array, np.ndarray) - shape = [str(length) for length in array.shape] - # name + shape + dtype - array_info = ' x '.join(shape) + f' [{dtype}]' - if name: - title += [name + ': ' + 
array_info] + dtype_str = '' + + if hasattr(value, 'shape'): + def format_int(value: int): + if value >= 10_000: + return f'{value:_}' + else: + return str(value) + + if isinstance(value, la.Array): + shape = [f'{display_name} ({format_int(len(axis))})' + for display_name, axis in zip(value.axes.display_names, value.axes)] + else: + try: + shape = [format_int(length) for length in value.shape] + except Exception: + shape = [] + shape_str = ' x '.join(shape) else: - title += [array_info] + shape_str = '' + + # name + shape + dtype + value_info = shape_str + dtype_str + if name and value_info: + title.append(name + ': ' + value_info) + elif name: + title.append(name) + elif value_info: + title.append(value_info) # extra info - title += [self._title] + title.append(self._title) # set title self.setWindowTitle(' - '.join(title)) @@ -336,8 +417,13 @@ def save_widgets_state_and_geometry(self): settings.beginGroup(self.settings_group_name) settings.setValue('geometry', self.saveGeometry()) settings.setValue('state', self.saveState()) - for widget_name, widget in self.widget_state_settings.items(): - settings.setValue(f'state/{widget_name}', widget.saveState()) + for widget_name, widget in self.widgets_to_save_to_settings.items(): + settings.beginGroup(f'widget/{widget_name}') + if hasattr(widget, 'save_to_settings'): + widget.save_to_settings(settings) + elif hasattr(widget, 'saveState'): + settings.setValue('state', widget.saveState()) + settings.endGroup() settings.endGroup() def restore_widgets_state_and_geometry(self): @@ -349,10 +435,15 @@ def restore_widgets_state_and_geometry(self): state = settings.value('state') if state: self.restoreState(state) - for widget_name, widget in self.widget_state_settings.items(): - state = settings.value(f'state/{widget_name}') - if state: - widget.restoreState(state) + for widget_name, widget in self.widgets_to_save_to_settings.items(): + settings.beginGroup(f'widget/{widget_name}') + if hasattr(widget, 'load_from_settings'): + widget.load_from_settings(settings) + elif hasattr(widget, 'restoreState'): + widget_state = settings.value('state') + if widget_state: + widget.restoreState(widget_state) + settings.endGroup() settings.endGroup() return (geometry is not None) or (state is not None) @@ -365,6 +456,27 @@ def update_title(self): raise NotImplementedError() +def void_formatter(obj, p, cycle): + """ + p: PrettyPrinter + has a .text() method to output text. + cycle: bool + Indicates whether the object is part of a reference cycle. 
+ """ + adapter_creator = arrayadapter.get_adapter_creator(obj) + if isinstance(adapter_creator, str): + # the string is an error message => we cannot handle that object + # => use normal formatting + # we can get in this case if we registered a void_formatter for a type + # (such as Sequence) for which we handle some instances of the type + # but not all + p.text(repr(obj)) + else: + # we already display the object in the grid + # => do not print it in the console + return + + class MappingEditorWindow(AbstractEditorWindow): """Session Editor Dialog""" @@ -373,17 +485,31 @@ class MappingEditorWindow(AbstractEditorWindow): file_menu = True help_menu = True - def __init__(self, data, title='', readonly=False, caller_info=None, parent=None, - stack_pos=None, add_larray_functions=False): + def __init__(self, data, title='', readonly=False, caller_info=None, + parent=None, stack_pos=None, add_larray_functions=False, + python_console=True, sql_console=None): AbstractEditorWindow.__init__(self, title=title, readonly=readonly, caller_info=caller_info, parent=parent) + if sql_console is None: + # This was meant to test whether users actually imported polars + # in their script instead of just testing whether polars is present + # in their environment but, in practice, this currently only does + # the later because: larray_editor unconditionally imports larray + # which imports xlwings when available, which imports polars when + # available. + sql_console = 'polars' in sys.modules + logger.debug("polars module is present, enabling SQL console") + elif sql_console: + if importlib.util.find_spec('polars') is None: + raise RuntimeError("SQL console is not available because " + "the 'polars' module is not available") self.current_file = None self.current_array = None self.current_expr_text = None self.expressions = {} - self.kernel = None + self.ipython_kernel = None self._unsaved_modifications = False # to handle recently opened data/script files @@ -407,65 +533,89 @@ def __init__(self, data, title='', readonly=False, caller_info=None, parent=None del_item_shortcut.activated.connect(self.delete_current_item) self.data = la.Session() - self.arraywidget = ArrayEditorWidget(self, readonly=readonly) - self.arraywidget.dataChanged.connect(self.push_changes) - self.arraywidget.model_data.dataChanged.connect(self.update_title) - - if qtconsole_available: - # silence a warning on Python 3.11 (see issue #263) - if "PYDEVD_DISABLE_FILE_VALIDATION" not in os.environ: - os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" - - # Create an in-process kernel - kernel_manager = QtInProcessKernelManager() - kernel_manager.start_kernel(show_banner=False) - kernel = kernel_manager.kernel - - if add_larray_functions: - kernel.shell.run_cell('from larray import *') - kernel.shell.push({ - '__editor__': self - }) - - text_formatter = kernel.shell.display_formatter.formatters['text/plain'] - - def void_formatter(array, *args, **kwargs): - return '' - - for type_ in DISPLAY_IN_GRID: - text_formatter.for_type(type_, void_formatter) - - self.kernel = kernel - - kernel_client = kernel_manager.client() - kernel_client.start_channels() - - ipython_widget = RichJupyterWidget() - ipython_widget.kernel_manager = kernel_manager - ipython_widget.kernel_client = kernel_client - ipython_widget.executed.connect(self.ipython_cell_executed) - ipython_widget._display_banner = False - - self.eval_box = ipython_widget - self.eval_box.setMinimumHeight(20) - - right_panel_widget = QSplitter(Qt.Vertical) - 
right_panel_widget.addWidget(self.arraywidget) - right_panel_widget.addWidget(self.eval_box) - right_panel_widget.setSizes([90, 10]) - self.widget_state_settings['right_panel_widget'] = right_panel_widget + self.array_widget = ArrayEditorWidget(self, readonly=readonly) + self.array_widget.dataChanged.connect(self.push_changes) + # FIXME: this is currently broken as it fires for each scroll + # we either need to fix model_data.dataChanged (but that might + # be needed for display) or find another way to add a star to + # the window title *only* when the user actually changed + # something + # self.array_widget.model_data.dataChanged.connect(self.update_title) + + if sql_console: + sql_widget = SQLWidget(self) + self.widgets_to_save_to_settings['sql_console'] = sql_widget else: - self.eval_box = QLineEdit() - self.eval_box.returnPressed.connect(self.line_edit_update) + sql_widget = None + self.sql_widget = sql_widget + if python_console: + if qtconsole_available: + # silence a warning on Python 3.11 (see issue #263) + if "PYDEVD_DISABLE_FILE_VALIDATION" not in os.environ: + os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + + # Create an in-process kernel + kernel_manager = QtInProcessKernelManager() + kernel_manager.start_kernel(show_banner=False) + kernel = kernel_manager.kernel + self.ipython_kernel = kernel + + text_formatter = kernel.shell.display_formatter.formatters['text/plain'] + for type_ in arrayadapter.REGISTERED_ADAPTERS: + text_formatter.for_type(type_, void_formatter) + + kernel.shell.push({ + '__editor__': self + }) + + if add_larray_functions: + kernel.shell.run_cell('from larray import *') + self.ipython_cell_executed() + + kernel_client = kernel_manager.client() + kernel_client.start_channels() + + ipython_widget = RichJupyterWidget() + ipython_widget.kernel_manager = kernel_manager + ipython_widget.kernel_client = kernel_client + ipython_widget.executed.connect(self.ipython_cell_executed) + ipython_widget._display_banner = False + + self.eval_box = ipython_widget + self.eval_box.setMinimumHeight(20) + + right_panel_widget = QSplitter(Qt.Vertical) + right_panel_widget.addWidget(self.array_widget) + if sql_console: + tab_widget = QTabWidget(self) + tab_widget.addTab(self.eval_box, 'Python Console') + tab_widget.addTab(sql_widget, 'SQL Console') + right_panel_widget.addWidget(tab_widget) + else: + right_panel_widget.addWidget(self.eval_box) - right_panel_layout = QVBoxLayout() - right_panel_layout.addWidget(self.arraywidget) - right_panel_layout.addWidget(self.eval_box) + right_panel_widget.setSizes([90, 10]) + self.widgets_to_save_to_settings['right_panel_widget'] = right_panel_widget + else: + # cannot easily use a QTextEdit because it has no returnPressed signal + self.eval_box = QLineEdit() + self.eval_box.returnPressed.connect(self.line_edit_update) + + right_panel_layout = QVBoxLayout() + right_panel_layout.addWidget(self.array_widget) + right_panel_layout.addWidget(self.eval_box) + + # you cant add a layout directly in a splitter, so we have to wrap + # it in a widget + right_panel_widget = QWidget() + right_panel_widget.setLayout(right_panel_layout) + elif sql_console: + right_panel_widget = QSplitter(Qt.Vertical) + right_panel_widget.addWidget(self.array_widget) + right_panel_widget.addWidget(sql_widget) - # you cant add a layout directly in a splitter, so we have to wrap - # it in a widget - right_panel_widget = QWidget() - right_panel_widget.setLayout(right_panel_layout) + right_panel_widget.setSizes([90, 10]) + 
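The console wiring above follows the usual qtconsole in-process pattern; condensed to a standalone sketch (assuming qtconsole is installed):

    from qtpy.QtWidgets import QApplication
    from qtconsole.inprocess import QtInProcessKernelManager
    from qtconsole.rich_jupyter_widget import RichJupyterWidget

    app = QApplication([])
    kernel_manager = QtInProcessKernelManager()
    kernel_manager.start_kernel(show_banner=False)
    kernel_client = kernel_manager.client()
    kernel_client.start_channels()

    console = RichJupyterWidget()
    console.kernel_manager = kernel_manager
    console.kernel_client = kernel_client
    console.show()
    app.exec()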
self.widgets_to_save_to_settings['right_panel_widget'] = right_panel_widget main_splitter = QSplitter(Qt.Horizontal) debug = isinstance(data, StackSummary) @@ -497,7 +647,7 @@ def void_formatter(array, *args, **kwargs): main_splitter.addWidget(right_panel_widget) main_splitter.setSizes([180, 620]) main_splitter.setCollapsible(1, False) - self.widget_state_settings['main_splitter'] = main_splitter + self.widgets_to_save_to_settings['main_splitter'] = main_splitter layout.addWidget(main_splitter) @@ -519,15 +669,59 @@ def void_formatter(array, *args, **kwargs): self._push_data(data) self.set_window_size_and_geometry() - self.windows = [] def _push_data(self, data): self.data = data if isinstance(data, la.Session) else la.Session(data) - if qtconsole_available: - self.kernel.shell.push(dict(self.data.items())) + if self.ipython_kernel is not None: + # Avoid displaying objects we handle in IPython console. + + # Sadly, we cannot do this for all objects we support without + # trying to import all the modules we support (which is clearly not + # desirable), because IPython has 3 limitations. + # 1) Its support for "string types" requires + # specifying the exact submodule a type is at (for example: + # pandas.core.frame.DataFrame instead of pandas.DataFrame). + # I do not think this is a maintainable approach for us (that is + # why the registering adapters using "string types" does not + # require that) so we use real/concrete types instead. + + # 2) It only supports *exact* types, not subclasses, so we cannot + # just register a custom formatter for "object" and be done + # with it. + + # 3) We cannot do this "just in time" by doing it in response + # to either ipython_widget executed or executing signals which + # both happen too late (the value is already displayed by the + # time those signals are fired) + + # The combination of the above limitations mean that types + # imported via the console will NOT use the void_formatter :(. + text_formatter = self.ipython_kernel.shell.display_formatter.formatters['text/plain'] + unique_types = {type(v) for v in self.data.values()} + for obj_type in unique_types: + adapter_creator = arrayadapter.get_adapter_creator_for_type(obj_type) + if adapter_creator is None: + # if None, it means we do not handle that type at all + # => do not touch its ipython formatter + continue + + # Otherwise, we know the type is at least partially handled + # (at least some instances are displayed) so we register our + # void formatter and rely on it to fallback to repr() if + # a particular instance of a type is not handled. 
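In other words, the registration step in _push_data amounts to the loop below (void_formatter is the function defined earlier in this diff; register_void_formatters is a hypothetical helper name, shell would be the in-process kernel's shell):

    def register_void_formatters(shell, types):
        # suppress the console echo for types the grid already displays
        text_formatter = shell.display_formatter.formatters['text/plain']
        for obj_type in types:
            try:
                current = text_formatter.for_type(obj_type)
            except KeyError:
                current = None
            if current is not void_formatter:
                text_formatter.for_type(obj_type, void_formatter)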
+ try: + current_formatter = text_formatter.for_type(obj_type) + except KeyError: + current_formatter = None + if current_formatter is not void_formatter: + logger.debug(f"applying void_formatter for {obj_type}") + text_formatter.for_type(obj_type, void_formatter) + self.ipython_kernel.shell.push(dict(self.data.items())) var_names = [k for k, v in self.data.items() if self._display_in_varlist(k, v)] self.add_list_items(var_names) self._listwidget.setCurrentRow(0) + if self.sql_widget is not None: + self.sql_widget.update_completer_options(self.data) def on_stack_frame_changed(self): selected = self._stack_frame_widget.selectedItems() @@ -556,8 +750,8 @@ def _reset(self): self.current_array = None self.current_expr_text = None self.edit_undo_stack.clear() - if qtconsole_available: - self.kernel.shell.reset() + if self.ipython_kernel is not None: + self.ipython_kernel.shell.reset() self.ipython_cell_executed() else: self.eval_box.setText('None') @@ -568,36 +762,56 @@ def _setup_file_menu(self, menu_bar): # ============= # # NEW # # ============= # - file_menu.addAction(create_action(self, _('&New'), shortcut="Ctrl+N", triggered=self.new)) + file_menu.addAction(create_action(self, _('&New'), + shortcut="Ctrl+N", + triggered=self.new)) file_menu.addSeparator() # ============= # # DATA # # ============= # file_menu.addSeparator() - file_menu.addAction(create_action(self, _('&Open Data'), shortcut="Ctrl+O", triggered=self.open_data, - statustip=_('Load session from file'))) - file_menu.addAction(create_action(self, _('&Save Data'), shortcut="Ctrl+S", triggered=self.save_data, - statustip=_('Save all arrays as a session in a file'))) - file_menu.addAction(create_action(self, _('Save Data &As'), triggered=self.save_data_as, - statustip=_('Save all arrays as a session in a file'))) + open_tip = _('Load session from file') + file_menu.addAction(create_action(self, _('&Open Data'), + shortcut="Ctrl+O", + triggered=self.open_data, + statustip=open_tip)) + save_tip = _('Save all arrays as a session in a file') + file_menu.addAction(create_action(self, _('&Save Data'), + shortcut="Ctrl+S", + triggered=self.save_data, + statustip=save_tip)) + file_menu.addAction(create_action(self, _('Save Data &As'), + triggered=self.save_data_as, + statustip=save_tip)) recent_files_menu = file_menu.addMenu("Open &Recent Data") for action in self.recent_data_files.actions: recent_files_menu.addAction(action) recent_files_menu.addSeparator() - recent_files_menu.addAction(create_action(self, _('&Clear List'), triggered=self.recent_data_files.clear)) + recent_files_menu.addAction(create_action(self, _('&Clear List'), + triggered=self.recent_data_files.clear)) # ============= # # EXAMPLES # # ============= # file_menu.addSeparator() - file_menu.addAction(create_action(self, _('&Load Example Dataset'), triggered=self.load_example)) + file_menu.addAction(create_action(self, _('&Load Example Dataset'), + triggered=self.load_example)) + # ============= # + # EXPLORER # + # ============= # + file_menu.addSeparator() + file_menu.addAction(create_action(self, _('Open File &Explorer'), + triggered=self.open_explorer)) # ============= # # SCRIPTS # # ============= # if qtconsole_available: file_menu.addSeparator() - file_menu.addAction(create_action(self, _('&Load from Script'), shortcut="Ctrl+Shift+O", - triggered=self.load_script, statustip=_('Load script from file'))) - file_menu.addAction(create_action(self, _('&Save Command History To Script'), shortcut="Ctrl+Shift+S", + file_menu.addAction(create_action(self, _('&Load from 
Script'), + shortcut="Ctrl+Shift+O", + triggered=self.load_script, + statustip=_('Load script from file'))) + file_menu.addAction(create_action(self, _('&Save Command History To Script'), + shortcut="Ctrl+Shift+S", triggered=self.save_script, statustip=_('Save command history in a file'))) @@ -605,7 +819,9 @@ def _setup_file_menu(self, menu_bar): # QUIT # # ============= # file_menu.addSeparator() - file_menu.addAction(create_action(self, _('&Quit'), shortcut="Ctrl+Q", triggered=self.close)) + file_menu.addAction(create_action(self, _('&Quit'), + shortcut="Ctrl+Q", + triggered=self.close)) def push_changes(self, changes): self.edit_undo_stack.push(EditSessionArrayCommand(self, self.current_expr_text, changes)) @@ -640,18 +856,19 @@ def display_item_in_new_window(self, list_item): assert isinstance(list_item, QListWidgetItem) varname = str(list_item.text()) value = self.data[varname] - self.new_editor_window(value, varname) + self.new_editor_window(value, title=varname) - def new_editor_window(self, data, title: str, readonly: bool=False): - window = EditorWindow(data, title=title, readonly=readonly) + def new_editor_window(self, data, title: str=None, readonly: bool=False, + cls=EditorWindow): + window = cls(data, title=title, readonly=readonly) window.show() - # FIXME: add some mechanism to remove them from the list on close # this is necessary so that the window does not disappear immediately - self.windows.append(window) + opened_secondary_windows.append(window) def select_list_item(self, to_display): changed_items = self._listwidget.findItems(to_display, Qt.MatchExactly) - assert len(changed_items) == 1 + assert len(changed_items) == 1, \ + f"len(changed_items) should be 1 but is {len(changed_items)}:\n{changed_items!r}" prev_selected = self._listwidget.selectedItems() assert len(prev_selected) <= 1 # if the currently selected item (value) need to be refreshed (e.g it was modified) @@ -694,7 +911,11 @@ def update_mapping_and_varlist(self, value): if len(changed_displayable_keys) > 0 or deleted_displayable_keys: self.unsaved_modifications = True - # 4) return variable to display, if any (if there are more than one, + # 4) update sql completer options if needed + if self.sql_widget is not None and (new_displayable_keys or deleted_displayable_keys): + self.sql_widget.update_completer_options(self.data) + + # 5) return variable to display, if any (if there are more than one, # return first) return changed_displayable_keys[0] if changed_displayable_keys else None @@ -702,8 +923,8 @@ def delete_current_item(self): current_item = self._listwidget.currentItem() name = str(current_item.text()) del self.data[name] - if qtconsole_available: - self.kernel.shell.del_var(name) + if self.ipython_kernel is not None: + self.ipython_kernel.shell.del_var(name) self.unsaved_modifications = True self._listwidget.takeItem(self._listwidget.row(current_item)) @@ -726,13 +947,17 @@ def view_expr(self, array, expr_text): self.set_current_array(array, expr_text) def _display_in_varlist(self, k, v): - return self._display_in_grid(v) and not k.startswith('__') + return (self._display_in_grid(v) and not k.startswith('__') and + # This is ugly (and larray specific) but I did not find an + # easy way to exclude that specific variable. I do not think + # it should be in larray top level namespace anyway. 
+ k != 'EXAMPLE_EXCEL_TEMPLATES_DIR') def _display_in_grid(self, v): - return isinstance(v, DISPLAY_IN_GRID) + return not isinstance(arrayadapter.get_adapter_creator(v), str) def ipython_cell_executed(self): - user_ns = self.kernel.shell.user_ns + user_ns = self.ipython_kernel.shell.user_ns ip_keys = {'In', 'Out', '_', '__', '___', '__builtin__', '_dh', '_ih', '_oh', '_sh', '_i', '_ii', '_iii', 'exit', 'get_ipython', 'quit'} # '__builtins__', '__doc__', '__loader__', '__name__', '__package__', '__spec__', @@ -810,7 +1035,7 @@ def ipython_cell_executed(self): if 'inline' not in matplotlib.get_backend(): figure = self._get_figure(cur_output) if figure is not None: - show_figure(self, figure, title=last_input_last_line) + show_figure(figure, title=last_input_last_line, parent=self) def _get_figure(self, cur_output): if isinstance(cur_output, matplotlib.figure.Figure): @@ -869,10 +1094,24 @@ def update_title(self): self._update_title(title, array, name) def set_current_array(self, array, expr_text): - # we should NOT check that "array is not self.current_array" because this method is also called to - # refresh the widget value because of an inplace setitem + if logger.isEnabledFor(logging.DEBUG): + logger.debug("") + clsname = self.__class__.__name__ + msg = f"{clsname}.set_current_array(<...>, {expr_text!r})" + logger.debug(msg) + logger.debug('=' * len(msg)) + + # we should NOT check that "array is not self.current_array" because + # this method is also called to refresh the widget value because of an + # inplace setitem + + if self.sql_widget is not None: + self.sql_widget.update_completer_options(self.data, selected=array) + # FIXME: we should never store the current_array but current_adapter instead self.current_array = array - self.arraywidget.set_data(array) + array_widget = self.array_widget + array_widget.back_button_bar.clear() + array_widget.set_data(array) self.current_expr_text = expr_text self.update_title() @@ -902,11 +1141,15 @@ def _ask_to_save_if_unsaved_modifications(self): return True def closeEvent(self, event): - # as per the example in the Qt doc (https://doc.qt.io/qt-5/qwidget.html#closeEvent), we should *NOT* call - # the closeEvent() method of the superclass in this case because all it does is "event.accept()" - # unconditionally which results in the application being closed regardless of what the user chooses (see #202). + logger.debug('MappingEditorWindow.closeEvent()') + # as per the example in the Qt doc + # (https://doc.qt.io/qt-5/qwidget.html#closeEvent), we should *NOT* call + # the superclass closeEvent() method in this case because all it does is + # "event.accept()" unconditionally which results in the application + # being closed regardless of what the user chooses (see #202). 
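The accept/ignore pattern this comment refers to, reduced to its core (a sketch, not the editor's full logic):

    from qtpy.QtWidgets import QMessageBox, QWidget

    class Window(QWidget):
        def closeEvent(self, event):
            answer = QMessageBox.question(self, "Unsaved changes",
                                          "Close without saving?")
            if answer == QMessageBox.Yes:
                event.accept()   # the window closes
            else:
                event.ignore()   # the close is aborted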
if self._ask_to_save_if_unsaved_modifications(): self.save_widgets_state_and_geometry() + self.array_widget.close() event.accept() else: event.ignore() @@ -918,7 +1161,7 @@ def closeEvent(self, event): def new(self): if self._ask_to_save_if_unsaved_modifications(): self._reset() - self.arraywidget.set_data(la.empty(0)) + self.array_widget.set_data(la.empty(0)) self.set_current_file(None) self.unsaved_modifications = False self.statusBar().showMessage("Viewer has been reset", 4000) @@ -1037,18 +1280,18 @@ def complete_slice(s): if start == '': start = 1 if stop == '': - stop = self.kernel.shell.execution_count + stop = self.ipython_kernel.shell.execution_count if sep == ':': stop += 1 return f'{start}{sep}{stop}' lines = ' '.join(complete_slice(s) for s in lines.split(' ')) else: - lines = f'1-{self.kernel.shell.execution_count}' + lines = f'1-{self.ipython_kernel.shell.execution_count}' with io.StringIO() as tmp_out: with redirect_stdout(tmp_out): - self.kernel.shell.run_line_magic('save', f'{overwrite} "{filepath}" {lines}') + self.ipython_kernel.shell.run_line_magic('save', f'{overwrite} "{filepath}" {lines}') stdout = tmp_out.getvalue() if 'commands were written to file' not in stdout: raise Exception(stdout) @@ -1060,7 +1303,7 @@ def complete_slice(s): # See http://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-save # for more details def save_script(self): - if self.kernel.shell.execution_count == 1: + if self.ipython_kernel.shell.execution_count == 1: QMessageBox.critical(self, "Error", "Cannot save an empty command history") return @@ -1206,8 +1449,31 @@ def open_recent_file(self): QMessageBox.warning(self, "Warning", f"File {filepath} could not be found") def _save_data(self, filepath): + CAN_BE_SAVED = (la.Array, la.Axis, la.Group) + in_var_list = {k: v for k, v in self.data.items() + if self._display_in_varlist(k, v)} + if not in_var_list: + QMessageBox.warning(self, "Warning", "Nothing to save") + return + + to_save = {k: v for k, v in in_var_list.items() + if isinstance(v, CAN_BE_SAVED)} + if not to_save: + msg = ("Nothing can be saved because " + "all the currently loaded variables " + "are of types which are not supported for saving.") + QMessageBox.warning(self, "Warning: unsavable objects", msg) + return + + unsaveable = in_var_list.keys() - to_save.keys() + if unsaveable: + object_names = ', '.join(sorted(unsaveable)) + QMessageBox.warning(self, "Warning: unsavable objects", + "The following variables are of types which " + "are not supported for saving and will be " + f"ignored:\n\n{object_names}") + session = la.Session(to_save) try: - session = la.Session({k: v for k, v in self.data.items() if self._display_in_varlist(k, v)}) session.save(filepath) self.set_current_file(filepath) self.edit_undo_stack.clear() @@ -1249,6 +1515,9 @@ def load_example(self): filepath = AVAILABLE_EXAMPLE_DATA[dataset_name] self._open_file(filepath) + def open_explorer(self): + self.new_editor_window(Path('.'), cls=FileExplorerWindow) + class ArrayEditorWindow(AbstractEditorWindow): """Array Editor Dialog""" @@ -1272,11 +1541,11 @@ def __init__(self, data, title='', readonly=False, caller_info=None, parent=None widget.setLayout(layout) self.data = data - self.arraywidget = ArrayEditorWidget(self, data, readonly, minvalue=minvalue, maxvalue=maxvalue) - self.arraywidget.dataChanged.connect(self.push_changes) - self.arraywidget.model_data.dataChanged.connect(self.update_title) + self.array_widget = ArrayEditorWidget(self, data, readonly, minvalue=minvalue, maxvalue=maxvalue) + 
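Saving the command history relies on driving IPython's %save line magic programmatically, roughly as in this standalone sketch ('history.py' and the line range are illustrative; -f forces overwriting):

    from IPython.core.interactiveshell import InteractiveShell

    shell = InteractiveShell.instance()
    shell.run_cell('x = 1')
    shell.run_cell('y = x + 1')
    # write input lines 1-2 to a script, overwriting any existing file
    shell.run_line_magic('save', '-f "history.py" 1-2')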
self.array_widget.dataChanged.connect(self.push_changes) + self.array_widget.model_data.dataChanged.connect(self.update_title) self.update_title() - layout.addWidget(self.arraywidget) + layout.addWidget(self.array_widget) self.set_window_size_and_geometry() def update_title(self): diff --git a/larray_editor/sql.py b/larray_editor/sql.py new file mode 100644 index 00000000..8a63c475 --- /dev/null +++ b/larray_editor/sql.py @@ -0,0 +1,310 @@ +import re + +from qtpy.QtCore import Qt, QStringListModel +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import QTextEdit, QCompleter + +from larray_editor.utils import _, logger + +MAX_SQL_QUERIES = 1000 +SQL_CREATE_TABLE_PATTERN = re.compile(r'CREATE\s+TABLE\s+([\w_]+)\s+', + flags=re.IGNORECASE) +SQL_DROP_TABLE_PATTERN = re.compile(r'DROP\s+TABLE\s+([\w_]+)', + flags=re.IGNORECASE) + + +class SQLWidget(QTextEdit): + SQL_KEYWORDS = [ + "SELECT", "FROM", "WHERE", "INSERT", "UPDATE", "DELETE", + "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", + "GROUP", "BY", "ORDER", "HAVING", + "AS", "ON", "IN", "AND", "OR", "NOT", "NULL", "IS", + "DISTINCT", "LIMIT", "OFFSET", "UNION", "ALL", + "CREATE", "TABLE", "DROP", "ALTER", "ADD", "INDEX", "PRIMARY", + "KEY", "FOREIGN", "VALUES", "SET", + "CASE", "WHEN", "THEN", "ELSE", "END" + ] + SQL_KEYWORDS_SET = set(SQL_KEYWORDS) + + def __init__(self, editor_window): + import polars as pl + + # avoid a circular module dependency by having the import here + from larray_editor.editor import MappingEditorWindow + assert isinstance(editor_window, MappingEditorWindow) + super().__init__() + self.editor_window = editor_window + + msg = _("""Enter an SQL query here and press SHIFT+ENTER to execute it. + +You can use any Polars object in the FROM clause. + +Use the UP/DOWN arrow keys to navigate through queries you typed previously \ +(including during previous sessions). +It will only display past queries which start with the text already typed so \ +far (the part before the cursor) so that one can more easily search for \ +specific queries. + +SQL keywords, names of variables usable as table and column names (once the \ +FROM clause is known) can be autocompleted by using TAB. + +The currently displayed table may be called 'self' (in addition to its real \ +name). 
+""") + self.setPlaceholderText(msg) + self.setAcceptRichText(False) + font = self.font() + font.setFamily('Calibri') + font.setPointSize(11) + self.setFont(font) + + self.history = [] + self.history_index = 0 + + self.completer = QCompleter([], self) + self.completer.setCaseSensitivity(Qt.CaseSensitivity.CaseInsensitive) + self.completer.setWidget(self) + self.completer.activated.connect(self.insert_completion) + self.sql_context = pl.SQLContext(eager=False) + self.update_completer_options({}) + + def update_completer_options(self, data=None, selected=None): + if data is not None: + data = self._filter_data_for_sql(data) + if selected is not None and self._handled_by_polars_sql(selected): + data['self'] = selected + self.data = data + self.sql_context.register_many(data) + else: + data = self.data + + if 'self' in data: + table_names_to_fetch_columns = ['self'] + else: + table_names_to_fetch_columns = [] + + table_names = [k for k, v in data.items()] + + # extract table names from the current FROM clause + query_text = self.toPlainText() + m = re.search(r'\s+FROM\s+(\S+)', query_text, re.IGNORECASE) + if m: + after_from = m.group(1) + # try any identifier found in the query after the FROM keyword + # there will probably be false positives if a column has the same + # name as another table but that should be rare + from_tables = [word for word in after_from.split() + if word not in self.SQL_KEYWORDS_SET and word in data] + if from_tables: + table_names_to_fetch_columns = from_tables + + # add column names from all the used tables or self, if present + col_names_set = set() + for table_name in table_names_to_fetch_columns: + col_names_set.update(set(data[table_name].collect_schema().names())) + col_names = sorted(col_names_set) + + logger.debug(f"available columns for SQL queries: {col_names}") + logger.debug(f"available tables for SQL queries: {table_names}") + completions = col_names + table_names + self.SQL_KEYWORDS + model = QStringListModel(completions, self.completer) + self.completer.setModel(model) + + def _filter_data_for_sql(self, data): + return {k: v for k, v in data.items() + if self._handled_by_polars_sql(v)} + + def _handled_by_polars_sql(self, obj): + import polars as pl + SUPPORTED_TYPES = (pl.DataFrame, pl.LazyFrame, pl.Series) + # We purposefully do not support pandas and pyarrow objects here, even if + # polars SQL can sort of handle them, because Polars does that by + # converting the object to their Polars counterpart first and that + # can be slow (e.g. 
>1s for pd_df_big) + + # if 'pandas' in sys.modules: + # import pandas as pd + # SUPPORTED_TYPES += (pd.DataFrame, pd.Series) + # if 'pyarrow' in sys.modules: + # import pyarrow as pa + # SUPPORTED_TYPES += (pa.Table, pa.RecordBatch) + return isinstance(obj, SUPPORTED_TYPES) + + def insert_completion(self, completion): + cursor = self.textCursor() + cursor.select(QTextCursor.WordUnderCursor) + cursor.removeSelectedText() + # Insert a space if the cursor is at the end of the text + at_end = cursor.position() == len(self.toPlainText()) + cursor.insertText(completion + (' ' if at_end else '')) + self.setTextCursor(cursor) + self.update_completer_options() + + def keyPressEvent(self, event): + completer_popup = self.completer.popup() + if completer_popup.isVisible(): + if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter, Qt.Key.Key_Tab): + # Insert the currently highlighted completion + current_index = completer_popup.currentIndex() + if not current_index.isValid(): + # Default to the first item if none is highlighted + current_index = completer_popup.model().index(0, 0) + completion = current_index.data() + self.insert_completion(completion) + completer_popup.hide() + return + elif event.key() == Qt.Key.Key_Escape: + completer_popup.hide() + return + + if (event.key() in (Qt.Key.Key_Enter, Qt.Key.Key_Return) and + event.modifiers() & Qt.KeyboardModifier.ShiftModifier): + query_text = self.toPlainText().strip() + if query_text: + self.append_to_history(query_text) + self.execute_sql(query_text) + return + elif event.key() == Qt.Key.Key_Tab: + prefix = self.get_word_prefix() + self.completer.setCompletionPrefix(prefix) + if self.completer.completionCount() == 1: + completion = self.completer.currentCompletion() + self.insert_completion(completion) + return + else: + self.show_autocomplete_popup() + return + + cursor = self.textCursor() + # for plaintext QTextEdit, blockNumber gives the line number + line_num = cursor.blockNumber() + if event.key() == Qt.Key.Key_Up: + if line_num == 0: + if self.search_and_recall_history(direction=-1): + return + elif event.key() == Qt.Key.Key_Down: + total_lines = self.document().blockCount() + if line_num == total_lines - 1: + if self.history_index < len(self.history) - 1: + if self.search_and_recall_history(direction=1): + return + else: + self.history_index = len(self.history) + self.clear() + return + super().keyPressEvent(event) + self.update_completer_options() + # we need to compute the prefix *after* the keypress event has been + # handled so that the prefix contains the last keystroke + prefix = self.get_word_prefix() + if prefix: + self.completer.setCompletionPrefix(prefix) + num_completion = self.completer.completionCount() + # we must show the popup even if it is already visible, because + # the number of completions might have changed and the popup size + # needs to be updated + if 0 < num_completion <= self.completer.maxVisibleItems(): + self.show_autocomplete_popup() + elif num_completion == 0 and completer_popup.isVisible(): + completer_popup.hide() + + def show_autocomplete_popup(self): + # create a new cursor and move it to the start of the word, so that + # we can position the popup correctly + word_start_cursor = self.textCursor() + word_start_cursor.movePosition(QTextCursor.StartOfWord) + rect = self.cursorRect(word_start_cursor) + completer_popup = self.completer.popup() + popup_scrollbar = completer_popup.verticalScrollBar() + popup_scrollbar_width = popup_scrollbar.sizeHint().width() + 10 + 
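The console is built on polars' SQLContext, used above to register frames for completion and below to execute queries. A minimal usage sketch (assuming polars is installed; the frame and table name are illustrative):

    import polars as pl

    df = pl.DataFrame({"name": ["C", "Fortran", "Python"],
                       "year": [1972, 1957, 1991]})
    ctx = pl.SQLContext(eager=False)
    ctx.register_many({"languages": df})
    # with eager=False, execute() returns a LazyFrame
    result = ctx.execute("SELECT name FROM languages WHERE year > 1980")
    print(result.collect())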
rect.setWidth(completer_popup.sizeHintForColumn(0) +
+                      popup_scrollbar_width)
+        self.completer.complete(rect)
+
+    def get_word_prefix(self):
+        text = self.toPlainText()
+        if not text:
+            return ''
+        cursor = self.textCursor()
+        cursor_pos = cursor.position()
+        # <= len(text) (instead of <) because cursor can be at the end
+        assert 0 <= cursor_pos <= len(text), f"{cursor_pos=} {len(text)=}"
+        word_start = cursor_pos
+        while (word_start > 0 and
+               (text[word_start - 1].isalnum() or
+                text[word_start - 1] == '_')):
+            word_start -= 1
+        return text[word_start:cursor_pos]
+
+    def search_and_recall_history(self, direction: int):
+        if not self.history:
+            return False
+        cursor = self.textCursor()
+        query_text = self.toPlainText()
+        cursor_pos = cursor.position()
+        prefix = query_text[:cursor_pos]
+        index = self.history_index + direction
+        while 0 <= index <= len(self.history) - 1:
+            if self.history[index].startswith(prefix):
+                self.history_index = index
+                self.setPlainText(self.history[self.history_index])
+                self.update_completer_options()
+                cursor.setPosition(cursor_pos)
+                self.setTextCursor(cursor)
+                return True
+            index += direction
+        # no matching prefix found, do not change history_index
+        return False
+
+    def append_to_history(self, sql_text):
+        history = self.history
+        if not history or history[-1] != sql_text:
+            history.append(sql_text)
+        if len(history) > MAX_SQL_QUERIES:
+            # keep only the last MAX_SQL_QUERIES entries
+            history = history[-MAX_SQL_QUERIES:]
+            self.history = history
+        self.history_index = len(history)
+
+    def _fetch_table(self, table_name):
+        return self.sql_context.execute(f"SELECT * FROM {table_name}",
+                                        eager=False)
+
+    def execute_sql(self, sql_text: str):
+        """Execute an SQL query and display its result"""
+        editor_window = self.editor_window
+        sql_context = self.sql_context
+        logger.debug(f"Executing SQL query:\n{sql_text}")
+        result = sql_context.execute(sql_text, eager=False)
+
+        # To determine whether we have added or dropped tables, comparing the
+        # resulting SQL context to what we had before would be more reliable
+        # than this regex-based solution but is not currently possible using
+        # Polars' public API
+        new_table_name = None
+        m = SQL_CREATE_TABLE_PATTERN.match(sql_text)
+        if m is not None:
+            new_table_name = m.group(1)
+        dropped_table_name = None
+        m = SQL_DROP_TABLE_PATTERN.match(sql_text)
+        if m is not None:
+            dropped_table_name = m.group(1)
+        if new_table_name or dropped_table_name:
+            # data might be a Session, make sure we have a dict copy
+            new_data = dict(editor_window.data.items())
+            if new_table_name:
+                new_data[new_table_name] = self._fetch_table(new_table_name)
+                logger.debug(f'added table {new_table_name} to session')
+            if dropped_table_name:
+                del new_data[dropped_table_name]
+                logger.debug(f'dropped table {dropped_table_name} from session')
+            editor_window.update_mapping_and_varlist(new_data)
+            if new_table_name:
+                editor_window.select_list_item(new_table_name)
+        elif not dropped_table_name:
+            editor_window.array_widget.set_data(result)
+
+    def save_to_settings(self, settings):
+        settings.setValue('queries', self.history)
+
+    def load_from_settings(self, settings):
+        self.history = settings.value('queries', [], type=list)
+        self.history_index = len(self.history)
diff --git a/larray_editor/start.py b/larray_editor/start.py
index 8a97d4e8..c74e4ac8 100644
--- a/larray_editor/start.py
+++ b/larray_editor/start.py
@@ -4,19 +4,29 @@
 from larray_editor.api import _show_dialog, create_edit_dialog
 
 
-def call_edit():
-    _show_dialog("Viewer", create_edit_dialog, *sys.argv[1:], display_caller_info=False,
add_larray_functions=True) +def call_edit(obj): + # we do not use edit() so that we can have display_caller_info=False + _show_dialog("Viewer", create_edit_dialog, obj=obj, + display_caller_info=False, add_larray_functions=True) def main(): + args = sys.argv[1:] + if len(args) > 1: + print(f"Usage: {sys.argv[0]} [file_path]") + sys.exit() + elif len(args) == 1: + obj = args[0] + else: + obj = {} if os.name == 'nt': stderr_path = os.path.join(os.getenv("TEMP"), "stderr-" + os.path.basename(sys.argv[0])) with open(os.devnull, "w") as out, open(stderr_path, "w") as err: sys.stdout = out sys.stderr = err - call_edit() + call_edit(obj) else: - call_edit() + call_edit(obj) if __name__ == '__main__': diff --git a/larray_editor/tests/data/test.xlsx b/larray_editor/tests/data/test.xlsx new file mode 100644 index 00000000..c5488c86 Binary files /dev/null and b/larray_editor/tests/data/test.xlsx differ diff --git a/larray_editor/tests/data/test.zip b/larray_editor/tests/data/test.zip new file mode 100644 index 00000000..690d70a3 Binary files /dev/null and b/larray_editor/tests/data/test.zip differ diff --git a/larray_editor/tests/test_adapter.py b/larray_editor/tests/test_adapter.py index 5c173caa..2e86d561 100644 --- a/larray_editor/tests/test_adapter.py +++ b/larray_editor/tests/test_adapter.py @@ -1,5 +1,4 @@ import pytest - import larray as la diff --git a/larray_editor/tests/test_api_larray.py b/larray_editor/tests/test_api_larray.py index 13b5b442..d5c920ca 100644 --- a/larray_editor/tests/test_api_larray.py +++ b/larray_editor/tests/test_api_larray.py @@ -1,96 +1,117 @@ """Array editor test""" +import importlib +import array import logging -# from pathlib import Path +import sys +import zipfile +from collections import OrderedDict, namedtuple +import sqlite3 +from pathlib import Path import numpy as np +import qtpy import larray as la +import pandas as pd -# from larray_editor.api import edit -from larray_editor.api import view, edit, debug, compare +from larray_editor.api import edit +# from larray_editor.api import view, edit, debug, compare from larray_editor.utils import logger -import qtpy - -print(f"Using {qtpy.API_NAME} as Qt API") - +# Configure logging to output messages to the console +logging.basicConfig( + # Show warnings and above for all loggers + level=logging.WARNING, + format="%(levelname)s:%(name)s:%(message)s", + stream=sys.stdout +) +# Set our own logger to DEBUG logger.setLevel(logging.DEBUG) - -lipro = la.Axis('lipro=P01..P15') -age = la.Axis('age=0..115') -sex = la.Axis('sex=M,F') - -vla = 'A11,A12,A13,A23,A24,A31,A32,A33,A34,A35,A36,A37,A38,A41,A42,A43,A44,A45,A46,A71,A72,A73' -wal = 'A25,A51,A52,A53,A54,A55,A56,A57,A61,A62,A63,A64,A65,A81,A82,A83,A84,A85,A91,A92,A93' -bru = 'A21' -# list of strings -belgium = la.union(vla, wal, bru) - -geo = la.Axis(belgium, 'geo') - -# arr1 = la.ndtest((sex, lipro)) -# edit(arr1) - -# data2 = np.arange(116 * 44 * 2 * 15).reshape(116, 44, 2, 15) \ -# .astype(float) -# data2 = np.random.random(116 * 44 * 2 * 15).reshape(116, 44, 2, 15) \ -# .astype(float) -# data2 = (np.random.randint(10, size=(116, 44, 2, 15)) - 5) / 17 -# data2 = np.random.randint(10, size=(116, 44, 2, 15)) / 100 + 1567 -# data2 = np.random.normal(51000000, 10000000, size=(116, 44, 2, 15)) -arr2 = la.random.normal(axes=(age, geo, sex, lipro)) -# arr2 = la.ndrange([100, 100, 100, 100, 5]) -# arr2 = arr2['F', 'A11', 1] - -# view(arr2[0, 'A11', 'F', 'P01']) -# view(arr1) -# view(arr2[0, 'A11']) -# edit(arr1) -# print(arr2[0, 'A11', :, 'P01']) -# edit(arr2.astype(int), 
minvalue=-99, maxvalue=55.123456) -# edit(arr2.astype(int), minvalue=-99) -# arr2.i[0, 0, 0, 0] = np.inf -# arr2.i[0, 0, 1, 1] = -np.inf -# arr2 = [0.0000111, 0.0000222] -# arr2 = [0.00001, 0.00002] -# edit(arr2, minvalue=-99, maxvalue=25.123456) -# print(arr2[0, 'A11', :, 'P01']) - -# arr2 = la.random.normal(0, 10, axes="d0=0..4999;d1=0..19") -# edit(arr2) - -# view(['a', 'bb', 5599]) -# view(np.arange(12).reshape(2, 3, 2)) -# view([]) - -data3 = np.random.normal(0, 1, size=(2, 15)) -# FIXME: the new align code makes this fail -# arr3 = la.ndtest((30, sex)) -arr3 = la.ndtest((age, sex)) -# data4 = np.random.normal(0, 1, size=(2, 15)) -# arr4 = la.Array(data4, axes=(sex, lipro)) - -# arr4 = arr3.copy() -# arr4['F'] /= 2 -arr4 = arr3.min(sex) -arr5 = arr3.max(sex) -arr6 = arr3.mean(sex) - -# test isssue #35 -arr7 = la.from_lists([['a', 1, 2, 3], - [ '', 1664780726569649730, -9196963249083393206, -7664327348053294350]]) - - -def make_circle(width=20, radius=9): - x, y = la.Axis(width, 'x'), la.Axis(width, 'y') - center = (width - 1) / 2 - return la.maximum(radius - la.sqrt((x - center) ** 2 + (y - center) ** 2), 0) - - -def make_sphere(width=20, radius=9): - x, y, z = la.Axis(width, 'x'), la.Axis(width, 'y'), la.Axis(width, 'z') - center = (width - 1) / 2 - return la.maximum(radius - la.sqrt((x - center) ** 2 + (y - center) ** 2 + (z - center) ** 2), 0) +logger.info(f"Using {qtpy.API_NAME} as Qt API") + +# array objects +array_double = array.array('d', [1.0, 2.0, 3.14]) +array_signed_int = array.array('l', [1, 2, 3, 4, 5]) +array_signed_int_empty = array.array('l') +# should show as hello alpha and omega +array_unicode = array.array('w', 'hello \u03B1 and \u03C9') + +# list +list_empty = [] +list_int = [2, 5, 7, 3] +list_mixed = ['abc', 1.1, True, 1.0, 42, [1, 2]] +list_seq_mixed = [[1], [2, 3, 4], [5, 6]] +list_seq_regular = [[1, 2], [3, 4], [5, 6]] +list_unicode = ["\N{grinning face}", "\N{winking face}"] +list_mixed_tuples = [ + ("C", 1972), + ("Fortran", 1957), + ("Python", 1991), + ("Go", 2009), +] + +# tuple +tuple_empty = () +tuple_int = (2, 5, 7, 3) +tuple_mixed = ('abc', 1.1, True, 1.0, 42, (1, 2)) +tuple_seq_mixed = ((1,), (2, 3, 4), (5, 6)) +tuple_seq_regular = ((1, 2), (3, 4), (5, 6)) + +# named tuple +PersonNamedTuple = namedtuple('Person', ['name', 'age', 'male', 'height']) +namedtuple1 = PersonNamedTuple("name1", age=42, male=True, height=1.80) +namedtuple2 = PersonNamedTuple("name2", age=41, male=False, height=1.76) + +# set +set_int = {2, 4, 7, 3} +set_int_big = set(range(10 ** 7)) +set_mixed = {2, "hello", 7.0, True} +set_str = {"a", "b", "c", "d"} + +# dict +dict_str_int = {"a": 2, "b": 5, "c": 7, "d": 3} +dict_int_int = {0: 2, 2: 4, 5: 7, 1: 3} +dict_int_str = {0: "a", 2: "b", 5: "c", 1: "d"} +dict_str_mixed = {"a": 2, "b": "hello", "c": 7.0, "d": True} + +# dict views +dictview_keys = dict_str_mixed.keys() +dictview_items = dict_str_mixed.items() +dictview_values = dict_str_mixed.values() + +# OrderedDict +odict_int_int = OrderedDict(dict_int_int) +odict_int_str = OrderedDict(dict_int_str) +odict_str_int = OrderedDict(dict_str_int) +odict_str_mixed = OrderedDict(dict_str_mixed) + +# numpy arrays +np_arr0d = np.full((), 42, dtype=float) +np_arr1d = np.random.normal(0, 1, size=100) +np_arr1d_empty = np.random.normal(0, 1, size=0) +np_arr2d = np.random.normal(0, 1, size=(100, 100)) +np_arr2d_0col = np.random.normal(0, 1, size=(10, 0)) +np_arr2d_0row = np.random.normal(0, 1, size=(0, 10)) +np_arr3d = np.random.normal(0, 1, size=(10, 10, 10)) +np_big1d = np.arange(1000 * 1000 
* 500) +np_big3d = np_big1d.reshape((1000, 1000, 500)) +np_dtype = np.dtype([('name', ' 0.35.0 + # and require larray >= 0.35 + # This function is necessary *only* for larray-editor version + # 0.35.ZERO. Because of the incorporation of the larray-editor changelog in + # the larray release, we cannot depend on larray >= 0.35 when releasing + # larray-editor 0.35.0 (which is very silly because we develop both in + # parallel) + def align_arrays(arrays, join='outer', fill_value=np.nan): + if len(arrays) > 2: + raise NotImplementedError("aligning more than two arrays requires " + "larray >= 0.35") + first_array = arrays[0] + + def is_raw(array): + return all(axis.iswildcard and axis.name is None + for axis in array.axes) + + if all(is_raw(array) and array.shape == first_array.shape + for array in arrays[1:]): + return arrays + return first_array.align(arrays[1], join=join, fill_value=fill_value) + +# field is field_name + conversion if any +M_SPECIFIER_PATTERN = re.compile(r'\{(?P[^:}]*):(?P[^m}]*)m\}') +ICON_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'images') + logger = logging.getLogger("editor") @@ -157,14 +189,14 @@ def _get_font(family, size, bold=False, italic=False): return font -def is_float(dtype): +def is_float_dtype(dtype): """Return True if datatype dtype is a float kind""" return ('float' in dtype.name) or dtype.name in ['single', 'double'] -def is_number(dtype): +def is_number_dtype(dtype): """Return True is datatype dtype is a number kind""" - return is_float(dtype) or ('int' in dtype.name) or ('long' in dtype.name) or ('short' in dtype.name) + return is_float_dtype(dtype) or ('int' in dtype.name) or ('long' in dtype.name) or ('short' in dtype.name) # When we drop support for Python3.9, we can use traceback.print_exception @@ -192,17 +224,24 @@ def keybinding(attr): return QKeySequence.keyBindings(ks)[0] -def create_action(parent, text, icon=None, triggered=None, shortcut=None, statustip=None): +def create_action(parent, text, icon=None, triggered=None, shortcut=None, statustip=None, + checkable=False, checked=False): """Create a QAction""" action = QAction(text, parent) if triggered is not None: action.triggered.connect(triggered) if icon is not None: - action.setIcon(icon) + action.setIcon(ima.icon(icon)) if shortcut is not None: action.setShortcut(shortcut) if statustip is not None: action.setStatusTip(statustip) + if checked: + assert checkable + if checkable: + action.setCheckable(True) + if checked: + action.setChecked(True) # action.setShortcutContext(Qt.WidgetShortcut) return action @@ -226,17 +265,17 @@ def get_idx_rect(index_list): class IconManager: _icons = {'larray': 'larray.ico'} - _icon_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'images') def icon(self, ref): if ref in self._icons: - icon_path = os.path.join(self._icon_dir, self._icons[ref]) + icon_path = os.path.join(ICON_DIR, self._icons[ref]) + assert os.path.exists(icon_path) return QIcon(icon_path) else: - # By default, only X11 will support themed icons. In order to use - # themed icons on Mac and Windows, you will have to bundle a compliant - # theme in one of your PySide.QtGui.QIcon.themeSearchPaths() and set the - # appropriate PySide.QtGui.QIcon.themeName() . + # By default, only X11 supports themed icons. In order to use + # themed icons on Mac and Windows, we need to bundle a + # compliant theme in one of the QtGui.QIcon.themeSearchPaths() + # directories and set QtGui.QIcon.themeName() accordingly. 
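For non-X11 platforms, the workaround described in this comment looks roughly like this (the theme name and search path are illustrative):

    from qtpy.QtGui import QIcon

    # point Qt at a bundled freedesktop-compliant icon theme
    QIcon.setThemeSearchPaths(QIcon.themeSearchPaths() + ['resources/icons'])
    QIcon.setThemeName('bundled-theme')
    icon = QIcon.fromTheme('document-open')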
return QIcon.fromTheme(ref) @@ -253,8 +292,8 @@ class LinearGradient: Parameters ---------- stop_points: list/tuple, optional - List containing pairs (stop_position, colors_HsvF). - `colors` is a 4 elements list containing `hue`, `saturation`, `value` and `alpha-channel` + List of (stop_position, color) pairs. + `color` is a 4 elements list containing `hue`, `saturation`, `value` and `alpha-channel` """ def __init__(self, stop_points=None, nan_color=None): if stop_points is None: @@ -338,7 +377,8 @@ def __getitem__(self, key): key_isnan = np.isnan(key)[..., np.newaxis] color = color0 + (color1 - color0) * normalized_value[..., np.newaxis] color = np.where(key_isnan, self.nan_color.getHsvF(), color) - return from_hsvf(color[..., 0], color[..., 1], color[..., 2], color[..., 3]) + return from_hsvf(color[..., 0], color[..., 1], color[..., 2], + color[..., 3]) class PlotDialog(QDialog): @@ -355,7 +395,7 @@ def __init__(self, canvas, parent=None): canvas.draw() -def show_figure(parent, figure, title=None): +def show_figure(figure, title=None, parent=None): if (figure.canvas is not None and figure.canvas.manager is not None and figure.canvas.manager.window is not None): figure.canvas.draw() @@ -369,176 +409,6 @@ def show_figure(parent, figure, title=None): window.show() -class Axis: - """ - Represents an Axis. - - Parameters - ---------- - id : str or int - Id of axis. - name : str - Name of the axis. Can be None. - labels : list or tuple or 1D array - List of labels - """ - def __init__(self, id, name, labels): - self.id = id - self.name = name - self.labels = labels - - @property - def id(self): - return self._id - - @id.setter - def id(self, id): - if not isinstance(id, (str, int)): - raise TypeError("id must a string or a integer") - self._id = id - - @property - def name(self): - return self._name - - @name.setter - def name(self, name): - if not isinstance(name, str): - raise TypeError("name must be a string") - self._name = name - - @property - def labels(self): - return self._labels - - @labels.setter - def labels(self, labels): - if not (hasattr(labels, '__len__') and hasattr(labels, '__getitem__')): - raise TypeError("labels must be a list or tuple or any 1D array-like") - self._labels = labels - - def __len__(self): - return len(self.labels) - - def __str__(self): - return f'Axis({self.id}, {self.name}, {self.labels})' - - -class _LazyLabels(object): - def __init__(self, arrays): - self.prod = Product(arrays) - - def __getitem__(self, key): - return ' '.join(self.prod[key]) - - def __len__(self): - return len(self.prod) - - -class _LazyDimLabels(object): - """ - Examples - -------- - >>> p = Product([['a', 'b', 'c'], [1, 2]]) - >>> list(p) - [('a', 1), ('a', 2), ('b', 1), ('b', 2), ('c', 1), ('c', 2)] - >>> l0 = _LazyDimLabels(p, 0) - >>> l1 = _LazyDimLabels(p, 1) - >>> for i in range(len(p)): - ... 
@@ -369,176 +409,6 @@
     window.show()


-class Axis:
-    """
-    Represents an Axis.
-
-    Parameters
-    ----------
-    id : str or int
-        Id of axis.
-    name : str
-        Name of the axis. Can be None.
-    labels : list or tuple or 1D array
-        List of labels
-    """
-    def __init__(self, id, name, labels):
-        self.id = id
-        self.name = name
-        self.labels = labels
-
-    @property
-    def id(self):
-        return self._id
-
-    @id.setter
-    def id(self, id):
-        if not isinstance(id, (str, int)):
-            raise TypeError("id must a string or a integer")
-        self._id = id
-
-    @property
-    def name(self):
-        return self._name
-
-    @name.setter
-    def name(self, name):
-        if not isinstance(name, str):
-            raise TypeError("name must be a string")
-        self._name = name
-
-    @property
-    def labels(self):
-        return self._labels
-
-    @labels.setter
-    def labels(self, labels):
-        if not (hasattr(labels, '__len__') and hasattr(labels, '__getitem__')):
-            raise TypeError("labels must be a list or tuple or any 1D array-like")
-        self._labels = labels
-
-    def __len__(self):
-        return len(self.labels)
-
-    def __str__(self):
-        return f'Axis({self.id}, {self.name}, {self.labels})'
-
-
-class _LazyLabels(object):
-    def __init__(self, arrays):
-        self.prod = Product(arrays)
-
-    def __getitem__(self, key):
-        return ' '.join(self.prod[key])
-
-    def __len__(self):
-        return len(self.prod)
-
-
-class _LazyDimLabels(object):
-    """
-    Examples
-    --------
-    >>> p = Product([['a', 'b', 'c'], [1, 2]])
-    >>> list(p)
-    [('a', 1), ('a', 2), ('b', 1), ('b', 2), ('c', 1), ('c', 2)]
-    >>> l0 = _LazyDimLabels(p, 0)
-    >>> l1 = _LazyDimLabels(p, 1)
-    >>> for i in range(len(p)):
-    ...     print(l0[i], l1[i])
-    a 1
-    a 2
-    b 1
-    b 2
-    c 1
-    c 2
-    >>> l0[1:4]
-    ['a', 'b', 'b']
-    >>> l1[1:4]
-    [2, 1, 2]
-    >>> list(l0)
-    ['a', 'a', 'b', 'b', 'c', 'c']
-    >>> list(l1)
-    [1, 2, 1, 2, 1, 2]
-    """
-    def __init__(self, prod, i):
-        self.prod = prod
-        self.i = i
-
-    def __iter__(self):
-        return iter(self.prod[i][self.i] for i in range(len(self.prod)))
-
-    def __getitem__(self, key):
-        key_prod = self.prod[key]
-        if isinstance(key, slice):
-            return [p[self.i] for p in key_prod]
-        else:
-            return key_prod[self.i]
-
-    def __len__(self):
-        return len(self.prod)
-
-
-class _LazyRange(object):
-    def __init__(self, length, offset):
-        self.length = length
-        self.offset = offset
-
-    def __getitem__(self, key):
-        if key >= self.offset:
-            return key - self.offset
-        else:
-            return ''
-
-    def __len__(self):
-        return self.length + self.offset
-
-
-class _LazyNone(object):
-    def __init__(self, length):
-        self.length = length
-
-    def __getitem__(self, key):
-        return ' '
-
-    def __len__(self):
-        return self.length
-
-
-def replace_inf(value):
-    """Replace -inf/+inf in array with respectively min(array_without_inf)/max(array_without_inf).
-
-    It leaves nans intact.
-
-    Parameters
-    ----------
-    value : np.ndarray or any compatible type
-        Input array.
-
-    Returns
-    -------
-    (np.ndarray, float, float)
-        array with infinite values replaced by the min and maximum respectively
-        minimum finite value
-        maximum finite value
-
-    Examples
-    --------
-    >>> replace_inf(np.array([-5, np.inf, 0, -np.inf, -4, np.nan, 5]))
-    (array([-5.,  5.,  0., -5., -4., nan,  5.]), -5.0, 5.0)
-    """
-    value = value.copy()
-    # replace -inf by min(value)
-    isneginf = value == -np.inf
-    minvalue = np.nanmin(value[~isneginf])
-    value[isneginf] = minvalue
-    # replace +inf by max(value)
-    isposinf = value == np.inf
-    maxvalue = np.nanmax(value[~isposinf])
-    value[isposinf] = maxvalue
-    return value, minvalue, maxvalue
-
-
 def scale_to_01range(value, vmin, vmax):
     """Scale value to 0-1 range based on vmin and vmax.
@@ -577,15 +447,13 @@ def scale_to_01range(value, vmin, vmax):
     array([ 0. , 1. , 0.5, 0. , 0.1, 1. ])
     """
     if hasattr(value, 'shape') and value.shape:
-        if np.isnan(vmin) or np.isnan(vmax) or (vmin == vmax):
-            return np.where(np.isnan(value), np.nan, 0)
-        else:
-            assert vmin < vmax, f"vmin ({vmin}) < vmax ({vmax})"
-            with np.errstate(divide='ignore', invalid='ignore'):
-                res = (value - vmin) / (vmax - vmin)
-                res[value == -np.inf] = 0
-                res[value == +np.inf] = 1
-            return res
+        with np.errstate(divide='ignore', invalid='ignore'):
+            res = np.where(np.isnan(vmin) | np.isnan(vmax) | (vmin == vmax),
+                           np.where(np.isnan(value), np.nan, 0),
+                           (value - vmin) / (vmax - vmin))
+        res[value == -np.inf] = 0
+        res[value == +np.inf] = 1
+        return res
     else:
         if np.isnan(value):
             return np.nan
@@ -600,7 +468,11 @@ def scale_to_01range(value, vmin, vmax):
         return (value - vmin) / (vmax - vmin)


-is_number_value = np.vectorize(lambda x: isinstance(x, (int, float, np.number)))
+def is_number_value(v):
+    return isinstance(v, (int, float, np.number))
+
+
+is_number_value_vectorized = np.vectorize(is_number_value, otypes=[bool])


 def get_sample_step(data, maxsize):
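
An editorial aside on the otypes=[bool] argument above: without it,
np.vectorize infers the output dtype by calling the function on the first
element, which raises a ValueError on empty input. A quick sketch:

    import numpy as np

    vectorized = np.vectorize(is_number_value, otypes=[bool])
    vectorized(np.array([1, 'a', 2.5], dtype=object))  # array([ True, False,  True])
    vectorized(np.array([], dtype=object))             # array([], dtype=bool)
    # without otypes=[bool], the last call would raise:
    # ValueError: cannot call `vectorize` on size 0 inputs unless `otypes` is set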
@@ -620,7 +492,7 @@ def get_sample(data, maxsize):

     Parameters
     ----------
-    data
+    data : array-like
     maxsize

     Returns
     -------
@@ -647,7 +519,8 @@ def __init__(self, list_name, parent_action=None, triggered=None):
         if self.settings.value(list_name) is None:
             self.settings.setValue(list_name, [])
         if parent_action is not None:
-            actions = [QAction(parent_action) for _ in range(self.MAX_RECENT_FILES)]
+            actions = [QAction(parent_action)
+                       for _ in range(self.MAX_RECENT_FILES)]
             for action in actions:
                 action.setVisible(False)
                 if triggered is not None:
@@ -698,7 +571,8 @@ def _update_actions(self):
             action.setStatusTip(filepath)
             action.setData(filepath)
             action.setVisible(True)
-        # if we have less recent files than actions, hide the remaining actions
+        # if we have fewer recent files than actions, hide the remaining
+        # actions
         for action in self.actions[len(recent_files):]:
             action.setVisible(False)
@@ -707,7 +581,8 @@ def cached_property(must_invalidate_cache_method):
     """A decorator to cache class properties."""
     def getter_decorator(original_getter):
         def caching_getter(self):
-            if must_invalidate_cache_method(self) or not hasattr(self, '_cached_property_values'):
+            if (must_invalidate_cache_method(self) or
+                    not hasattr(self, '_cached_property_values')):
                 self._cached_property_values = {}
             try:
                 # cache hit
@@ -721,6 +596,37 @@ def caching_getter(self):
     return getter_decorator


+def broadcast_get(seq, row, col):
+    # allow "broadcasting" (length one sequences are valid) in either direction
+    if isinstance(seq, (tuple, list, np.ndarray, Product)):
+        # FIXME: does not handle len(seq) == 0 nicely but I am unsure this should be fixed here
+        if len(seq) == 0:
+            return None
+        elif len(seq) == 1:
+            row_data = seq[0]
+        else:
+            row_data = seq[row]
+
+        if isinstance(row_data, (tuple, list, np.ndarray, Product)):
+            # FIXME: does not handle len(row_data) == 0 nicely but I am unsure this should be fixed here
+            if len(row_data) == 0:
+                return None
+            elif len(row_data) == 1:
+                return row_data[0]
+            else:
+                return row_data[col]
+            # try:
+            #     return row_data[0] if len(row_data) == 1 else row_data[col]
+            # except IndexError:
+            #     raise IndexError(f"list index {col} is out of range for list of length {len(row_data)}")
+        else:
+            return row_data
+    else:
+        return seq
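
To make the intended "broadcasting" concrete, a sketch of broadcast_get()
with made-up inputs (not from this PR):

    broadcast_get([['a']], row=3, col=2)                   # -> 'a' (length-1 dims repeat)
    broadcast_get([['a', 'b'], ['c', 'd']], row=1, col=0)  # -> 'c'
    broadcast_get('title', row=5, col=7)                   # -> 'title' (non-sequence passthrough)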
+
+
 # The following two functions (_allow_interrupt and _allow_interrupt_qt) are
 # copied from matplotlib code, because they are not part of their public API
 # and relying on them would require us to pin the matplotlib version, which

@@ -827,3 +733,209 @@ def handle_sigint():


 PY312 = sys.version_info >= (3, 12)
+
+
+def data_frac_digits(data: np.ndarray, max_frac_digits: int = 99):
+    """
+    Determine the minimum number of fractional digits needed to represent the
+    data array accurately.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        Input array of numeric values.
+    max_frac_digits : int
+        Maximum number of fractional digits to consider. Must be >= 0.
+
+    Returns
+    -------
+    int
+        Number of fractional digits needed (0 for empty or integer arrays,
+        at most max_frac_digits).
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> data_frac_digits(np.array([1, 2, 3]))
+    0
+    >>> data_frac_digits(np.array([1.0, 2.0, 3.0]))
+    0
+    >>> data_frac_digits(np.array([1.5, 2.7, 3.1]))
+    1
+    >>> data_frac_digits(np.array([1.2, 2.751, 3.1]))
+    3
+    >>> data_frac_digits(np.array([1.1234567]))
+    7
+    >>> data_frac_digits(np.array([]))
+    0
+    >>> data_frac_digits(np.array([1.0000001]))
+    7
+    """
+    assert isinstance(data, np.ndarray)
+    assert isinstance(max_frac_digits, int) and max_frac_digits >= 0
+    if not data.size:
+        return 0
+    if np.issubdtype(data.dtype, np.integer):
+        return 0
+    threshold = 10 ** -(max_frac_digits + 1)
+    for frac_digits in range(max_frac_digits):
+        maxdiff = np.max(np.abs(data - np.round(data, frac_digits)))
+        if maxdiff < threshold:
+            return frac_digits
+    return max_frac_digits
+
+
+MAX_INT_DIGITS = 308
+
+
+def num_int_digits(value):
+    """
+    Number of integer digits. Completely ignores the fractional part.
+    Does not take sign into account.
+
+    Examples
+    --------
+    >>> num_int_digits(1)
+    1
+    >>> num_int_digits(99)
+    2
+    >>> num_int_digits(-99.1)
+    2
+    >>> num_int_digits(np.array([1, 99, -99.1]))
+    array([1, 2, 2])
+    """
+    value = abs(value)
+    log10 = np.where(value > 0, np.log10(value), 0)
+    res = np.where(np.isinf(log10), MAX_INT_DIGITS,
+                   # maximum(..., 1) because there must be at least one
+                   # integer digit (the 0 in 0.00..X)
+                   np.maximum(np.floor(log10).astype(int) + 1, 1))
+    # use a normal Python scalar instead of a 0D array
+    return res if res.ndim > 0 else res.item()
+
+
+def log_caller(logger=logger, level=logging.DEBUG):
+    if logger.isEnabledFor(level):
+        # We start from our caller (f_back): the real current frame (this
+        # function's code) is of no interest to us.
+        current_frame = inspect.currentframe().f_back
+        caller_frame = current_frame.f_back
+        caller_info = inspect.getframeinfo(caller_frame)
+        caller_module = os.path.basename(caller_info.filename)
+        logger.log(
+            level,
+            f"{get_func_name(current_frame)}() "
+            f"called by {get_func_name(caller_frame)}() "
+            f"from module {caller_module} at line {caller_info.lineno}")
+
+
+def get_func_name(frame):
+    # assume that if we have 'self' in the frame, it is a method, otherwise
+    # it is a function.
+    func_name = frame.f_code.co_name
+    if 'self' in frame.f_locals:
+        # We do not use Python 3.11+ frame.f_code.co_qualname because it
+        # returns the (super)class where the method is defined, not the
+        # instance class, which is usually much more useful
+        func_name = f"{frame.f_locals['self'].__class__.__name__}.{func_name}"
+    return func_name
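
For illustration, a hedged sketch of what log_caller() would emit (the
function names, module name and line number are invented):

    import logging
    logging.basicConfig(level=logging.DEBUG)

    def refresh():
        log_caller()

    def on_click():
        refresh()

    on_click()
    # emits something like:
    # DEBUG:editor:refresh() called by on_click() from module demo.py at line 9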
+
+
+# FIXME: round values instead of truncating them
+def time2str(seconds, precision="auto"):
+    """Format a duration in seconds as a string using given precision.
+
+    Parameters
+    ----------
+    seconds : float
+        Duration (in seconds) to format.
+    precision : str, optional
+        Precision of the output. Defaults to "auto" (the largest unit
+        minus 2). See examples below.
+
+    Returns
+    -------
+    str
+        Human-readable duration.
+
+    Examples
+    --------
+    >>> time2str(3727.2785, precision="ns")
+    '1 hour 2 minutes 7 seconds 278 ms 500 µs'
+    >>> # auto: the largest unit is hour, the unit two steps below is seconds => precision = seconds
+    >>> time2str(3727.2785)
+    '1 hour 2 minutes 7 seconds'
+    >>> time2str(3605.2785)
+    '1 hour 5 seconds'
+    >>> time2str(3727.2785, precision="hour")
+    '1 hour'
+    >>> time2str(3723.1234567890123456789, precision="ns")
+    '1 hour 2 minutes 3 seconds 123 ms 456 µs 789 ns'
+    >>> time2str(3723.1234567890123456789)
+    '1 hour 2 minutes 3 seconds'
+    """
+    # for Python 3.7+, we could use a dict (and rely on dict ordering)
+    divisors = [
+        ('ns', 1000),
+        ('µs', 1000),
+        ('ms', 1000),
+        ('second', 60),
+        ('minute', 60),
+        ('hour', 24),
+        ('day', 365),
+    ]
+    precision_map = {
+        'day': 6,
+        'hour': 5,
+        'minute': 4,
+        'second': 3,
+        'ms': 2,
+        'µs': 1,
+        'ns': 0,
+    }
+
+    values = []
+    str_parts = []
+    ns = int(seconds * 10 ** 9)
+    value = ns
+    for cur_precision, (unit, divisor_for_next) in enumerate(divisors):
+        next_value, cur_value = divmod(value, divisor_for_next)
+        values.append(cur_value)
+        if next_value == 0:
+            break
+        value = next_value
+    max_prec = len(values) - 1
+    int_precision = max_prec - 2 if precision == 'auto' else precision_map[precision]
+    for cur_precision, (cur_value, (unit, divisor_for_next)) in enumerate(zip(values, divisors)):
+        if cur_value > 0 and cur_precision >= int_precision:
+            str_parts.append(f"{cur_value:d} {unit}{'s' if cur_value > 1 and cur_precision > 2 else ''}")
+    return ' '.join(str_parts[::-1])
+
+
+def timed(logger):
+    def decorator(func):
+        def new_func(*args, **kwargs):
+            # testing for this outside new_func (to make the decorator
+            # return the original func when the logger is not enabled) does
+            # not work because the logger is not configured yet when the
+            # code to be profiled is defined (and the decorator is called)
+            if logger.isEnabledFor(logging.DEBUG):
+                start = time.perf_counter()
+                res = func(*args, **kwargs)
+                time_taken = time.perf_counter() - start
+                logger.debug(f"{func.__name__} done in {time2str(time_taken)}")
+                return res
+            else:
+                return func(*args, **kwargs)
+        return new_func
+    return decorator
+
+
+def list_drives():
+    if PY312:
+        return os.listdrives()
+    else:
+        try:
+            import win32api
+            drives_str = win32api.GetLogicalDriveStrings()
+            return [drivestr for drivestr in drives_str.split('\000')
+                    if drivestr]
+        except ImportError:
+            logger.warning("Unable to list drives: on Python < 3.12, "
+                           "this needs the 'win32api' module")
+            return []
\ No newline at end of file
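
A small usage sketch for the timed() decorator (hypothetical function; the
logger setup is illustrative):

    import logging
    import time

    logging.basicConfig(level=logging.DEBUG)

    @timed(logging.getLogger("editor"))
    def load_data():
        time.sleep(0.25)

    load_data()
    # logs something like: "load_data done in 250 ms 123 µs" (via time2str)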
diff --git a/setup.py b/setup.py
index e0adf99c..6c33a2d7 100644
--- a/setup.py
+++ b/setup.py
@@ -14,12 +14,24 @@ def readlocal(fname):

 LONG_DESCRIPTION = readlocal("README.rst")
 LONG_DESCRIPTION_CONTENT_TYPE = "text/x-rst"
 SETUP_REQUIRES = []
-# pyqt cannot be installed via pypi. Dependencies (pyqt, qtpy and matplotlib) moved to conda recipe
-# requires larray >= 0.32 because of the LArray -> Array rename
-# TODO: add qtpy as dependency and mention pyqt or pyside
-# when using pyqt, we require at least pyqt >= 4.6 (for API v2)
-# jedi >=0.18 to workaround incompatibility between jedi <0.18 and parso >=0.8 (see #220)
-INSTALL_REQUIRES = ['larray >=0.32', 'jedi >=0.18']
+
+# * jedi >=0.18 to work around an incompatibility between jedi <0.18 and
+#   parso >=0.8 (see #220)
+# * Technically, we should require larray >=0.35 because we need align_arrays
+#   for compare(), but to make larray-editor releasable, we cannot depend on
+#   larray X.Y when releasing larray-editor X.Y (see utils.py for more
+#   details)
+#   TODO: require 0.35 for next larray-editor version and drop shim in utils.py
+# * Pandas is required directly for a silly reason (to support converting
+#   pandas dataframes to arrays before comparing them). We could make it an
+#   optional dependency by importing it lazily, but since it is also
+#   indirectly required via larray, it does not really matter.
+# * we do not actually require PyQt6 but rather any of PyQt5, PyQt6 or
+#   PySide6, but I do not know how to specify this
+# * we also have optional dependencies (but I don't know how to specify them):
+#   - 'xlwings' for the "Copy to Excel" context-menu action
+#   - 'tables' (PyTables) to load the example datasets from larray
+INSTALL_REQUIRES = ['jedi >=0.18', 'larray >=0.32', 'matplotlib', 'numpy',
+                    'pandas', 'PyQt6', 'qtpy']
 TESTS_REQUIRE = ['pytest']

 LICENSE = 'GPLv3'
@@ -35,11 +47,11 @@ def readlocal(fname):
         'Intended Audience :: Developers',
         'Programming Language :: Python',
         'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
         'Programming Language :: Python :: 3.11',
         'Programming Language :: Python :: 3.12',
+        'Programming Language :: Python :: 3.13',
         'Topic :: Scientific/Engineering',
         'Topic :: Software Development :: Libraries',
     ]
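
An editorial note on the "I do not know how to specify this" comments above:
setuptools cannot express "any one of PyQt5/PyQt6/PySide6" as a single hard
requirement, but the optional dependencies can be declared via
extras_require (a sketch; the extra names are invented, not part of this PR):

    EXTRAS_REQUIRE = {
        'pyqt5': ['PyQt5'],
        'pyqt6': ['PyQt6'],
        'pyside6': ['PySide6'],
        'excel': ['xlwings'],  # "Copy to Excel" context-menu action
        'hdf': ['tables'],     # load the example datasets from larray
    }
    # passed as setup(..., extras_require=EXTRAS_REQUIRE), installable with
    # e.g.: pip install larray-editor[pyqt6,excel]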