66from inspect import getattr_static , isclass , iscoroutinefunction , isfunction , signature
77from functools import partial
88from warnings import warn
9- from typing import get_type_hints , Any , Callable , Dict , NamedTuple , Optional
9+ from types import TracebackType
10+ from typing import (
11+ cast ,
12+ get_type_hints ,
13+ Any ,
14+ Callable ,
15+ ContextManager ,
16+ Dict ,
17+ NamedTuple ,
18+ Optional ,
19+ Type ,
20+ )
1021
1122from .spy_calls import SpyCall
1223from .warnings import IncorrectCallWarning
@@ -37,7 +48,7 @@ def _get_type_hints(obj: Any) -> Dict[str, Any]:
3748 return {}
3849
3950
40- class BaseSpy :
51+ class BaseSpy ( ContextManager [ Any ]) :
4152 """Spy object base class.
4253
4354 - Pretends to be another class, if another class is given as a spec
@@ -84,25 +95,35 @@ def __class__(self) -> Any:
8495
8596 return type (self )
8697
87- def _call (self , * args : Any , ** kwargs : Any ) -> Any :
88- spy_id = id (self )
89- spy_name = (
90- self ._name
91- if self ._name
92- else f"{ type (self ).__module__ } .{ type (self ).__qualname__ } "
93- )
98+ def __enter__ (self ) -> Any :
99+ """Allow a spy to be used as a context manager."""
100+ enter_spy = self ._get_or_create_child_spy ("__enter__" )
101+ return enter_spy ()
94102
95- if hasattr (self , "__signature__" ):
96- try :
97- bound_args = self .__signature__ .bind (* args , ** kwargs )
98- except TypeError as e :
99- # stacklevel: 3 ensures warning is linked to call location
100- warn (IncorrectCallWarning (e ), stacklevel = 3 )
101- else :
102- args = bound_args .args
103- kwargs = bound_args .kwargs
104-
105- return self ._handle_call (SpyCall (spy_id , spy_name , args , kwargs ))
103+ def __exit__ (
104+ self ,
105+ exc_type : Optional [Type [BaseException ]],
106+ exc_value : Optional [BaseException ],
107+ traceback : Optional [TracebackType ],
108+ ) -> Optional [bool ]:
109+ """Allow a spy to be used as a context manager."""
110+ exit_spy = self ._get_or_create_child_spy ("__exit__" )
111+ return cast (Optional [bool ], exit_spy (exc_type , exc_value , traceback ))
112+
113+ async def __aenter__ (self ) -> Any :
114+ """Allow a spy to be used as an async context manager."""
115+ enter_spy = self ._get_or_create_child_spy ("__aenter__" )
116+ return await enter_spy ()
117+
118+ async def __aexit__ (
119+ self ,
120+ exc_type : Optional [Type [BaseException ]],
121+ exc_value : Optional [BaseException ],
122+ traceback : Optional [TracebackType ],
123+ ) -> Optional [bool ]:
124+ """Allow a spy to be used as a context manager."""
125+ exit_spy = self ._get_or_create_child_spy ("__aexit__" )
126+ return cast (Optional [bool ], await exit_spy (exc_type , exc_value , traceback ))
106127
107128 def __repr__ (self ) -> str :
108129 """Get a helpful string representation of the spy."""
@@ -118,14 +139,15 @@ def __repr__(self) -> str:
118139 return "<Decoy mock>"
119140
120141 def __getattr__ (self , name : str ) -> Any :
121- """Get a property of the spy.
122-
123- Lazily constructs child spies, basing them on type hints if available.
124- """
142+ """Get a property of the spy, always returning a child spy."""
125143 # do not attempt to mock magic methods
126144 if name .startswith ("__" ) and name .endswith ("__" ):
127145 return super ().__getattribute__ (name )
128146
147+ return self ._get_or_create_child_spy (name )
148+
149+ def _get_or_create_child_spy (self , name : str ) -> Any :
150+ """Lazily construct a child spy, basing it on type hints if available."""
129151 # return previously constructed (and cached) child spies
130152 if name in self ._spy_children :
131153 return self ._spy_children [name ]
@@ -167,6 +189,26 @@ def __getattr__(self, name: str) -> Any:
167189
168190 return spy
169191
192+ def _call (self , * args : Any , ** kwargs : Any ) -> Any :
193+ spy_id = id (self )
194+ spy_name = (
195+ self ._name
196+ if self ._name
197+ else f"{ type (self ).__module__ } .{ type (self ).__qualname__ } "
198+ )
199+
200+ if hasattr (self , "__signature__" ):
201+ try :
202+ bound_args = self .__signature__ .bind (* args , ** kwargs )
203+ except TypeError as e :
204+ # stacklevel: 3 ensures warning is linked to call location
205+ warn (IncorrectCallWarning (e ), stacklevel = 3 )
206+ else :
207+ args = bound_args .args
208+ kwargs = bound_args .kwargs
209+
210+ return self ._handle_call (SpyCall (spy_id , spy_name , args , kwargs ))
211+
170212
171213class Spy (BaseSpy ):
172214 """An object that records all calls made to itself and its children."""
0 commit comments