11import contextlib
22import inspect
33from collections .abc import AsyncGenerator , Callable
4- from contextlib import AbstractAsyncContextManager
4+ from contextlib import AbstractAsyncContextManager , AbstractContextManager
55from typing import TYPE_CHECKING , Any , cast
66
77from litestar .constants import HTTP_DISCONNECT , HTTP_RESPONSE_START , WEBSOCKET_CLOSE , WEBSOCKET_DISCONNECT
1313 get_sqlspec_scope_state ,
1414 set_sqlspec_scope_state ,
1515)
16- from sqlspec .utils .sync_tools import ensure_async_
16+ from sqlspec .utils .sync_tools import ensure_async_ , with_ensure_async_
1717
1818if TYPE_CHECKING :
1919 from collections .abc import Awaitable , Coroutine
@@ -239,8 +239,14 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
239239 raise ImproperConfigurationError (msg )
240240
241241 connection_cm : Any = config .provide_connection (db_pool )
242+ context_manager : AbstractAsyncContextManager [ConnectionT ] | None = None
242243
243- if not isinstance (connection_cm , AbstractAsyncContextManager ):
244+ if isinstance (connection_cm , AbstractAsyncContextManager ):
245+ context_manager = connection_cm
246+ elif isinstance (connection_cm , AbstractContextManager ):
247+ context_manager = with_ensure_async_ (connection_cm )
248+
249+ if context_manager is None :
244250 conn_instance : ConnectionT
245251 if inspect .isawaitable (connection_cm ):
246252 conn_instance = await cast ("Awaitable[ConnectionT]" , connection_cm )
@@ -250,12 +256,12 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
250256 yield conn_instance
251257 return
252258
253- entered_connection = await connection_cm .__aenter__ ()
259+ entered_connection = await context_manager .__aenter__ ()
254260 try :
255261 set_sqlspec_scope_state (scope , connection_key , entered_connection )
256262 yield entered_connection
257263 finally :
258- await connection_cm .__aexit__ (None , None , None )
264+ await context_manager .__aexit__ (None , None , None )
259265 delete_sqlspec_scope_state (scope , connection_key )
260266
261267 return provide_connection
0 commit comments