
Commit 3a693e9

DarthMax authored and FlorentinD committed
Address PR comments
1 parent 1bb35e3 commit 3a693e9

File tree

2 files changed: +43 -42 lines changed


graphdatascience/arrow_client/v2/gds_arrow_client.py

Lines changed: 39 additions & 39 deletions

@@ -173,7 +173,7 @@ def get_relationships(
 
         return JobClient.run_job(self._flight_client, endpoint, config)
 
-    def stream(self, graph_name: str, job_id: str) -> pandas.DataFrame:
+    def stream_job(self, graph_name: str, job_id: str) -> pandas.DataFrame:
         """
         Streams the results of a previously started job.
@@ -391,44 +391,7 @@ def upload_triplets(
         """
         self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback)
 
-    def _upload_data(
-        self,
-        endpoint: str,
-        job_id: str,
-        data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
-        batch_size: int = 10000,
-        progress_callback: Callable[[int], None] = lambda x: None,
-    ) -> None:
-        match data:
-            case pyarrow.Table():
-                batches = data.to_batches(batch_size)
-            case pandas.DataFrame():
-                batches = pyarrow.Table.from_pandas(data).to_batches(batch_size)
-            case _:
-                batches = data
-
-        flight_descriptor = {
-            "name": endpoint,
-            "version": ArrowEndpointVersion.V2.version(),
-            "body": {
-                "jobId": job_id,
-            },
-        }
-        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
-
-        put_stream, ack_stream = self._flight_client.do_put_with_retry(upload_descriptor, batches[0].schema)
-
-        @self._flight_client._retry_config.decorator(operation_name="Upload batch", logger=self._logger)
-        def upload_batch(p: RecordBatch) -> None:
-            put_stream.write_batch(p)
-
-        with put_stream:
-            for partition in batches:
-                upload_batch(partition)
-                ack_stream.read()
-                progress_callback(partition.num_rows)
-
-    def abort(self, job_id: str) -> None:
+    def abort_job(self, job_id: str) -> None:
         """
         Aborts the specified process
@@ -494,6 +457,43 @@ def request_token(self) -> str | None:
 
         return self._flight_client.request_token()
 
+    def _upload_data(
+        self,
+        endpoint: str,
+        job_id: str,
+        data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
+        batch_size: int = 10000,
+        progress_callback: Callable[[int], None] = lambda x: None,
+    ) -> None:
+        match data:
+            case pyarrow.Table():
+                batches = data.to_batches(batch_size)
+            case pandas.DataFrame():
+                batches = pyarrow.Table.from_pandas(data).to_batches(batch_size)
+            case _:
+                batches = data
+
+        flight_descriptor = {
+            "name": endpoint,
+            "version": ArrowEndpointVersion.V2.version(),
+            "body": {
+                "jobId": job_id,
+            },
+        }
+        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
+
+        put_stream, ack_stream = self._flight_client.do_put_with_retry(upload_descriptor, batches[0].schema)
+
+        @self._flight_client._retry_config.decorator(operation_name="Upload batch", logger=self._logger)
+        def upload_batch(p: RecordBatch) -> None:
+            put_stream.write_batch(p)
+
+        with put_stream:
+            for partition in batches:
+                upload_batch(partition)
+                ack_stream.read()
+                progress_callback(partition.num_rows)
+
     def __enter__(self) -> GdsArrowClient:
         return self
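For reference, the relocated _upload_data helper first normalizes its input (a pyarrow.Table, a pandas.DataFrame, or a list of RecordBatches) into a list of RecordBatches before opening the Flight put stream. A standalone sketch of that normalization step, using only public pyarrow/pandas calls; the sample data and batch size here are made up for illustration:

import pandas
import pyarrow

# Made-up sample data; any DataFrame with an Arrow-compatible schema works.
df = pandas.DataFrame({"nodeId": [0, 1, 2], "prop1": [1, 2, 3]})
batch_size = 2

# Mirrors the match statement in _upload_data: a DataFrame is converted
# to a pyarrow.Table, then split into RecordBatches of at most batch_size
# rows each; a Table skips the conversion; a list of RecordBatches is
# used as-is.
batches = pyarrow.Table.from_pandas(df).to_batches(batch_size)

assert all(b.num_rows <= batch_size for b in batches)
assert sum(b.num_rows for b in batches) == len(df)

Each resulting batch is then written to the put stream and acknowledged individually, which is what allows the per-batch retry decorator and progress_callback in _upload_data to work at batch granularity.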
graphdatascience/tests/integrationV2/arrow_client/v2/test_gds_arrow_client.py

Lines changed: 4 additions & 3 deletions

@@ -50,27 +50,28 @@ def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, N
 
 def test_stream_node_label(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
     job_id = gds_arrow_client.get_nodes(sample_graph.name(), node_filter="n.prop1 > 1")
-    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+    result = gds_arrow_client.stream_job(sample_graph.name(), job_id)
 
     assert ["nodeId"] == list(result.columns)
     assert len(result) == 2
 
 
 def test_stream_node_properties(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
     job_id = gds_arrow_client.get_node_properties(sample_graph.name(), node_properties=["prop1", "prop2"])
-    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+    result = gds_arrow_client.stream_job(sample_graph.name(), job_id)
 
     assert len(result) == 3
     assert "nodeId" in result.columns
     assert "prop1" in result.columns
     assert "prop2" in result.columns
+    assert {"nodeId", "prop1", "prop2"} == set(result.columns)
     assert set(result["prop1"].tolist()) == {1, 2, 3}
     assert set(result["prop2"].tolist()) == {42.0, 43.0, 44.0}
 
 
 def test_stream_relationship_properties(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
     job_id = gds_arrow_client.get_relationships(sample_graph.name(), ["REL"], relationship_properties=["relX", "relY"])
-    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+    result = gds_arrow_client.stream_job(sample_graph.name(), job_id)
 
     assert len(result) == 2
     assert "sourceNodeId" in result.columns
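Taken together with the test changes above, a call site after this commit reads roughly as follows. This is a sketch, not code from the commit: obtaining the client instance is elided, the graph name "my_graph" is made up, and only get_node_properties, stream_job, and abort_job are taken from the diff.

import pandas

# Assumes an already-constructed client; construction is not part of this diff.
client: GdsArrowClient = ...

# Start a server-side job (existing API, unchanged by this commit).
job_id = client.get_node_properties("my_graph", node_properties=["prop1", "prop2"])

# Renamed in this commit: stream() -> stream_job().
result: pandas.DataFrame = client.stream_job("my_graph", job_id)

# Renamed in this commit: abort() -> abort_job(), e.g. for a job
# whose results are no longer needed.
client.abort_job(job_id)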