diff --git a/agave/blueprints/rest_api.py b/agave/blueprints/rest_api.py index ef07eb21..0c1838ed 100644 --- a/agave/blueprints/rest_api.py +++ b/agave/blueprints/rest_api.py @@ -1,11 +1,11 @@ -from typing import Optional, Type +from typing import Type from urllib.parse import urlencode from chalice import Blueprint, NotFoundError, Response from cuenca_validations.types import QueryParams -from mongoengine import DoesNotExist, Q from pydantic import BaseModel, ValidationError +from ..exc import DoesNotExist from .decorators import copy_attributes @@ -112,9 +112,10 @@ def update(id: str): params = self.current_request.json_body or dict() try: data = cls.update_validator(**params) - model = cls.model.objects.get(id=id) except ValidationError as e: return Response(e.json(), status_code=400) + try: + model = cls.model.retrieve(id=id) except DoesNotExist: raise NotFoundError('Not valid id') else: @@ -140,14 +141,15 @@ def retrieve(id: str): # at the moment, there are no resources with a custom # retrieve method return cls.retrieve(id) # pragma: no cover + + user_id = None + if self.user_id_filter_required(): + user_id = self.current_user_id try: - id_query = Q(id=id) - if self.user_id_filter_required(): - id_query = id_query & Q(user_id=self.current_user_id) - data = cls.model.objects.get(id_query) + data = cls.model.retrieve(id, user_id=user_id) except DoesNotExist: raise NotFoundError('Not valid id') - return data.to_dict() + return data.dict() @self.get(path) @copy_attributes(cls) @@ -183,37 +185,34 @@ def query(): return _count(filters) return _all(query_params, filters) - def _count(filters: Q): - count = cls.model.objects.filter(filters).count() + def _count(filters): + count = cls.model.count(filters) return dict(count=count) - def _all(query: QueryParams, filters: Q): + def _all(query: QueryParams, filters): if query.limit: limit = min(query.limit, query.page_size) query.limit = max(0, query.limit - limit) # type: ignore else: limit = query.page_size - items = ( - cls.model.objects.order_by("-created_at") - .filter(filters) - .limit(limit) - ) - item_dicts = [i.to_dict() for i in items] - has_more: Optional[bool] = None - if wants_more := query.limit is None or query.limit > 0: - # only perform this query if it's necessary - has_more = items.limit(limit + 1).count() > limit + wants_more = query.limit is None or query.limit > 0 + items, has_more = cls.model.all( + filters, limit=limit, wants_more=wants_more + ) - next_page_uri: Optional[str] = None + next_page_uri = None if wants_more and has_more: - query.created_before = item_dicts[-1]['created_at'] + query.created_before = items[-1].created_at.isoformat() path = self.current_request.context['resourcePath'] params = query.dict() if self.user_id_filter_required(): params.pop('user_id') next_page_uri = f'{path}?{urlencode(params)}' - return dict(items=item_dicts, next_page_uri=next_page_uri) + return dict( + items=[i.dict() for i in items], # type: ignore + next_page_uri=next_page_uri, + ) return cls diff --git a/agave/exc.py b/agave/exc.py new file mode 100644 index 00000000..5cdf3691 --- /dev/null +++ b/agave/exc.py @@ -0,0 +1,2 @@ +class DoesNotExist(Exception): + '''object does not exist''' diff --git a/agave/filters.py b/agave/filters.py index b44978ed..803f59c8 100644 --- a/agave/filters.py +++ b/agave/filters.py @@ -1,14 +1,10 @@ +from typing import Any, Dict + from cuenca_validations.types import QueryParams -from mongoengine import Q -def generic_query(query: QueryParams) -> Q: - filters = Q() - if query.created_before: - filters &= Q(created_at__lt=query.created_before) - if query.created_after: - filters &= Q(created_at__gt=query.created_after) - exclude_fields = { +def exclude_fields(query: QueryParams) -> Dict[str, Any]: + excluded_fields = { 'created_before', 'created_after', 'active', @@ -16,7 +12,7 @@ def generic_query(query: QueryParams) -> Q: 'page_size', 'key', } - fields = query.dict(exclude=exclude_fields) + fields = query.dict(exclude=excluded_fields) if 'count' in fields: del fields['count'] - return filters & Q(**fields) + return fields diff --git a/agave/models/__init__.py b/agave/models/__init__.py index 5f4f48a7..93a3e24e 100644 --- a/agave/models/__init__.py +++ b/agave/models/__init__.py @@ -1,3 +1,23 @@ -__all__ = ['BaseModel'] +__all__ = [] -from .base import BaseModel +try: + import mongoengine # noqa +except ImportError: # pragma: no cover + ... +else: + from .mongo import MongoModel # noqa + from .mongo.filters import generic_mongo_query # noqa + + __all__.extend(['MongoModel', 'generic_mongo_query']) + + +try: + import rom # noqa +except ImportError: # pragma: no cover + ... +else: + + from .redis import RedisModel # noqa + from .redis.filters import generic_redis_query # noqa + + __all__.extend(['RedisModel', 'generic_redis_query']) diff --git a/agave/models/base.py b/agave/models/base.py index 4a370ba3..98d58801 100644 --- a/agave/models/base.py +++ b/agave/models/base.py @@ -1,6 +1,6 @@ -from typing import ClassVar, Dict +from typing import Callable, ClassVar -from ..lib.mongoengine.model_helpers import mongo_to_dict +from cuenca_validations.typing import DictStrAny class BaseModel: @@ -10,14 +10,14 @@ class BaseModel: def __init__(self, *args, **values): return super().__init__(*args, **values) - def to_dict(self) -> Dict: + def _dict(self, dict_func: Callable) -> DictStrAny: private_fields = [f for f in dir(self) if f.startswith('_')] excluded = self._excluded + private_fields - mongo_dict: dict = mongo_to_dict(self, excluded) + model_dict = dict_func(self, excluded) for field in self._hidden: - mongo_dict[field] = '********' - return mongo_dict + model_dict[field] = '********' + return model_dict def __repr__(self) -> str: - return str(self.to_dict()) # pragma: no cover + return str(self.dict()) # type: ignore # pragma: no cover diff --git a/agave/models/mongo/__init__.py b/agave/models/mongo/__init__.py new file mode 100644 index 00000000..866d39ce --- /dev/null +++ b/agave/models/mongo/__init__.py @@ -0,0 +1,3 @@ +__all__ = ['MongoModel'] + +from .mongo_model import MongoModel diff --git a/agave/models/mongo/filters.py b/agave/models/mongo/filters.py new file mode 100644 index 00000000..418f5354 --- /dev/null +++ b/agave/models/mongo/filters.py @@ -0,0 +1,14 @@ +from cuenca_validations.types import QueryParams +from mongoengine import Q + +from ...filters import exclude_fields + + +def generic_mongo_query(query: QueryParams) -> Q: + filters = Q() + if query.created_before: + filters &= Q(created_at__lt=query.created_before) + if query.created_after: + filters &= Q(created_at__gt=query.created_after) + fields = exclude_fields(query) + return filters & Q(**fields) diff --git a/agave/models/mongo/mongo_model.py b/agave/models/mongo/mongo_model.py new file mode 100644 index 00000000..7c5122c3 --- /dev/null +++ b/agave/models/mongo/mongo_model.py @@ -0,0 +1,47 @@ +from typing import List, Optional, Tuple + +import mongoengine as mongo +from cuenca_validations.typing import DictStrAny +from mongoengine import Document, Q + +from agave import exc +from agave.lib.mongoengine.model_helpers import mongo_to_dict +from agave.models.base import BaseModel + + +class MongoModel(BaseModel, Document): + meta = {'allow_inheritance': True} + + def dict(self) -> DictStrAny: + return self._dict(mongo_to_dict) + + @classmethod + def retrieve( + cls, id: str, *, user_id: Optional[str] = None + ) -> 'MongoModel': + query = Q(id=id) + if user_id: + query = query & Q(user_id=user_id) + try: + obj = cls.objects.get(query) + except mongo.DoesNotExist: + raise exc.DoesNotExist + return obj + + @classmethod + def count(cls, filters: Q) -> int: + return cls.objects.filter(filters).count() + + @classmethod + def all( + cls, filters: Q, *, limit: int, wants_more: bool + ) -> Tuple[List['MongoModel'], bool]: + items = ( + cls.objects.order_by("-created_at").filter(filters).limit(limit) + ) + + has_more = False + if wants_more: + has_more = items.limit(limit + 1).count() > limit + + return list(items), has_more diff --git a/agave/models/redis/__init__.py b/agave/models/redis/__init__.py new file mode 100644 index 00000000..b5317760 --- /dev/null +++ b/agave/models/redis/__init__.py @@ -0,0 +1,3 @@ +__all__ = ['RedisModel', 'String'] + +from .redis_model import RedisModel, String diff --git a/agave/models/redis/filters.py b/agave/models/redis/filters.py new file mode 100644 index 00000000..75dc4370 --- /dev/null +++ b/agave/models/redis/filters.py @@ -0,0 +1,31 @@ +import datetime as dt + +from cuenca_validations.types import QueryParams +from cuenca_validations.typing import DictStrAny + +from ...filters import exclude_fields + + +def generic_redis_query(query: QueryParams, **kwargs) -> DictStrAny: + filters = dict() + if query.created_before or query.created_after: + # Restamos o sumamos un microsegundo porque la comparación + # aquí es inclusiva + created_at_lt = ( + query.created_before.replace(tzinfo=None) + + dt.timedelta(microseconds=-1) + if query.created_before + else None + ) + created_at_gt = ( + query.created_after.replace(tzinfo=None) + + dt.timedelta(microseconds=1) + if query.created_after + else None + ) + filters['created_at'] = (created_at_gt, created_at_lt) + fields = exclude_fields(query) + fields = {**fields, **kwargs} + if not filters: + filters = fields + return filters diff --git a/agave/models/redis/redis_model.py b/agave/models/redis/redis_model.py new file mode 100644 index 00000000..6cfd0f7c --- /dev/null +++ b/agave/models/redis/redis_model.py @@ -0,0 +1,73 @@ +from typing import Dict, List, Optional, Tuple + +from cuenca_validations.typing import DictStrAny +from cuenca_validations.validators import sanitize_item +from rom import Column, Model, PrimaryKey + +from agave.exc import DoesNotExist +from agave.models.base import BaseModel + +EXCLUDED = ['o_id'] + + +class String(Column): + """ + No utilizo la clase String de rom porque todo lo maneja en bytes + codificado en latin-1. + """ + + _allowed = str + + def _to_redis(self, value): + return value.encode('utf-8') + + def _from_redis(self, value): + return value.decode('utf-8') + + +def redis_to_dict(obj, exclude_fields: List[str]) -> DictStrAny: + excluded = EXCLUDED + exclude_fields + response = { + key: sanitize_item(value) + for key, value in obj._data.items() + if key not in excluded + } + return response + + +class RedisModel(BaseModel, Model): + meta = {'allow_inheritance': True} + o_id = PrimaryKey() # Para que podamos usar `id` en los modelos + + def dict(self) -> DictStrAny: + return self._dict(redis_to_dict) + + @classmethod + def retrieve( + cls, id: str, *, user_id: Optional[str] = None + ) -> 'RedisModel': + params = dict(id=id) + if user_id: + params['user_id'] = user_id + obj = cls.query.filter(**params).first() + if not obj: + raise DoesNotExist + return obj + + @classmethod + def count(cls, filters: Dict) -> int: + return cls.query.filter(**filters).count() + + @classmethod + def all( + cls, filters: Dict, *, limit: int, wants_more: bool + ) -> Tuple[List['RedisModel'], bool]: + items = ( + cls.query.filter(**filters).order_by('-created_at').limit(0, limit) + ) + + has_more = False + if wants_more: + has_more = items.limit(0, limit + 1).count() > limit + + return list(items), has_more diff --git a/agave/version.py b/agave/version.py index eead3198..2458914b 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '0.0.5' +__version__ = '0.0.6.dev0' diff --git a/examples/chalicelib/models/__init__.py b/examples/chalicelib/models/__init__.py index a3a1c293..e69de29b 100644 --- a/examples/chalicelib/models/__init__.py +++ b/examples/chalicelib/models/__init__.py @@ -1,4 +0,0 @@ -__all__ = ['Account', 'Transaction'] - -from .accounts import Account -from .transactions import Transaction diff --git a/examples/chalicelib/models/accounts.py b/examples/chalicelib/models/accounts.py deleted file mode 100644 index fd6617c2..00000000 --- a/examples/chalicelib/models/accounts.py +++ /dev/null @@ -1,12 +0,0 @@ -from mongoengine import DateTimeField, Document, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class Account(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('AC')) - name = StringField(required=True) - user_id = StringField(required=True) - created_at = DateTimeField() - deactivated_at = DateTimeField() diff --git a/examples/chalicelib/models/mongo_models.py b/examples/chalicelib/models/mongo_models.py new file mode 100644 index 00000000..b726027a --- /dev/null +++ b/examples/chalicelib/models/mongo_models.py @@ -0,0 +1,21 @@ +from mongoengine import DateTimeField, FloatField, StringField + +from agave.models.mongo import MongoModel +from agave.models.helpers import uuid_field + + +class Account(MongoModel): + id = StringField(primary_key=True, default=uuid_field('AC')) + name = StringField(required=True) + user_id = StringField(required=True) + created_at = DateTimeField() + deactivated_at = DateTimeField() + secret_field = StringField() + + _hidden = ['secret_field'] + + +class Transaction(MongoModel): + id = StringField(primary_key=True, default=uuid_field('TR')) + user_id = StringField(required=True) + amount = FloatField(required=True) diff --git a/examples/chalicelib/models/redis_models.py b/examples/chalicelib/models/redis_models.py new file mode 100644 index 00000000..7019d03b --- /dev/null +++ b/examples/chalicelib/models/redis_models.py @@ -0,0 +1,26 @@ +import datetime as dt + +from rom import DateTime, util + +from agave.models.redis import RedisModel, String +from agave.models.helpers import uuid_field + +DEFAULT_MISSING_DATE = dt.datetime.utcfromtimestamp(0) + + +class Account(RedisModel): + id = String( + default=uuid_field('US'), + required=True, + unique=True, + index=True, + keygen=util.IDENTITY, + ) + name = String(required=True, index=True, keygen=util.IDENTITY) + user_id = String(required=True, index=True, keygen=util.IDENTITY) + created_at = DateTime(default=dt.datetime.utcnow, index=True) + deactivated_at = DateTime(default=DEFAULT_MISSING_DATE, index=True) + secret_field = String(index=True, keygen=util.IDENTITY) + + _hidden = ['secret_field'] + _excluded = ['deactivated_at'] diff --git a/examples/chalicelib/models/transactions.py b/examples/chalicelib/models/transactions.py index dcde86d4..e69de29b 100644 --- a/examples/chalicelib/models/transactions.py +++ b/examples/chalicelib/models/transactions.py @@ -1,10 +0,0 @@ -from mongoengine import Document, FloatField, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class Transaction(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('TR')) - user_id = StringField(required=True) - amount = FloatField(required=True) diff --git a/examples/chalicelib/resources/accounts.py b/examples/chalicelib/resources/accounts.py index 37c4b2f0..972769ae 100644 --- a/examples/chalicelib/resources/accounts.py +++ b/examples/chalicelib/resources/accounts.py @@ -1,13 +1,12 @@ import datetime as dt - from chalice import NotFoundError, Response from mongoengine import DoesNotExist -from agave.filters import generic_query - -from ..models import Account as AccountModel +from agave.models.mongo.filters import generic_mongo_query +from ..models.mongo_models import Account as AccountModel from ..validators import AccountQuery, AccountRequest, AccountUpdateRequest from .base import app +from agave.exc import DoesNotExist @app.resource('/accounts') @@ -15,7 +14,7 @@ class Account: model = AccountModel query_validator = AccountQuery update_validator = AccountUpdateRequest - get_query_filter = generic_query + get_query_filter = generic_mongo_query @staticmethod @app.validate(AccountRequest) @@ -25,7 +24,7 @@ def create(request: AccountRequest) -> Response: user_id=app.current_user_id, ) account.save() - return Response(account.to_dict(), status_code=201) + return Response(account.dict(), status_code=201) @staticmethod def update( @@ -33,15 +32,18 @@ def update( ) -> Response: account.name = request.name account.save() - return Response(account.to_dict(), status_code=200) + return Response(account.dict(), status_code=200) @staticmethod def delete(id: str) -> Response: + account = None try: - account = AccountModel.objects.get(id=id) + account = AccountModel.retrieve(id=id) # type: ignore except DoesNotExist: raise NotFoundError('Not valid id') - + except Exception: + if not account: + raise NotFoundError('Not valid id') account.deactivated_at = dt.datetime.utcnow().replace(microsecond=0) account.save() - return Response(account.to_dict(), status_code=200) + return Response(account.dict(), status_code=200) diff --git a/examples/chalicelib/resources/transactions.py b/examples/chalicelib/resources/transactions.py index f940510f..0db4fe24 100644 --- a/examples/chalicelib/resources/transactions.py +++ b/examples/chalicelib/resources/transactions.py @@ -1,6 +1,5 @@ -from agave.filters import generic_query - -from ..models.transactions import Transaction as TransactionModel +from agave.models.mongo.filters import generic_mongo_query +from ..models.mongo_models import Transaction as TransactionModel from ..validators import TransactionQuery from .base import app @@ -9,4 +8,4 @@ class Transaction: model = TransactionModel query_validator = TransactionQuery - get_query_filter = generic_query + get_query_filter = generic_mongo_query diff --git a/requirements-test.txt b/requirements-test.txt index c0523696..1df8f17e 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -8,3 +8,4 @@ mypy==0.790 pytest-chalice==0.0.* mongomock==3.21.* mock==4.0.3 +redislite==5.0.* diff --git a/requirements.txt b/requirements.txt index 9bdbff20..9f21026b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ chalice==1.21.6 mongoengine==0.22.1 cuenca-validations==0.6.8 dnspython==2.0.0 +rom==1.0.0 diff --git a/setup.py b/setup.py index 9fb90bf2..df933924 100644 --- a/setup.py +++ b/setup.py @@ -26,10 +26,13 @@ 'chalice>=1.16.0,<1.21.7', 'cuenca-validations>=0.4,<0.7', 'blinker>=1.4,<1.5', - 'mongoengine>=0.20.0,<0.23.0', 'dnspython>=2.0.0,<2.1.0', 'dataclasses>=0.6;python_version<"3.7"', ], + extras_require={ + 'mongoengine': 'mongoengine>=0.20.0,<0.23.0', + 'rom': 'rom>=1.0.0,<1.1.0', + }, classifiers=[ 'Programming Language :: Python :: 3.8', 'License :: OSI Approved :: MIT License', diff --git a/tests/blueprint/test_blueprint.py b/tests/blueprint/test_blueprint.py index 4324cf7e..c15fa7e8 100644 --- a/tests/blueprint/test_blueprint.py +++ b/tests/blueprint/test_blueprint.py @@ -4,7 +4,7 @@ from chalice.test import Client from mock import MagicMock, patch -from examples.chalicelib.models import Account +from examples.chalicelib.models.mongo_models import Account USER_ID_FILTER_REQUIRED = ( 'examples.chalicelib.blueprints.authed.' @@ -15,9 +15,9 @@ def test_create_resource(client: Client) -> None: data = dict(name='Doroteo Arango') resp = client.http.post('/accounts', json=data) - model = Account.objects.get(id=resp.json_body['id']) + model = Account.retrieve(id=resp.json_body['id']) # type: ignore assert resp.status_code == 201 - assert model.to_dict() == resp.json_body + assert model.dict() == resp.json_body model.delete() @@ -30,7 +30,7 @@ def test_create_resource_bad_request(client: Client) -> None: def test_retrieve_resource(client: Client, account: Account) -> None: resp = client.http.get(f'/accounts/{account.id}') assert resp.status_code == 200 - assert resp.json_body == account.to_dict() + assert resp.json_body == account.dict() @patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) @@ -68,6 +68,7 @@ def test_update_resource(client: Client, account: Account) -> None: f'/accounts/{account.id}', json=dict(name='Maria Felix'), ) + account.reload() assert resp.json_body['name'] == 'Maria Felix' assert account.name == 'Maria Felix' @@ -116,21 +117,17 @@ def test_query_all_resource(client: Client) -> None: @pytest.mark.usefixtures('accounts') @patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) -def test_query_user_id_filter_required(client: Client) -> None: +def test_query_user_id_filter_required(client: Client, user_id: str) -> None: query_params = dict(page_size=2) resp = client.http.get(f'/accounts?{urlencode(query_params)}') assert resp.status_code == 200 assert len(resp.json_body['items']) == 2 - assert all( - item['user_id'] == 'US123456789' for item in resp.json_body['items'] - ) + assert all(item['user_id'] == user_id for item in resp.json_body['items']) resp = client.http.get(resp.json_body['next_page_uri']) assert resp.status_code == 200 assert len(resp.json_body['items']) == 1 - assert all( - item['user_id'] == 'US123456789' for item in resp.json_body['items'] - ) + assert all(item['user_id'] == user_id for item in resp.json_body['items']) def test_query_resource_with_invalid_params(client: Client) -> None: diff --git a/tests/blueprint/test_filters.py b/tests/blueprint/test_filters.py index b9a58849..14100a6e 100644 --- a/tests/blueprint/test_filters.py +++ b/tests/blueprint/test_filters.py @@ -2,18 +2,26 @@ from cuenca_validations.types import QueryParams -from agave.filters import generic_query +from agave.models.mongo.filters import generic_mongo_query +from agave.models.redis.filters import generic_redis_query def test_generic_query_before(): params = QueryParams(created_before=dt.datetime.utcnow().isoformat()) - query = generic_query(params) + query = generic_mongo_query(params) assert "created_at__lt" in repr(query) assert "user" not in repr(query) def test_generic_query_after(): params = QueryParams(created_after=dt.datetime.utcnow().isoformat()) - query = generic_query(params) + query = generic_mongo_query(params) assert "created_at__gt" in repr(query) assert "user" not in repr(query) + + +def test_generic_query_redis(): + params = QueryParams(created_before=dt.datetime.utcnow().isoformat()) + query = generic_redis_query(params) + assert "created_at" in repr(query) + assert "user" not in repr(query) diff --git a/tests/conftest.py b/tests/conftest.py index 9081cf9b..56730e1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,13 +2,50 @@ from typing import Generator, List import pytest +import rom +from _pytest.monkeypatch import MonkeyPatch from chalice.test import Client +from redislite import Redis -from examples.chalicelib.models import Account +from examples.chalicelib.models.mongo_models import Account from .helpers import accept_json +@pytest.fixture +def user_id() -> str: + return 'US123456789' + + +@pytest.fixture +def another_user_id() -> str: + return 'US987654321' + + +@pytest.fixture(scope='session') +def monkeypatchsession(request): + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(autouse=True) +def setup_redis(monkeypatchsession) -> Generator[None, None, None]: + # Usa un fake redis para no utilizar un servidor de Redis + redis_connection = Redis('/tmp/redis.db') + monkeypatchsession.setattr( + rom.util, 'get_connection', lambda: redis_connection + ) + yield + + +@pytest.fixture(autouse=True) +def flush_redis() -> Generator[None, None, None]: + yield + redis_connection = Redis('/tmp/redis.db') + redis_connection.flushall() + + @pytest.fixture() def client() -> Generator[Client, None, None]: from examples import app @@ -28,7 +65,9 @@ def client() -> Generator[Client, None, None]: @pytest.fixture -def accounts() -> Generator[List[Account], None, None]: +def accounts( + user_id: str, another_user_id: str +) -> Generator[List[Account], None, None]: user_id = 'US123456789' accs = [ Account( @@ -48,7 +87,7 @@ def accounts() -> Generator[List[Account], None, None]: ), Account( name='Remedios Varo', - user_id='US987654321', + user_id=another_user_id, created_at=dt.datetime(2020, 4, 1), ), ] diff --git a/tests/models/conftest.py b/tests/models/conftest.py new file mode 100644 index 00000000..ad225024 --- /dev/null +++ b/tests/models/conftest.py @@ -0,0 +1,59 @@ +import datetime as dt +from typing import Generator, List, Type, Union + +import pytest + +from agave.models import mongo, redis + +DbModel = Union[mongo.MongoModel, redis.RedisModel] + + +@pytest.fixture +def db_model(request) -> Type[DbModel]: + return request.param + + +@pytest.fixture +def accounts( + db_model: Type[DbModel], user_id: str, another_user_id: str +) -> Generator[List[DbModel], None, None]: + accs = [ + db_model( + name='Frida Kahlo', + user_id=user_id, + created_at=dt.datetime(2020, 1, 1), + ), + db_model( + name='Sor Juana Inés', + user_id=user_id, + created_at=dt.datetime(2020, 2, 1), + ), + db_model( + name='Leona Vicario', + user_id=user_id, + created_at=dt.datetime(2020, 3, 1), + ), + db_model( + name='Remedios Varo', + user_id=another_user_id, + created_at=dt.datetime(2020, 4, 1), + ), + ] + + for acc in accs: + acc.save() + yield accs + for acc in accs: + acc.delete() + + +@pytest.fixture +def account(accounts: List[DbModel]) -> Generator[DbModel, None, None]: + yield accounts[0] + + +@pytest.fixture +def another_account( + accounts: List[DbModel], +) -> Generator[DbModel, None, None]: + yield accounts[-1] diff --git a/tests/models/test_base.py b/tests/models/test_base.py deleted file mode 100644 index d86f6815..00000000 --- a/tests/models/test_base.py +++ /dev/null @@ -1,17 +0,0 @@ -from mongoengine import Document, StringField - -from agave.models import BaseModel - - -class TestModel(BaseModel, Document): - id = StringField() - secret_field = StringField() - __test__ = False - _hidden = ['secret_field'] - - -def test_hide_field(): - model = TestModel(id='12345', secret_field='secret') - model_dict = model.to_dict() - assert model_dict['secret_field'] == '********' - assert model_dict['id'] == '12345' diff --git a/tests/models/test_models.py b/tests/models/test_models.py new file mode 100644 index 00000000..b20dd643 --- /dev/null +++ b/tests/models/test_models.py @@ -0,0 +1,134 @@ +import datetime as dt +from typing import Callable, List, Type, Union + +import pytest + +from agave.exc import DoesNotExist +from agave.models import mongo, redis +from examples.chalicelib.models import mongo_models, redis_models +from examples.chalicelib.validators import AccountQuery + +DbModel = Union[mongo.MongoModel, redis.RedisModel] +models = [mongo_models.Account, redis_models.Account] +generic_query_funcs = [ + mongo.filters.generic_mongo_query, + redis.filters.generic_redis_query, +] + + +@pytest.mark.parametrize('db_model', models, indirect=['db_model']) +def test_retrieve(db_model: Type[DbModel], account: DbModel) -> None: + obj = db_model.retrieve(account.id) + assert obj.id == account.id + + +@pytest.mark.parametrize('db_model', models, indirect=['db_model']) +def test_retrieve_not_found(db_model: Type[DbModel]) -> None: + with pytest.raises(DoesNotExist): + db_model.retrieve('unknown-id') + + +@pytest.mark.parametrize('db_model', models, indirect=['db_model']) +def test_retrieve_with_user_id_filter( + db_model: Type[DbModel], account: DbModel, user_id: str +) -> None: + obj = db_model.retrieve(account.id, user_id=user_id) + assert obj.id == account.id + assert obj.user_id == user_id + + +@pytest.mark.parametrize('db_model', models, indirect=['db_model']) +def test_retrieve_not_found_with_user_id_filter( + db_model: Type[DbModel], account: DbModel, another_user_id +) -> None: + with pytest.raises(DoesNotExist): + db_model.retrieve(account.id, user_id=another_user_id) + + +@pytest.mark.parametrize( + 'db_model,generic_query_func', + zip(models, generic_query_funcs), + indirect=['db_model'], +) +def test_query_count( + db_model: Type[DbModel], + generic_query_func: Callable, + accounts: List[DbModel], + user_id: str, +) -> None: + query_params = AccountQuery(count=1, name='Frida Kahlo') + assert db_model.count(generic_query_func(query_params)) == 1 + + query_params = AccountQuery(count=1) + assert db_model.count(generic_query_func(query_params)) == len(accounts) + + query_params = AccountQuery(count=1, user_id=user_id) + assert db_model.count(generic_query_func(query_params)) == len( + [acc for acc in accounts if acc.user_id == user_id] + ) + + +@pytest.mark.parametrize( + 'db_model,generic_query_func', + zip(models, generic_query_funcs), + indirect=['db_model'], +) +@pytest.mark.usefixtures('accounts') +def test_query_all_with_limit( + db_model: Type[DbModel], generic_query_func: Callable +) -> None: + limit = 2 + query_params = AccountQuery(limit=limit) + items, has_more = db_model.all( + generic_query_func(query_params), limit=limit, wants_more=False + ) + assert not has_more + assert len(items) == limit + + +@pytest.mark.parametrize( + 'db_model,generic_query_func', + zip(models, generic_query_funcs), + indirect=['db_model'], +) +@pytest.mark.usefixtures('accounts') +def test_query_all_resource( + db_model: Type[DbModel], generic_query_func: Callable +) -> None: + limit = 3 + query_params = AccountQuery(page_size=limit) + items, has_more = db_model.all( + generic_query_func(query_params), limit=limit, wants_more=True + ) + assert has_more + assert len(items) == limit + + query_params = AccountQuery( + page_size=limit, created_before=items[-1].created_at + ) + items, has_more = db_model.all( + generic_query_func(query_params), limit=limit, wants_more=True + ) + assert not has_more + assert len(items) == 1 + + +@pytest.mark.parametrize('db_model', models, indirect=['db_model']) +def test_to_dict(db_model: Type[DbModel]) -> None: + now = dt.datetime.utcnow() + expected = dict( + id='12345', + name='frida', + user_id='w72638', + secret_field='********', + ) + model = db_model( + id='12345', + name='frida', + user_id='w72638', + secret_field='secret', + created_at=now, + ) + model.save() + model_dict = model.dict() + assert all(model_dict[key] == val for key, val in expected.items())