diff --git a/examples/graph-analytics-serverless-spark.ipynb b/examples/graph-analytics-serverless-spark.ipynb new file mode 100644 index 000000000..7bc1de62f --- /dev/null +++ b/examples/graph-analytics-serverless-spark.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "aura" + ] + }, + "source": [ + "# Aura Graph Analytics with Spark" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "<a target=\"_blank\" href=\"https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb\">\n", + "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", + "</a>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb) in the Neo4j Graph Data Science Client GitHub repository.\n", + "\n", + "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.\n", + "\n", + "We consider a graph of bicycle rentals, which we're using as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually return results back to Spark.\n", + "In this notebook we will focus on the interaction with Apache Spark, and will not cover all possible actions using GDS Sessions. We refer to the other tutorials for additional details." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "We need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "# This allows us to load required secrets from the `.env` file in the local directory.\n", + "# These can include Aura API credentials and database credentials.\n", + "# If the file does not exist, this is a no-op.\n", + "load_dotenv(\"sessions.env\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Connecting to a Spark Session\n", + "\n", + "To interact with the Spark cluster, we first need to instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", + "Connecting to a remote Spark cluster works similarly. For more information about setting up PySpark, visit https://spark.apache.org/docs/latest/api/python/getting_started/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import SparkSession\n", + "\n", + "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n", + "\n", + "# Enable Arrow-based columnar data transfers\n", + "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Aura API credentials\n", + "\n", + "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from graphdatascience.session import AuraAPICredentials, GdsSessions\n", + "\n", + "# You can also use AuraAPICredentials.from_env() to load credentials from environment variables.\n", + "api_credentials = AuraAPICredentials(\n", + " client_id=os.environ[\"CLIENT_ID\"],\n", + " client_secret=os.environ[\"CLIENT_SECRET\"],\n", + " # If your account is a member of several projects, you must also specify the project ID to use\n", + " project_id=os.environ.get(\"PROJECT_ID\", None),\n", + ")\n", + "\n", + "sessions = GdsSessions(api_credentials=api_credentials)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a new session\n", + "\n", + "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n", + "\n", + "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n", + "* The session memory.\n", + "* The cloud location.\n", + "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n", + "\n", + "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import timedelta\n", + "\n", + "from graphdatascience.session import CloudLocation, SessionMemory\n", + "\n", + "# Create a GDS session!\n", + "gds = sessions.get_or_create(\n", + " # we give it a representative name\n", + " session_name=\"bike_trips\",\n", + " memory=SessionMemory.m_2GB,\n", + " ttl=timedelta(minutes=30),\n", + " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding a dataset\n", + "\n", + "As the next step, we will set up a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike rental stations and relationships connect the start and end stations of each trip." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import os\n", + "import zipfile\n", + "\n", + "import requests\n", + "\n", + "download_path = \"bike_trips_data\"\n", + "if not os.path.exists(download_path):\n", + " url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n", + "\n", + " response = requests.get(url)\n", + " response.raise_for_status()\n", + "\n", + " # Unzip the content\n", + " with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n", + " z.extractall(download_path)\n", + "\n", + "df = spark.read.csv(download_path, header=True, inferSchema=True)\n", + "df.createOrReplaceTempView(\"bike_trips\")\n", + "df.limit(10).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Projecting Graphs\n", + "\n", + "Now that we have our dataset available within our Spark session, it is time to project it to the GDS Session.\n", + "\n", + "We first need to get access to the `GdsArrowClient`. 
This client allows us to directly communicate with the Arrow Flight server provided by the session.\n", + "\n", + "Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow server's \"graph import from triplets\" functionality, which requires the following protocol:\n", + "\n", + "1. Send the action `v2/graph.project.fromTriplets`.\n", + " This initializes the import process and allows us to specify the graph name and settings like `undirected_relationship_types`. It returns a job ID, which we need to reference the import job in the following steps.\n", + "2. Send the data in batches to the Arrow server.\n", + "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n", + "4. Wait for the import process to reach the `DONE` state.\n", + "\n", + "The most complicated step here is running the actual data upload on each Spark worker. We will use the `mapInArrow` function to run custom code on each worker. Each worker will receive a number of Arrow record batches that we can send directly to the GDS session's Arrow server." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "import pandas as pd\n", + "import pyarrow\n", + "from pyspark.sql import functions\n", + "\n", + "graph_name = \"bike_trips\"\n", + "\n", + "arrow_client = gds.arrow_client()\n", + "\n", + "# 1. Start the import process\n", + "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n", + "\n", + "\n", + "# Define a function that receives an Arrow batch and uploads it to the GDS session\n", + "def upload_batch(iterator):\n", + " for batch in iterator:\n", + " arrow_client.upload_triplets(job_id, [batch])\n", + " yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n", + "\n", + "\n", + "# Select the source-target pairs from our source data\n", + "source_target_pairs = spark.sql(\"\"\"\n", + " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", + " FROM bike_trips\n", + " \"\"\")\n", + "\n", + "# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n", + "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", + "\n", + "# Aggregate the batch sizes to obtain the total row count.\n", + "aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n", + "\n", + "# Show the result. This will trigger the computation and thus run the data upload.\n", + "aggregated_batch_sizes.show()\n", + "\n", + "# 3. Finish the import process\n", + "arrow_client.triplet_load_done(job_id)\n", + "\n", + "# 4. 
Wait for the import to finish\n", + "while not arrow_client.job_status(job_id).succeeded():\n", + " time.sleep(1)\n", + "\n", + "G = gds.v2.graph.get(graph_name)\n", + "G" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running Algorithms\n", + "\n", + "We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Running PageRank ...\")\n", + "pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sending the computation result back to Spark\n", + "\n", + "Once the computation is done, we might want to further use the result in Spark.\n", + "We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n", + "Retrieving the data is a bit more complicated since we need some input DataFrame in order to trigger computations on the Spark workers.\n", + "We use a range equal to the number of workers in our cluster as our driving table.\n", + "On the workers we will disregard the input and instead stream the computation data from the GDS Session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Start the node property export on the GDS session\n", + "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n", + "\n", + "\n", + "# Define a function that receives data from the GDS Session and turns it into data batches\n", + "def retrieve_data(ignored):\n", + " stream_data = arrow_client.stream_job(G.name(), job_id)\n", + " batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n", + " for b in batches:\n", + " yield b\n", + "\n", + "\n", + "# Create a DataFrame with a single column and one row per worker\n", + "input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n", + "# 2. Stream the data from the GDS Session into the Spark workers\n", + "received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n", + "# Optional: Repartition the data to make sure it is distributed equally\n", + "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n", + "\n", + "result.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup\n", + "\n", + "Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.\n", + "\n", + "Deleting the GDS session will release all resources associated with it and stop incurring costs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.delete()\n", + "spark.stop()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py index a2f59511f..73cbfa37c 100644 --- a/src/graphdatascience/arrow_client/v1/gds_arrow_client.py +++ b/src/graphdatascience/arrow_client/v1/gds_arrow_client.py @@ -2,13 +2,11 @@ import json import logging -import re from types import TracebackType from typing import Any, Callable, Iterable, Type import pandas import pyarrow -from neo4j.exceptions import ClientError from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight from pyarrow.types import is_dictionary from pydantic import BaseModel @@ -17,6 +15,7 @@ from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single +from ...procedure_surface.arrow.error_handler import handle_flight_error from ...semantic_version.semantic_version import SemanticVersion @@ -515,7 +514,7 @@ def upload_batch(p: RecordBatch) -> None: ack_stream.read() progress_callback(partition.num_rows) except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) def _get_data( self, @@ -560,7 +559,7 @@ def _fetch_get_result(self, get: flight.FlightStreamReader) -> pandas.DataFrame: try: arrow_table = get.read_all() except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) arrow_table = self._sanitize_arrow_table(arrow_table) if SemanticVersion.from_string(pandas.__version__) >= SemanticVersion(2, 0, 0): return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore @@ -615,26 +614,6 @@ def _decode_pyarrow_array(array: Array) -> Array: else: return array - @staticmethod - def handle_flight_error(e: Exception) -> None: - if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError): - original_message = e.args[0] if len(e.args) > 0 else e.message - improved_message = original_message.replace( - "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", "" - ) - improved_message = improved_message.replace( - "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", "" - ) - improved_message = improved_message.replace( - "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ", - "", - ) - improved_message = re.sub(r"(\. 
)?gRPC client debug context: .+$", "", improved_message) - - raise flight.FlightServerError(improved_message) - else: - raise e - class NodeLoadDoneResult(BaseModel): name: str diff --git a/src/graphdatascience/arrow_client/v2/api_types.py b/src/graphdatascience/arrow_client/v2/api_types.py index 1b7436001..d55a9791c 100644 --- a/src/graphdatascience/arrow_client/v2/api_types.py +++ b/src/graphdatascience/arrow_client/v2/api_types.py @@ -34,13 +34,13 @@ def sub_tasks(self) -> str | None: return None def aborted(self) -> bool: - return self.status == "Aborted" + return self.status.lower() == "aborted" def succeeded(self) -> bool: - return self.status == "Done" + return self.status.lower() == "done" def running(self) -> bool: - return self.status == "Running" + return self.status.lower() == "running" class MutateResult(ArrowBaseModel): diff --git a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py index 9a6974b92..2085a7e0b 100644 --- a/src/graphdatascience/arrow_client/v2/gds_arrow_client.py +++ b/src/graphdatascience/arrow_client/v2/gds_arrow_client.py @@ -93,7 +93,7 @@ def get_nodes( job_id: str | None = None, ) -> str: """ - Start a new export process to stream the nodes with the specified label and filter from the graph. + Start a new export process to stream the nodes that match the filter from the graph. Parameters ---------- diff --git a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py index 2181b9747..ed3ca8602 100644 --- a/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py +++ b/src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py @@ -16,6 +16,22 @@ class CatalogEndpoints(ABC): + @abstractmethod + def get(self, graph_name: str) -> GraphV2: + """Retrieve a handle to a graph from the graph catalog. + + Parameters + ---------- + graph_name + The name of the graph. + + Returns + ------- + GraphV2 + A handle to the graph. 
+ """ + pass + @abstractmethod def construct( self, diff --git a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py index 89131d4ca..5295f28ab 100644 --- a/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py +++ b/src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py @@ -56,6 +56,11 @@ def __init__( protocol_version = ProtocolVersionResolver(query_runner).resolve() self._project_protocol = ProjectProtocol.select(protocol_version) + def get(self, graph_name: str) -> GraphV2: + if not self.list(graph_name): + raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.") + return get_graph(graph_name, self._arrow_client) + def project( self, graph_name: str, diff --git a/src/graphdatascience/procedure_surface/arrow/error_handler.py b/src/graphdatascience/procedure_surface/arrow/error_handler.py new file mode 100644 index 000000000..e9c6bd04d --- /dev/null +++ b/src/graphdatascience/procedure_surface/arrow/error_handler.py @@ -0,0 +1,24 @@ +import re + +from neo4j.exceptions import ClientError +from pyarrow import flight + + +def handle_flight_error(e: Exception) -> None: + if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError): + original_message = e.args[0] if len(e.args) > 0 else e.message + improved_message = original_message.replace( + "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", "" + ) + improved_message = improved_message.replace( + "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ", + "", + ) + improved_message = re.sub(r"(\. 
)?gRPC client debug context: .+$", "", improved_message) + + raise flight.FlightServerError(improved_message) + else: + raise e diff --git a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py index 0c217192e..856b37b49 100644 --- a/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py +++ b/src/graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py @@ -39,6 +39,11 @@ def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient self._cypher_runner = cypher_runner self._arrow_client = arrow_client + def get(self, graph_name: str) -> GraphV2: + if not self.list(graph_name): + raise ValueError(f"A graph with name '{graph_name}' does not exist in the catalog.") + return get_graph(graph_name, self._cypher_runner) + def construct( self, graph_name: str, diff --git a/src/graphdatascience/query_runner/session_query_runner.py b/src/graphdatascience/query_runner/session_query_runner.py index 6c693dce1..046889889 100644 --- a/src/graphdatascience/query_runner/session_query_runner.py +++ b/src/graphdatascience/query_runner/session_query_runner.py @@ -6,7 +6,7 @@ from pandas import DataFrame -from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger from graphdatascience.query_runner.query_mode import QueryMode @@ -14,6 +14,7 @@ from graphdatascience.server_version.server_version import ServerVersion from ..call_parameters import CallParameters +from ..procedure_surface.arrow.error_handler import handle_flight_error from ..session.dbms.protocol_resolver import ProtocolVersionResolver from .protocol.project_protocols import ProjectProtocol from .protocol.write_protocols import WriteProtocol @@ -183,7 +184,7 @@ def run_projection() -> DataFrame: else: return run_projection() except Exception as e: - GdsArrowClient.handle_flight_error(e) + handle_flight_error(e) raise e # above should already raise def _remote_write_back( diff --git a/src/graphdatascience/session/aura_graph_data_science.py b/src/graphdatascience/session/aura_graph_data_science.py index a94932ccb..47397deda 100644 --- a/src/graphdatascience/session/aura_graph_data_science.py +++ b/src/graphdatascience/session/aura_graph_data_science.py @@ -5,7 +5,7 @@ from pandas import DataFrame from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient -from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient from graphdatascience.call_builder import IndirectCallBuilder from graphdatascience.endpoints import ( AlphaRemoteEndpoints, @@ -94,6 +94,7 @@ def create( v2_endpoints=SessionV2Endpoints( session_auth_arrow_client, db_bolt_query_runner, show_progress=show_progress ), + authenticated_arrow_client=session_auth_arrow_client, ) else: standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) @@ -102,6 +103,7 @@ def create( delete_fn=delete_fn, gds_version=gds_version, v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None, show_progress=show_progress), + authenticated_arrow_client=session_auth_arrow_client, ) def __init__( @@ -110,11 +112,13 @@ def __init__( delete_fn: 
Callable[[], bool], gds_version: ServerVersion, v2_endpoints: SessionV2Endpoints, + authenticated_arrow_client: AuthenticatedArrowClient, ): self._query_runner = query_runner self._delete_fn = delete_fn self._server_version = gds_version self._v2_endpoints = v2_endpoints + self._authenticated_arrow_client = authenticated_arrow_client super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) @@ -177,6 +181,18 @@ def v2(self) -> SessionV2Endpoints: def __getattr__(self, attr: str) -> IndirectCallBuilder: return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version) + def arrow_client(self) -> GdsArrowClient: + """ + Returns a GdsArrowClient that is authenticated to communicate with the Aura Graph Analytics Session. + This client can be used to get direct access to the session's Arrow Flight server. + + Returns + ------- + GdsArrowClient + An Arrow client authenticated against this session. + """ + return GdsArrowClient(self._authenticated_arrow_client) + def set_database(self, database: str) -> None: """ Set the database which cypher queries are run against. diff --git a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py index 9ed81668a..3039dc281 100644 --- a/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py +++ b/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py @@ -15,6 +15,7 @@ def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) -> delete_fn=lambda: True, gds_version=ServerVersion.from_string("1.2.3"), v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner), + authenticated_arrow_client=arrow_client, ) diff --git a/tests/unit/arrow_client/V1/test_gds_arrow_client.py b/tests/unit/arrow_client/V1/test_gds_arrow_client.py index b2f777b00..ddab3819c 100644 --- a/tests/unit/arrow_client/V1/test_gds_arrow_client.py +++ b/tests/unit/arrow_client/V1/test_gds_arrow_client.py @@ -16,6 +16,7 @@ from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient +from graphdatascience.procedure_surface.arrow.error_handler import handle_flight_error from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.arrow_info import ArrowInfo @@ -397,7 +398,7 @@ def test_handle_flight_error() -> None: FlightServerError, match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.", ): - GdsArrowClient.handle_flight_error( + handle_flight_error( FlightServerError( 'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. 
Detail: Internal' ) @@ -407,7 +408,7 @@ def test_handle_flight_error() -> None: FlightServerError, match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"), ): - GdsArrowClient.handle_flight_error( + handle_flight_error( FlightServerError( "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]" ) diff --git a/tests/unit/session/test_aura_graph_data_science.py b/tests/unit/session/test_aura_graph_data_science.py index 28ba62f31..7183c30ed 100644 --- a/tests/unit/session/test_aura_graph_data_science.py +++ b/tests/unit/session/test_aura_graph_data_science.py @@ -14,6 +14,7 @@ def test_remote_projection_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) g = gds.graph.project( @@ -50,6 +51,7 @@ def test_remote_projection_defaults(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) g = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -78,6 +80,7 @@ def test_remote_algo_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -99,6 +102,7 @@ def test_remote_algo_write_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -125,6 +129,7 @@ def test_remote_graph_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -149,6 +154,7 @@ def test_remote_graph_write_configuration(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)") @@ -176,6 +182,7 @@ def test_run_cypher_write(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.WRITE, database="bar", retryable=True) @@ -193,6 +200,7 @@ def test_run_cypher_read(mocker: MockerFixture) -> None: delete_fn=lambda: True, gds_version=v, v2_endpoints=mocker.Mock(), + authenticated_arrow_client=mocker.Mock(), ) gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.READ, retryable=False)