 from __future__ import annotations
 
-import base64
 import json
 import logging
 import re
-import time
-import warnings
-from dataclasses import dataclass
 from types import TracebackType
 from typing import Any, Callable, Iterable, Type
 
 import pandas
 import pyarrow
 from neo4j.exceptions import ClientError
-from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Schema, Table, chunked_array, flight
-from pyarrow import __version__ as arrow_version
-from pyarrow.flight import (
-    ClientMiddleware,
-    ClientMiddlewareFactory,
-    FlightDescriptor,
-    FlightInternalError,
-    FlightMetadataReader,
-    FlightStreamWriter,
-    FlightTimedOutError,
-    FlightUnavailableError,
-)
+from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
 from pyarrow.types import is_dictionary
 from pydantic import BaseModel
-from tenacity import (
-    retry,
-    retry_any,
-    retry_if_exception_type,
-    stop_after_attempt,
-    stop_after_delay,
-    wait_exponential,
-)
 
 from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
-from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication
-from graphdatascience.retry_utils.retry_config import RetryConfig
-from graphdatascience.retry_utils.retry_utils import before_log
+from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single
 
 from ...semantic_version.semantic_version import SemanticVersion
-from ...version import __version__
-from ..arrow_info import ArrowInfo
 
 
 class GdsArrowClient:
     def __init__(
         self,
         flight_client: AuthenticatedArrowClient,
+        auto_close: bool = True,
     ):
         """Creates a new GdsArrowClient instance."""
         self._flight_client = flight_client
+        self._auto_close = auto_close
         self._logger = logging.getLogger("gds_arrow_client")
 
     def get_node_properties(
@@ -102,7 +77,7 @@ def get_node_properties(
         if node_labels:
             config["node_labels"] = node_labels
 
-        result = self._get_data(database, graph_name, proc, concurrency, config)
+        result = self._get_data(graph_name, database, proc, concurrency, config)
         if list_node_labels:
             result.rename(columns={"labels": "nodeLabels"}, inplace=True)
 
@@ -126,7 +101,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: int | Non
         DataFrame
             The requested nodes as a DataFrame
         """
-        return self._get_data(database, graph_name, "gds.graph.nodeLabels.stream", concurrency, {})
+        return self._get_data(graph_name, database, "gds.graph.nodeLabels.stream", concurrency, {})
 
     def get_relationships(
         self,
@@ -154,11 +129,13 @@ def get_relationships(
         DataFrame
             The requested relationships as a DataFrame
         """
-        return self._get_data(database,
+        return self._get_data(
             graph_name,
+            database,
             "gds.graph.relationships.stream",
             concurrency,
-            {"relationship_types": relationship_types})
+            {"relationship_types": relationship_types},
+        )
 
     def get_relationship_properties(
         self,
@@ -200,8 +177,7 @@ def get_relationship_properties(
         if relationship_types:
             config["relationship_types"] = relationship_types
 
-        return self._get_data(database, graph_name, proc, concurrency, config)
-
+        return self._get_data(graph_name, database, proc, concurrency, config)
 
     def create_graph(
         self,
@@ -353,7 +329,7 @@ def node_load_done(self, graph_name: str) -> NodeLoadDoneResult:
         NodeLoadDoneResult
             A result object containing the name of the import process and the number of nodes loaded
         """
-        return NodeLoadDoneResult.from_json(self._send_action("NODE_LOAD_DONE", {"name": graph_name}))
+        return NodeLoadDoneResult(**self._send_action("NODE_LOAD_DONE", {"name": graph_name}))
 
     def relationship_load_done(self, graph_name: str) -> RelationshipLoadDoneResult:
         """
@@ -371,7 +347,7 @@ def relationship_load_done(self, graph_name: str) -> RelationshipLoadDoneResult:
         RelationshipLoadDoneResult
             A result object containing the name of the import process and the number of relationships loaded
         """
-        return RelationshipLoadDoneResult.from_json(self._send_action("RELATIONSHIP_LOAD_DONE", {"name": graph_name}))
+        return RelationshipLoadDoneResult(**self._send_action("RELATIONSHIP_LOAD_DONE", {"name": graph_name}))
 
     def triplet_load_done(self, graph_name: str) -> TripletLoadDoneResult:
         """
@@ -389,7 +365,7 @@ def triplet_load_done(self, graph_name: str) -> TripletLoadDoneResult:
         TripletLoadDoneResult
             A result object containing the name of the import process and the number of nodes and relationships loaded
         """
-        return TripletLoadDoneResult.from_json(self._send_action("TRIPLET_LOAD_DONE", {"name": graph_name}))
+        return TripletLoadDoneResult(**self._send_action("TRIPLET_LOAD_DONE", {"name": graph_name}))
 
     def abort(self, graph_name: str) -> None:
         """
@@ -472,8 +448,9 @@ def upload_triplets(
         self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
 
     def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
-        action_type = f"{ArrowEndpointVersion.V1.prefix()}/{action_type}"
-        return self._flight_client.do_action_with_retry(action_type, meta_data)
+        action_type = f"{ArrowEndpointVersion.V1.prefix()}{action_type}"
+        raw_result = self._flight_client.do_action_with_retry(action_type, meta_data)
+        return deserialize_single(raw_result)
 
     def _upload_data(
         self,
@@ -500,13 +477,7 @@ def _upload_data(
 
         put_stream, ack_stream = self._flight_client.do_put_with_retry(upload_descriptor, batches[0].schema)
 
-        @retry(
-            reraise=True,
-            before=before_log("Upload batch", self._logger, logging.DEBUG),
-            retry=self._retry_config.retry,
-            stop=self._retry_config.stop,
-            wait=self._retry_config.wait,
-        )
+        @self._flight_client._retry_config.decorator(operation_name="Upload batch", logger=self._logger)
         def upload_batch(p: RecordBatch) -> None:
             put_stream.write_batch(p)
 
@@ -582,9 +553,8 @@ def __exit__(
         self.close()
 
     def close(self) -> None:
-        self._flight_client.close()
-
-    def _parse
+        if self._auto_close:
+            self._flight_client.close()
 
     @staticmethod
     def _sanitize_arrow_table(arrow_table: Table) -> Table:
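
For reference, a minimal usage sketch of the refactored client after this change. It relies only on names visible in the diff; how the AuthenticatedArrowClient instance is obtained, and the graph/database names, are illustrative assumptions rather than part of this change.

    # Assumption: flight_client comes from the caller's existing connection setup;
    # its construction is not shown in this diff.
    flight_client: AuthenticatedArrowClient = ...

    # auto_close=True (the new default) lets the context manager close the wrapped
    # AuthenticatedArrowClient on exit; pass auto_close=False if the caller keeps
    # ownership of the underlying connection.
    with GdsArrowClient(flight_client, auto_close=True) as client:
        labels = client.get_node_labels("my_graph", "neo4j", concurrency=4)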