diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb
new file mode 100644
index 000000000..7bc1de62f
--- /dev/null
+++ b/examples/graph-analytics-serverless-spark.ipynb
@@ -0,0 +1,374 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "tags": [
+ "aura"
+ ]
+ },
+ "source": [
+ "# Aura Graph Analytics with Spark"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb\">\n",
+ "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
+ "</a>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb) in the Neo4j Graph Data Science Client GitHub repository.\n",
+ "\n",
+ "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.\n",
+ "\n",
+ "We consider a graph of bicycle rentals as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and return the results back to Spark.\n",
+ "This notebook focuses on the interaction with Apache Spark and does not cover all possible actions with GDS Sessions. Refer to the other tutorials for additional details."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prerequisites\n",
+ "\n",
+ "This notebook requires the `graphdatascience` Python library, version `1.18` or later, as well as `pyspark`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "# Load required secrets from a `.env` file in the local directory.\n",
+ "# This can include Aura API credentials and database credentials.\n",
+ "# If the file does not exist, this is a no-op.\n",
+ "load_dotenv(\"sessions.env\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Connecting to a Spark Session\n",
+ "\n",
+ "To interact with the Spark cluster we first need to instantiate a Spark session. In this example we use a local Spark session, which runs Spark on the same machine.\n",
+ "Working with a remote Spark cluster works similarly. For more information about setting up PySpark, visit https://spark.apache.org/docs/latest/api/python/getting_started/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyspark.sql import SparkSession\n",
+ "\n",
+ "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n",
+ "\n",
+ "# Enable Arrow-based columnar data transfers\n",
+ "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Aura API credentials\n",
+ "\n",
+ "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "from graphdatascience.session import AuraAPICredentials, GdsSessions\n",
+ "\n",
+ "# you can also use AuraAPICredentials.from_env() to load credentials from environment variables\n",
+ "api_credentials = AuraAPICredentials(\n",
+ " client_id=os.environ[\"CLIENT_ID\"],\n",
+ " client_secret=os.environ[\"CLIENT_SECRET\"],\n",
+ " # If your account is a member of several projects, you must also specify the project ID to use\n",
+ " project_id=os.environ.get(\"PROJECT_ID\", None),\n",
+ ")\n",
+ "\n",
+ "sessions = GdsSessions(api_credentials=api_credentials)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating a new session\n",
+ "\n",
+ "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n",
+ "\n",
+ "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n",
+ "* The session memory. \n",
+ "* The cloud location.\n",
+ "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n",
+ "\n",
+ "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datetime import timedelta\n",
+ "\n",
+ "from graphdatascience.session import CloudLocation, SessionMemory\n",
+ "\n",
+ "# Create a GDS session!\n",
+ "gds = sessions.get_or_create(\n",
+ " # we give it a representative name\n",
+ " session_name=\"bike_trips\",\n",
+ " memory=SessionMemory.m_2GB,\n",
+ " ttl=timedelta(minutes=30),\n",
+ " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Adding a dataset\n",
+ "\n",
+ "As the next step we set up a dataset in Spark. In this example we use the New York bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike rental stations and each relationship represents a trip from its start station to its end station."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import io\n",
+ "import os\n",
+ "import zipfile\n",
+ "\n",
+ "import requests\n",
+ "\n",
+ "download_path = \"bike_trips_data\"\n",
+ "if not os.path.exists(download_path):\n",
+ " url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n",
+ "\n",
+ " response = requests.get(url)\n",
+ " response.raise_for_status()\n",
+ "\n",
+ " # Unzip the content\n",
+ " with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n",
+ " z.extractall(download_path)\n",
+ "\n",
+ "df = spark.read.csv(download_path, header=True, inferSchema=True)\n",
+ "df.createOrReplaceTempView(\"bike_trips\")\n",
+ "df.limit(10).show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Projecting Graphs\n",
+ "\n",
+ "Now that we have our dataset available within our Spark session it is time to project it to the GDS Session.\n",
+ "\n",
+ "We first need to get access to the `GdsArrowClient`. This client allows us to communicate directly with the Arrow Flight server provided by the session.\n",
+ "\n",
+ "Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow Server's \"graph import from triplets\" functionality, which requires the following protocol:\n",
+ "\n",
+ "1. Send an action `v2/graph.project.fromTriplets`.\n",
+ " This initializes the import process and allows us to specify the graph name and settings like `undirected_relationship_types`. It returns a job ID, which we use to reference the import job in the following steps.\n",
+ "2. Send the data in batches to the Arrow server.\n",
+ "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n",
+ "4. Wait for the import process to reach the `DONE` state.\n",
+ "\n",
+ "The most complicated step here is running the actual data upload on each Spark worker. We use the `mapInArrow` function to run custom code on each worker. Each worker receives a number of Arrow record batches that we can send directly to the GDS session's Arrow server."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "import pandas as pd\n",
+ "import pyarrow\n",
+ "from pyspark.sql import functions\n",
+ "\n",
+ "graph_name = \"bike_trips\"\n",
+ "\n",
+ "arrow_client = gds.arrow_client()\n",
+ "\n",
+ "# 1. Start the import process\n",
+ "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n",
+ "\n",
+ "\n",
+ "# Define a function that receives an arrow batch and uploads it to the GDS session\n",
+ "def upload_batch(iterator):\n",
+ " for batch in iterator:\n",
+ " arrow_client.upload_triplets(job_id, [batch])\n",
+ " yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n",
+ "\n",
+ "\n",
+ "# Select the source/target pairs from our source data\n",
+ "source_target_pairs = spark.sql(\"\"\"\n",
+ " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n",
+ " FROM bike_trips\n",
+ " \"\"\")\n",
+ "\n",
+ "# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n",
+ "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n",
+ "\n",
+ "# Aggregate the batch sizes to receive the row count.\n",
+ "aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n",
+ "\n",
+ "# Show the result. This will trigger the computation and thus run the data upload.\n",
+ "aggregated_batch_sizes.show()\n",
+ "\n",
+ "# 3. Finish the import process\n",
+ "arrow_client.triplet_load_done(job_id)\n",
+ "\n",
+ "# 4. Wait for the import to finish\n",
+ "while not arrow_client.job_status(job_id).succeeded():\n",
+ " time.sleep(1)\n",
+ "\n",
+ "G = gds.v2.graph.get(graph_name)\n",
+ "G"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Running Algorithms\n",
+ "\n",
+ "We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Running PageRank ...\")\n",
+ "pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sending the computation result back to Spark\n",
+ "\n",
+ "Once the computation is done, we might want to further use the result in Spark.\n",
+ "We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n",
+ "Retrieving the data is a bit more involved since we need an input DataFrame to trigger computation on the Spark workers.\n",
+ "As the driving table we use a numeric range with one row per worker in our cluster.\n",
+ "On the workers we disregard this input and instead stream the computation results from the GDS Session."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 1. Start the node property export on the GDS session\n",
+ "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n",
+ "\n",
+ "\n",
+ "# Define a function that receives data from the GDS Session and turns it into data batches\n",
+ "def retrieve_data(ignored):\n",
+ " stream_data = arrow_client.stream_job(G.name(), job_id)\n",
+ " batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n",
+ " for b in batches:\n",
+ " yield b\n",
+ "\n",
+ "\n",
+ "# Create DataFrame with a single column and one row per worker\n",
+ "input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n",
+ "# 2. Stream the data from the GDS Session into the Spark workers\n",
+ "received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n",
+ "# Optional: Repartition the data to make sure it is distributed equally\n",
+ "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n",
+ "\n",
+ "result.toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cleanup\n",
+ "\n",
+ "Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.\n",
+ "\n",
+ "Deleting the GDS session will release all resources associated with it, and stop incurring costs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gds.delete()\n",
+ "spark.stop()"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py
index a2f59511f..73cbfa37c 100644
--- a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py
+++ b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py
@@ -2,13 +2,11 @@
import json
import logging
-import re
from types import TracebackType
from typing import Any, Callable, Iterable, Type
import pandas
import pyarrow
-from neo4j.exceptions import ClientError
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
from pyarrow.types import is_dictionary
from pydantic import BaseModel
@@ -17,6 +15,7 @@
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single
+from ...procedure_surface.arrow.error_handler import handle_flight_error
from ...semantic_version.semantic_version import SemanticVersion
@@ -515,7 +514,7 @@ def upload_batch(p: RecordBatch) -> None:
ack_stream.read()
progress_callback(partition.num_rows)
except Exception as e:
- GdsArrowClient.handle_flight_error(e)
+ handle_flight_error(e)
def _get_data(
self,
@@ -560,7 +559,7 @@ def _fetch_get_result(self, get: flight.FlightStreamReader) -> pandas.DataFrame:
try:
arrow_table = get.read_all()
except Exception as e:
- GdsArrowClient.handle_flight_error(e)
+ handle_flight_error(e)
arrow_table = self._sanitize_arrow_table(arrow_table)
if SemanticVersion.from_string(pandas.__version__) >= SemanticVersion(2, 0, 0):
return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore
@@ -615,26 +614,6 @@ def _decode_pyarrow_array(array: Array) -> Array:
else:
return array
- @staticmethod
- def handle_flight_error(e: Exception) -> None:
- if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError):
- original_message = e.args[0] if len(e.args) > 0 else e.message
- improved_message = original_message.replace(
- "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
- )
- improved_message = improved_message.replace(
- "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
- )
- improved_message = improved_message.replace(
- "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
- "",
- )
- improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)
-
- raise flight.FlightServerError(improved_message)
- else:
- raise e
-
class NodeLoadDoneResult(BaseModel):
name: str
diff --git a/src/graphdatascience/arrow_client/v2/api_types.py b/src/graphdatascience/arrow_client/v2/api_types.py
index 1b7436001..d55a9791c 100644
--- a/src/graphdatascience/arrow_client/v2/api_types.py
+++ b/src/graphdatascience/arrow_client/v2/api_types.py
@@ -34,13 +34,13 @@ def sub_tasks(self) -> str | None:
return None
def aborted(self) -> bool:
- return self.status == "Aborted"
+ return self.status.lower() == "aborted"
def succeeded(self) -> bool:
- return self.status == "Done"
+ return self.status.lower() == "done"
def running(self) -> bool:
- return self.status == "Running"
+ return self.status.lower() == "running"
class MutateResult(ArrowBaseModel):
diff --git a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py
index 9a6974b92..2085a7e0b 100644
--- a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py
+++ b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py
@@ -93,7 +93,7 @@ def get_nodes(
job_id: str | None = None,
) -> str:
"""
- Start a new export process to stream the nodes with the specified label and filter from the graph.
+ Start a new export process to stream the nodes that match the filter from the graph.
Parameters
----------
diff --git a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py
index 2181b9747..ed3ca8602 100644
--- a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py
+++ b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py
@@ -16,6 +16,22 @@
class CatalogEndpoints(ABC):
+ @abstractmethod
+ def get(self, graph_name: str) -> GraphV2:
+ """Retrieve a handle to a graph from the graph catalog.
+
+ Parameters
+ ----------
+ graph_name
+ The name of the graph.
+
+ Returns
+ -------
+ GraphV2
+ A handle to the graph.
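+
+ Examples
+ --------
+ As used in the example notebooks (graph name is a placeholder):
+
+ >>> G = gds.v2.graph.get("my_graph")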
+ """
+ pass
+
@abstractmethod
def construct(
self,
diff --git a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py
index 89131d4ca..5295f28ab 100644
--- a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py
+++ b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py
@@ -56,6 +56,11 @@ def __init__(
protocol_version = ProtocolVersionResolver(query_runner).resolve()
self._project_protocol = ProjectProtocol.select(protocol_version)
+ def get(self, graph_name: str) -> GraphV2:
+ if not self.list(graph_name):
+ raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.")
+ return get_graph(graph_name, self._arrow_client)
+
def project(
self,
graph_name: str,
diff --git a/src/graphdatascience/procedure_surface/arrow/error_handler.py b/src/graphdatascience/procedure_surface/arrow/error_handler.py
new file mode 100644
index 000000000..e9c6bd04d
--- /dev/null
+++ b/src/graphdatascience/procedure_surface/arrow/error_handler.py
@@ -0,0 +1,24 @@
+import re
+
+from neo4j.exceptions import ClientError
+from pyarrow import flight
+
+
+def handle_flight_error(e: Exception) -> None:
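+ """Strip Arrow Flight transport prefixes and gRPC debug context from the error message and re-raise.
+
+ Flight server errors, Flight internal errors, and Neo4j ClientError are re-raised as FlightServerError
+ with a cleaned-up message; any other exception is re-raised unchanged.
+ """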
+ if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError):
+ original_message = e.args[0] if len(e.args) > 0 else e.message
+ improved_message = original_message.replace(
+ "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
+ )
+ improved_message = improved_message.replace(
+ "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
+ )
+ improved_message = improved_message.replace(
+ "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
+ "",
+ )
+ improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)
+
+ raise flight.FlightServerError(improved_message)
+ else:
+ raise e
diff --git a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py
index 0c217192e..856b37b49 100644
--- a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py
+++ b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py
@@ -39,6 +39,11 @@ def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient
self._cypher_runner = cypher_runner
self._arrow_client = arrow_client
+ def get(self, graph_name: str) -> GraphV2:
+ if not self.list(graph_name):
+ raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.")
+ return get_graph(graph_name, self._cypher_runner)
+
def construct(
self,
graph_name: str,
diff --git a/src/graphdatascience/query_runner/session_query_runner.py b/src/graphdatascience/query_runner/session_query_runner.py
index 6c693dce1..046889889 100644
--- a/src/graphdatascience/query_runner/session_query_runner.py
+++ b/src/graphdatascience/query_runner/session_query_runner.py
@@ -6,7 +6,7 @@
from pandas import DataFrame
-from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
+from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
from graphdatascience.query_runner.graph_constructor import GraphConstructor
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
from graphdatascience.query_runner.query_mode import QueryMode
@@ -14,6 +14,7 @@
from graphdatascience.server_version.server_version import ServerVersion
from ..call_parameters import CallParameters
+from ..procedure_surface.arrow.error_handler import handle_flight_error
from ..session.dbms.protocol_resolver import ProtocolVersionResolver
from .protocol.project_protocols import ProjectProtocol
from .protocol.write_protocols import WriteProtocol
@@ -183,7 +184,7 @@ def run_projection() -> DataFrame:
else:
return run_projection()
except Exception as e:
- GdsArrowClient.handle_flight_error(e)
+ handle_flight_error(e)
raise e # above should already raise
def _remote_write_back(
diff --git a/src/graphdatascience/session/aura_graph_data_science.py b/src/graphdatascience/session/aura_graph_data_science.py
index a94932ccb..47397deda 100644
--- a/src/graphdatascience/session/aura_graph_data_science.py
+++ b/src/graphdatascience/session/aura_graph_data_science.py
@@ -5,7 +5,7 @@
from pandas import DataFrame
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
-from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
+from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
from graphdatascience.call_builder import IndirectCallBuilder
from graphdatascience.endpoints import (
AlphaRemoteEndpoints,
@@ -94,6 +94,7 @@ def create(
v2_endpoints=SessionV2Endpoints(
session_auth_arrow_client, db_bolt_query_runner, show_progress=show_progress
),
+ authenticated_arrow_client=session_auth_arrow_client,
)
else:
standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner)
@@ -102,6 +103,7 @@ def create(
delete_fn=delete_fn,
gds_version=gds_version,
v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None, show_progress=show_progress),
+ authenticated_arrow_client=session_auth_arrow_client,
)
def __init__(
@@ -110,11 +112,13 @@ def __init__(
delete_fn: Callable[[], bool],
gds_version: ServerVersion,
v2_endpoints: SessionV2Endpoints,
+ authenticated_arrow_client: AuthenticatedArrowClient,
):
self._query_runner = query_runner
self._delete_fn = delete_fn
self._server_version = gds_version
self._v2_endpoints = v2_endpoints
+ self._authenticated_arrow_client = authenticated_arrow_client
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
@@ -177,6 +181,18 @@ def v2(self) -> SessionV2Endpoints:
def __getattr__(self, attr: str) -> IndirectCallBuilder:
return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
+ def arrow_client(self) -> GdsArrowClient:
+ """
+ Returns a GdsArrowClient that is authenticated to communicate with the Aura Graph Analytics Session.
+ This client can be used to get direct access to the specific session's Arrow Flight server.
+
+ Returns
+ -------
+ GdsArrowClient
+ A client for direct access to the session's Arrow Flight server.
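+
+ Examples
+ --------
+ Start a triplet-based import, as in the graph-analytics-serverless-spark example notebook
+ (graph name is a placeholder):
+
+ >>> arrow_client = gds.arrow_client()
+ >>> job_id = arrow_client.create_graph_from_triplets("my_graph")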
+ """
+ return GdsArrowClient(self._authenticated_arrow_client)
+
def set_database(self, database: str) -> None:
"""
Set the database which cypher queries are run against.
diff --git a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py
index 9ed81668a..3039dc281 100644
--- a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py
+++ b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py
@@ -15,6 +15,7 @@ def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) ->
delete_fn=lambda: True,
gds_version=ServerVersion.from_string("1.2.3"),
v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner),
+ authenticated_arrow_client=arrow_client,
)
diff --git a/tests/unit/arrow_client/V1/test_gds_arrow_client.py b/tests/unit/arrow_client/V1/test_gds_arrow_client.py
index b2f777b00..ddab3819c 100644
--- a/tests/unit/arrow_client/V1/test_gds_arrow_client.py
+++ b/tests/unit/arrow_client/V1/test_gds_arrow_client.py
@@ -16,6 +16,7 @@
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
+from graphdatascience.procedure_surface.arrow.error_handler import handle_flight_error
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
from graphdatascience.query_runner.arrow_info import ArrowInfo
@@ -397,7 +398,7 @@ def test_handle_flight_error() -> None:
FlightServerError,
match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
):
- GdsArrowClient.handle_flight_error(
+ handle_flight_error(
FlightServerError(
'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal'
)
@@ -407,7 +408,7 @@ def test_handle_flight_error() -> None:
FlightServerError,
match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
):
- GdsArrowClient.handle_flight_error(
+ handle_flight_error(
FlightServerError(
"FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"
)
diff --git a/tests/unit/session/test_aura_graph_data_science.py b/tests/unit/session/test_aura_graph_data_science.py
index 28ba62f31..7183c30ed 100644
--- a/tests/unit/session/test_aura_graph_data_science.py
+++ b/tests/unit/session/test_aura_graph_data_science.py
@@ -14,6 +14,7 @@ def test_remote_projection_configuration(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
g = gds.graph.project(
@@ -50,6 +51,7 @@ def test_remote_projection_defaults(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
g = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -78,6 +80,7 @@ def test_remote_algo_write(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -99,6 +102,7 @@ def test_remote_algo_write_configuration(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -125,6 +129,7 @@ def test_remote_graph_write(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -149,6 +154,7 @@ def test_remote_graph_write_configuration(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -176,6 +182,7 @@ def test_run_cypher_write(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.WRITE, database="bar", retryable=True)
@@ -193,6 +200,7 @@ def test_run_cypher_read(mocker: MockerFixture) -> None:
delete_fn=lambda: True,
gds_version=v,
v2_endpoints=mocker.Mock(),
+ authenticated_arrow_client=mocker.Mock(),
)
gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.READ, retryable=False)