Commit 1bb35e3

DarthMax authored and FlorentinD committed
write integration tests for v2 GdsArrowClient
1 parent 8efd83c · commit 1bb35e3

6 files changed: +302 −107 lines changed


graphdatascience/arrow_client/v2/gds_arrow_client.py

Lines changed: 65 additions & 11 deletions
@@ -82,12 +82,11 @@ def get_node_properties(
 
         return JobClient.run_job(self._flight_client, "v2/graph.nodeProperties.stream", config)
 
-    def get_node_labels(
+    def get_nodes(
         self,
         graph_name: str,
-        node_label: str,
         *,
-        node_filter: str,
+        node_filter: str | None = None,
         log_progress: bool = True,
         concurrency: int | None = None,
         job_id: str | None = None,
@@ -99,8 +98,6 @@ def get_node_labels(
         ----------
         graph_name
             The name of the graph
-        node_label
-            The node label to stream back.
         node_filter
             A Cypher predicate for filtering nodes in the input graph.
         log_progress
@@ -117,14 +114,14 @@ def get_node_labels(
         """
         config = ConfigConverter.convert_to_gds_config(
             graph_name=graph_name,
-            node_label=node_label,
+            node_label="__IGNORED__",
             node_filter=node_filter,
             concurrency=concurrency,
             log_progress=log_progress,
             job_id=job_id,
         )
 
-        return JobClient.run_job(self._flight_client, "v2/graph.nodeLabel.stream", config)
+        return JobClient.run_job_and_wait(self._flight_client, "v2/graph.nodeLabel.stream", config, log_progress)
 
     def get_relationships(
         self,
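For orientation, a minimal usage sketch of the renamed method; the client instance and graph name are illustrative, and the integration tests added below exercise the same path via fixtures:

job_id = client.get_nodes("my_graph", node_filter="n.prop1 > 1")  # node_filter is now optional
nodes = client.stream("my_graph", job_id)  # fetch the job's result as a DataFrame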
@@ -325,15 +322,15 @@ def triplet_load_done(self, job_id: str) -> None:
         """
         self._flight_client.do_action_with_retry("v2/graph.project.fromTriplets.done", {"jobId": job_id})
 
-    def upload_data(
+    def upload_nodes(
         self,
         job_id: str,
         data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
         batch_size: int = 10000,
         progress_callback: Callable[[int], None] = lambda x: None,
     ) -> None:
         """
-        Uploads data to the server for a given job.
+        Uploads node data to the server for a given job.
 
         Parameters
         ----------
@@ -346,7 +343,62 @@ def upload_data(
         progress_callback
             A callback function that is called with the number of rows uploaded after each batch
         """
+        self._upload_data("graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback)
 
+    def upload_relationships(
+        self,
+        job_id: str,
+        data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
+        batch_size: int = 10000,
+        progress_callback: Callable[[int], None] = lambda x: None,
+    ) -> None:
+        """
+        Uploads relationship data to the server for a given job.
+
+        Parameters
+        ----------
+        job_id
+            The job id of the import process
+        data
+            The data to upload
+        batch_size
+            The number of rows per batch
+        progress_callback
+            A callback function that is called with the number of rows uploaded after each batch
+        """
+        self._upload_data("graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback)
+
+    def upload_triplets(
+        self,
+        job_id: str,
+        data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
+        batch_size: int = 10000,
+        progress_callback: Callable[[int], None] = lambda x: None,
+    ) -> None:
+        """
+        Uploads triplet data to the server for a given job.
+
+        Parameters
+        ----------
+        job_id
+            The job id of the import process
+        data
+            The data to upload
+        batch_size
+            The number of rows per batch
+        progress_callback
+            A callback function that is called with the number of rows uploaded after each batch
+        """
+        self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback)
+
+    def _upload_data(
+        self,
+        endpoint: str,
+        job_id: str,
+        data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
+        batch_size: int = 10000,
+        progress_callback: Callable[[int], None] = lambda x: None,
+    ) -> None:
         match data:
             case pyarrow.Table():
                 batches = data.to_batches(batch_size)
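The three public upload methods are thin wrappers around _upload_data that differ only in the server endpoint. A hedged sketch of uploading a node table with batching and progress reporting (the client instance, nodes_df, and the printed message are illustrative):

job_id = client.create_graph("my_graph")
client.upload_nodes(
    job_id,
    nodes_df,  # pyarrow.Table, list[pyarrow.RecordBatch], or pandas.DataFrame
    batch_size=5000,  # rows per record batch sent over Flight
    progress_callback=lambda n: print(f"uploaded {n} rows"),  # invoked after each batch
)
client.node_load_done(job_id)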
@@ -356,9 +408,11 @@ def upload_data(
                 batches = data
 
         flight_descriptor = {
-            "name": job_id,
+            "name": endpoint,
             "version": ArrowEndpointVersion.V2.version(),
-            "body": {},
+            "body": {
+                "jobId": job_id,
+            },
         }
         upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
 
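With this change the Flight command names the endpoint and carries the job id in the body, rather than naming the job directly. Assuming ArrowEndpointVersion.V2.version() yields the string "v2" (not shown in this diff), a node-table upload would build its command payload roughly like this:

import json

flight_descriptor = {"name": "graph.project.fromTables.nodes", "version": "v2", "body": {"jobId": "job-42"}}
command = json.dumps(flight_descriptor).encode("utf-8")  # payload for flight.FlightDescriptor.for_command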
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+from pathlib import Path
+from typing import Generator
+
+import numpy as np
+import pandas as pd
+import pytest
+from testcontainers.core.network import Network
+
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
+from graphdatascience.procedure_surface.api.catalog import GraphV2
+from graphdatascience.procedure_surface.arrow.catalog import CatalogArrowEndpoints
+from graphdatascience.tests.integrationV2.conftest import GdsSessionConnectionInfo, create_arrow_client, start_session
+from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph
+
+
+@pytest.fixture(scope="package")
+def session_connection(
+    network: Network, password_dir: Path, logs_dir: Path
+) -> Generator[GdsSessionConnectionInfo, None, None]:
+    yield from start_session(logs_dir, network, password_dir)
+
+
+@pytest.fixture(scope="package")
+def arrow_client(session_connection: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
+    return create_arrow_client(session_connection)
+
+
+@pytest.fixture(scope="package")
+def gds_arrow_client(arrow_client: AuthenticatedArrowClient) -> GdsArrowClient:
+    return GdsArrowClient(arrow_client)
+
+
+@pytest.fixture
+def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
+    gdl = """
+    CREATE
+        (a: Node:Foo {prop1: 1, prop2: 42.0}),
+        (b: Node {prop1: 2, prop2: 43.0}),
+        (c: Node:Foo {prop1: 3, prop2: 44.0}),
+
+        (a)-[:REL {relX: 1, relY: 42}]->(b),
+        (b)-[:REL {relX: 2, relY: 43}]->(c),
+        (c)-[:REL2 {relX: 1, relY: 2}]->(a),
+    """
+
+    with create_graph(arrow_client, "g", gdl) as G:
+        yield G
+
+
+def test_stream_node_label(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
+    job_id = gds_arrow_client.get_nodes(sample_graph.name(), node_filter="n.prop1 > 1")
+    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+
+    assert ["nodeId"] == list(result.columns)
+    assert len(result) == 2
+
+
+def test_stream_node_properties(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
+    job_id = gds_arrow_client.get_node_properties(sample_graph.name(), node_properties=["prop1", "prop2"])
+    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+
+    assert len(result) == 3
+    assert "nodeId" in result.columns
+    assert "prop1" in result.columns
+    assert "prop2" in result.columns
+    assert set(result["prop1"].tolist()) == {1, 2, 3}
+    assert set(result["prop2"].tolist()) == {42.0, 43.0, 44.0}
+
+
+def test_stream_relationship_properties(gds_arrow_client: GdsArrowClient, sample_graph: GraphV2) -> None:
+    job_id = gds_arrow_client.get_relationships(sample_graph.name(), ["REL"], relationship_properties=["relX", "relY"])
+    result = gds_arrow_client.stream(sample_graph.name(), job_id)
+
+    assert len(result) == 2
+    assert "sourceNodeId" in result.columns
+    assert "targetNodeId" in result.columns
+    assert "relationshipType" in result.columns
+    assert "REL" in result["relationshipType"].tolist()
+    assert "relX" in result.columns
+    assert "relY" in result.columns
+    assert list(result["relX"].tolist()) == [1.0, 2.0]
+    assert list(result["relY"].tolist()) == [42.0, 43.0]
+
+
+def test_project_from_triplets(arrow_client: AuthenticatedArrowClient, gds_arrow_client: GdsArrowClient) -> None:
+    df = pd.DataFrame(
+        {"sourceNode": np.array([1, 2, 3], dtype=np.int64), "targetNode": np.array([4, 5, 6], dtype=np.int64)}
+    )
+
+    job_id = gds_arrow_client.create_graph_from_triplets("triplets")
+    gds_arrow_client.upload_triplets(job_id, df)
+    gds_arrow_client.triplet_load_done(job_id)
+
+    while gds_arrow_client.job_status(job_id).status != "Done":
+        pass
+
+    listing = CatalogArrowEndpoints(arrow_client).list("triplets")[0]
+    assert listing.node_count == 6
+    assert listing.relationship_count == 3
+    assert listing.graph_name == "triplets"
+
+
+def test_project_from_tables(arrow_client: AuthenticatedArrowClient, gds_arrow_client: GdsArrowClient) -> None:
+    nodes = pd.DataFrame(
+        {
+            "nodeId": np.array([1, 2, 3, 4, 5, 6], dtype=np.int64),
+        }
+    )
+
+    rels = pd.DataFrame(
+        {
+            "sourceNodeId": np.array([1, 2, 3], dtype=np.int64),
+            "targetNodeId": np.array([4, 5, 6], dtype=np.int64),
+        }
+    )
+
+    job_id = gds_arrow_client.create_graph("table")
+    gds_arrow_client.upload_nodes(job_id, nodes)
+    gds_arrow_client.node_load_done(job_id)
+
+    while gds_arrow_client.job_status(job_id).status != "RELATIONSHIP_LOADING":
+        pass
+
+    gds_arrow_client.upload_relationships(job_id, rels)
+    gds_arrow_client.relationship_load_done(job_id)
+
+    while gds_arrow_client.job_status(job_id).status != "Done":
+        pass
+
+    listing = CatalogArrowEndpoints(arrow_client).list("table")[0]
+    assert listing.node_count == 6
+    assert listing.relationship_count == 3
+    assert listing.graph_name == "table"
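The tests above busy-wait on job_status, which is acceptable for short-lived test imports. In longer-running code a short sleep between polls avoids a hot loop; a sketch against the same client API:

import time

while gds_arrow_client.job_status(job_id).status != "Done":
    time.sleep(0.1)  # yield between status polls instead of spinning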

graphdatascience/tests/integrationV2/conftest.py

Lines changed: 97 additions & 0 deletions
@@ -1,8 +1,19 @@
+import logging
 import os
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Generator
 
 import pytest
+from testcontainers.core.container import DockerContainer
+from testcontainers.core.network import Network
+from testcontainers.core.waiting_utils import wait_for_logs
+
+from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
+from graphdatascience.arrow_client.arrow_info import ArrowInfo
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+
+LOGGER = logging.getLogger(__name__)
 
 
 def pytest_collection_modifyitems(config: Any, items: Any) -> None:
@@ -31,3 +42,89 @@ def logs_dir(tmp_path_factory: pytest.TempPathFactory) -> Generator[Path, None,
 
 def inside_ci() -> bool:
     return os.environ.get("BUILD_NUMBER") is not None
+
+
+@dataclass
+class GdsSessionConnectionInfo:
+    host: str
+    arrow_port: int
+    bolt_port: int
+
+
+@pytest.fixture(scope="package")
+def password_dir(tmp_path_factory: pytest.TempPathFactory) -> Generator[Path, None, None]:
+    """Create a temporary file and return its path."""
+    tmp_dir = tmp_path_factory.mktemp("passwords")
+    temp_file_path = os.path.join(tmp_dir, "password")
+
+    with open(temp_file_path, "w") as f:
+        f.write("password")
+
+    yield tmp_dir
+
+    # Clean up the file
+    os.unlink(temp_file_path)
+
+
+@pytest.fixture(scope="package")
+def network() -> Generator[Network, None, None]:
+    with Network() as network:
+        yield network
+
+
+def start_session(
+    logs_dir: Path, network: Network, password_dir: Path
+) -> Generator[GdsSessionConnectionInfo, None, None]:
+    if (session_uri := os.environ.get("GDS_SESSION_URI")) is not None:
+        uri_parts = session_uri.split(":")
+        yield GdsSessionConnectionInfo(host=uri_parts[0], arrow_port=8491, bolt_port=int(uri_parts[1]))
+        return
+
+    session_image = os.getenv(
+        "GDS_SESSION_IMAGE", "europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest"
+    )
+    LOGGER.info(f"Using session image: {session_image}")
+    session_container = (
+        DockerContainer(
+            image=session_image,
+        )
+        .with_env("ALLOW_LIST", "DEFAULT")
+        .with_env("DNS_NAME", "gds-session")
+        .with_env("PAGE_CACHE_SIZE", "100M")
+        .with_env("MODEL_STORAGE_BASE_LOCATION", "/models")
+        .with_exposed_ports(8491)
+        .with_volume_mapping(password_dir, "/passwords")
+    )
+    if not inside_ci():
+        session_container = session_container.with_network(network).with_network_aliases("gds-session")
+    with session_container as session_container:
+        try:
+            wait_for_logs(session_container, "Running GDS tasks: 0", timeout=20)
+            yield GdsSessionConnectionInfo(
+                host=session_container.get_container_host_ip(),
+                arrow_port=session_container.get_exposed_port(8491),
+                bolt_port=-1,  # not used in tests
+            )
+        finally:
+            stdout, stderr = session_container.get_logs()
+
+            if stderr:
+                print(f"Error logs from session container:\n{stderr}")
+
+            if inside_ci():
+                print(f"Session container logs:\n{stdout}")
+
+            out_file = logs_dir / "session_container.log"
+            with open(out_file, "w") as f:
+                f.write(stdout.decode("utf-8"))
+
+
+def create_arrow_client(session_uri: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
+    """Create an authenticated Arrow client connected to the session container."""
+
+    return AuthenticatedArrowClient.create(
+        arrow_info=ArrowInfo(f"{session_uri.host}:{session_uri.arrow_port}", True, True, ["v1", "v2"]),
+        auth=UsernamePasswordAuthentication("neo4j", "password"),
+        encrypted=False,
+        advertised_listen_address=("gds-session", 8491),
+    )
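These helpers can also be used outside the fixtures. A minimal sketch, assuming a session already listening on localhost:8491 (host and ports are illustrative):

conn = GdsSessionConnectionInfo(host="localhost", arrow_port=8491, bolt_port=-1)  # bolt_port unused here
client = create_arrow_client(conn)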

graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py

Lines changed: 5 additions & 3 deletions
@@ -9,13 +9,15 @@
 from graphdatascience import QueryRunner
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
 from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
-from graphdatascience.tests.integrationV2.conftest import inside_ci
-from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
+from graphdatascience.tests.integrationV2.conftest import (
     GdsSessionConnectionInfo,
     create_arrow_client,
+    inside_ci,
+    start_session,
+)
+from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
     create_db_query_runner,
     start_database,
-    start_session,
 )
 
 LOGGER = logging.getLogger(__name__)
