From 4af468dd02491e189c192ffb6dae2f07272a9893 Mon Sep 17 00:00:00 2001 From: Zahari Kassabov Date: Tue, 4 Mar 2025 19:44:24 +0000 Subject: [PATCH] Support forward refs Use get_type_hints instead of the raw annotations to resolve references within dataclasses, namedtuples and typed dicts. --- validobj/tests/test_custom.py | 21 ++++++++++++++++-- validobj/tests/test_forward.py | 39 ++++++++++++++++++++++++++++++++++ validobj/validation.py | 23 +++++++++++++++----- 3 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 validobj/tests/test_forward.py diff --git a/validobj/tests/test_custom.py b/validobj/tests/test_custom.py index 769df2f..d226a19 100644 --- a/validobj/tests/test_custom.py +++ b/validobj/tests/test_custom.py @@ -1,6 +1,6 @@ import pytest import dataclasses -from typing import Any +from typing import Any, TypedDict try: from validobj.custom import Parser, Validator, InputType @@ -13,7 +13,7 @@ @pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found") -def test_custom(): +def test_custom_dataclass(): def my_float(inp: str) -> float: return float(inp) + 1 @@ -30,6 +30,23 @@ class Container: with pytest.raises(ValidationError): parse_input({"value": 5}, Container) +@pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found") +def test_custom_typeddict(): + def my_float(inp: str) -> float: + return float(inp) + 1 + + MyFloat = Parser(my_float) + assert MyFloat.__origin__ is float + assert isinstance(MyFloat.__metadata__[0], InputType) + assert isinstance(MyFloat.__metadata__[1], Validator) + + class Container(TypedDict): + value: MyFloat + + assert parse_input({"value": "5"}, Container) == Container(value=6) + with pytest.raises(ValidationError): + parse_input({"value": 5}, Container) + @pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found") def test_no_annotations(): diff --git a/validobj/tests/test_forward.py b/validobj/tests/test_forward.py new file mode 100644 index 0000000..c839ab5 --- /dev/null +++ b/validobj/tests/test_forward.py @@ -0,0 +1,39 @@ +from typing import NamedTuple, TypedDict, Optional +import dataclasses + +import validobj + + +class C(TypedDict): + a: 'A' + b: 'B' + c: Optional['C'] = None + + +class B(NamedTuple): + a0: 'A' + a1: 'A' + + +@dataclasses.dataclass +class A: + children: list['A'] + + +def test_dataclass(): + assert validobj.parse_input({"children": [{"children": []}]}, A) + + +def test_namedtuple(): + assert validobj.parse_input([{"children": []}, {"children": []}], B) + + +def test_typeddict(): + assert validobj.parse_input( + { + 'a': {'children': [{'children': []}]}, + 'b': [{"children": []}, {"children": []}], + 'c': None, + }, + C, + ) diff --git a/validobj/validation.py b/validobj/validation.py index 859fb70..c6eda55 100644 --- a/validobj/validation.py +++ b/validobj/validation.py @@ -10,7 +10,8 @@ """ -from typing import Set, Union, Any, Optional, TypeVar, Type, Literal + +from typing import Set, Union, Any, Optional, TypeVar, Type, Literal, get_type_hints import sys try: @@ -209,10 +210,18 @@ def _parse_dataclass(value, spec): header=f"Cannot process value into {_typename(spec)!r} because " f"fields do not match.", ) + + # Use this to resolve forward references. + annotations = get_type_hints(spec, include_extras=True) + res = {} field_dict = { # Look inside InitVar - f.name: f.type if not isinstance(f.type, dataclasses.InitVar) else f.type.type + f.name: ( + annotations[f.name] + if not isinstance(f.type, dataclasses.InitVar) + else annotations[f.name].type + ) for f in fields } for k, v in value.items(): @@ -239,10 +248,12 @@ def _parse_typed_dict(value, spec): header=f"Cannot process value into {_typename(spec)!r} because " f"fields do not match.", ) + # Resolve forward references. + annotations = get_type_hints(spec, include_extras=True) res = {} for k, v in value.items(): try: - res[k] = parse_input(v, spec.__annotations__[k]) + res[k] = parse_input(v, annotations[k]) except ValidationError as e: raise WrongFieldError( f"Cannot process field {k!r} of value into the " @@ -281,10 +292,12 @@ def _parse_namedtuple(value, spec): res = {} + annotations = get_type_hints(spec, include_extras=True) + for i, (k, v) in enumerate(field_inputs.items()): - if k in spec.__annotations__: + if k in annotations: try: - res[k] = parse_input(v, spec.__annotations__[k]) + res[k] = parse_input(v, annotations[k]) except ValidationError as e: raise WrongListItemError( f"Cannot process list item {i+1} into the field {k!r} of {_typename(spec)!r}",