44import logging
55import platform
66from dataclasses import dataclass
7- from typing import Any , Iterator
7+ from types import TracebackType
8+ from typing import Any , Iterator , Type
89
910import certifi
11+ from pyarrow import Schema , flight
1012from pyarrow import __version__ as arrow_version
11- from pyarrow import flight , Schema
1213from pyarrow ._flight import (
1314 Action ,
1415 ActionType ,
1920 Result ,
2021 Ticket ,
2122)
22- from tenacity import retry , retry_any , retry_if_exception_type , stop_after_attempt , stop_after_delay , wait_exponential
2323
2424from graphdatascience .arrow_client .arrow_authentication import ArrowAuthentication
2525from graphdatascience .arrow_client .arrow_info import ArrowInfo
26- from graphdatascience .retry_utils .retry_config import RetryConfig
26+ from graphdatascience .retry_utils .retry_config import ExponentialWaitConfig , RetryConfigV2 , StopConfig
2727
28- from ..retry_utils .retry_utils import before_log
2928from ..version import __version__
3029from .middleware .auth_middleware import AuthFactory , AuthMiddleware
3130from .middleware .user_agent_middleware import UserAgentFactory
@@ -39,7 +38,7 @@ def create(
3938 encrypted : bool = False ,
4039 arrow_client_options : dict [str , Any ] | None = None ,
4140 connection_string_override : str | None = None ,
42- retry_config : RetryConfig | None = None ,
41+ retry_config : RetryConfigV2 | None = None ,
4342 advertised_listen_address : tuple [str , int ] | None = None ,
4443 ) -> AuthenticatedArrowClient :
4544 connection_string : str
@@ -51,14 +50,14 @@ def create(
5150 host , port = connection_string .split (":" )
5251
5352 if retry_config is None :
54- retry_config = RetryConfig (
55- retry = retry_any (
56- retry_if_exception_type ( FlightTimedOutError ) ,
57- retry_if_exception_type ( FlightUnavailableError ) ,
58- retry_if_exception_type ( FlightInternalError ) ,
59- ) ,
60- stop = ( stop_after_delay ( 10 ) | stop_after_attempt ( 5 ) ),
61- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
53+ retry_config = RetryConfigV2 (
54+ retryable_exceptions = [
55+ FlightTimedOutError ,
56+ FlightUnavailableError ,
57+ FlightInternalError ,
58+ ] ,
59+ stop = StopConfig ( after_delay = 10 , after_attempt = 5 ),
60+ wait = ExponentialWaitConfig (multiplier = 1 , min = 1 , max = 10 ),
6261 )
6362
6463 return AuthenticatedArrowClient (
@@ -74,7 +73,7 @@ def create(
7473 def __init__ (
7574 self ,
7675 host : str ,
77- retry_config : RetryConfig ,
76+ retry_config : RetryConfigV2 ,
7877 port : int = 8491 ,
7978 auth : ArrowAuthentication | None = None ,
8079 encrypted : bool = False ,
@@ -117,7 +116,7 @@ def __init__(
117116 self ._auth_middleware = AuthMiddleware (auth )
118117 self .advertised_listen_address = advertised_listen_address
119118
120- self ._flight_client = self ._instantiate_flight_client ()
119+ self ._flight_client : flight . FlightClient = self ._instantiate_flight_client ()
121120
122121 def connection_info (self ) -> ConnectionInfo :
123122 """
@@ -155,13 +154,7 @@ def request_token(self) -> str | None:
155154 a token from the server and returns it.
156155 """
157156
158- @retry (
159- reraise = True ,
160- before = before_log ("Request token" , self ._logger , logging .DEBUG ),
161- retry = self ._retry_config .retry ,
162- stop = self ._retry_config .stop ,
163- wait = self ._retry_config .wait ,
164- )
157+ @self ._retry_config .decorator (operation_name = "Request token" , logger = self ._logger )
165158 def auth_with_retry () -> None :
166159 client = self ._flight_client
167160 if self ._auth :
@@ -183,13 +176,7 @@ def do_action(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[
183176 return self ._flight_client .do_action (Action (endpoint , payload_bytes )) # type: ignore
184177
185178 def do_action_with_retry (self , endpoint : str , payload : bytes | dict [str , Any ]) -> list [Result ]:
186- @retry (
187- reraise = True ,
188- before = before_log ("Send action" , self ._logger , logging .DEBUG ),
189- retry = self ._retry_config .retry ,
190- stop = self ._retry_config .stop ,
191- wait = self ._retry_config .wait ,
192- )
179+ @self ._retry_config .decorator (operation_name = "Send action" , logger = self ._logger )
193180 def run_with_retry () -> list [Result ]:
194181 # the Flight response error code is only checked on iterator consumption
195182 # we eagerly collect iterator here to trigger retry in case of an error
@@ -203,18 +190,27 @@ def list_actions(self) -> set[ActionType]:
203190 def do_put_with_retry (
204191 self , descriptor : flight .FlightDescriptor , schema : Schema
205192 ) -> tuple [flight .FlightStreamWriter , flight .FlightMetadataReader ]:
206- @retry (
207- reraise = True ,
208- before = before_log ("Do put" , self ._logger , logging .DEBUG ),
209- retry = self ._retry_config .retry ,
210- stop = self ._retry_config .stop ,
211- wait = self ._retry_config .wait ,
212- )
213- def run_with_retry () -> list [Result ]:
214- return self ._client ().do_put (descriptor , schema )
193+ @self ._retry_config .decorator (operation_name = "Do put" , logger = self ._logger )
194+ def run_with_retry () -> tuple [flight .FlightStreamWriter , flight .FlightMetadataReader ]:
195+ return self ._flight_client .do_put (descriptor , schema ) # type: ignore
215196
216197 return run_with_retry ()
217198
199+ def __enter__ (self ) -> AuthenticatedArrowClient :
200+ return self
201+
202+ def __exit__ (
203+ self ,
204+ exception_type : Type [BaseException ] | None ,
205+ exception_value : BaseException | None ,
206+ traceback : TracebackType | None ,
207+ ) -> None :
208+ self .close ()
209+
210+ def close (self ) -> None :
211+ if self ._flight_client :
212+ self ._flight_client .close ()
213+
218214 def _instantiate_flight_client (self ) -> flight .FlightClient :
219215 location = (
220216 flight .Location .for_grpc_tls (self ._host , self ._port )
0 commit comments