From b8cd8b51b4796f9c7292bd65ce841f80b2788f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Nov 2025 13:54:22 +0100 Subject: [PATCH 1/9] Add `get` method to CatalogEndpoints --- .../api/catalog/catalog_endpoints.py | 16 ++++++++++++++++ .../arrow/catalog/catalog_arrow_endpoints.py | 5 +++++ .../cypher/catalog_cypher_endpoints.py | 5 +++++ 3 files changed, 26 insertions(+) diff --git a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py index 2181b9747..ed3ca8602 100644 --- a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py +++ b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py @@ -16,6 +16,22 @@ class CatalogEndpoints(ABC): + @abstractmethod + def get(self, graph_name: str) -> GraphV2: + """Retrieve a handle to a graph from the graph catalog. + + Parameters + ---------- + graph_name + The name of the graph. + + Returns + ------- + GraphV2 + A handle to the graph. + """ + pass + @abstractmethod def construct( self, diff --git a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py index 89131d4ca..5295f28ab 100644 --- a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py +++ b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py @@ -56,6 +56,11 @@ def __init__( protocol_version = ProtocolVersionResolver(query_runner).resolve() self._project_protocol = ProjectProtocol.select(protocol_version) + def get(self, graph_name: str) -> GraphV2: + if not self.list(graph_name): + raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.") + return get_graph(graph_name, self._arrow_client) + def project( self, graph_name: str, diff --git a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py index 0c217192e..cdb051b20 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py @@ -39,6 +39,11 @@ def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient self._cypher_runner = cypher_runner self._arrow_client = arrow_client + def get(self, graph_name: str) -> GraphV2: + if not self.list(graph_name): + raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.") + return get_graph(graph_name, self.cypher_runner) + def construct( self, graph_name: str, From ff04d028acc3d1fa0070b555ac726e91934822d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Nov 2025 13:54:51 +0100 Subject: [PATCH 2/9] Expose arrow_client method in AuraGraphDataScience --- .../arrow_client/v1/gds_arrow_client.py | 27 +++---------------- .../arrow_client/v2/api_types.py | 4 +-- .../query_runner/session_query_runner.py | 5 ++-- .../session/aura_graph_data_science.py | 18 ++++++++++++- .../session/test_walking_skeleton.py | 1 + .../arrow_client/V1/test_gds_arrow_client.py | 5 ++-- .../session/test_aura_graph_data_science.py | 8 ++++++ 7 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py index a2f59511f..860c23005 100644 --- a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py +++ 
b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py @@ -2,13 +2,11 @@ import json import logging -import re from types import TracebackType from typing import Any, Callable, Iterable, Type import pandas import pyarrow -from neo4j.exceptions import ClientError from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight from pyarrow.types import is_dictionary from pydantic import BaseModel @@ -18,6 +16,7 @@ from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single from ...semantic_version.semantic_version import SemanticVersion +from ..error_handler import handle_flight_error class GdsArrowClient: @@ -515,7 +514,7 @@ def upload_batch(p: RecordBatch) -> None: ack_stream.read() progress_callback(partition.num_rows) except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) def _get_data( self, @@ -560,7 +559,7 @@ def _fetch_get_result(self, get: flight.FlightStreamReader) -> pandas.DataFrame: try: arrow_table = get.read_all() except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) arrow_table = self._sanitize_arrow_table(arrow_table) if SemanticVersion.from_string(pandas.__version__) >= SemanticVersion(2, 0, 0): return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore @@ -615,26 +614,6 @@ def _decode_pyarrow_array(array: Array) -> Array: else: return array - @staticmethod - def handle_flight_error(e: Exception) -> None: - if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError): - original_message = e.args[0] if len(e.args) > 0 else e.message - improved_message = original_message.replace( - "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", "" - ) - improved_message = improved_message.replace( - "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", "" - ) - improved_message = improved_message.replace( - "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ", - "", - ) - improved_message = re.sub(r"(\. 
)?gRPC client debug context: .+$", "", improved_message) - - raise flight.FlightServerError(improved_message) - else: - raise e - class NodeLoadDoneResult(BaseModel): name: str diff --git a/src/graphdatascience/arrow_client/v2/api_types.py b/src/graphdatascience/arrow_client/v2/api_types.py index 1b7436001..a0b736811 100644 --- a/src/graphdatascience/arrow_client/v2/api_types.py +++ b/src/graphdatascience/arrow_client/v2/api_types.py @@ -37,10 +37,10 @@ def aborted(self) -> bool: return self.status == "Aborted" def succeeded(self) -> bool: - return self.status == "Done" + return self.status.lower() == "done" def running(self) -> bool: - return self.status == "Running" + return self.status.lower() == "running" class MutateResult(ArrowBaseModel): diff --git a/src/graphdatascience/query_runner/session_query_runner.py b/src/graphdatascience/query_runner/session_query_runner.py index 6c693dce1..1b4876042 100644 --- a/src/graphdatascience/query_runner/session_query_runner.py +++ b/src/graphdatascience/query_runner/session_query_runner.py @@ -6,13 +6,14 @@ from pandas import DataFrame -from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.query_runner.termination_flag import TerminationFlag from graphdatascience.server_version.server_version import ServerVersion +from ..arrow_client.error_handler import handle_flight_error from ..call_parameters import CallParameters from ..session.dbms.protocol_resolver import ProtocolVersionResolver from .protocol.project_protocols import ProjectProtocol @@ -183,7 +184,7 @@ def run_projection() -> DataFrame: else: return run_projection() except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) raise e # above should already raise def _remote_write_back( diff --git a/src/graphdatascience/session/aura_graph_data_science.py b/src/graphdatascience/session/aura_graph_data_science.py index a94932ccb..54881eb46 100644 --- a/src/graphdatascience/session/aura_graph_data_science.py +++ b/src/graphdatascience/session/aura_graph_data_science.py @@ -5,7 +5,7 @@ from pandas import DataFrame from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient -from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.call_builder import IndirectCallBuilder from graphdatascience.endpoints import ( AlphaRemoteEndpoints, @@ -94,6 +94,7 @@ def create( v2_endpoints=SessionV2Endpoints( session_auth_arrow_client, db_bolt_query_runner, show_progress=show_progress ), + authenticated_arrow_client=session_auth_arrow_client, ) else: standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) @@ -102,6 +103,7 @@ def create( delete_fn=delete_fn, gds_version=gds_version, v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None, show_progress=show_progress), + authenticated_arrow_client=session_auth_arrow_client, ) def __init__( @@ -110,11 +112,13 @@ def __init__( delete_fn: Callable[[], bool], gds_version: ServerVersion, v2_endpoints: SessionV2Endpoints, + authenticated_arrow_client: AuthenticatedArrowClient, ): self._query_runner = 
query_runner self._delete_fn = delete_fn self._server_version = gds_version self._v2_endpoints = v2_endpoints + self._authenticated_arrow_client = authenticated_arrow_client super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) @@ -177,6 +181,18 @@ def v2(self) -> SessionV2Endpoints: def __getattr__(self, attr: str) -> IndirectCallBuilder: return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version) + def arrow_client(self) -> GdsArrowClient: + """ + Returns a GdsArrowClient that is authenticated to communicate with the Aura Graph Analytics Session. + This client can be used to get direct access to the sessions Arrow Flight server. + + Returns: + A GdsArrowClient + ------- + + """ + return GdsArrowClient(self._authenticated_arrow_client) + def set_database(self, database: str) -> None: """ Set the database which cypher queries are run against. diff --git a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py index 9ed81668a..3039dc281 100644 --- a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py +++ b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py @@ -15,6 +15,7 @@ def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) -> delete_fn=lambda: True, gds_version=ServerVersion.from_string("1.2.3"), v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner), + authenticated_arrow_client=arrow_client, ) diff --git a/tests/unit/arrow_client/V1/test_gds_arrow_client.py b/tests/unit/arrow_client/V1/test_gds_arrow_client.py index b2f777b00..b8f90ee8a 100644 --- a/tests/unit/arrow_client/V1/test_gds_arrow_client.py +++ b/tests/unit/arrow_client/V1/test_gds_arrow_client.py @@ -15,6 +15,7 @@ ) from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.error_handler import handle_flight_error from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.arrow_info import ArrowInfo @@ -397,7 +398,7 @@ def test_handle_flight_error() -> None: FlightServerError, match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.", ): - GdsArrowClient.handle_flight_error( + handle_flight_error( FlightServerError( 'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. 
Detail: Internal' ) @@ -407,7 +408,7 @@ def test_handle_flight_error() -> None: FlightServerError, match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"), ): - GdsArrowClient.handle_flight_error( + handle_flight_error( FlightServerError( "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]" ) diff --git a/tests/unit/session/test_aura_graph_data_science.py b/tests/unit/session/test_aura_graph_data_science.py index 28ba62f31..7183c30ed 100644 --- a/tests/unit/session/test_aura_graph_data_science.py +++ b/tests/unit/session/test_aura_graph_data_science.py @@ -14,6 +14,7 @@ def test_remote_projection_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) g = gds.graph.project( @@ -50,6 +51,7 @@ def test_remote_projection_defaults(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) g = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -78,6 +80,7 @@ def test_remote_algo_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -99,6 +102,7 @@ def test_remote_algo_write_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -125,6 +129,7 @@ def test_remote_graph_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -149,6 +154,7 @@ def test_remote_graph_write_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -176,6 +182,7 @@ def test_run_cypher_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.WRITE, database="bar", retryable=True) @@ -193,6 +200,7 @@ def test_run_cypher_read(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.READ, retryable=False) From c9aaf72f54961c55e2268a8a9368890699b14c63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Nov 2025 13:55:19 +0100 Subject: [PATCH 3/9] Create new example notebook for Sessions + Spark --- .../graph-analytics-serverless-spark.ipynb | 362 ++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 examples/graph-analytics-serverless-spark.ipynb diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb new file mode 100644 index 000000000..278820002 --- /dev/null +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -0,0 +1,362 @@ +{ + "cells": [ + { + "cell_type": "markdown", + 
"metadata": { + "tags": [ + "aura" + ] + }, + "source": [ + "# Aura Graph Analytics with Spark" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb) in the Neo4j Graph Data Science Client Github repository.\n", + "\n", + "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session.\n", + "\n", + "We consider a graph of people and fruits, which we're using as a simple example to show how to connect your AuraDB instance to a GDS Session, run algorithms, and eventually write back your analytical results to the AuraDB database. \n", + "We will cover all management operations: creation, listing, and deletion.\n", + "\n", + "If you are using self managed DB, follow [this example](../graph-analytics-serverless-self-managed)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "This notebook requires having an AuraDB instance available and have the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.\n", + "\n", + "You also need to have the `graphdatascience` Python library installed, version `1.15` or later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "# This allows to load required secrets from `.env` file in local directory\n", + "# This can include Aura API Credentials and Database Credentials.\n", + "# If file does not exist this is a noop.\n", + "load_dotenv(\"sessions.env\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Connecting to a Spark Session\n", + "\n", + "To interact with the Spark Cluster we need to first instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", + "Working with a remote Spark cluster will work similarly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from pyspark.sql import SparkSession\n", + "\n", + "os.environ[\"JAVA_HOME\"] = \"/home/max/.sdkman/candidates/java/current\"\n", + "\n", + "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n", + "\n", + "# Enable Arrow-based columnar data transfers\n", + "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Aura API credentials\n", + "\n", + "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from graphdatascience.session import AuraAPICredentials, GdsSessions\n", + "\n", + "# you can also use AuraAPICredentials.from_env() to load credentials from environment variables\n", + "api_credentials = AuraAPICredentials(\n", + " client_id=os.environ[\"CLIENT_ID\"],\n", + " client_secret=os.environ[\"CLIENT_SECRET\"],\n", + " # If your account is a member of several project, you must also specify the project ID to use\n", + " project_id=os.environ.get(\"PROJECT_ID\", None),\n", + ")\n", + "\n", + "sessions = GdsSessions(api_credentials=api_credentials)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a new session\n", + "\n", + "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n", + "\n", + "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n", + "* The session memory. \n", + "* The cloud location.\n", + "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n", + "\n", + "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import timedelta\n", + "\n", + "from graphdatascience.session import CloudLocation, SessionMemory\n", + "\n", + "# Create a GDS session!\n", + "gds = sessions.get_or_create(\n", + " # we give it a representative name\n", + " session_name=\"people_and_fruits\",\n", + " memory=SessionMemory.m_2GB,\n", + " ttl=timedelta(minutes=30),\n", + " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding a dataset\n", + "\n", + "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import os\n", + "import zipfile\n", + "\n", + "import requests\n", + "\n", + "download_path = \"bike_trips_data\"\n", + "if not os.path.exists(download_path):\n", + " url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n", + "\n", + " response = requests.get(url)\n", + " response.raise_for_status()\n", + "\n", + " # Unzip the content\n", + " with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n", + " z.extractall(download_path)\n", + "\n", + "df = spark.read.csv(download_path, header=True, inferSchema=True)\n", + "df.createOrReplaceTempView(\"bike_trips\")\n", + "df.limit(10).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Projecting Graphs\n", + "\n", + "Now that we have our dataset available within our Spark session it is time to project it to the GDS Session.\n", + "\n", + "We first need to get access to the GDSArrowClient. 
This client allows us to directly communicate with the Arrow Flight server provided by the session.\n", + "\n", + "Our input data already resembles edge triplets, where each of the rows represents an edge from a source station to a target station. This allows us to use the arrows servers graph import from triplets functionality, which requires the following protocol:\n", + "\n", + "1. Send an action `v2/graph.project.fromTriplets`\n", + " This will initialize the import process and allows us to specify the graph name, and settings like `undirected_relationship_types`. It returns a job id, that we need to reference the import job in the following steps.\n", + "2. Send the data in batches to the arrow server.\n", + "3. Send another action called `v2/graph.project.fromTriples.done` to tell the import process that no more data will be send. This will trigger the final graph creation inside the session.\n", + "4. Wait for the import process to reach the `DONE` state.\n", + "\n", + "While the overall process is straight forward, we need to somehow tell Spark to" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import pyarrow\n", + "from pyspark.sql import functions\n", + "\n", + "graph_name = \"bike_trips\"\n", + "\n", + "arrow_client = gds.arrow_client()\n", + "\n", + "# 1. Start the import process\n", + "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n", + "\n", + "\n", + "# Define a function that receives an arrow batch and uploads it to the session\n", + "def upload_batch(iterator):\n", + " for batch in iterator:\n", + " arrow_client.upload_triplets(job_id, [batch])\n", + " yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n", + "\n", + "\n", + "# Select the source target pairs from our source data\n", + "source_target_pairs = spark.sql(\"\"\"\n", + " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", + " FROM bike_trips\n", + "\"\"\")\n", + "\n", + "# 2. Use the `mapInArrow` function to upload the data to the sessions. Returns a dataframe with a single column with the batch sizes.\n", + "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", + "\n", + "# Aggregate the batch sizes to receive the row count.\n", + "uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\")).show()\n", + "\n", + "# 3. Finish the import process\n", + "arrow_client.triplet_load_done(job_id)\n", + "\n", + "# 4. Wait for the import to finish\n", + "while not arrow_client.job_status(job_id).succeeded():\n", + " pass\n", + "\n", + "G = gds.v2.graph.get(graph_name)\n", + "G" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running Algorithms\n", + "\n", + "We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Running PageRank ...\")\n", + "pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sending the computation result back to Spark\n", + "\n", + "Once the computation is done. 
We might want to further use the result in Spark.\n", + "We can do this in a similar to the projection, by streaming batches of data into each of the Spark workers.\n", + "Retrieving the data is a bit more complicated since we need some input data frame in order to trigger computations on the Spark workers.\n", + "We use a data range equal to the size of workers we have in our cluster as our driving table.\n", + "On the workers we will disregard the input and instead stream the computation data from the GDS Session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Start the node property export on the session\n", + "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n", + "\n", + "\n", + "# Define a function that receives data from the GDS Session and turns it into data batches\n", + "def retrieve_data(ignored):\n", + " stream_data = arrow_client.stream_job(G.name(), job_id)\n", + " batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n", + " for b in batches:\n", + " yield b\n", + "\n", + "\n", + "# Create DataFrame with a single column and one row per worker\n", + "input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n", + "# 2. Stream the data from the GDS Session into the Spark workers\n", + "received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n", + "# Optional: Repartition the data to make sure it is distributed equally\n", + "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n", + "\n", + "result.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup\n", + "\n", + "Now that we have finished our analysis, we can delete the session and stop the spark connection.\n", + "\n", + "Deleting the session will release all resources associated with it, and stop incurring costs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.delete()\n", + "spark.stop()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 01c79ec96e308b3d2d6eea58ba75dd43ebe5716a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Nov 2025 16:40:35 +0100 Subject: [PATCH 4/9] Improve notebook docs --- examples/graph-analytics-serverless-spark.ipynb | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb index 278820002..44c87e8f4 100644 --- a/examples/graph-analytics-serverless-spark.ipynb +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -30,10 +30,8 @@ "\n", "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session.\n", "\n", - "We consider a graph of people and fruits, which we're using as a simple example to show how to connect your AuraDB instance to a GDS Session, run algorithms, and eventually write back your analytical results to the AuraDB database. \n", - "We will cover all management operations: creation, listing, and deletion.\n", - "\n", - "If you are using self managed DB, follow [this example](../graph-analytics-serverless-self-managed)." 
+ "We consider a graph of bicycle rentals, which we're using as a simple example to show how project data from Spark to a GDS Session, run algorithms, and eventually retrieving the results back to Spark.\n", + "We will cover all management operations: creation, listing, and deletion." ] }, { @@ -157,7 +155,7 @@ "# Create a GDS session!\n", "gds = sessions.get_or_create(\n", " # we give it a representative name\n", - " session_name=\"people_and_fruits\",\n", + " session_name=\"bike_trips\",\n", " memory=SessionMemory.m_2GB,\n", " ttl=timedelta(minutes=30),\n", " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", @@ -216,7 +214,7 @@ "1. Send an action `v2/graph.project.fromTriplets`\n", " This will initialize the import process and allows us to specify the graph name, and settings like `undirected_relationship_types`. It returns a job id, that we need to reference the import job in the following steps.\n", "2. Send the data in batches to the arrow server.\n", - "3. Send another action called `v2/graph.project.fromTriples.done` to tell the import process that no more data will be send. This will trigger the final graph creation inside the session.\n", + "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be send. This will trigger the final graph creation inside the session.\n", "4. Wait for the import process to reach the `DONE` state.\n", "\n", "While the overall process is straight forward, we need to somehow tell Spark to" @@ -295,8 +293,8 @@ "source": [ "## Sending the computation result back to Spark\n", "\n", - "Once the computation is done. We might want to further use the result in Spark.\n", - "We can do this in a similar to the projection, by streaming batches of data into each of the Spark workers.\n", + "Once the computation is done, we might want to further use the result in Spark.\n", + "We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n", "Retrieving the data is a bit more complicated since we need some input data frame in order to trigger computations on the Spark workers.\n", "We use a data range equal to the size of workers we have in our cluster as our driving table.\n", "On the workers we will disregard the input and instead stream the computation data from the GDS Session." From 792309c2bafb8d1430c4140dfe51114b1f1585ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Nov 2025 16:41:40 +0100 Subject: [PATCH 5/9] Fix documentation in GdsArrowClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre --- src/graphdatascience/arrow_client/v2/gds_arrow_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py index 9a6974b92..2085a7e0b 100644 --- a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py +++ b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py @@ -93,7 +93,7 @@ def get_nodes( job_id: str | None = None, ) -> str: """ - Start a new export process to stream the nodes with the specified label and filter from the graph. + Start a new export process to stream the nodes that match the filter from the graph. 
Parameters ---------- From ec66ddb16bd9cca93e58aefa3d4648d903d47c06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 8 Dec 2025 17:40:59 +0100 Subject: [PATCH 6/9] Remove setting java home --- examples/graph-analytics-serverless-spark.ipynb | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb index 44c87e8f4..462630685 100644 --- a/examples/graph-analytics-serverless-spark.ipynb +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -42,7 +42,7 @@ "\n", "This notebook requires having an AuraDB instance available and have the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.\n", "\n", - "You also need to have the `graphdatascience` Python library installed, version `1.15` or later." + "We also need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`. For more information about setting up pyspark visit https://spark.apache.org/docs/latest/api/python/getting_started/" ] }, { @@ -84,12 +84,8 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", "from pyspark.sql import SparkSession\n", "\n", - "os.environ[\"JAVA_HOME\"] = \"/home/max/.sdkman/candidates/java/current\"\n", - "\n", "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n", "\n", "# Enable Arrow-based columnar data transfers\n", From eaaf560a4e269af8523293413197252f904623d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 10 Dec 2025 10:57:52 +0100 Subject: [PATCH 7/9] Improve notebook documentation Co-authored-by: Mats Rydberg --- .../graph-analytics-serverless-spark.ipynb | 31 ++++++++++--------- .../session/aura_graph_data_science.py | 2 +- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb index 462630685..432ff71ed 100644 --- a/examples/graph-analytics-serverless-spark.ipynb +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -26,12 +26,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb) in the Neo4j Graph Data Science Client Github repository.\n", + "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb) in the Neo4j Graph Data Science Client Github repository.\n", "\n", - "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session.\n", + "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.\n", "\n", - "We consider a graph of bicycle rentals, which we're using as a simple example to show how project data from Spark to a GDS Session, run algorithms, and eventually retrieving the results back to Spark.\n", - "We will cover all management operations: creation, listing, and deletion." 
+ "We consider a graph of bicycle rentals, which we're using as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually return results back to Spark.\n", + "In this notebook we will focus on the interaction with Apache Spark, and will not cover all possible actions using GDS sessions. We refer to other Tutorials for additional details." ] }, { @@ -74,7 +74,7 @@ "source": [ "### Connecting to a Spark Session\n", "\n", - "To interact with the Spark Cluster we need to first instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", + "To interact with the Spark cluster we need to first instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", "Working with a remote Spark cluster will work similarly." ] }, @@ -115,7 +115,7 @@ "api_credentials = AuraAPICredentials(\n", " client_id=os.environ[\"CLIENT_ID\"],\n", " client_secret=os.environ[\"CLIENT_SECRET\"],\n", - " # If your account is a member of several project, you must also specify the project ID to use\n", + " # If your account is a member of several projects, you must also specify the project ID to use\n", " project_id=os.environ.get(\"PROJECT_ID\", None),\n", ")\n", "\n", @@ -164,7 +164,8 @@ "source": [ "## Adding a dataset\n", "\n", - "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips)." + "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips).", + "The bike trips form a graph where nodes represent bike renting stations and relationships represent start and end points for a bike rental trip." ] }, { @@ -209,8 +210,8 @@ "\n", "1. Send an action `v2/graph.project.fromTriplets`\n", " This will initialize the import process and allows us to specify the graph name, and settings like `undirected_relationship_types`. It returns a job id, that we need to reference the import job in the following steps.\n", - "2. Send the data in batches to the arrow server.\n", - "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be send. This will trigger the final graph creation inside the session.\n", + "2. Send the data in batches to the Arrow server.\n", + "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n", "4. Wait for the import process to reach the `DONE` state.\n", "\n", "While the overall process is straight forward, we need to somehow tell Spark to" @@ -234,7 +235,7 @@ "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n", "\n", "\n", - "# Define a function that receives an arrow batch and uploads it to the session\n", + "# Define a function that receives an arrow batch and uploads it to the GDS session\n", "def upload_batch(iterator):\n", " for batch in iterator:\n", " arrow_client.upload_triplets(job_id, [batch])\n", @@ -247,7 +248,7 @@ " FROM bike_trips\n", "\"\"\")\n", "\n", - "# 2. Use the `mapInArrow` function to upload the data to the sessions. Returns a dataframe with a single column with the batch sizes.\n", + "# 2. Use the `mapInArrow` function to upload the data to the GDS session. 
Returns a DataFrame with a single column containing the batch sizes.\n", "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", "\n", "# Aggregate the batch sizes to receive the row count.\n", @@ -291,7 +292,7 @@ "\n", "Once the computation is done, we might want to further use the result in Spark.\n", "We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n", - "Retrieving the data is a bit more complicated since we need some input data frame in order to trigger computations on the Spark workers.\n", + "Retrieving the data is a bit more complicated since we need some input DataFrame in order to trigger computations on the Spark workers.\n", "We use a data range equal to the size of workers we have in our cluster as our driving table.\n", "On the workers we will disregard the input and instead stream the computation data from the GDS Session." ] @@ -302,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "# 1. Start the node property export on the session\n", + "# 1. Start the node property export on the GDS session\n", "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n", "\n", "\n", @@ -330,9 +331,9 @@ "source": [ "## Cleanup\n", "\n", - "Now that we have finished our analysis, we can delete the session and stop the spark connection.\n", + "Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.\n", "\n", - "Deleting the session will release all resources associated with it, and stop incurring costs." + "Deleting the GDS session will release all resources associated with it, and stop incurring costs." ] }, { diff --git a/src/graphdatascience/session/aura_graph_data_science.py b/src/graphdatascience/session/aura_graph_data_science.py index 54881eb46..47397deda 100644 --- a/src/graphdatascience/session/aura_graph_data_science.py +++ b/src/graphdatascience/session/aura_graph_data_science.py @@ -184,7 +184,7 @@ def __getattr__(self, attr: str) -> IndirectCallBuilder: def arrow_client(self) -> GdsArrowClient: """ Returns a GdsArrowClient that is authenticated to communicate with the Aura Graph Analytics Session. - This client can be used to get direct access to the sessions Arrow Flight server. + This client can be used to get direct access to the specific session's Arrow Flight server. Returns: A GdsArrowClient From 9569339834de0a353e7c38d3db2f32ccafc7fc15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 10 Dec 2025 11:38:19 +0100 Subject: [PATCH 8/9] Improve notebook descriptions and address PR comments --- .../graph-analytics-serverless-spark.ipynb | 48 ++++++++++++------- .../arrow_client/v2/api_types.py | 2 +- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb index 432ff71ed..2846f154a 100644 --- a/examples/graph-analytics-serverless-spark.ipynb +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -40,9 +40,7 @@ "source": [ "## Prerequisites\n", "\n", - "This notebook requires having an AuraDB instance available and have the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.\n", - "\n", - "We also need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`. 
For more information about setting up pyspark visit https://spark.apache.org/docs/latest/api/python/getting_started/" + "We also need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`." ] }, { @@ -75,7 +73,7 @@ "### Connecting to a Spark Session\n", "\n", "To interact with the Spark cluster we need to first instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", - "Working with a remote Spark cluster will work similarly." + "Working with a remote Spark cluster will work similarly. For more information about setting up pyspark visit https://spark.apache.org/docs/latest/api/python/getting_started/" ] }, { @@ -164,8 +162,7 @@ "source": [ "## Adding a dataset\n", "\n", - "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips).", - "The bike trips form a graph where nodes represent bike renting stations and relationships represent start and end points for a bike rental trip." + "As the next step we will set up a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike renting stations and relationships represent the start and end points of a bike rental trip." ] }, { @@ -206,7 +203,7 @@ "\n", "We first need to get access to the GDSArrowClient. This client allows us to directly communicate with the Arrow Flight server provided by the session.\n", "\n", - "Our input data already resembles edge triplets, where each of the rows represents an edge from a source station to a target station. This allows us to use the arrows servers graph import from triplets functionality, which requires the following protocol:\n", + "Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow Server's \"graph import from triplets\" functionality, which requires the following protocol:\n", "\n", "1. Send an action `v2/graph.project.fromTriplets`\n", @@ -214,15 +211,29 @@ "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n", "4. Wait for the import process to reach the `DONE` state.\n", "\n", - "While the overall process is straight forward, we need to somehow tell Spark to" + "The most complicated step is running the actual data upload on each Spark worker. We will use the `mapInArrow` function to run custom code on each worker. Each worker receives a number of Arrow record batches that we can send directly to the GDS session's Arrow server." ] }, { "metadata": {}, "cell_type": "markdown", "source": [ "While the data is being uploaded, the import job runs inside the GDS session. 
Once we signal that the upload is done, we poll the job status in a loop, sleeping for one second between checks, until the import has finished.\n", "\n" ] }, { "metadata": {}, "cell_type": "markdown", "source": "Let's run the import:\n" }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "import time\n", "import pandas as pd\n", "import pyarrow\n", "from pyspark.sql import functions\n", @@ -244,30 +255,33 @@ "\n", "# Select the source target pairs from our source data\n", "source_target_pairs = spark.sql(\"\"\"\n", - " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", - " FROM bike_trips\n", - "\"\"\")\n", + " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", + " FROM bike_trips\n", + " \"\"\")\n", "\n", "# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n", "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", "\n", "# Aggregate the batch sizes to receive the row count.\n", - "uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\")).show()\n", + "aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n", + "\n", + "# Show the result. This will trigger the computation and thus run the data upload.\n", + "aggregated_batch_sizes.show()\n", "\n", "# 3. Finish the import process\n", "arrow_client.triplet_load_done(job_id)\n", "\n", "# 4. Wait for the import to finish\n", "while not arrow_client.job_status(job_id).succeeded():\n", - " pass\n", + " time.sleep(1)\n", "\n", "G = gds.v2.graph.get(graph_name)\n", "G" ] }, { - "cell_type": "markdown", "metadata": {}, + "metadata": {}, + "cell_type": "markdown", "source": [ "## Running Algorithms\n", "\n", @@ -322,7 +336,7 @@ "# Optional: Repartition the data to make sure it is distributed equally\n", "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n", "\n", - "result.show()" + "result.toPandas()" ] }, { diff --git a/src/graphdatascience/arrow_client/v2/api_types.py b/src/graphdatascience/arrow_client/v2/api_types.py index a0b736811..d55a9791c 100644 --- a/src/graphdatascience/arrow_client/v2/api_types.py +++ b/src/graphdatascience/arrow_client/v2/api_types.py @@ -34,7 +34,7 @@ def sub_tasks(self) -> str | None: return None def aborted(self) -> bool: - return self.status == "Aborted" + return self.status.lower() == "aborted" def succeeded(self) -> bool: return self.status.lower() == "done" From 8a3a49b6ff26da7bccd9d04411975b843a4ef40c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Wed, 10 Dec 2025 12:32:04 +0100 Subject: [PATCH 9/9] Fix imports --- .../graph-analytics-serverless-spark.ipynb | 15 +++++++----- .../arrow_client/v1/gds_arrow_client.py | 2 +- .../procedure_surface/arrow/error_handler.py | 24 +++++++++++++++++++ .../cypher/catalog_cypher_endpoints.py | 2 +- .../query_runner/session_query_runner.py | 2 +- .../arrow_client/V1/test_gds_arrow_client.py | 2 +- 6 files changed, 37 insertions(+), 10 deletions(-) create mode 100644 src/graphdatascience/procedure_surface/arrow/error_handler.py diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb index 2846f154a..7bc1de62f 100644 --- a/examples/graph-analytics-serverless-spark.ipynb +++ b/examples/graph-analytics-serverless-spark.ipynb 
@@ -215,25 +215,28 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "cell_type": "markdown", + "metadata": {}, "source": [ "While the data is being uploaded, the import job runs inside the GDS session. Once we signal that the upload is done, we poll the job status in a loop, sleeping for one second between checks, until the import has finished.\n", "\n" ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Let's run the import:\n" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run the import:\n" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import time\n", + "\n", "import pandas as pd\n", "import pyarrow\n", "from pyspark.sql import functions\n", @@ -280,8 +283,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "cell_type": "markdown", + "metadata": {}, "source": [ "## Running Algorithms\n", "\n", diff --git a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py index 860c23005..73cbfa37c 100644 --- a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py +++ b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py @@ -15,8 +15,8 @@ from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single +from ...procedure_surface.arrow.error_handler import handle_flight_error from ...semantic_version.semantic_version import SemanticVersion -from ..error_handler import handle_flight_error class GdsArrowClient: diff --git a/src/graphdatascience/procedure_surface/arrow/error_handler.py b/src/graphdatascience/procedure_surface/arrow/error_handler.py new file mode 100644 index 000000000..e9c6bd04d --- /dev/null +++ b/src/graphdatascience/procedure_surface/arrow/error_handler.py @@ -0,0 +1,24 @@ +import re + +from neo4j.exceptions import ClientError +from pyarrow import flight + + +def handle_flight_error(e: Exception) -> None: + if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError): + original_message = e.args[0] if len(e.args) > 0 else e.message + improved_message = original_message.replace( + "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ", + "", + ) + improved_message = re.sub(r"(\. 
)?gRPC client debug context: .+$", "", improved_message) + + raise flight.FlightServerError(improved_message) + else: + raise e diff --git a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py index cdb051b20..856b37b49 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py @@ -42,7 +42,7 @@ def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient def get(self, graph_name: str) -> GraphV2: if not self.list(graph_name): raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.") - return get_graph(graph_name, self.cypher_runner) + return get_graph(graph_name, self._cypher_runner) def construct( self, diff --git a/src/graphdatascience/query_runner/session_query_runner.py b/src/graphdatascience/query_runner/session_query_runner.py index 1b4876042..046889889 100644 --- a/src/graphdatascience/query_runner/session_query_runner.py +++ b/src/graphdatascience/query_runner/session_query_runner.py @@ -13,8 +13,8 @@ from graphdatascience.query_runner.termination_flag import TerminationFlag from graphdatascience.server_version.server_version import ServerVersion -from ..arrow_client.error_handler import handle_flight_error from ..call_parameters import CallParameters +from ..procedure_surface.arrow.error_handler import handle_flight_error from ..session.dbms.protocol_resolver import ProtocolVersionResolver from .protocol.project_protocols import ProjectProtocol from .protocol.write_protocols import WriteProtocol diff --git a/tests/unit/arrow_client/V1/test_gds_arrow_client.py b/tests/unit/arrow_client/V1/test_gds_arrow_client.py index b8f90ee8a..ddab3819c 100644 --- a/tests/unit/arrow_client/V1/test_gds_arrow_client.py +++ b/tests/unit/arrow_client/V1/test_gds_arrow_client.py @@ -15,8 +15,8 @@ ) from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient -from graphdatascience.arrow_client.error_handler import handle_flight_error from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.procedure_surface.arrow.error_handler import handle_flight_error from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.arrow_info import ArrowInfo