Skip to content

Commit 73e7eaf

Browse files
authored
Merge pull request #4 from bogdandm/attrs
Attrs
2 parents 49b5f26 + e5fb8c9 commit 73e7eaf

File tree

11 files changed

+330
-32
lines changed

11 files changed

+330
-32
lines changed

rest_client_gen/dynamic_typing/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Iterable, List, Tuple, Union
22

3-
ImportPathList = List[Tuple[str, Union[Iterable[str], str]]]
3+
ImportPathList = List[Tuple[str, Union[Iterable[str], str, None]]]
44

55

66
class BaseType:

rest_client_gen/dynamic_typing/typing.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,31 @@ def compile_imports(imports: ImportPathList) -> str:
2828
"""
2929
Merge list of imports path and convert them into list code (string)
3030
"""
31-
imports_map: Dict[str, Set[str]] = OrderedDict()
31+
class_imports_map: Dict[str, Set[str]] = OrderedDict()
32+
package_imports_set: Set[str] = set()
3233
for module, classes in filter(None, imports):
33-
classes_set = imports_map.get(module, set())
34-
if isinstance(classes, str):
35-
classes_set.add(classes)
34+
if classes is None:
35+
package_imports_set.add(module)
3636
else:
37-
classes_set.update(classes)
38-
imports_map[module] = classes_set
37+
classes_set = class_imports_map.get(module, set())
38+
if isinstance(classes, str):
39+
classes_set.add(classes)
40+
else:
41+
classes_set.update(classes)
42+
class_imports_map[module] = classes_set
3943

4044
# Sort imports by package name and sort class names of each import
41-
imports_map = OrderedDict(sorted(
42-
((module, sorted(classes)) for module, classes in imports_map.items()),
45+
class_imports_map = OrderedDict(sorted(
46+
((module, sorted(classes)) for module, classes in class_imports_map.items()),
4347
key=operator.itemgetter(0)
4448
))
4549

46-
return "\n".join(f"from {module} import {', '.join(classes)}" for module, classes in imports_map.items())
50+
class_imports = "\n".join(
51+
f"from {module} import {', '.join(classes)}"
52+
for module, classes in class_imports_map.items()
53+
)
54+
package_imports = "\n".join(
55+
f"import {module}"
56+
for module in sorted(package_imports_set)
57+
)
58+
return "\n".join(filter(None, (package_imports, class_imports)))

rest_client_gen/generator.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
from enum import Enum
33
from typing import Any, Callable, List, Optional, Union
44

5-
import inflection
65
from unidecode import unidecode
76

8-
from rest_client_gen.dynamic_typing import ComplexType, SingleType
9-
from .dynamic_typing import (DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, StringSerializable,
10-
StringSerializableRegistry, Unknown, registry)
7+
from .dynamic_typing import (ComplexType, DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, SingleType,
8+
StringSerializable, StringSerializableRegistry, Unknown, registry)
119

1210

1311
class Hierarchy(Enum):
@@ -32,10 +30,6 @@ def __str__(self):
3230
class MetadataGenerator:
3331
CONVERTER_TYPE = Optional[Callable[[str], Any]]
3432

35-
# TODO: sep_style: SepStyle = SepStyle.Underscore
36-
# TODO: hierarchy: Hierarchy = Hierarchy.Nested
37-
# TODO: fpolicy: OptionalFieldsPolicy = OptionalFieldsPolicy.Optional
38-
3933
def __init__(self, str_types_registry: StringSerializableRegistry = None):
4034
self.str_types_registry = str_types_registry if str_types_registry is not None else registry
4135

@@ -57,8 +51,7 @@ def _convert(self, data: dict):
5751
# ! _detect_type function can crash at some complex data sets if value is unicode with some characters (maybe German)
5852
# Crash does not produce any useful logs and can occur any time after bad string was processed
5953
# It can be reproduced on real_apis tests (openlibrary API)
60-
fields[inflection.underscore(key)] = self._detect_type(value if not isinstance(value, str)
61-
else unidecode(value))
54+
fields[key] = self._detect_type(value if not isinstance(value, str) else unidecode(value))
6255
return fields
6356

6457
def _detect_type(self, value, convert_dict=True) -> MetaData:

rest_client_gen/models/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Dict, Generic, Iterable, List, Set, Tuple, TypeVar
22

3-
from rest_client_gen.dynamic_typing import DOptional
4-
from ..dynamic_typing import ModelMeta, ModelPtr
3+
from ..dynamic_typing import DOptional, ModelMeta, ModelPtr
54

65
Index = str
76
T = TypeVar('T')

rest_client_gen/models/attr.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from inspect import isclass
2+
from typing import Iterable, List, Tuple
3+
4+
from .base import GenericModelCodeGenerator, template
5+
from ..dynamic_typing import DList, DOptional, ImportPathList, MetaData, ModelMeta, StringSerializable
6+
7+
METADATA_FIELD_NAME = "RCG_ORIGINAL_FIELD"
8+
KWAGRS_TEMPLATE = "{% for key, value in kwargs.items() %}" \
9+
"{{ key }}={{ value }}" \
10+
"{% if not loop.last %}, {% endif %}" \
11+
"{% endfor %}"
12+
13+
DEFAULT_ORDER = (
14+
("default", "converter", "factory"),
15+
"*",
16+
("metadata",)
17+
)
18+
19+
20+
def sort_kwargs(kwargs: dict, ordering: Iterable[Iterable[str]] = DEFAULT_ORDER) -> dict:
21+
sorted_dict_1 = {}
22+
sorted_dict_2 = {}
23+
current = sorted_dict_1
24+
for group in ordering:
25+
if isinstance(group, str):
26+
if group != "*":
27+
raise ValueError(f"Unknown kwarg group: {group}")
28+
current = sorted_dict_2
29+
else:
30+
for item in group:
31+
if item in kwargs:
32+
value = kwargs.pop(item)
33+
current[item] = value
34+
sorted_dict = {**sorted_dict_1, **kwargs, **sorted_dict_2}
35+
return sorted_dict
36+
37+
38+
class AttrsModelCodeGenerator(GenericModelCodeGenerator):
39+
ATTRS = template("attr.s"
40+
"{% if kwargs %}"
41+
f"({KWAGRS_TEMPLATE})"
42+
"{% endif %}")
43+
ATTRIB = template(f"attr.ib({KWAGRS_TEMPLATE})")
44+
45+
def __init__(self, model: ModelMeta, no_meta=False, attrs_kwargs: dict = None, **kwargs):
46+
"""
47+
:param model: ModelMeta instance
48+
:param no_meta: Disable generation of metadata as attrib argument
49+
:param attrs_kwargs: kwargs for @attr.s() decorators
50+
:param kwargs:
51+
"""
52+
super().__init__(model, **kwargs)
53+
self.no_meta = no_meta
54+
self.attrs_kwargs = attrs_kwargs or {}
55+
56+
def generate(self, nested_classes: List[str] = None) -> Tuple[ImportPathList, str]:
57+
"""
58+
:param nested_classes: list of strings that contains classes code
59+
:return: list of import data, class code
60+
"""
61+
imports, code = super().generate(nested_classes)
62+
imports.append(('attr', None))
63+
return imports, code
64+
65+
@property
66+
def decorators(self) -> List[str]:
67+
"""
68+
:return: List of decorators code (without @)
69+
"""
70+
return [self.ATTRS.render(kwargs=self.attrs_kwargs)]
71+
72+
def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportPathList, dict]:
73+
"""
74+
Form field data for template
75+
76+
:param name: Field name
77+
:param meta: Field metadata
78+
:param optional: Is field optional
79+
:return: imports, field data
80+
"""
81+
imports, data = super().field_data(name, meta, optional)
82+
body_kwargs = {}
83+
if optional:
84+
meta: DOptional
85+
if isinstance(meta.type, DList):
86+
body_kwargs["factory"] = "list"
87+
else:
88+
body_kwargs["default"] = "None"
89+
if isclass(meta.type) and issubclass(meta.type, StringSerializable):
90+
body_kwargs["converter"] = f"optional({meta.type.__name__})"
91+
imports.append(("attr.converter", "optional"))
92+
elif isclass(meta) and issubclass(meta, StringSerializable):
93+
body_kwargs["converter"] = meta.__name__
94+
95+
if not self.no_meta:
96+
body_kwargs["metadata"] = {METADATA_FIELD_NAME: name}
97+
data["body"] = self.ATTRIB.render(kwargs=sort_kwargs(body_kwargs))
98+
return imports, data

rest_client_gen/models/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from typing import List, Tuple, Type
22

3+
import inflection
34
from jinja2 import Template
45

5-
from rest_client_gen.dynamic_typing import AbsoluteModelRef, compile_imports
6-
from rest_client_gen.models import INDENT, ModelsStructureType, OBJECTS_DELIMITER
7-
from . import indent, sort_fields
8-
from ..dynamic_typing import ImportPathList, MetaData, ModelMeta, metadata_to_typing
6+
from . import INDENT, ModelsStructureType, OBJECTS_DELIMITER, indent, sort_fields
7+
from ..dynamic_typing import AbsoluteModelRef, ImportPathList, MetaData, ModelMeta, compile_imports, metadata_to_typing
98

109

1110
def template(pattern: str, indent: str = INDENT) -> Template:
@@ -82,7 +81,7 @@ def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportP
8281
"""
8382
imports, typing = metadata_to_typing(meta)
8483
data = {
85-
"name": name,
84+
"name": inflection.underscore(name),
8685
"type": typing
8786
}
8887
return imports, data
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from typing import Dict, List
2+
3+
import pytest
4+
5+
from rest_client_gen.dynamic_typing import (DList, DOptional, FloatString, IntString, ModelMeta, compile_imports)
6+
from rest_client_gen.models import sort_fields
7+
from rest_client_gen.models.attr import AttrsModelCodeGenerator, METADATA_FIELD_NAME, sort_kwargs
8+
from rest_client_gen.models.base import generate_code
9+
from test.test_code_generation.test_models_code_generator import model_factory, trim
10+
11+
12+
def test_attrib_kwargs_sort():
13+
sorted_kwargs = sort_kwargs(dict(
14+
y=2,
15+
metadata='b',
16+
converter='a',
17+
default=None,
18+
x=1,
19+
))
20+
expected = ['default', 'converter', 'y', 'x', 'metadata']
21+
for k1, k2 in zip(sorted_kwargs.keys(), expected):
22+
assert k1 == k2
23+
try:
24+
sort_kwargs({}, ['wrong_char'])
25+
except ValueError as e:
26+
assert e.args[0].endswith('wrong_char')
27+
else:
28+
assert 0, "XPass"
29+
30+
31+
32+
def field_meta(original_name):
33+
return f"metadata={{'{METADATA_FIELD_NAME}': '{original_name}'}}"
34+
35+
36+
# Data structure:
37+
# pytest.param id -> {
38+
# "model" -> (model_name, model_metadata),
39+
# test_name -> expected, ...
40+
# }
41+
test_data = {
42+
"base": {
43+
"model": ("Test", {
44+
"foo": int,
45+
"bar": int,
46+
"baz": float
47+
}),
48+
"fields_data": {
49+
"foo": {
50+
"name": "foo",
51+
"type": "int",
52+
"body": f"attr.ib({field_meta('foo')})"
53+
},
54+
"bar": {
55+
"name": "bar",
56+
"type": "int",
57+
"body": f"attr.ib({field_meta('bar')})"
58+
},
59+
"baz": {
60+
"name": "baz",
61+
"type": "float",
62+
"body": f"attr.ib({field_meta('baz')})"
63+
}
64+
},
65+
"fields": {
66+
"imports": "",
67+
"fields": [
68+
f"foo: int = attr.ib({field_meta('foo')})",
69+
f"bar: int = attr.ib({field_meta('bar')})",
70+
f"baz: float = attr.ib({field_meta('baz')})",
71+
]
72+
},
73+
"generated": trim(f"""
74+
import attr
75+
76+
77+
@attr.s
78+
class Test:
79+
foo: int = attr.ib({field_meta('foo')})
80+
bar: int = attr.ib({field_meta('bar')})
81+
baz: float = attr.ib({field_meta('baz')})
82+
""")
83+
},
84+
"complex": {
85+
"model": ("Test", {
86+
"foo": int,
87+
"baz": DOptional(DList(DList(str))),
88+
"bar": DOptional(IntString),
89+
"qwerty": FloatString,
90+
"asdfg": DOptional(int)
91+
}),
92+
"fields_data": {
93+
"foo": {
94+
"name": "foo",
95+
"type": "int",
96+
"body": f"attr.ib({field_meta('foo')})"
97+
},
98+
"baz": {
99+
"name": "baz",
100+
"type": "Optional[List[List[str]]]",
101+
"body": f"attr.ib(factory=list, {field_meta('baz')})"
102+
},
103+
"bar": {
104+
"name": "bar",
105+
"type": "Optional[IntString]",
106+
"body": f"attr.ib(default=None, converter=optional(IntString), {field_meta('bar')})"
107+
},
108+
"qwerty": {
109+
"name": "qwerty",
110+
"type": "FloatString",
111+
"body": f"attr.ib(converter=FloatString, {field_meta('qwerty')})"
112+
},
113+
"asdfg": {
114+
"name": "asdfg",
115+
"type": "Optional[int]",
116+
"body": f"attr.ib(default=None, {field_meta('asdfg')})"
117+
}
118+
},
119+
"generated": trim(f"""
120+
import attr
121+
from attr.converter import optional
122+
from rest_client_gen.dynamic_typing.string_serializable import FloatString, IntString
123+
from typing import List, Optional
124+
125+
126+
@attr.s
127+
class Test:
128+
foo: int = attr.ib({field_meta('foo')})
129+
qwerty: FloatString = attr.ib(converter=FloatString, {field_meta('qwerty')})
130+
baz: Optional[List[List[str]]] = attr.ib(factory=list, {field_meta('baz')})
131+
bar: Optional[IntString] = attr.ib(default=None, converter=optional(IntString), {field_meta('bar')})
132+
asdfg: Optional[int] = attr.ib(default=None, {field_meta('asdfg')})
133+
""")
134+
}
135+
}
136+
137+
test_data_unzip = {
138+
test: [
139+
pytest.param(
140+
model_factory(*data["model"]),
141+
data[test],
142+
id=id
143+
)
144+
for id, data in test_data.items()
145+
if test in data
146+
]
147+
for test in ("fields_data", "fields", "generated")
148+
}
149+
150+
151+
@pytest.mark.parametrize("value,expected", test_data_unzip["fields_data"])
152+
def test_fields_data_attr(value: ModelMeta, expected: Dict[str, dict]):
153+
gen = AttrsModelCodeGenerator(value)
154+
required, optional = sort_fields(value)
155+
for is_optional, fields in enumerate((required, optional)):
156+
for field in fields:
157+
field_imports, data = gen.field_data(field, value.type[field], bool(is_optional))
158+
assert data == expected[field]
159+
160+
161+
@pytest.mark.parametrize("value,expected", test_data_unzip["fields"])
162+
def test_fields_attr(value: ModelMeta, expected: dict):
163+
expected_imports: str = expected["imports"]
164+
expected_fields: List[str] = expected["fields"]
165+
gen = AttrsModelCodeGenerator(value)
166+
imports, fields = gen.fields
167+
imports = compile_imports(imports)
168+
assert imports == expected_imports
169+
assert fields == expected_fields
170+
171+
172+
@pytest.mark.parametrize("value,expected", test_data_unzip["generated"])
173+
def test_generated_attr(value: ModelMeta, expected: str):
174+
generated = generate_code(([{"model": value, "nested": []}], {}), AttrsModelCodeGenerator)
175+
assert generated.rstrip() == expected, generated

0 commit comments

Comments
 (0)