Skip to content

Commit 754387a

Browse files
committed
Implement GdsArrowClient based on AuthArrowClient
1 parent b549f86 commit 754387a

File tree

4 files changed

+466
-51
lines changed

4 files changed

+466
-51
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import json
2+
from typing import Any
3+
4+
from pyarrow._flight import Result
5+
6+
7+
def deserialize_single(input_stream: list[Result]) -> dict[str, Any]:
8+
rows = deserialize(input_stream)
9+
if len(rows) != 1:
10+
raise ValueError(f"Expected exactly one result, got {len(rows)}")
11+
12+
return rows[0]
13+
14+
15+
def deserialize(input_stream: list[Result]) -> list[dict[str, Any]]:
16+
def deserialize_row(row: Result): # type:ignore
17+
return json.loads(row.body.to_pybytes().decode())
18+
19+
return [deserialize_row(row) for row in input_stream]

graphdatascience/arrow_client/v1/gds_arrow_client.py

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,34 @@
11
from __future__ import annotations
22

3-
import base64
43
import json
54
import logging
65
import re
7-
import time
8-
import warnings
9-
from dataclasses import dataclass
106
from types import TracebackType
117
from typing import Any, Callable, Iterable, Type
128

139
import pandas
1410
import pyarrow
1511
from neo4j.exceptions import ClientError
16-
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Schema, Table, chunked_array, flight
17-
from pyarrow import __version__ as arrow_version
18-
from pyarrow.flight import (
19-
ClientMiddleware,
20-
ClientMiddlewareFactory,
21-
FlightDescriptor,
22-
FlightInternalError,
23-
FlightMetadataReader,
24-
FlightStreamWriter,
25-
FlightTimedOutError,
26-
FlightUnavailableError,
27-
)
12+
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
2813
from pyarrow.types import is_dictionary
2914
from pydantic import BaseModel
30-
from tenacity import (
31-
retry,
32-
retry_any,
33-
retry_if_exception_type,
34-
stop_after_attempt,
35-
stop_after_delay,
36-
wait_exponential,
37-
)
3815

3916
from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
4017
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
41-
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication
42-
from graphdatascience.retry_utils.retry_config import RetryConfig
43-
from graphdatascience.retry_utils.retry_utils import before_log
18+
from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single
4419

4520
from ...semantic_version.semantic_version import SemanticVersion
46-
from ...version import __version__
47-
from ..arrow_info import ArrowInfo
4821

4922

5023
class GdsArrowClient:
5124
def __init__(
5225
self,
5326
flight_client: AuthenticatedArrowClient,
27+
auto_close: bool = True,
5428
):
5529
"""Creates a new GdsArrowClient instance."""
5630
self._flight_client = flight_client
31+
self._auto_close = auto_close
5732
self._logger = logging.getLogger("gds_arrow_client")
5833

5934
def get_node_properties(
@@ -102,7 +77,7 @@ def get_node_properties(
10277
if node_labels:
10378
config["node_labels"] = node_labels
10479

105-
result = self._get_data(database, graph_name, proc, concurrency, config)
80+
result = self._get_data(graph_name, database, proc, concurrency, config)
10681
if list_node_labels:
10782
result.rename(columns={"labels": "nodeLabels"}, inplace=True)
10883

@@ -126,7 +101,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: int | Non
126101
DataFrame
127102
The requested nodes as a DataFrame
128103
"""
129-
return self._get_data(database, graph_name, "gds.graph.nodeLabels.stream", concurrency, {})
104+
return self._get_data(graph_name, database, "gds.graph.nodeLabels.stream", concurrency, {})
130105

131106
def get_relationships(
132107
self,
@@ -154,11 +129,13 @@ def get_relationships(
154129
DataFrame
155130
The requested relationships as a DataFrame
156131
"""
157-
return self._flight_client._get_data( database,
132+
return self._get_data(
158133
graph_name,
134+
database,
159135
"gds.graph.relationships.stream",
160136
concurrency,
161-
{"relationship_types": relationship_types})
137+
{"relationship_types": relationship_types},
138+
)
162139

163140
def get_relationship_properties(
164141
self,
@@ -200,8 +177,7 @@ def get_relationship_properties(
200177
if relationship_types:
201178
config["relationship_types"] = relationship_types
202179

203-
return self._get_data(database, graph_name, proc, concurrency, config)
204-
180+
return self._get_data(graph_name, database, proc, concurrency, config)
205181

206182
def create_graph(
207183
self,
@@ -353,7 +329,7 @@ def node_load_done(self, graph_name: str) -> NodeLoadDoneResult:
353329
NodeLoadDoneResult
354330
A result object containing the name of the import process and the number of nodes loaded
355331
"""
356-
return NodeLoadDoneResult.from_json(self._send_action("NODE_LOAD_DONE", {"name": graph_name}))
332+
return NodeLoadDoneResult(**self._send_action("NODE_LOAD_DONE", {"name": graph_name}))
357333

358334
def relationship_load_done(self, graph_name: str) -> RelationshipLoadDoneResult:
359335
"""
@@ -371,7 +347,7 @@ def relationship_load_done(self, graph_name: str) -> RelationshipLoadDoneResult:
371347
RelationshipLoadDoneResult
372348
A result object containing the name of the import process and the number of relationships loaded
373349
"""
374-
return RelationshipLoadDoneResult.from_json(self._send_action("RELATIONSHIP_LOAD_DONE", {"name": graph_name}))
350+
return RelationshipLoadDoneResult(**self._send_action("RELATIONSHIP_LOAD_DONE", {"name": graph_name}))
375351

376352
def triplet_load_done(self, graph_name: str) -> TripletLoadDoneResult:
377353
"""
@@ -389,7 +365,7 @@ def triplet_load_done(self, graph_name: str) -> TripletLoadDoneResult:
389365
TripletLoadDoneResult
390366
A result object containing the name of the import process and the number of nodes and relationships loaded
391367
"""
392-
return TripletLoadDoneResult.from_json(self._send_action("TRIPLET_LOAD_DONE", {"name": graph_name}))
368+
return TripletLoadDoneResult(**self._send_action("TRIPLET_LOAD_DONE", {"name": graph_name}))
393369

394370
def abort(self, graph_name: str) -> None:
395371
"""
@@ -472,8 +448,9 @@ def upload_triplets(
472448
self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
473449

474450
def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
475-
action_type = f"{ArrowEndpointVersion.V1.prefix()}/{action_type}"
476-
return self._flight_client.do_action_with_retry(action_type, meta_data)
451+
action_type = f"{ArrowEndpointVersion.V1.prefix()}{action_type}"
452+
raw_result = self._flight_client.do_action_with_retry(action_type, meta_data)
453+
return deserialize_single(raw_result)
477454

478455
def _upload_data(
479456
self,
@@ -500,13 +477,7 @@ def _upload_data(
500477

501478
put_stream, ack_stream = self._flight_client.do_put_with_retry(upload_descriptor, batches[0].schema)
502479

503-
@retry(
504-
reraise=True,
505-
before=before_log("Upload batch", self._logger, logging.DEBUG),
506-
retry=self._retry_config.retry,
507-
stop=self._retry_config.stop,
508-
wait=self._retry_config.wait,
509-
)
480+
@self._flight_client._retry_config.decorator(operation_name="Upload batch", logger=self._logger)
510481
def upload_batch(p: RecordBatch) -> None:
511482
put_stream.write_batch(p)
512483

@@ -582,9 +553,8 @@ def __exit__(
582553
self.close()
583554

584555
def close(self) -> None:
585-
self._client.close()
586-
587-
def _parse
556+
if self._auto_close:
557+
self._flight_client.close()
588558

589559
@staticmethod
590560
def _sanitize_arrow_table(arrow_table: Table) -> Table:

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@
4141
from graphdatascience.retry_utils.retry_config import RetryConfig
4242
from graphdatascience.retry_utils.retry_utils import before_log
4343

44+
from ..arrow_client.arrow_endpoint_version import ArrowEndpointVersion
4445
from ..semantic_version.semantic_version import SemanticVersion
4546
from ..version import __version__
46-
from ..arrow_client.arrow_endpoint_version import ArrowEndpointVersion
4747
from .arrow_info import ArrowInfo
4848

4949

0 commit comments

Comments
 (0)