Skip to content

Commit 714f034

Browse files
committed
Expose v2 endpoints on plugin
using the cypher-api here
1 parent e96d0f3 commit 714f034

18 files changed

+1020
-130
lines changed

graphdatascience/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .pipeline.lp_training_pipeline import LPTrainingPipeline
1111
from .pipeline.nc_training_pipeline import NCTrainingPipeline
1212
from .pipeline.nr_training_pipeline import NRTrainingPipeline
13+
from .plugin_v2_endpoints import PluginV2Endpoints
1314
from .query_runner.query_runner import QueryRunner
1415
from .server_version.server_version import ServerVersion
1516
from .session.gds_sessions import GdsSessions
@@ -33,4 +34,5 @@
3334
"NRModel",
3435
"GraphSageModel",
3536
"SimpleRelEmbeddingModel",
37+
"PluginV2Endpoints",
3638
]

graphdatascience/graph_data_science.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from neo4j import Driver
99
from pandas import DataFrame
1010

11+
from graphdatascience.plugin_v2_endpoints import PluginV2Endpoints
1112
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
1213
from graphdatascience.query_runner.query_mode import QueryMode
1314
from graphdatascience.topological_lp.topological_lp_runner import TopologicalLPRunner
@@ -87,15 +88,17 @@ def __init__(
8788
if aura_ds:
8889
GraphDataScience._validate_endpoint(endpoint)
8990

91+
neo4j_query_runner: Neo4jQueryRunner | None = None
9092
if isinstance(endpoint, QueryRunner):
9193
self._query_runner = endpoint
9294
else:
9395
db_auth = None
9496
if auth:
9597
db_auth = neo4j.basic_auth(*auth)
96-
self._query_runner = Neo4jQueryRunner.create_for_db(
98+
neo4j_query_runner = Neo4jQueryRunner.create_for_db(
9799
endpoint, db_auth, aura_ds, database, bookmarks, show_progress
98100
)
101+
self._query_runner = neo4j_query_runner
99102

100103
self._server_version = self._query_runner.server_version()
101104

@@ -130,6 +133,14 @@ def __init__(
130133
connection_string_override=None if arrow is True else arrow,
131134
)
132135

136+
arrow_client = (
137+
None if not isinstance(self._query_runner, ArrowQueryRunner) else self._query_runner._gds_arrow_client
138+
)
139+
if neo4j_query_runner:
140+
self._v2_endpoints: PluginV2Endpoints | None = PluginV2Endpoints(neo4j_query_runner, arrow_client)
141+
else:
142+
self._v2_endpoints = None
143+
133144
self._query_runner.set_show_progress(show_progress)
134145
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
135146

@@ -265,6 +276,17 @@ def driver_config(self) -> dict[str, Any]:
265276
"""
266277
return self._query_runner.driver_config()
267278

279+
@property
280+
def v2(self) -> PluginV2Endpoints:
281+
"""
282+
Return preview v2 endpoints. These endpoints may change without warning.
283+
These endpoints are a preview of the API for the next major version of this library.
284+
"""
285+
if not self._v2_endpoints:
286+
raise RuntimeError("v2 endpoints are not available.")
287+
288+
return self._v2_endpoints
289+
268290
@classmethod
269291
def from_neo4j_driver(
270292
cls: Type[GraphDataScience],

graphdatascience/plugin_v2_endpoints.py

Lines changed: 436 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from graphdatascience.procedure_surface.cypher.catalog_cypher_endpoints import (
2+
CatalogCypherEndpoints,
3+
GraphProjectResult,
4+
GraphWithProjectResult,
5+
)
6+
7+
__all__ = ["CatalogCypherEndpoints", "GraphWithProjectResult", "GraphProjectResult"]

graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees
1717
from graphdatascience.procedure_surface.api.catalog.graph_sampling_endpoints import GraphSamplingEndpoints
1818
from graphdatascience.procedure_surface.cypher.catalog.graph_backend_cypher import get_graph
19+
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
20+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1921

2022
from ...call_parameters import CallParameters
21-
from ...query_runner.query_runner import QueryRunner
2223
from ..api.base_result import BaseResult
2324
from ..utils.config_converter import ConfigConverter
2425
from .catalog.graph_sampling_cypher_endpoints import GraphSamplingCypherEndpoints
@@ -28,14 +29,15 @@
2829

2930

3031
class CatalogCypherEndpoints(CatalogEndpoints):
31-
def __init__(self, query_runner: QueryRunner):
32-
self._query_runner = query_runner
32+
def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient | None = None):
33+
self.cypher_runner = cypher_runner
34+
self._arrow_client = arrow_client
3335

3436
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
3537
graph_name = G if isinstance(G, str) else G.name() if G is not None else None
3638
params = CallParameters(graphName=graph_name) if graph_name else CallParameters()
3739

38-
result = self._query_runner.call_procedure(endpoint="gds.graph.list", params=params)
40+
result = self.cypher_runner.call_procedure(endpoint="gds.graph.list", params=params)
3941
return [GraphInfoWithDegrees(**row.to_dict()) for _, row in result.iterrows()]
4042

4143
def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
@@ -47,7 +49,7 @@ def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | No
4749
else CallParameters(graphName=graph_name)
4850
)
4951

50-
result = self._query_runner.call_procedure(endpoint="gds.graph.drop", params=params)
52+
result = self.cypher_runner.call_procedure(endpoint="gds.graph.drop", params=params)
5153
if len(result) > 0:
5254
return GraphInfo(**result.iloc[0].to_dict())
5355
else:
@@ -64,6 +66,7 @@ def project(
6466
job_id: str | None = None,
6567
sudo: bool = False,
6668
username: str | None = None,
69+
log_progress: bool = True,
6770
) -> GraphWithProjectResult:
6871
config = ConfigConverter.convert_to_gds_config(
6972
nodeProperties=node_properties,
@@ -82,9 +85,11 @@ def project(
8285
)
8386
params.ensure_job_id_in_config()
8487

85-
result = self._query_runner.call_procedure(endpoint="gds.graph.project", params=params).squeeze()
88+
result = self.cypher_runner.call_procedure(
89+
endpoint="gds.graph.project", params=params, logging=log_progress
90+
).squeeze()
8691
project_result = GraphProjectResult(**result.to_dict())
87-
return GraphWithProjectResult(get_graph(project_result.graph_name, self._query_runner), project_result)
92+
return GraphWithProjectResult(get_graph(project_result.graph_name, self.cypher_runner), project_result)
8893

8994
def filter(
9095
self,
@@ -94,6 +99,7 @@ def filter(
9499
relationship_filter: str,
95100
concurrency: int | None = None,
96101
job_id: str | None = None,
102+
log_progress: bool = True,
97103
) -> GraphWithFilterResult:
98104
config = ConfigConverter.convert_to_gds_config(
99105
concurrency=concurrency,
@@ -109,8 +115,10 @@ def filter(
109115
)
110116
params.ensure_job_id_in_config()
111117

112-
result = self._query_runner.call_procedure(endpoint="gds.graph.filter", params=params).squeeze()
113-
return GraphWithFilterResult(get_graph(graph_name, self._query_runner), GraphFilterResult(**result.to_dict()))
118+
result = self.cypher_runner.call_procedure(
119+
endpoint="gds.graph.filter", params=params, logging=log_progress
120+
).squeeze()
121+
return GraphWithFilterResult(get_graph(graph_name, self.cypher_runner), GraphFilterResult(**result.to_dict()))
114122

115123
def generate(
116124
self,
@@ -151,26 +159,28 @@ def generate(
151159

152160
params.ensure_job_id_in_config()
153161

154-
result = self._query_runner.call_procedure(endpoint="gds.graph.generate", params=params).squeeze()
162+
result = self.cypher_runner.call_procedure(
163+
endpoint="gds.graph.generate", params=params, logging=log_progress
164+
).squeeze()
155165
return GraphWithGenerationStats(
156-
get_graph(graph_name, self._query_runner), GraphGenerationStats(**result.to_dict())
166+
get_graph(graph_name, self.cypher_runner), GraphGenerationStats(**result.to_dict())
157167
)
158168

159169
@property
160170
def sample(self) -> GraphSamplingEndpoints:
161-
return GraphSamplingCypherEndpoints(self._query_runner)
171+
return GraphSamplingCypherEndpoints(self.cypher_runner)
162172

163173
@property
164174
def node_labels(self) -> NodeLabelCypherEndpoints:
165-
return NodeLabelCypherEndpoints(self._query_runner)
175+
return NodeLabelCypherEndpoints(self.cypher_runner)
166176

167177
@property
168178
def node_properties(self) -> NodePropertiesCypherEndpoints:
169-
return NodePropertiesCypherEndpoints(self._query_runner)
179+
return NodePropertiesCypherEndpoints(self.cypher_runner, self._arrow_client)
170180

171181
@property
172182
def relationships(self) -> RelationshipCypherEndpoints:
173-
return RelationshipCypherEndpoints(self._query_runner)
183+
return RelationshipCypherEndpoints(self.cypher_runner, self._arrow_client)
174184

175185

176186
class GraphProjectResult(BaseResult):

graphdatascience/procedure_surface/utils/result_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from pandas import DataFrame
44

5-
from graphdatascience import QueryRunner
65
from graphdatascience.query_runner.query_mode import QueryMode
6+
from graphdatascience.query_runner.query_runner import QueryRunner
77

88

99
def transpose_property_columns(result: DataFrame, list_node_labels: bool) -> DataFrame:

graphdatascience/session/session_v2_endpoints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ def graph(self) -> CatalogArrowEndpoints:
152152

153153
@property
154154
def config(self) -> ConfigArrowEndpoints:
155+
"""
156+
Return configuration-related endpoints.
157+
"""
155158
return ConfigArrowEndpoints(self._arrow_client)
156159

157160
@property
158161
def system(self) -> SystemArrowEndpoints:
162+
"""
163+
Return system-related endpoints.
164+
"""
159165
return SystemArrowEndpoints(self._arrow_client)
160166

161167
## Algorithms

graphdatascience/tests/integrationV2/procedure_surface/conftest.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
import logging
22
import os
3+
import subprocess
34
from dataclasses import dataclass
45
from datetime import datetime
56
from pathlib import Path
67
from typing import Generator
78

9+
import dotenv
810
import pytest
911
from dateutil.relativedelta import relativedelta
1012
from testcontainers.core.container import DockerContainer
1113
from testcontainers.core.network import Network
1214
from testcontainers.core.waiting_utils import wait_for_logs
15+
from testcontainers.neo4j import Neo4jContainer
1316

1417
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
1518
from graphdatascience.arrow_client.arrow_info import ArrowInfo
1619
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
1720
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1821
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1922
from graphdatascience.tests.integrationV2.conftest import inside_ci
23+
from graphdatascience.tests.integrationV2.procedure_surface.gds_api_spec import (
24+
EndpointWithModesSpec,
25+
resolve_spec_from_file,
26+
)
2027

2128
LOGGER = logging.getLogger(__name__)
2229

@@ -49,6 +56,49 @@ def network() -> Generator[Network, None, None]:
4956
yield network
5057

5158

59+
@pytest.fixture(scope="session")
60+
def gds_api_spec(tmp_path_factory: pytest.TempPathFactory) -> Generator[list[EndpointWithModesSpec], None, None]:
61+
provided_spec_file = os.environ.get("GDS_API_SPEC_FILE")
62+
63+
spec_file: Path | None = None
64+
if provided_spec_file:
65+
spec_file = Path(provided_spec_file)
66+
67+
if spec_file and not spec_file.exists():
68+
raise FileNotFoundError(f"GDS_API_SPEC_FILE is set to '{spec_file}', but the file does not exist.")
69+
70+
if not spec_file:
71+
spec_dir = tmp_path_factory.mktemp("gds_api_spec")
72+
spec_file = spec_dir / "gds-api-spec.json"
73+
74+
# allow for caching
75+
if not spec_file.exists():
76+
download_gds_api_spec(spec_file)
77+
78+
# Adjust the path to pull from graph-analytics
79+
yield resolve_spec_from_file(spec_file)
80+
81+
82+
def download_gds_api_spec(destination: Path) -> None:
83+
import requests
84+
85+
url = "https://raw.githubusercontent.com/neo-technology/graph-analytics/refs/heads/master/tools/gds-api-spec/gds-api-spec.json"
86+
gh_token = os.environ.get("GITHUB_TOKEN")
87+
if not gh_token:
88+
try:
89+
result = subprocess.run(["gh", "auth", "token"], capture_output=True, text=True, check=True)
90+
gh_token = result.stdout.strip()
91+
except (subprocess.CalledProcessError, FileNotFoundError) as e:
92+
raise ValueError("Failed to get GitHub token. Set GITHUB_TOKEN or authenticate with gh CLI.") from e
93+
94+
headers = {"Authorization": f"Token {gh_token}"}
95+
response = requests.get(url, headers=headers)
96+
response.raise_for_status()
97+
98+
with open(destination, "wb") as f:
99+
f.write(response.content)
100+
101+
52102
def latest_neo4j_version() -> str:
53103
today = datetime.now()
54104
previous_month = today - relativedelta(months=1)
@@ -156,6 +206,60 @@ def start_database(logs_dir: Path, network: Network) -> Generator[DbmsConnection
156206
f.write(stdout.decode("utf-8"))
157207

158208

209+
def start_gds_plugin_database(
210+
logs_dir: Path, tmp_path_factory: pytest.TempPathFactory
211+
) -> Generator[Neo4jContainer, None, None]:
212+
neo4j_image = os.getenv("NEO4J_DATABASE_IMAGE", "neo4j:enterprise")
213+
214+
dotenv.load_dotenv("graphdatascience/tests/test.env", override=True)
215+
GDS_LICENSE_KEY = os.getenv("GDS_LICENSE_KEY")
216+
217+
db_logs_dir = logs_dir / "cypher_surface" / "db_logs"
218+
db_logs_dir.mkdir(parents=True)
219+
db_logs_dir.chmod(0o777)
220+
221+
neo4j_container = (
222+
Neo4jContainer(
223+
image=neo4j_image,
224+
)
225+
.with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
226+
.with_env("NEO4J_PLUGINS", '["graph-data-science"]')
227+
.with_env("NEO4J_gds_arrow_enabled", "true")
228+
.with_env("NEO4J_gds_arrow_listen__address", "0.0.0.0:8491")
229+
.with_exposed_ports(8491)
230+
.with_volume_mapping(db_logs_dir, "/logs", mode="rw")
231+
)
232+
233+
if GDS_LICENSE_KEY is not None:
234+
license_dir = tmp_path_factory.mktemp("gds_license")
235+
license_dir.chmod(0o755)
236+
license_file = os.path.join(license_dir, "license_key")
237+
with open(license_file, "w") as f:
238+
f.write(GDS_LICENSE_KEY)
239+
240+
neo4j_container.with_volume_mapping(
241+
license_dir,
242+
"/licenses",
243+
)
244+
neo4j_container.with_env("NEO4J_gds_enterprise_license__file", "/licenses/license_key")
245+
246+
with neo4j_container as neo4j_db:
247+
try:
248+
wait_for_logs(neo4j_db, "Started.")
249+
yield neo4j_db
250+
finally
251+
stdout, stderr = neo4j_db.get_logs()
252+
if stderr:
253+
print(f"Error logs from Neo4j container:\n{stderr}")
254+
255+
if inside_ci():
256+
print(f"Neo4j container logs:\n{stdout}")
257+
258+
out_file = db_logs_dir / "stdout.log"
259+
with open(out_file, "w") as f:
260+
f.write(stdout.decode("utf-8"))
261+
262+
159263
def create_db_query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[Neo4jQueryRunner, None, None]:
160264
query_runner = Neo4jQueryRunner.create_for_db(
161265
f"bolt://{neo4j_connection.uri}",

graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from graphdatascience.procedure_surface.cypher.catalog.relationship_cypher_endpoints import (
1010
RelationshipCypherEndpoints,
1111
)
12+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1213
from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
1314

1415

@@ -45,7 +46,7 @@ def sample_graph(query_runner: QueryRunner) -> Generator[GraphV2, None, None]:
4546

4647
@pytest.fixture
4748
def relationship_endpoints(
48-
query_runner: QueryRunner,
49+
query_runner: Neo4jQueryRunner,
4950
) -> Generator[RelationshipCypherEndpoints, None, None]:
5051
yield RelationshipCypherEndpoints(query_runner)
5152

0 commit comments

Comments
 (0)