Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ test:
-Wdefault:"Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working":DeprecationWarning:: \
-Wdefault:"set_output_charset() is deprecated":DeprecationWarning:: \
-Wdefault:"parameter codeset is deprecated":DeprecationWarning:: \
# Remove cgi warning when dropping support for Django<=4.1.
-Wdefault:"'cgi' is deprecated and slated for removal in Python 3.13":DeprecationWarning:: \
-m unittest
# Remove cgi warning when dropping support for Django<=4.1.

# DOC: Test the examples
example-test:
Expand Down
235 changes: 235 additions & 0 deletions factory/fireo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Auto factory for Fireo models."""
from typing import Type

from faker import config
from fireo import fields
from fireo.models import Model

import factory
from factory.base import (
FactoryMetaClass,
FactoryOptions,
OptionDefault,
resolve_attribute,
)
from factory.declarations import BaseDeclaration


class FireoFactory(factory.Factory):
"""Factory for FireO models."""

class Meta:
abstract = True

@classmethod
def _build(cls, model_class, **kwargs):
instance = model_class()
for key, value in kwargs.items():
setattr(instance, key, value)

return instance

@classmethod
def _create(cls, model_class, **kwargs):
instance = model_class()
for key, value in kwargs.items():
setattr(instance, key, value)

if instance.collection_name:
instance.save()

return instance


class FireoAutoFactoryOptions(FactoryOptions):
def _build_default_options(self):
return super()._build_default_options() + [
OptionDefault('extra_mapping', {}, inherit=True),
OptionDefault('sub_factories', {}, inherit=True),
OptionDefault('nested_factory_cls', None, inherit=True),
]


class FireoAutoFactoryMetaClass(FactoryMetaClass):
def __new__(mcs, name, bases, attrs):
meta_cls = attrs.get('Meta')
model = getattr(meta_cls, 'model', None)
if model is not None:
base_meta = resolve_attribute('_meta', bases)
extra_mapping = getattr(meta_cls, 'extra_mapping', getattr(base_meta, 'extra_mapping', {}))
sub_factories = getattr(meta_cls, 'sub_factories', getattr(base_meta, 'sub_factories', {}))
nested_factory_cls = mcs._get_nested_factory_cls(base_meta, bases, meta_cls)

gen_attrs = FireoAutoFactoryMaker(
sub_factories,
extra_mapping,
nested_factory_cls, # type: ignore
).generate_factory_fields(
model,
attrs,
)
attrs.update(gen_attrs)

cls = super().__new__(mcs, name, bases, attrs)
return cls

@staticmethod
def _get_nested_factory_cls(base_meta, bases, meta_cls):
nested_factory_cls = getattr(meta_cls, 'nested_factory_cls', getattr(base_meta, 'nested_factory_cls', None))
if not nested_factory_cls:
if len(bases) != 1:
raise ValueError('You must specify nested_factory_cls in Meta if you have multiple bases')

nested_factory_cls = bases[0]

if not nested_factory_cls._meta.abstract:
raise ValueError('nested_factory_cls must be abstract')

if not issubclass(nested_factory_cls, FireoAutoFactory):
raise ValueError('nested_factory_cls must be a FireoAutoFactory')

return nested_factory_cls


class FireoAutoFactory(FireoFactory, metaclass=FireoAutoFactoryMetaClass):
"""Auto factory for Fireo models.

This factory will generate fields for all fields in the model.
You can override any field by defining it in the class.

You can also specify Meta.extra_mapping to map fields to other factories.
You can also specify Meta.sub_factories to map fields to other SubFactories.
Note: Meta.sub_factories is populated automatically by the factory for not specified fields.
You can also specify Meta.nested_factory_cls to use as a base for nested factories.

Example:
>>> class Comment(Model):
... text = fields.TextField()
>>>
>>> class User(Model):
... name = fields.TextField()
... email = MyCustomEmailField()
... comments = fields.ListField(Comment)
>>>
>>> class UserFactory(FireoAutoFactory):
... class Meta:
... model = User
... extra_mapping = {
... MyCustomEmailField: factory.Faker('email'),
... }
>>>
>>> model = UserFactory.create()
>>> assert model.name
>>> assert model.email # works for custom fields too
>>> assert model.comments[0].text # works for nested models too
"""
_options_class = FireoAutoFactoryOptions

class Meta:
abstract = True


def _raise_not_implemented(*_, **__):
raise NotImplementedError()


def nullable(field, factory_field):
if not field.raw_attributes.get('required'):
factory_field = MaybeNone(factory_field)

return factory_field


class FireoAutoFactoryMaker:
mapping = {
fields.IDField: lambda maker, field: nullable(field, factory.Faker('pystr', min_chars=20, max_chars=20)),
fields.BooleanField: lambda maker, field: nullable(field, factory.Faker('pybool')),
fields.DateTime: lambda maker, field: nullable(field, factory.Faker('date_time')),
fields.NumberField: lambda maker, field: nullable(field, (
factory.Faker('pyint')
if field.raw_attributes.get('int_only') else
factory.Faker('pyfloat')
)),
fields.TextField: lambda maker, field: nullable(field, factory.Faker('pystr')),
fields.MapField: lambda maker, field: nullable(field, factory.Faker('pydict', value_types=[str])),
fields.ReferenceField: _raise_not_implemented,
fields.GeoPoint: _raise_not_implemented,
fields.Field: _raise_not_implemented,
fields.ListField: lambda maker, field: nullable(field, (
factory.Faker('pylist', value_types=[str])
if field.raw_attributes.get('nested_field') is None else
factory.List([maker.get_field_factory(field.raw_attributes['nested_field'])])
)),
fields.NestedModelField: lambda maker, field: nullable(
field, factory.SubFactory(maker.get_model_factory(field.nested_model))
),
}

def __init__(
self,
sub_factories: dict[Model, Type[FireoFactory]] | None = None,
extra_mapping: None = None,
base_cls: Type[FireoAutoFactory] = FireoAutoFactory, # type: ignore
):
self.sub_factories = sub_factories or {}
self.mapping = {
**self.mapping,
**(extra_mapping or {}),
}
self.base_cls = base_cls

def get_field_factory(self, field: fields.Field) -> BaseDeclaration:
field_type = type(field)
if field_type in self.mapping:
return self.mapping[field_type](self, field) # type: ignore
else:
raise NotImplementedError(f'Field type {field_type} is not implemented')

def get_model_factory(self, model: Model):
if model not in self.sub_factories:
the_model = model

class AutoFactory(self.base_cls): # type: ignore
class Meta:
model = the_model

AutoFactory.__name__ = f'{model.__name__}Factory'
self.sub_factories[model] = AutoFactory

return self.sub_factories[model]

def generate_factory_fields(self, model: Model, attrs: dict) -> dict[str, BaseDeclaration]:
generated_fields = {}
for field_name, field in model._meta.field_list.items():
if field_name in attrs:
continue

if not isinstance(field, fields.Field):
continue

generated_fields[field_name] = self.get_field_factory(field)

return generated_fields


class MaybeNone(factory.Maybe):
"""Factory for optional fields."""

def __init__(self, field):
super().__init__(factory.Faker("boolean", locale=config.DEFAULT_LOCALE), field, None)

def evaluate_pre(self, instance, step, overrides):
choice = self.decider.evaluate(instance=instance, step=step, extra={
'locale': config.DEFAULT_LOCALE,
})
target = self.yes if choice else self.no

if isinstance(target, BaseDeclaration):
return target.evaluate_pre(
instance=instance,
step=step,
overrides=overrides,
)
else:
# Flat value (can't be POST_INSTANTIATION, checked in __init__)
return target
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dev =
SQLAlchemy
sqlalchemy_utils
mongoengine
fireo >= 2.0.0
wheel>=0.32.0
tox
zest.releaser[recommended]
Expand Down
95 changes: 95 additions & 0 deletions tests/test_fireo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import unittest
from typing import List, Optional, Union

from fireo.typedmodels import TypedModel

import factory
from factory.fireo import FireoAutoFactory, MaybeNone


class Deep1Model(TypedModel):
nested_int: int


class RootModel(TypedModel):
int_: int
float_: float
str_: str
bool_: bool
list_: list
dict_: dict
int_or_float: Union[int, float]
optional_int_or_float: Union[int, float, None]
optional_int: Optional[int]
list_of_int: List[int]
list_of_int_or_none: List[Optional[int]]
none_or_list_of_int: Optional[List[int]]

nested: Deep1Model
list_of_nested: List[Deep1Model]


class RootModelFactory(FireoAutoFactory):
class Meta:
model = RootModel


class FireoAutoFactoryTestCase(unittest.TestCase):
def test_generate_fields_by_auto_factory(self):
assert RootModelFactory.int_.provider == 'pyint'
assert RootModelFactory.float_.provider == 'pyfloat'
assert RootModelFactory.str_.provider == 'pystr'
assert RootModelFactory.bool_.provider == 'pybool'
assert RootModelFactory.list_.provider == 'pylist'
assert RootModelFactory.dict_.provider == 'pydict'
assert RootModelFactory.int_or_float.provider == 'pyfloat'

assert RootModelFactory.optional_int_or_float.__class__ is MaybeNone
assert RootModelFactory.optional_int_or_float.yes.provider == 'pyfloat'

assert RootModelFactory.optional_int.__class__ is MaybeNone
assert RootModelFactory.optional_int.yes.provider == 'pyint'

assert RootModelFactory.list_of_int.__class__ is factory.List
assert RootModelFactory.list_of_int._defaults['0'].provider == 'pyint'

assert RootModelFactory.list_of_int_or_none.__class__ == factory.List
assert RootModelFactory.list_of_int_or_none._defaults['0'].__class__ is MaybeNone
assert RootModelFactory.list_of_int_or_none._defaults['0'].yes.provider == 'pyint'

assert RootModelFactory.none_or_list_of_int.__class__ is MaybeNone
assert RootModelFactory.none_or_list_of_int.yes.__class__ is factory.List
assert RootModelFactory.none_or_list_of_int.yes._defaults['0'].provider == 'pyint'

assert RootModelFactory.nested.__class__ is factory.SubFactory
deep1_model_factory = RootModelFactory.nested.factory_wrapper.factory
assert deep1_model_factory.__name__ == 'Deep1ModelFactory'
assert RootModelFactory.nested.factory_wrapper.factory._meta.model is Deep1Model
assert deep1_model_factory.nested_int.provider == 'pyint'

assert RootModelFactory.list_of_nested.__class__ is factory.List
assert RootModelFactory.list_of_nested._defaults['0'].__class__ is factory.SubFactory
assert RootModelFactory.list_of_nested._defaults['0'].factory_wrapper.factory is deep1_model_factory

def test_generate_model_by_auto_factory(self):
model = RootModelFactory.build()

assert isinstance(model, RootModel)
assert isinstance(model.int_, int)
assert isinstance(model.float_, float)
assert isinstance(model.str_, str)
assert isinstance(model.bool_, bool)
assert isinstance(model.list_, list)
assert isinstance(model.dict_, dict)
assert isinstance(model.int_or_float, (int, float))
assert isinstance(model.optional_int_or_float, (int, float, type(None)))
assert isinstance(model.optional_int, (int, type(None)))
assert isinstance(model.list_of_int, list)
assert all(isinstance(i, int) for i in model.list_of_int)
assert isinstance(model.list_of_int_or_none, list)
assert all(isinstance(i, (int, type(None))) for i in model.list_of_int_or_none)
assert isinstance(model.nested, Deep1Model)
assert isinstance(model.nested.nested_int, int)
assert isinstance(model.list_of_nested, list)
assert all(isinstance(i, Deep1Model) for i in model.list_of_nested)
assert all(isinstance(i.nested_int, int) for i in model.list_of_nested)