Skip to content

Commit 9dfc8c7

Browse files
committed
- fix typo
- autocommit and autorollback in run_in_new_ctx
1 parent d334c36 commit 9dfc8c7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+351
-203
lines changed

.python-version

Lines changed: 0 additions & 1 deletion
This file was deleted.

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ lint:
55
flake8 .
66

77
test:
8-
pytest --cov context_async_sqlalchemy exmaples/fastapi_example/tests --cov-report=term-missing
8+
pytest --cov context_async_sqlalchemy examples/fastapi_example/tests --cov-report=term-missing
99

1010
uv:
1111
uv sync

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ async def some_func() -> None:
3636

3737
The repository includes an example integration with FastAPI,
3838
which describes numerous workflows.
39-
[FastAPI example](https://github.com/krylosov-aa/context-async-sqlalchemy/tree/main/exmaples/fastapi_example/routes)
39+
[FastAPI example](https://github.com/krylosov-aa/context-async-sqlalchemy/tree/main/examples/fastapi_example/routes)
4040

4141

4242
It also includes two types of test setups you can use in your projects.
4343
The library currently has 90% test coverage. The tests are in the
4444
examples, as we want to test not in the abstract but in the context of a real
4545
asynchronous web application.
4646

47-
[FastAPI tests example](https://github.com/krylosov-aa/context-async-sqlalchemy/tree/main/exmaples/fastapi_example/tests)
47+
[FastAPI tests example](https://github.com/krylosov-aa/context-async-sqlalchemy/tree/main/examples/fastapi_example/tests)
4848

4949
### The most basic example
5050

@@ -278,7 +278,7 @@ achieved through fast transaction rollback.
278278
You can see the capabilities in the examples:
279279

280280
[Here are tests with a common transaction between the
281-
application and the tests.](https://github.com/krylosov-aa/context-async-sqlalchemy/blob/main/exmaples/fastapi_example/tests/transactional/__init__.py)
281+
application and the tests.](https://github.com/krylosov-aa/context-async-sqlalchemy/blob/main/examples/fastapi_example/tests/transactional/__init__.py)
282282

283283

284-
[And here's an example with different transactions.](https://github.com/krylosov-aa/context-async-sqlalchemy/blob/main/exmaples/fastapi_example/tests/non_transactional/__init__.py)
284+
[And here's an example with different transactions.](https://github.com/krylosov-aa/context-async-sqlalchemy/blob/main/examples/fastapi_example/tests/non_transactional/__init__.py)

context_async_sqlalchemy/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
get_db_session_from_context,
66
put_db_session_to_context,
77
pop_db_session_from_context,
8-
run_in_new_ctx,
98
)
109
from .connect import DBConnect
1110
from .session import (
@@ -23,9 +22,15 @@
2322
rollback_all_sessions,
2423
close_all_sessions,
2524
)
25+
from .run_in_new_context import run_in_new_ctx
26+
from .starlette_utils import (
27+
add_starlette_http_db_session_middleware,
28+
starlette_http_db_session_middleware,
29+
)
30+
2631
from .fastapi_utils import (
27-
fastapi_db_session_middleware,
28-
add_fastapi_db_session_middleware,
32+
fastapi_http_db_session_middleware,
33+
add_fastapi_http_db_session_middleware,
2934
)
3035

3136
__all__ = [
@@ -48,6 +53,8 @@
4853
"commit_all_sessions",
4954
"rollback_all_sessions",
5055
"close_all_sessions",
51-
"fastapi_db_session_middleware",
52-
"add_fastapi_db_session_middleware",
56+
"add_starlette_http_db_session_middleware",
57+
"starlette_http_db_session_middleware",
58+
"fastapi_http_db_session_middleware",
59+
"add_fastapi_http_db_session_middleware",
5360
]

context_async_sqlalchemy/connect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ async def change_host(self, host: str) -> None:
6666

6767
async def create_session(self) -> AsyncSession:
6868
"""Creates a new session"""
69-
maker = await self.get_session_maker()
69+
maker = await self.session_maker()
7070
return maker()
7171

72-
async def get_session_maker(self) -> async_sessionmaker[AsyncSession]:
72+
async def session_maker(self) -> async_sessionmaker[AsyncSession]:
7373
"""Gets the session maker"""
7474
if self._before_create_session_handler:
7575
await self._before_create_session_handler(self)

context_async_sqlalchemy/context.py

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
from contextvars import ContextVar, copy_context, Token
2-
from typing import Any, Awaitable, Callable, Generator, TypeVar
1+
from contextvars import ContextVar, Token
2+
from typing import Any, Generator
33

44
from sqlalchemy.ext.asyncio import AsyncSession
55

66
from .connect import DBConnect
77

88

9-
def init_db_session_ctx() -> Token[dict[str, AsyncSession] | None]:
9+
def init_db_session_ctx(
10+
force: bool = False,
11+
) -> Token[dict[str, AsyncSession] | None]:
1012
"""
1113
Initiates a context for storing sessions
1214
"""
13-
if is_context_initiated():
15+
if not force and is_context_initiated():
1416
raise Exception("Context already initiated")
1517

1618
return _init_db_session_ctx()
@@ -73,30 +75,6 @@ def sessions_stream() -> Generator[AsyncSession, Any, None]:
7375
yield session
7476

7577

76-
AsyncCallableResult = TypeVar("AsyncCallableResult")
77-
AsyncCallable = Callable[..., Awaitable[AsyncCallableResult]]
78-
79-
80-
async def run_in_new_ctx(
81-
callable_func: AsyncCallable[AsyncCallableResult],
82-
*args: Any,
83-
**kwargs: Any,
84-
) -> AsyncCallableResult:
85-
"""
86-
Runs a function in a new context with new sessions that have their
87-
own connection.
88-
The intended use is to run multiple database queries concurrently.
89-
90-
example of use:
91-
await asyncio.gather(
92-
run_in_new_ctx(your_function_with_db_session, ...),
93-
run_in_new_ctx(your_function_with_db_session, ...),
94-
)
95-
"""
96-
new_ctx = copy_context()
97-
return await new_ctx.run(_new_ctx_wrapper, callable_func, *args, **kwargs)
98-
99-
10078
_db_session_ctx: ContextVar[dict[str, AsyncSession] | None] = ContextVar(
10179
"db_session_ctx", default=None
10280
)
@@ -112,15 +90,3 @@ def _get_initiated_context() -> dict[str, AsyncSession]:
11290
def _init_db_session_ctx() -> Token[dict[str, AsyncSession] | None]:
11391
session_ctx: dict[str, AsyncSession] | None = {}
11492
return _db_session_ctx.set(session_ctx)
115-
116-
117-
async def _new_ctx_wrapper(
118-
callable_func: AsyncCallable[AsyncCallableResult],
119-
*args: Any,
120-
**kwargs: Any,
121-
) -> AsyncCallableResult:
122-
token = _init_db_session_ctx()
123-
try:
124-
return await callable_func(*args, **kwargs)
125-
finally:
126-
await reset_db_session_ctx(token)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from .middleware import (
2-
fastapi_db_session_middleware,
3-
add_fastapi_db_session_middleware,
2+
fastapi_http_db_session_middleware,
3+
add_fastapi_http_db_session_middleware,
44
)
55

66
__all__ = [
7-
"fastapi_db_session_middleware",
8-
"add_fastapi_db_session_middleware",
7+
"fastapi_http_db_session_middleware",
8+
"add_fastapi_http_db_session_middleware",
99
]
Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,22 @@
1-
from fastapi import FastAPI, Request
1+
from fastapi import FastAPI
22
from starlette.middleware.base import ( # type: ignore[attr-defined]
3+
Request,
34
Response,
45
RequestResponseEndpoint,
5-
BaseHTTPMiddleware,
66
)
77

88
from context_async_sqlalchemy import (
9-
init_db_session_ctx,
10-
is_context_initiated,
11-
reset_db_session_ctx,
12-
auto_commit_by_status_code,
13-
rollback_all_sessions,
9+
add_starlette_http_db_session_middleware,
10+
starlette_http_db_session_middleware,
1411
)
1512

1613

17-
def add_fastapi_db_session_middleware(app: FastAPI) -> None:
14+
def add_fastapi_http_db_session_middleware(app: FastAPI) -> None:
1815
"""Adds middleware to the application"""
19-
app.add_middleware(
20-
BaseHTTPMiddleware, dispatch=fastapi_db_session_middleware
21-
)
16+
add_starlette_http_db_session_middleware(app)
2217

2318

24-
async def fastapi_db_session_middleware(
19+
async def fastapi_http_db_session_middleware(
2520
request: Request, call_next: RequestResponseEndpoint
2621
) -> Response:
2722
"""
@@ -33,22 +28,4 @@ async def fastapi_db_session_middleware(
3328
3429
But you can commit or rollback manually in the handler.
3530
"""
36-
# Tests have different session management rules
37-
# so if the context variable is already set, we do nothing
38-
if is_context_initiated():
39-
return await call_next(request)
40-
41-
# We set the context here, meaning all child coroutines will receive the
42-
# same context. And even if a child coroutine requests the
43-
# session first, the dictionary itself is shared, and this coroutine will
44-
# add the session to dictionary = shared context.
45-
token = init_db_session_ctx()
46-
try:
47-
response = await call_next(request)
48-
await auto_commit_by_status_code(response.status_code)
49-
return response
50-
except Exception:
51-
await rollback_all_sessions()
52-
raise
53-
finally:
54-
await reset_db_session_ctx(token)
31+
return await starlette_http_db_session_middleware(request, call_next)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from contextvars import copy_context
2+
from typing import Any, Awaitable, Callable, TypeVar
3+
4+
from .context import init_db_session_ctx, reset_db_session_ctx
5+
from .auto_commit import commit_all_sessions, rollback_all_sessions
6+
7+
8+
AsyncCallableResult = TypeVar("AsyncCallableResult")
9+
AsyncCallable = Callable[..., Awaitable[AsyncCallableResult]]
10+
11+
12+
async def run_in_new_ctx(
13+
callable_func: AsyncCallable[AsyncCallableResult],
14+
*args: Any,
15+
**kwargs: Any,
16+
) -> AsyncCallableResult:
17+
"""
18+
Runs a function in a new context with new sessions that have their
19+
own connection.
20+
21+
It will commit the transaction automatically if callable_func does not
22+
raise exceptions. Otherwise, the transaction will be rolled back.
23+
24+
The intended use is to run multiple database queries concurrently.
25+
26+
example of use:
27+
await asyncio.gather(
28+
run_in_new_ctx(
29+
your_function_with_db_session, some_arg, some_kwarg=123,
30+
),
31+
run_in_new_ctx(your_function_with_db_session, ...),
32+
)
33+
"""
34+
new_ctx = copy_context()
35+
return await new_ctx.run(_new_ctx_wrapper, callable_func, *args, **kwargs)
36+
37+
38+
async def _new_ctx_wrapper(
39+
callable_func: AsyncCallable[AsyncCallableResult],
40+
*args: Any,
41+
**kwargs: Any,
42+
) -> AsyncCallableResult:
43+
token = init_db_session_ctx(force=True)
44+
try:
45+
result = await callable_func(*args, **kwargs)
46+
await commit_all_sessions()
47+
return result
48+
except Exception:
49+
await rollback_all_sessions()
50+
raise
51+
finally:
52+
await reset_db_session_ctx(token)

context_async_sqlalchemy/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def new_non_ctx_session(
9999
async with new_non_ctx_session(connect) as session:
100100
await session.execute(...)
101101
"""
102-
session_maker = await connect.get_session_maker()
102+
session_maker = await connect.session_maker()
103103
async with session_maker() as session:
104104
yield session
105105

0 commit comments

Comments
 (0)