Skip to content

Commit 111d619

Browse files
authored
fix(litestar): correctly handle sync context manager
Ensure the Litestar SQLSpecPlugin connection provider unwraps sync context managers before injecting connections.
1 parent 24b0a1d commit 111d619

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

sqlspec/extensions/litestar/handlers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
import inspect
33
from collections.abc import AsyncGenerator, Callable
4-
from contextlib import AbstractAsyncContextManager
4+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
55
from typing import TYPE_CHECKING, Any, cast
66

77
from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
@@ -13,7 +13,7 @@
1313
get_sqlspec_scope_state,
1414
set_sqlspec_scope_state,
1515
)
16-
from sqlspec.utils.sync_tools import ensure_async_
16+
from sqlspec.utils.sync_tools import ensure_async_, with_ensure_async_
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import Awaitable, Coroutine
@@ -239,8 +239,14 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
239239
raise ImproperConfigurationError(msg)
240240

241241
connection_cm: Any = config.provide_connection(db_pool)
242+
context_manager: AbstractAsyncContextManager[ConnectionT] | None = None
242243

243-
if not isinstance(connection_cm, AbstractAsyncContextManager):
244+
if isinstance(connection_cm, AbstractAsyncContextManager):
245+
context_manager = connection_cm
246+
elif isinstance(connection_cm, AbstractContextManager):
247+
context_manager = with_ensure_async_(connection_cm)
248+
249+
if context_manager is None:
244250
conn_instance: ConnectionT
245251
if inspect.isawaitable(connection_cm):
246252
conn_instance = await cast("Awaitable[ConnectionT]", connection_cm)
@@ -250,12 +256,12 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
250256
yield conn_instance
251257
return
252258

253-
entered_connection = await connection_cm.__aenter__()
259+
entered_connection = await context_manager.__aenter__()
254260
try:
255261
set_sqlspec_scope_state(scope, connection_key, entered_connection)
256262
yield entered_connection
257263
finally:
258-
await connection_cm.__aexit__(None, None, None)
264+
await context_manager.__aexit__(None, None, None)
259265
delete_sqlspec_scope_state(scope, connection_key)
260266

261267
return provide_connection

tests/unit/test_extensions/test_litestar/test_handlers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from litestar.constants import HTTP_RESPONSE_START
88

99
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
10+
from sqlspec.adapters.sqlite.config import SqliteConfig
1011
from sqlspec.exceptions import ImproperConfigurationError
1112
from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state
1213
from sqlspec.extensions.litestar.handlers import (
@@ -245,6 +246,29 @@ async def test_async_connection_provider_raises_when_pool_missing() -> None:
245246
assert pool_key in str(exc_info.value)
246247

247248

249+
async def test_sync_connection_provider_supports_context_manager() -> None:
250+
"""Test sync connection provider wraps sync context managers."""
251+
config = SqliteConfig(pool_config={"database": ":memory:"})
252+
pool_key = "test_pool"
253+
connection_key = "test_connection"
254+
255+
provider = connection_provider_maker(config, pool_key, connection_key)
256+
257+
pool = config.create_pool()
258+
state = MagicMock()
259+
state.get.return_value = pool
260+
scope = cast("Scope", {})
261+
262+
try:
263+
async for connection in provider(state, scope):
264+
assert connection is not None
265+
assert get_sqlspec_scope_state(scope, connection_key) is connection
266+
finally:
267+
pool.close()
268+
269+
assert get_sqlspec_scope_state(scope, connection_key) is None
270+
271+
248272
async def test_async_session_provider_creates_session() -> None:
249273
"""Test async session provider creates driver session."""
250274
config = AiosqliteConfig(pool_config={"database": ":memory:"})

0 commit comments

Comments
 (0)