diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 91ca9fc6..d6d6227e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Added - Signature methods now when given ``sub_configs=True``, list of paths types can now receive a file containing a list of paths (`#816 `__). +- Public interface for enabling/disabling support of type subclasses (`#817 + `__). Fixed ^^^^^ diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index acf22626..f834f33c 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -501,15 +501,15 @@ Some notes about this support are: :py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all classes in a config object. For more details see :ref:`sub-classes`. -- ``Protocol`` types are also supported the same as sub-classes. The protocols +- ``Protocol`` types are also supported the same as subclasses. The protocols are not required to be ``runtime_checkable``. But the accepted classes must match exactly the signature of the protocol's public methods. -- ``dataclasses`` are supported even when nested. Final classes, attrs' - ``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel`` - classes are supported and behave like standard dataclasses. For more details - see :ref:`dataclass-like`. If a dataclass is mixed inheriting from a normal - class, it is considered a subclass type instead of a dataclass. +- ``dataclasses`` are supported even when nested and by default don't accept + subclasses. Final classes, attrs' ``define``, pydantic's ``dataclass`` and + pydantic's ``BaseModel`` classes are supported and behave like standard + dataclasses. For more details see :ref:`subclasses-disabled`. If a dataclass + is mixed inheriting from a normal class, by default it will accept subclasses. - User-defined ``Generic`` types are supported. For more details see :ref:`generic-types`. @@ -953,55 +953,6 @@ be achieved as follows: Namespace(dict={'key1': 'val1', 'key2': 'val2'}) -.. _dataclass-like: - -Dataclass-like classes ----------------------- - -In contrast to subclasses, which requires the user to provide a ``class_path``, -in some cases it is not expected to have subclasses. In this case the init args -are given directly in a dictionary without specifying a ``class_path``. This is -the behavior for standard ``dataclasses``, ``final`` classes, attrs' ``define`` -decorator, and pydantic's ``dataclass`` decorator and ``BaseModel`` classes. - -As an example, take a class that is decorated with :func:`.final`, meaning that -it shouldn't be subclassed. The code below would accept the corresponding YAML -structure. - -.. testsetup:: final_classes - - cwd = os.getcwd() - tmpdir = tempfile.mkdtemp(prefix="_jsonargparse_doctest_") - os.chdir(tmpdir) - with open("config.yaml", "w") as f: - f.write("data:\n number: 8\n accepted: true\n") - -.. testcleanup:: final_classes - - os.chdir(cwd) - shutil.rmtree(tmpdir) - -.. testcode:: final_classes - - from jsonargparse.typing import final - - - @final - class FinalClass: - def __init__(self, number: int = 0, accepted: bool = False): - ... - - - parser = ArgumentParser() - parser.add_argument("--data", type=FinalClass) - cfg = parser.parse_path("config.yaml") - -.. code-block:: yaml - - data: - number: 8 - accepted: true - .. _generic-types: Generic types @@ -1132,9 +1083,9 @@ requires to give both a serializer and a deserializer as seen below. .. note:: The registering of types is only intended for simple types. By default any - class used as a type hint is considered a sub-class (see :ref:`sub-classes`) + class used as a type hint is considered a subclass (see :ref:`sub-classes`) which might be good for many use cases. If a class is registered with - :func:`.register_type` then the sub-class option is no longer available. + :func:`.register_type` then the subclass option is no longer available. .. _custom-types: @@ -1953,8 +1904,8 @@ In Python, dependency injection is achieved by: .. _sub-classes: -Class type and sub-classes --------------------------- +Class type and subclasses +------------------------- When a class is used as a type hint, jsonargparse expects in config files a dictionary with a ``class_path`` entry indicating the dot notation expression to @@ -1963,7 +1914,7 @@ instantiate it. When parsing, it will be checked that the class can be imported, that it is a subclass of the given type and that ``init_args`` values correspond to valid arguments to instantiate it. After parsing, the config object will include the ``class_path`` and ``init_args`` entries. To get a config object -with all nested sub-classes instantiated, the +with all nested subclasses instantiated, the :py:meth:`.ArgumentParser.instantiate_classes` method is used. Additional to using a class as type hint in signatures, for low level @@ -2314,6 +2265,123 @@ to :ref:`instance-factories`. module from where the respective object can be imported. +.. _subclasses-disabled: + +Class types with subclasses disabled +------------------------------------ + +In certain situations, it is preferable to use a class as a type hint with no +intention to receive subclasses. From a parser perspective, this means that +providing a subclass is not permitted, and when serializing, the instantiation +arguments are stored directly, without including ``class_path`` and +``init_args``. The standard Python approach for this scenario is to decorate +classes with :func:`.final`, which explicitly indicates that subclassing is not +intended. A parsing example would be: + +.. testcode:: final_classes + + from jsonargparse.typing import final + + + @final + class FinalClass: + def __init__(self, number: int = 0, accepted: bool = False): + ... + + + parser = ArgumentParser() + parser.add_argument("--data", type=FinalClass) + cfg = parser.parse_args(["--data.number=8", "--data.accepted=true"]) + +for which a dump would give as output: + +.. doctest:: final_classes + + >>> print(parser.dump(cfg)) # doctest: +NORMALIZE_WHITESPACE + data: + number: 8 + accepted: true + +In some cases, subclasses are not intended, but the :func:`.final` decorator is +not applied. For example, having ``class_path`` for a simple ``x, y`` +coordinates dataclass would be unnecessarily cumbersome. For this reason, +``jsonargparse`` early on, implemented the same behavior for pure (not mixed +with normal classes) ``dataclasses``, attrs' ``define``, pydantic's +``dataclass``, and pydantic's ``BaseModel`` classes. However, since these +classes technically support subclassing, subclass support can be enabled as +described below. Subclass support has been kept disabled for these types by +default to avoid introducing breaking changes. + + +.. _enable-disable-subclasses: + +Enable/disable subclasses +------------------------- + +The :func:`.set_parsing_settings` function provides the ``subclasses_disabled`` +and ``subclasses_enabled`` parameters, which, as their names suggest, control +which class types support subclasses. The ``subclasses_disabled`` parameter +accepts a list of class types and functions. When a type is provided, that type +and its descendants will have subclass support disabled. Functions in the list +should accept a type and return ``True`` if subclasses should be disabled for +that type. + +The ``subclasses_enabled`` parameter accepts a list of class types and function +names. When a type is provided, both the type and its descendants will have +subclass support enabled. Types specified in ``subclasses_enabled`` take +precedence over those in ``subclasses_disabled``. If a function name is given to +``subclasses_enabled``, it must correspond to a function previously registered +in ``subclasses_disabled``; in this case, the effect is to unregister it. By +default, the following disable functions are registered: ``is_pure_dataclass``, +``is_pydantic_model``, ``is_attrs_class``, and ``is_final_class``. + +Some examples. Since ``subclasses_enabled`` takes precedence, it is possible to +keep subclass support disabled for dataclasses, but enable enable it for a +specific dataclass as follows: + +.. testsetup:: enable_disable_subclasses + + selectors = _common.subclasses_disabled_selectors + _common.subclasses_disabled_selectors = selectors.copy() + + @dataclass + class DataClassBaseType: + pass + +.. testcleanup:: enable_disable_subclasses + + _common.subclasses_disabled_selectors = selectors + +.. testcode:: enable_disable_subclasses + + from jsonargparse import set_parsing_settings + + set_parsing_settings(subclasses_enabled=[DataClassBaseType]) + +To enable subclass support for all pydantic models, the following can be done: + +.. testcode:: enable_disable_subclasses + + set_parsing_settings(subclasses_enabled=["is_pydantic_model"]) + +To enable subclass support for all dataclasses, but have it disabled for a +specific dataclass, the following can be done: + +.. testcode:: enable_disable_subclasses + + set_parsing_settings( + subclasses_enabled=["is_pure_dataclass"], + subclasses_disabled=[DataClassBaseType], + ) + +.. note:: + + Enabling subclass support for types is currently experimental. While the + interface and behavior is expected to be stable, fundamental issues may + arise that require changes to the design, which could result in breaking + changes in future releases. + + .. _argument-linking: Argument linking diff --git a/jsonargparse/_actions.py b/jsonargparse/_actions.py index 24b02d9c..a675b524 100644 --- a/jsonargparse/_actions.py +++ b/jsonargparse/_actions.py @@ -9,7 +9,7 @@ from contextvars import ContextVar from typing import Any, Optional, Union -from ._common import Action, NonParsingAction, is_not_subclass_type, is_subclass, parser_context +from ._common import Action, NonParsingAction, is_subclass, is_subclasses_disabled, parser_context from ._loaders_dumpers import get_loader_exceptions, load_value from ._namespace import Namespace, NSKeyError, split_key, split_key_root from ._optionals import _get_config_read_mode, ruamel_support @@ -365,13 +365,13 @@ def update_init_kwargs(self, kwargs): self._typehint = kwargs.pop("_typehint") self._help_types = self.get_help_types(self._typehint) assert self._help_types and all(isinstance(b, type) for b in self._help_types) - self._not_subclass = len(self._help_types) == 1 and is_not_subclass_type(self._help_types[0]) + self._single_class = len(self._help_types) == 1 and is_subclasses_disabled(self._help_types[0]) self._basename = iter_to_set_str(t.__name__ for t in self._help_types) if len(self._help_types) == 1: - kwargs["nargs"] = 0 if self._not_subclass else "?" + kwargs["nargs"] = 0 if self._single_class else "?" - if self._not_subclass: + if self._single_class: msg = "" else: kwargs["metavar"] = "CLASS_PATH_OR_NAME" diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index 885c87bb..0db60ec3 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -110,6 +110,8 @@ def set_parsing_settings( parse_optionals_as_positionals: Optional[bool] = None, stubs_resolver_allow_py_files: Optional[bool] = None, omegaconf_absolute_to_relative_paths: Optional[bool] = None, + subclasses_disabled: Optional[list[Union[type, Callable[[type], bool]]]] = None, + subclasses_enabled: Optional[list[Union[type, str]]] = None, ) -> None: """ Modify settings that affect the parsing behavior. @@ -138,6 +140,16 @@ def set_parsing_settings( with ``omegaconf+`` parser mode, absolute interpolation paths are converted to relative. This is only intended for backward compatibility with ``omegaconf`` parser mode. + subclasses_disabled: List of types or functions, so that when parsing + only the exact type hints (not their subclasses) are accepted. + Descendants of the configured types are also disabled. Functions + should return ``True`` for types to disable. + subclasses_enabled: List of types or disable function names, so that + subclasses are accepted. Types given here have precedence over those + in ``subclasses_disabled``. Giving a function name removes the + corresponding function from ``subclasses_disabled``. By default, the + following disable functions are registered: ``is_pure_dataclass``, + ``is_pydantic_model``, ``is_attrs_class`` and ``is_final_class``. """ # validate_defaults if isinstance(validate_defaults, bool): @@ -171,6 +183,12 @@ def set_parsing_settings( raise ValueError( f"omegaconf_absolute_to_relative_paths must be a boolean, but got {omegaconf_absolute_to_relative_paths}." ) + # subclass behavior + if subclasses_disabled or subclasses_enabled: + subclass_type_behavior( + subclasses_disabled=subclasses_disabled, + subclasses_enabled=subclasses_enabled, + ) def get_parsing_setting(name: str): @@ -283,20 +301,55 @@ def is_pure_dataclass(cls) -> bool: return all(dataclasses.is_dataclass(c) for c in classes) -not_subclass_type_selectors: dict[str, Callable[[type], Union[bool, int]]] = { - "final": is_final_class, - "dataclass": is_pure_dataclass, - "pydantic": is_pydantic_model, - "attrs": is_attrs_class, +subclasses_enabled_types: set[type] = set() +subclasses_disabled_types: set[type] = set() +subclasses_disabled_selectors: dict[str, Callable[[type], Union[bool, int]]] = { + "is_pure_dataclass": is_pure_dataclass, + "is_pydantic_model": is_pydantic_model, + "is_attrs_class": is_attrs_class, + "is_final_class": is_final_class, } -def is_not_subclass_type(cls) -> bool: +def is_subclasses_disabled(cls) -> bool: if is_generic_class(cls): - return is_not_subclass_type(cls.__origin__) + return is_subclasses_disabled(cls.__origin__) if not inspect.isclass(cls): return False - return any(validator(cls) for validator in not_subclass_type_selectors.values()) + subclass_disabled = any(selector(cls) for selector in subclasses_disabled_selectors.values()) + if not subclass_disabled: + subclass_disabled = any(issubclass(cls, disable_type) for disable_type in subclasses_disabled_types) + if subclass_disabled: + subclass_disabled = not any(issubclass(cls, enable_type) for enable_type in subclasses_enabled_types) + return subclass_disabled + + +def subclass_type_behavior( + subclasses_disabled: Optional[list[Union[type, Callable[[type], bool]]]] = None, + subclasses_enabled: Optional[list[Union[type, str]]] = None, +) -> None: + """Configures whether class types accept or not subclasses.""" + for enable_item in subclasses_enabled or []: + if isinstance(enable_item, str): + if enable_item not in subclasses_disabled_selectors: + raise ValueError(f"There is no function '{enable_item}' registered in subclasses_disabled") + subclasses_disabled_selectors.pop(enable_item) + elif inspect.isclass(enable_item): + subclasses_enabled_types.add(enable_item) + else: + raise ValueError( + f"Expected 'subclasses_enabled' list items to be types or strings, but got {enable_item!r}" + ) + + for disable_item in subclasses_disabled or []: + if inspect.isclass(disable_item): + subclasses_disabled_types.add(disable_item) + elif inspect.isfunction(disable_item): + subclasses_disabled_selectors[disable_item.__name__] = disable_item + else: + raise ValueError( + f"Expected 'subclasses_disabled' list items to be types or functions, but got {disable_item!r}" + ) def default_class_instantiator(class_type: type[ClassType], *args, **kwargs) -> ClassType: diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 75b771ee..339d83ef 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -38,7 +38,7 @@ class_instantiators, debug_mode_active, get_optionals_as_positionals_actions, - is_not_subclass_type, + is_subclasses_disabled, lenient_check, parser_context, supports_optionals_as_positionals, @@ -126,7 +126,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs): return ActionParser._move_parser_actions(parser, args, kwargs) ActionConfigFile._ensure_single_config_argument(self, kwargs["action"]) if "type" in kwargs: - if is_not_subclass_type(kwargs["type"]): + if is_subclasses_disabled(kwargs["type"]): nested_key = args[0].lstrip("-") self.add_class_arguments(kwargs.pop("type"), nested_key, **kwargs) return _find_action(parser, nested_key) diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index f7361374..c0aac79c 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -13,8 +13,8 @@ get_generic_origin, get_unaliased_type, is_final_class, - is_not_subclass_type, is_subclass, + is_subclasses_disabled, ) from ._namespace import Namespace from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model @@ -85,7 +85,7 @@ def add_class_arguments( or (isinstance(default, LazyInitBaseClass) and isinstance(default, unaliased_class_type)) or ( not is_final_class(default.__class__) - and is_not_subclass_type(default.__class__) + and is_subclasses_disabled(default.__class__) and isinstance(default, unaliased_class_type) ) ): @@ -386,7 +386,7 @@ def _add_signature_parameter( elif not as_positional or is_non_positional: kwargs["required"] = True is_subclass_typehint = False - is_not_subclass_typehint = is_not_subclass_type(annotation) + subclasses_disabled = is_subclasses_disabled(annotation) dest = (nested_key + "." if nested_key else "") + name args = [dest if is_required and as_positional and not is_non_positional else "--" + dest] if param.origin: @@ -401,11 +401,7 @@ def _add_signature_parameter( f"Conditional arguments [origins: {group_name}]", name=group_name, ) - if ( - annotation in {str, int, float, bool} - or is_subclass(annotation, (str, int, float)) - or is_not_subclass_typehint - ): + if annotation in {str, int, float, bool} or is_subclass(annotation, (str, int, float)) or subclasses_disabled: kwargs["type"] = annotation register_pydantic_type(annotation) elif annotation != inspect_empty: @@ -440,7 +436,7 @@ def _add_signature_parameter( "sub_configs": sub_configs, "instantiate": instantiate, } - if is_not_subclass_typehint: + if subclasses_disabled: kwargs.update(sub_add_kwargs) with ActionTypeHint.allow_default_instance_context(): action = container.add_argument(*args, **kwargs) @@ -612,6 +608,6 @@ def convert_to_dict(value) -> dict: attr[num] = convert_to_dict(item) init_args[name] = attr - if is_not_subclass_type(value_type): + if is_subclasses_disabled(value_type): return init_args return {"class_path": get_import_path(value_type), "init_args": init_args} diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 039a27b3..0e39776e 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -54,8 +54,8 @@ get_unaliased_type, is_generic_class, is_instance, - is_not_subclass_type, is_subclass, + is_subclasses_disabled, lenient_check, nested_links, parent_parser, @@ -312,7 +312,7 @@ def is_supported_typehint(typehint, full=False): or get_typehint_origin(typehint) in root_types or get_registered_type(typehint) is not None or is_subclass(typehint, Enum) - or is_not_subclass_type(typehint) + or is_subclasses_disabled(typehint) or ActionTypeHint.is_subclass_typehint(typehint) ) if full and supported: @@ -1085,15 +1085,15 @@ def adapt_typehints( return val_class # importable instance if is_protocol(val_class): raise_unexpected_value(f"Expected an instantiatable class, but {val['class_path']} is a protocol") - not_subclass = False + subclass = True if not is_subclass_or_implements_protocol(val_class, typehint): - not_subclass = True + subclass = False if not inspect.isclass(val_class) and callable(val_class): from ._postponed_annotations import get_return_type return_type = get_return_type(val_class, logger) if is_subclass_or_implements_protocol(return_type, typehint): - not_subclass = False + subclass = True elif prev_implicit_defaults: inner_parser = ActionTypeHint.get_class_parser(typehint, sub_add_kwargs) prev_val.init_args = inner_parser.get_defaults() @@ -1101,7 +1101,7 @@ def adapt_typehints( inner_parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs) for key in inner_parser.get_defaults().keys(): prev_val.init_args.pop(key, None) - if not_subclass: + if not subclass: msg = "implement protocol" if is_protocol(typehint) else "correspond to a subclass of" raise_unexpected_value(f"Import path {val['class_path']} does not {msg} {typehint.__name__}") val["class_path"] = class_path @@ -1264,7 +1264,7 @@ def is_single_class_type(typehint, typehint_origin, closed_class): ): return False if not closed_class: - return not is_not_subclass_type(typehint) + return not is_subclasses_disabled(typehint) return True @@ -1536,7 +1536,7 @@ def adapt_class_type( val = load_value(val, simple_types=True) value["dict_kwargs"][key] = val - if is_not_subclass_type(typehint) and value.class_path == get_import_path(typehint): + if is_subclasses_disabled(typehint) and value.class_path == get_import_path(typehint): value = Namespace({**value.get("init_args", {}), **value.get("dict_kwargs", {})}) return value diff --git a/jsonargparse_tests/conftest.py b/jsonargparse_tests/conftest.py index aad70abe..9c78cfd7 100644 --- a/jsonargparse_tests/conftest.py +++ b/jsonargparse_tests/conftest.py @@ -146,6 +146,14 @@ def example_parser() -> ArgumentParser: return parser +@pytest.fixture +def subclass_behavior(monkeypatch) -> Iterator[None]: + monkeypatch.setattr("jsonargparse._common.subclasses_enabled_types", set()) + monkeypatch.setattr("jsonargparse._common.subclasses_disabled_types", set()) + with patch.dict("jsonargparse._common.subclasses_disabled_selectors"): + yield + + @pytest.fixture def tmp_cwd(tmpdir) -> Iterator[Path]: with tmpdir.as_cwd(): diff --git a/jsonargparse_tests/test_dataclasses.py b/jsonargparse_tests/test_dataclasses.py index 3823ac00..785e9e56 100644 --- a/jsonargparse_tests/test_dataclasses.py +++ b/jsonargparse_tests/test_dataclasses.py @@ -14,6 +14,7 @@ Namespace, set_parsing_settings, ) +from jsonargparse._common import subclasses_disabled_selectors from jsonargparse._namespace import NSKeyError from jsonargparse._optionals import ( docstring_parser_support, @@ -31,6 +32,13 @@ annotated = typing_extensions_import("Annotated") + +@pytest.fixture +def enable_subclasses(subclass_behavior): + set_parsing_settings(subclasses_enabled=["is_pure_dataclass"]) + yield + + BetweenThreeAndNine = restricted_number_type("BetweenThreeAndNine", float, [(">=", 3), ("<=", 9)]) ListPositiveInt = List[PositiveInt] @@ -752,7 +760,7 @@ class DataSub(DataMain): p2: str = "-" -def test_dataclass_not_subclass(parser): +def test_dataclass_subclasses_disabled(parser): parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) help_str = get_parser_help(parser) @@ -763,20 +771,13 @@ def test_dataclass_not_subclass(parser): parser.parse_args([f"--data={json.dumps(config)}"]) -def test_add_subclass_dataclass_not_subclass(parser): +def test_add_subclass_dataclass_subclasses_disabled(parser): with pytest.raises(ValueError, match="Expected .* a subclass type or a tuple of subclass types"): parser.add_subclass_arguments(DataMain, "data") -@pytest.fixture -def subclass_behavior(): - with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors: - not_subclass_type_selectors.pop("dataclass") - yield - - @pytest.mark.parametrize("default", [None, DataMain()]) -def test_add_subclass_dataclass_as_subclass(parser, default, subclass_behavior): +def test_add_subclass_dataclass_subclasses_enabled(parser, default, enable_subclasses): parser.add_subclass_arguments(DataMain, "data", default=default) config = {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}} @@ -796,7 +797,7 @@ def test_add_subclass_dataclass_as_subclass(parser, default, subclass_behavior): assert dump == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 1, "p2": "y"}} -def test_add_argument_dataclass_as_subclass(parser, subtests, subclass_behavior): +def test_add_argument_dataclass_subclasses_enabled(parser, subtests, enable_subclasses): parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) with subtests.test("help"): @@ -849,12 +850,52 @@ def test_add_argument_dataclass_as_subclass(parser, subtests, subclass_behavior) assert dataclasses.asdict(init.data) == {"p1": 2, "p2": "-"} +def test_add_argument_dataclass_single_type_subclasses_enabled(parser, subclass_behavior): + set_parsing_settings(subclasses_enabled=[DataMain]) + assert "is_pure_dataclass" in subclasses_disabled_selectors + + parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) + + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 2, "p2": "y"} + dump = json_or_yaml_load(parser.dump(cfg))["data"] + assert dump == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 2, "p2": "y"}} + + +def test_add_argument_dataclass_single_type_subclasses_disabled(parser, enable_subclasses): + set_parsing_settings(subclasses_disabled=[DataMain]) + assert "is_pure_dataclass" not in subclasses_disabled_selectors + + parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) + + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + with pytest.raises(ArgumentError, match="Group 'data' does not accept option 'init_args.p2'"): + parser.parse_args([f"--data={json.dumps(config)}"]) + + +def test_add_argument_dataclass_subclasses_disabled_function(parser, enable_subclasses): + def is_data_main(obj): + return obj is DataMain + + set_parsing_settings(subclasses_disabled=[is_data_main]) + assert "is_pure_dataclass" not in subclasses_disabled_selectors + + parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) + + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + with pytest.raises(ArgumentError, match="Group 'data' does not accept option 'init_args.p2'"): + parser.parse_args([f"--data={json.dumps(config)}"]) + + class ParentData: def __init__(self, data: DataMain = DataMain(p1=2)): self.data = data -def test_dataclass_nested_not_subclass(parser): +def test_dataclass_nested_subclasses_disabled(parser): parser.add_argument("--parent", type=ParentData) help_str = get_parse_args_stdout(parser, ["--parent.help"]) @@ -873,7 +914,7 @@ def test_dataclass_nested_not_subclass(parser): parser.parse_args([f"--parent={json.dumps(config)}"]) -def test_dataclass_nested_as_subclass(parser, subclass_behavior): +def test_dataclass_nested_subclasses_enabled(parser, enable_subclasses): parser.add_argument("--parent", type=ParentData) help_str = get_parse_args_stdout(parser, ["--parent.help"]) @@ -938,7 +979,7 @@ class Person(Pet): ) -def test_convert_to_dict_not_subclass(): +def test_convert_to_dict_subclasses_disabled(): person_dict = convert_to_dict(person) assert person_dict == { "name": "jt", @@ -953,7 +994,7 @@ def test_convert_to_dict_not_subclass(): } -def test_convert_to_dict_subclass(subclass_behavior): +def test_convert_to_dict_subclasses_enabled(enable_subclasses): person_dict = convert_to_dict(person) assert person_dict == { "class_path": f"{__name__}.Person", diff --git a/jsonargparse_tests/test_parsing_settings.py b/jsonargparse_tests/test_parsing_settings.py index c1c824ed..d4473d2c 100644 --- a/jsonargparse_tests/test_parsing_settings.py +++ b/jsonargparse_tests/test_parsing_settings.py @@ -203,3 +203,30 @@ def test_set_stubs_resolver_allow_py_files_failure(): def test_set_omegaconf_absolute_to_relative_paths_failure(): with pytest.raises(ValueError, match="omegaconf_absolute_to_relative_paths must be a boolean"): set_parsing_settings(omegaconf_absolute_to_relative_paths="invalid") + + +# enable/disable-subclasses + + +def test_default_subclass_disable_functions(subclass_behavior): + from jsonargparse._common import subclasses_disabled_selectors + + set_parsing_settings( + subclasses_enabled=["is_pure_dataclass", "is_pydantic_model", "is_attrs_class", "is_final_class"] + ) + assert not subclasses_disabled_selectors + + +def test_unknown_subclass_disable_function(): + with pytest.raises(ValueError, match="no function 'unknown_selector'"): + set_parsing_settings(subclasses_enabled=["unknown_selector"]) + + +def test_invalid_item_type_subclass_enable(): + with pytest.raises(ValueError, match="Expected 'subclasses_enabled' list items to be types or strings"): + set_parsing_settings(subclasses_enabled=[123]) + + +def test_invalid_item_type_subclass_disable(): + with pytest.raises(ValueError, match="Expected 'subclasses_disabled' list items to be types or functions"): + set_parsing_settings(subclasses_disabled=[123]) diff --git a/jsonargparse_tests/test_pydantic.py b/jsonargparse_tests/test_pydantic.py index 132e15f8..68c7b9a8 100644 --- a/jsonargparse_tests/test_pydantic.py +++ b/jsonargparse_tests/test_pydantic.py @@ -4,11 +4,10 @@ import json from copy import deepcopy from typing import Dict, List, Literal, Optional, Union -from unittest.mock import patch import pytest -from jsonargparse import ArgumentError, ArgumentParser, Namespace +from jsonargparse import ArgumentError, ArgumentParser, Namespace, set_parsing_settings from jsonargparse._optionals import ( docstring_parser_support, pydantic_support, @@ -40,10 +39,9 @@ def missing_pydantic(): @pytest.fixture -def subclass_behavior(): - with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors: - not_subclass_type_selectors.pop("pydantic") - yield +def enable_subclasses(subclass_behavior): + set_parsing_settings(subclasses_enabled=["is_pydantic_model"]) + yield @skip_if_pydantic_v1_on_v2 @@ -396,7 +394,7 @@ class Person(Pet): } -def test_model_argument_as_subclass(parser, subtests, subclass_behavior): +def test_model_argument_subclasses_enabled(parser, subtests, enable_subclasses): parser.add_argument("--person", type=Person, default=person) with subtests.test("help"): @@ -425,11 +423,11 @@ def test_model_argument_as_subclass(parser, subtests, subclass_behavior): assert dump == expected -def test_convert_to_dict_not_subclass(): +def test_convert_to_dict_closed_to_subclasses(): converted = convert_to_dict(person) assert converted == person_expected_dict -def test_convert_to_dict_subclass(subclass_behavior): +def test_convert_to_dict_subclasses_enabled(enable_subclasses): converted = convert_to_dict(person) assert converted == person_expected_subclass_dict diff --git a/sphinx/conf.py b/sphinx/conf.py index 93d5050d..1e5d9f17 100644 --- a/sphinx/conf.py +++ b/sphinx/conf.py @@ -85,6 +85,7 @@ def check_output(self, want, got, optionflags): from typing import Callable, Iterable, List, Protocol import jsonargparse_tests from jsonargparse import * +from jsonargparse import _common from jsonargparse.typing import * from jsonargparse._util import unresolvable_import_paths