diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b0ac581b..295c55db9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ ### Fixed - Fixed copying of `TypeMap` and `TypeConfigurator`. Previously, the same global `TypeConfigurator` instance was used in all copies of a `TypeMap`. @rly [#1302](https://github.com/hdmf-dev/hdmf/pull/1302) - Fixed `get_data_shape` to use `Data.data.shape` instead of `Data.shape`, which may be overridden by subclasses. @rly [#1311](https://github.com/hdmf-dev/hdmf/pull/1311) +- Added a check when setting or adding data to a `DynamicTableRegion` or setting the `table` attribute of a `DynamicTableRegion` + that the data values are in bounds of the linked table. This can be turned off for + `DynamicTableRegion.__init__` using the keyword argument `validate_data=False`. @rly [#1168](https://github.com/hdmf-dev/hdmf/pull/1168) ### Added - Added a check for a compound datatype that is not defined in the schema or spec. This is currently not supported. @mavaylon1 [#1276](https://github.com/hdmf-dev/hdmf/pull/1276) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index f4b2cd011..62ab7f535 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -1404,17 +1404,45 @@ class DynamicTableRegion(VectorData): 'table', ) + MAX_ROWS_TO_VALIDATE_INIT = int(1e3) + @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors'}, {'name': 'description', 'type': str, 'doc': 'a description of what this region represents'}, {'name': 'table', 'type': DynamicTable, 'doc': 'the DynamicTable this region applies to', 'default': None}, + {'name': 'validate_data', 'type': bool, + 'doc': 'whether to validate the data is in bounds of the linked table', 'default': True}, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): - t = popargs('table', kwargs) + table, validate_data = popargs('table', 'validate_data', kwargs) + data = getargs('data', kwargs) + self._validate_data = validate_data + if self._validate_data: + self._validate_index_in_range(data, table) + super().__init__(**kwargs) - self.table = t + if table is not None: # set the table attribute using fields to avoid another validation in the setter + self.fields['table'] = table + + def _validate_index_in_range(self, data, table): + """If the length of data is small, and if data contains an index that is out of bounds, then raise an error. + If the object is being constructed from a file, raise a warning instead to ensure invalid data can still be + read. + """ + if table and len(data) <= self.MAX_ROWS_TO_VALIDATE_INIT: + if isinstance(data, (list, tuple)): + data_arr = np.array(data) + else: + data_arr = data[:] + violators = np.where((data_arr >= len(table)) | (data_arr < 0))[0] + if violators.size > 0: + error_msg = ( + f"DynamicTableRegion values {data_arr[violators]} are out of bounds for " + f"{type(table)} '{table.name}'." + ) + self._error_on_new_warn_on_construct(error_msg, error_cls=IndexError) @property def table(self): @@ -1422,24 +1450,27 @@ def table(self): return self.fields.get('table') @table.setter - def table(self, val): + def table(self, table): """ - Set the table this DynamicTableRegion should be pointing to + Set the table this DynamicTableRegion should be pointing to. + + If the length of the data is small, this will validate all data elements in this DynamicTableRegion to + ensure they are within bounds. - :param val: The DynamicTable this DynamicTableRegion should be pointing to + :param table: The DynamicTable this DynamicTableRegion should be pointing to :raises: AttributeError if table is already in fields :raises: IndexError if the current indices are out of bounds for the new table given by val """ - if val is None: + if table is None: return if 'table' in self.fields: msg = "can't set attribute 'table' -- already set" raise AttributeError(msg) - dat = self.data - if isinstance(dat, DataIO): - dat = dat.data - self.fields['table'] = val + + self.fields['table'] = table + if self._validate_data: + self._validate_index_in_range(self.data, table) def __getitem__(self, arg): return self.get(arg) @@ -1583,6 +1614,12 @@ def _validate_on_set_parent(self): warn(msg, stacklevel=2) return super()._validate_on_set_parent() + def _validate_new_data_element(self, arg): + """Validate that the new index is within bounds of the table. Raises an IndexError if not.""" + if self.table and (arg >= len(self.table) or arg < 0): + raise IndexError(f"DynamicTableRegion index {arg} is out of bounds for " + f"{type(self.table)} '{self.table.name}'.") + def _uint_precision(elements): """ Calculate the uint precision needed to encode a set of elements """ diff --git a/src/hdmf/container.py b/src/hdmf/container.py index dd30d208b..33100e9a5 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -573,6 +573,19 @@ def _validate_on_set_parent(self): """ pass + def _error_on_new_warn_on_construct(self, error_msg: str, error_cls: type = ValueError): + """Raise a ValueError when a check is violated on instance creation. + To ensure backwards compatibility, this method throws a warning + instead of raising an error when reading from a file, ensuring that + files with invalid data can be read. If error_msg is set to None + the function will simply return without further action. + """ + if error_msg is None: + return + if not self._in_construct_mode: + raise error_cls(error_msg) + warn(error_msg) + class Container(AbstractContainer): """A container that can contain other containers and has special functionality for printing.""" diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 16136a8da..6199c679d 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -1,5 +1,4 @@ import numpy as np -import os import shutil import tempfile from warnings import warn @@ -659,9 +658,6 @@ class TestGetClassSeparateNamespace(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - if os.path.exists(self.test_dir): # start clean - self.tearDown() - os.mkdir(self.test_dir) self.bar_spec = GroupSpec( doc='A test group specification with a data type', @@ -843,9 +839,6 @@ class TestGetClassObjectReferences(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - if os.path.exists(self.test_dir): # start clean - self.tearDown() - os.mkdir(self.test_dir) self.type_map = TypeMap() def tearDown(self): diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index 15a0c9e91..1c439248a 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -1285,6 +1285,34 @@ def test_no_df_nested(self): with self.assertRaisesWith(ValueError, msg): dynamic_table_region.get(0, df=False, index=False) + def test_init_out_of_bounds(self): + table = self.with_columns_and_data() + with self.assertRaises(IndexError): + DynamicTableRegion(name='dtr', data=[0, 1, 2, 2, 5], description='desc', table=table) + + def test_init_out_of_bounds_long(self): + table = self.with_columns_and_data() + data = np.ones(DynamicTableRegion.MAX_ROWS_TO_VALIDATE_INIT+1, dtype=int)*5 + dtr = DynamicTableRegion(name='dtr', data=data, description='desc', table=table) + assert dtr.data is data # no exception raised + + def test_init_out_of_bounds_no_validate(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 5], description='desc', table=table, validate_data=False) + self.assertEqual(dtr.data, [0, 1, 5]) # no exception raised + + def test_add_row_out_of_bounds(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + with self.assertRaises(IndexError): + dtr.add_row(5) + + def test_set_table_out_of_bounds(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 5], description='desc') + with self.assertRaises(IndexError): + dtr.table = table + class DynamicTableRegionRoundTrip(H5RoundTripMixin, TestCase): diff --git a/tests/unit/test_io_hdf5_streaming.py b/tests/unit/test_io_hdf5_streaming.py index 1a487b939..c03110a76 100644 --- a/tests/unit/test_io_hdf5_streaming.py +++ b/tests/unit/test_io_hdf5_streaming.py @@ -80,9 +80,9 @@ def setUp(self): self.manager = BuildManager(type_map) def tearDown(self): - if os.path.exists(self.ns_filename): + if hasattr(self, 'ns_filename') and os.path.exists(self.ns_filename): os.remove(self.ns_filename) - if os.path.exists(self.ext_filename): + if hasattr(self, 'ext_filename') and os.path.exists(self.ext_filename): os.remove(self.ext_filename) def test_basic_read(self):