Skip to content

Commit b549f86

Browse files
committed
Support serializing authenticated flight client
1 parent 41b1051 commit b549f86

File tree

3 files changed

+191
-55
lines changed

3 files changed

+191
-55
lines changed

graphdatascience/arrow_client/authenticated_flight_client.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import logging
55
import platform
66
from dataclasses import dataclass
7-
from typing import Any, Iterator
7+
from types import TracebackType
8+
from typing import Any, Iterator, Type
89

910
import certifi
11+
from pyarrow import Schema, flight
1012
from pyarrow import __version__ as arrow_version
11-
from pyarrow import flight, Schema
1213
from pyarrow._flight import (
1314
Action,
1415
ActionType,
@@ -19,13 +20,11 @@
1920
Result,
2021
Ticket,
2122
)
22-
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
2323

2424
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
2525
from graphdatascience.arrow_client.arrow_info import ArrowInfo
26-
from graphdatascience.retry_utils.retry_config import RetryConfig
26+
from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2, StopConfig
2727

28-
from ..retry_utils.retry_utils import before_log
2928
from ..version import __version__
3029
from .middleware.auth_middleware import AuthFactory, AuthMiddleware
3130
from .middleware.user_agent_middleware import UserAgentFactory
@@ -39,7 +38,7 @@ def create(
3938
encrypted: bool = False,
4039
arrow_client_options: dict[str, Any] | None = None,
4140
connection_string_override: str | None = None,
42-
retry_config: RetryConfig | None = None,
41+
retry_config: RetryConfigV2 | None = None,
4342
advertised_listen_address: tuple[str, int] | None = None,
4443
) -> AuthenticatedArrowClient:
4544
connection_string: str
@@ -51,14 +50,14 @@ def create(
5150
host, port = connection_string.split(":")
5251

5352
if retry_config is None:
54-
retry_config = RetryConfig(
55-
retry=retry_any(
56-
retry_if_exception_type(FlightTimedOutError),
57-
retry_if_exception_type(FlightUnavailableError),
58-
retry_if_exception_type(FlightInternalError),
59-
),
60-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
61-
wait=wait_exponential(multiplier=1, min=1, max=10),
53+
# Default policy: retry transient Flight failures, giving up after 10 seconds
# or 5 attempts, with exponential back-off between 1 and 10 seconds.
# The keywords must be the RetryConfigV2 field names (stop_config / wait_config);
# an unknown keyword such as `stop=` is not stored on the model (unless extras
# are forbidden), which would leave the policy without stop or wait settings.
retry_config = RetryConfigV2(
    retryable_exceptions=[
        FlightTimedOutError,
        FlightUnavailableError,
        FlightInternalError,
    ],
    stop_config=StopConfig(after_delay=10, after_attempt=5),
    wait_config=ExponentialWaitConfig(multiplier=1, min=1, max=10),
)
6362

6463
return AuthenticatedArrowClient(
@@ -74,7 +73,7 @@ def create(
7473
def __init__(
7574
self,
7675
host: str,
77-
retry_config: RetryConfig,
76+
retry_config: RetryConfigV2,
7877
port: int = 8491,
7978
auth: ArrowAuthentication | None = None,
8079
encrypted: bool = False,
@@ -117,7 +116,7 @@ def __init__(
117116
self._auth_middleware = AuthMiddleware(auth)
118117
self.advertised_listen_address = advertised_listen_address
119118

120-
self._flight_client = self._instantiate_flight_client()
119+
self._flight_client: flight.FlightClient = self._instantiate_flight_client()
121120

122121
def connection_info(self) -> ConnectionInfo:
123122
"""
@@ -155,13 +154,7 @@ def request_token(self) -> str | None:
155154
a token from the server and returns it.
156155
"""
157156

158-
@retry(
159-
reraise=True,
160-
before=before_log("Request token", self._logger, logging.DEBUG),
161-
retry=self._retry_config.retry,
162-
stop=self._retry_config.stop,
163-
wait=self._retry_config.wait,
164-
)
157+
@self._retry_config.decorator(operation_name="Request token", logger=self._logger)
165158
def auth_with_retry() -> None:
166159
client = self._flight_client
167160
if self._auth:
@@ -183,13 +176,7 @@ def do_action(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[
183176
return self._flight_client.do_action(Action(endpoint, payload_bytes)) # type: ignore
184177

185178
def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> list[Result]:
186-
@retry(
187-
reraise=True,
188-
before=before_log("Send action", self._logger, logging.DEBUG),
189-
retry=self._retry_config.retry,
190-
stop=self._retry_config.stop,
191-
wait=self._retry_config.wait,
192-
)
179+
@self._retry_config.decorator(operation_name="Send action", logger=self._logger)
193180
def run_with_retry() -> list[Result]:
194181
# the Flight response error code is only checked on iterator consumption
195182
# we eagerly collect iterator here to trigger retry in case of an error
@@ -203,18 +190,27 @@ def list_actions(self) -> set[ActionType]:
203190
def do_put_with_retry(
204191
self, descriptor: flight.FlightDescriptor, schema: Schema
205192
) -> tuple[flight.FlightStreamWriter, flight.FlightMetadataReader]:
206-
@retry(
207-
reraise=True,
208-
before=before_log("Do put", self._logger, logging.DEBUG),
209-
retry=self._retry_config.retry,
210-
stop=self._retry_config.stop,
211-
wait=self._retry_config.wait,
212-
)
213-
def run_with_retry() -> list[Result]:
214-
return self._client().do_put(descriptor, schema)
193+
@self._retry_config.decorator(operation_name="Do put", logger=self._logger)
194+
def run_with_retry() -> tuple[flight.FlightStreamWriter, flight.FlightMetadataReader]:
195+
return self._flight_client.do_put(descriptor, schema) # type: ignore
215196

216197
return run_with_retry()
217198

199+
def __enter__(self) -> AuthenticatedArrowClient:
    """Enter a ``with`` block; the client itself is the context value."""
    return self
201+
202+
def __exit__(
    self,
    exception_type: Type[BaseException] | None,
    exception_value: BaseException | None,
    traceback: TracebackType | None,
) -> None:
    """Leave a ``with`` block by closing the client.

    The exception arguments are not inspected. Because ``None`` is returned,
    an exception raised inside the block is not suppressed.
    """
    self.close()
209+
210+
def close(self) -> None:
    """Close the underlying Flight client; do nothing if there is none."""
    flight_client = self._flight_client
    if not flight_client:
        return
    flight_client.close()
213+
218214
def _instantiate_flight_client(self) -> flight.FlightClient:
219215
location = (
220216
flight.Location.for_grpc_tls(self._host, self._port)
Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,142 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import sys
15
from dataclasses import dataclass
6+
from typing import Any, Callable
7+
8+
from pydantic import BaseModel
9+
from tenacity import WrappedFn, retry
10+
from tenacity.retry import retry_always, retry_any, retry_base, retry_if_exception_type
11+
from tenacity.stop import stop_after_attempt, stop_after_delay, stop_any, stop_base
12+
from tenacity.wait import wait_base, wait_exponential
213

3-
from tenacity.retry import retry_base
4-
from tenacity.stop import stop_base
5-
from tenacity.wait import wait_base
14+
from graphdatascience.retry_utils.retry_utils import before_log
615

716

817
@dataclass(frozen=True, repr=True)
class RetryConfig:
    """Immutable bundle of tenacity strategy objects describing a retry policy."""

    stop: stop_base  # tenacity stop strategy (when to give up)
    wait: wait_base  # tenacity wait strategy (pause between attempts)
    retry: retry_base  # tenacity retry predicate (which outcomes are retried)
22+
23+
24+
class RetryConfigV2(BaseModel):
    """Retry configuration which can be serialized/deserialized.

    The policy is held as plain data (exception types plus stop/wait settings)
    and is only turned into tenacity objects when a decorator is requested.
    """

    # Exception types that trigger a retry. None means "retry on any exception";
    # an empty list means "never retry".
    retryable_exceptions: list[type[BaseException]] | None = None
    # When to give up. None leaves tenacity's own default in place.
    stop_config: StopConfig | None = None
    # How long to pause between attempts. None leaves tenacity's own default in place.
    wait_config: ExponentialWaitConfig | None = None

    def decorator(
        self,
        logger: logging.Logger | None = None,
        log_level: int = logging.DEBUG,
        operation_name: str = "Operation",
    ) -> Callable[[WrappedFn], WrappedFn]:
        """
        Create a tenacity retry decorator configured with this retry config.

        The decorator always re-raises the last exception of the wrapped
        function (``reraise=True``) once retrying stops.

        Parameters
        ----------
        logger
            Logger to use for retry logging. If None, no logging is performed.
        log_level
            Logging level to use (e.g., logging.DEBUG, logging.INFO).
        operation_name
            Name of the operation being retried, used in log messages.

        Returns
        -------
        Callable
            A decorator that can be applied to functions to add retry behavior.

        Examples
        --------
        >>> config = RetryConfigV2(
        ...     retryable_exceptions=[ValueError, KeyError],
        ...     stop_config=StopConfig(after_attempt=3),
        ...     wait_config=ExponentialWaitConfig(multiplier=1, min=1, max=10)
        ... )
        >>> @config.decorator(operation_name="My function")
        ... def my_function():
        ...     # function logic here
        ...     pass
        """
        kwargs: dict[str, Any] = {}

        # Only pass the strategies that are configured, so that tenacity's
        # own defaults apply to everything left unset.
        retry_condition = self.tenacity_retry()
        if retry_condition is not None:
            kwargs["retry"] = retry_condition

        stop_condition = self.tenacity_stop()
        if stop_condition is not None:
            kwargs["stop"] = stop_condition

        wait_strategy = self.tenacity_wait()
        if wait_strategy is not None:
            kwargs["wait"] = wait_strategy

        # Log before each attempt, but only when a logger is provided.
        if logger is not None:
            kwargs["before"] = before_log(operation_name, logger, log_level)

        return retry(**kwargs, reraise=True)

    def tenacity_wait(self) -> wait_base | None:
        """Return the tenacity wait strategy, or None if no wait is configured."""
        if self.wait_config is None:
            return None

        return self.wait_config.to_tenacity()

    def tenacity_stop(self) -> stop_base | None:
        """Return the tenacity stop strategy, or None if no stop is configured."""
        if self.stop_config is None:
            return None

        return self.stop_config.to_tenacity()

    def tenacity_retry(self) -> retry_base | None:
        """Return the tenacity retry predicate for the configured exceptions."""
        if self.retryable_exceptions is None:
            # Retry on any exception. `retry_always` is not used here because it
            # also rejects successful results, so a call that succeeds would be
            # repeated until the stop condition instead of returning its value.
            return retry_if_exception_type()

        retries = [retry_if_exception_type(exc) for exc in self.retryable_exceptions]
        if len(retries) == 1:
            return retries[0]

        return retry_any(*retries)
110+
111+
112+
class StopConfig(BaseModel):
    """Serializable description of when retrying should stop."""

    # Maximum number of attempts; None disables the attempt limit.
    after_attempt: int | None = None
    # Limit passed to tenacity's stop_after_delay; None disables the time limit.
    after_delay: int | None = None

    def to_tenacity(self) -> stop_base | None:
        """Build the tenacity stop strategy, or return None when no limit is set.

        When both limits are set they are combined with ``stop_any``, so
        retrying stops as soon as either limit is reached.
        """
        candidates = (
            (self.after_attempt, stop_after_attempt),
            (self.after_delay, stop_after_delay),
        )
        conditions: list[stop_base] = [
            factory(limit) for limit, factory in candidates if limit is not None
        ]

        if len(conditions) > 1:
            return stop_any(*conditions)
        return conditions[0] if conditions else None
129+
130+
131+
class ExponentialWaitConfig(BaseModel):
    """Serializable description of an exponential back-off wait strategy.

    The fields are passed to tenacity's ``wait_exponential`` under the same names.
    """

    multiplier: float = 1
    min: int = 0
    max: float | int = sys.maxsize / 2
    exp_base: int = 2

    def to_tenacity(self) -> wait_base | None:
        """Build the tenacity ``wait_exponential`` strategy from this config."""
        return wait_exponential(
            multiplier=self.multiplier,
            min=self.min,
            max=self.max,
            # Forward the configured base; otherwise the field would have no effect.
            exp_base=self.exp_base,
        )
Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# graphdatascience/tests/test_authenticated_flight_client.py
22
import pytest
33
from pyarrow._flight import FlightInternalError, FlightTimedOutError, FlightUnavailableError
4-
from tenacity import retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
54

65
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
76
from graphdatascience.arrow_client.arrow_info import ArrowInfo
87
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
9-
from graphdatascience.retry_utils.retry_config import RetryConfig
8+
from graphdatascience.retry_utils.retry_config import ExponentialWaitConfig, RetryConfigV2, StopConfig
109

1110

1211
@pytest.fixture
@@ -15,15 +14,15 @@ def arrow_info() -> ArrowInfo:
1514

1615

1716
@pytest.fixture
18-
def retry_config() -> RetryConfig:
19-
return RetryConfig(
20-
retry=retry_any(
21-
retry_if_exception_type(FlightTimedOutError),
22-
retry_if_exception_type(FlightUnavailableError),
23-
retry_if_exception_type(FlightInternalError),
24-
),
25-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
26-
wait=wait_exponential(multiplier=1, min=1, max=10),
17+
def retry_config() -> RetryConfigV2:
18+
return RetryConfigV2(
19+
retryable_exceptions=[
20+
FlightTimedOutError,
21+
FlightUnavailableError,
22+
FlightInternalError,
23+
],
24+
stop_config=StopConfig(after_delay=10, after_attempt=5),
25+
wait_config=ExponentialWaitConfig(multiplier=1, min=1, max=10),
2726
)
2827

2928

@@ -37,7 +36,7 @@ def auth_pair(self) -> tuple[str, str]:
3736

3837

3938
def test_create_authenticated_arrow_client(
40-
arrow_info: ArrowInfo, retry_config: RetryConfig, mock_auth: ArrowAuthentication
39+
arrow_info: ArrowInfo, retry_config: RetryConfigV2, mock_auth: ArrowAuthentication
4140
) -> None:
4241
client = AuthenticatedArrowClient.create(
4342
arrow_info=arrow_info, auth=mock_auth, encrypted=True, retry_config=retry_config
@@ -46,7 +45,18 @@ def test_create_authenticated_arrow_client(
4645
assert client.connection_info() == ConnectionInfo("localhost", 8491, encrypted=True)
4746

4847

49-
def test_connection_info(arrow_info: ArrowInfo, retry_config: RetryConfig) -> None:
48+
def test_connection_info(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None:
    """A directly constructed client reports its host, port and encryption flag."""
    expected = ConnectionInfo("localhost", 8491, encrypted=False)
    client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config)
    assert client.connection_info() == expected
52+
53+
54+
def test_pickle_roundtrip(arrow_info: ArrowInfo, retry_config: RetryConfigV2) -> None:
    """A pickled and restored client keeps its connection info and retry config."""
    import pickle

    original = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config)

    restored = pickle.loads(pickle.dumps(original))

    assert isinstance(restored, AuthenticatedArrowClient)
    assert restored.connection_info() == original.connection_info()
    assert restored._retry_config == original._retry_config

0 commit comments

Comments
 (0)