Merged
Session + Spark example #1022
Changes from all commits (9 commits):
b8cd8b5  Add `get` method to CatalogEndpoints (DarthMax)
ff04d02  Expose arrow_client method in AuraGraphDataScience (DarthMax)
c9aaf72  Create new example notebook for Sessions + Spark (DarthMax)
01c79ec  Improve notebook docs (DarthMax)
792309c  Fix documentation in GdsArrowClient (DarthMax)
ec66ddb  Remove setting java home (DarthMax)
eaaf560  Improve notebook documentation (DarthMax)
9569339  Improve notebook descriptions and address PR comments (DarthMax)
8a3a49b  Fix imports (DarthMax)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,374 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "tags": [ | ||
| "aura" | ||
| ] | ||
| }, | ||
| "source": [ | ||
| "# Aura Graph Analytics with Spark" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "colab_type": "text" | ||
| }, | ||
| "source": [ | ||
| "<a target=\"_blank\" href=\"https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb\">\n", | ||
| " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", | ||
| "</a>" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb) in the Neo4j Graph Data Science Client Github repository.\n", | ||
| "\n", | ||
| "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.\n", | ||
| "\n", | ||
| "We consider a graph of bicycle rentals, which we're using as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually return results back to Spark.\n", | ||
| "In this notebook we will focus on the interaction with Apache Spark, and will not cover all possible actions using GDS Sessions. Refer to the other tutorials for additional details." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Prerequisites\n", | ||
| "\n", | ||
| "We need the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\"" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from dotenv import load_dotenv\n", | ||
| "\n", | ||
| "# Load required secrets from a `.env` file in the local directory.\n", | ||
| "# These can include Aura API credentials and database credentials.\n", | ||
| "# If the file does not exist, this is a no-op.\n", | ||
| "load_dotenv(\"sessions.env\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Connecting to a Spark Session\n", | ||
| "\n", | ||
| "To interact with the Spark cluster we first need to instantiate a Spark session. In this example we will use a local Spark session, which runs Spark on the same machine.\n", | ||
| "Working with a remote Spark cluster is similar. For more information about setting up PySpark, visit https://spark.apache.org/docs/latest/api/python/getting_started/" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from pyspark.sql import SparkSession\n", | ||
| "\n", | ||
| "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n", | ||
| "\n", | ||
| "# Enable Arrow-based columnar data transfers\n", | ||
| "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Aura API credentials\n", | ||
| "\n", | ||
| "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import os\n", | ||
| "\n", | ||
| "from graphdatascience.session import AuraAPICredentials, GdsSessions\n", | ||
| "\n", | ||
| "# you can also use AuraAPICredentials.from_env() to load credentials from environment variables\n", | ||
| "api_credentials = AuraAPICredentials(\n", | ||
| " client_id=os.environ[\"CLIENT_ID\"],\n", | ||
| " client_secret=os.environ[\"CLIENT_SECRET\"],\n", | ||
| " # If your account is a member of several projects, you must also specify the project ID to use\n", | ||
| " project_id=os.environ.get(\"PROJECT_ID\", None),\n", | ||
| ")\n", | ||
| "\n", | ||
| "sessions = GdsSessions(api_credentials=api_credentials)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Creating a new session\n", | ||
| "\n", | ||
| "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n", | ||
| "\n", | ||
| "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n", | ||
| "* The session memory. \n", | ||
| "* The cloud location.\n", | ||
| "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n", | ||
| "\n", | ||
| "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from datetime import timedelta\n", | ||
| "\n", | ||
| "from graphdatascience.session import CloudLocation, SessionMemory\n", | ||
| "\n", | ||
| "# Create a GDS session!\n", | ||
| "gds = sessions.get_or_create(\n", | ||
| " # we give it a representative name\n", | ||
| " session_name=\"bike_trips\",\n", | ||
Contributor: i read about bike trips just today in our logs. nice to learn its your workload 👀
| " memory=SessionMemory.m_2GB,\n", | ||
| " ttl=timedelta(minutes=30),\n", | ||
| " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", | ||
| ")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Adding a dataset\n", | ||
| "\n", | ||
| "As the next step we will set up a dataset in Spark. In this example we will use the New York bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike rental stations and relationships represent trips from a start station to an end station." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import io\n", | ||
| "import os\n", | ||
| "import zipfile\n", | ||
| "\n", | ||
| "import requests\n", | ||
| "\n", | ||
| "download_path = \"bike_trips_data\"\n", | ||
| "if not os.path.exists(download_path):\n", | ||
| " url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n", | ||
| "\n", | ||
| " response = requests.get(url)\n", | ||
| " response.raise_for_status()\n", | ||
| "\n", | ||
| " # Unzip the content\n", | ||
| " with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n", | ||
| " z.extractall(download_path)\n", | ||
| "\n", | ||
| "df = spark.read.csv(download_path, header=True, inferSchema=True)\n", | ||
| "df.createOrReplaceTempView(\"bike_trips\")\n", | ||
| "df.limit(10).show()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Projecting Graphs\n", | ||
| "\n", | ||
| "Now that we have our dataset available within our Spark session, it is time to project it to the GDS Session.\n", | ||
| "\n", | ||
| "We first need to get access to the `GdsArrowClient`. This client allows us to communicate directly with the Arrow Flight server provided by the session.\n", | ||
| "\n", | ||
| "Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow Server's \"graph import from triplets\" functionality, which requires the following protocol:\n", | ||
| "\n", | ||
| "1. Send an action `v2/graph.project.fromTriplets`.\n", | ||
| " This initializes the import process and allows us to specify the graph name and settings like `undirected_relationship_types`. It returns a job ID, which we need to reference the import job in the following steps.\n", | ||
| "2. Send the data in batches to the Arrow server.\n", | ||
| "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n", | ||
| "4. Wait for the import process to reach the `DONE` state.\n", | ||
| "\n", | ||
| "The most complicated step here is running the actual data upload on each Spark worker. We will use the `mapInArrow` function to run custom code on the workers. Each worker will receive a number of Arrow record batches that we can send directly to the GDS Session's Arrow server." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import time\n", | ||
| "\n", | ||
| "import pandas as pd\n", | ||
| "import pyarrow\n", | ||
| "from pyspark.sql import functions\n", | ||
| "\n", | ||
| "graph_name = \"bike_trips\"\n", | ||
| "\n", | ||
| "arrow_client = gds.arrow_client()\n", | ||
| "\n", | ||
| "# 1. Start the import process\n", | ||
| "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n", | ||
| "\n", | ||
| "\n", | ||
| "# Define a function that receives an arrow batch and uploads it to the GDS session\n", | ||
| "def upload_batch(iterator):\n", | ||
| " for batch in iterator:\n", | ||
| " arrow_client.upload_triplets(job_id, [batch])\n", | ||
| " yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n", | ||
| "\n", | ||
| "\n", | ||
| "# Select the source target pairs from our source data\n", | ||
| "source_target_pairs = spark.sql(\"\"\"\n", | ||
| " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", | ||
| " FROM bike_trips\n", | ||
| " \"\"\")\n", | ||
| "\n", | ||
| "# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n", | ||
| "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", | ||
| "\n", | ||
| "# Aggregate the batch sizes to receive the row count.\n", | ||
| "aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n", | ||
| "\n", | ||
| "# Show the result. This will trigger the computation and thus run the data upload.\n", | ||
| "aggregated_batch_sizes.show()\n", | ||
| "\n", | ||
| "# 3. Finish the import process\n", | ||
| "arrow_client.triplet_load_done(job_id)\n", | ||
| "\n", | ||
| "# 4. Wait for the import to finish\n", | ||
| "while not arrow_client.job_status(job_id).succeeded():\n", | ||
| " time.sleep(1)\n", | ||
| "\n", | ||
| "G = gds.v2.graph.get(graph_name)\n", | ||
| "G" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Running Algorithms\n", | ||
| "\n", | ||
| "We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "print(\"Running PageRank ...\")\n", | ||
| "pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Sending the computation result back to Spark\n", | ||
| "\n", | ||
| "Once the computation is done, we might want to further use the result in Spark.\n", | ||
| "We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n", | ||
| "Retrieving the data is a bit more complicated since we need an input DataFrame to trigger computations on the Spark workers.\n", | ||
| "We use a range equal to the number of workers in our cluster as the driving table.\n", | ||
| "On the workers we disregard the input and instead stream the computation results from the GDS Session." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# 1. Start the node property export on the GDS session\n", | ||
| "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n", | ||
| "\n", | ||
| "\n", | ||
| "# Define a function that receives data from the GDS Session and turns it into data batches\n", | ||
| "def retrieve_data(ignored):\n", | ||
| " stream_data = arrow_client.stream_job(G.name(), job_id)\n", | ||
| " batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n", | ||
| " for b in batches:\n", | ||
| " yield b\n", | ||
| "\n", | ||
| "\n", | ||
| "# Create DataFrame with a single column and one row per worker\n", | ||
| "input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n", | ||
| "# 2. Stream the data from the GDS Session into the Spark workers\n", | ||
| "received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n", | ||
| "# Optional: Repartition the data to make sure it is distributed equally\n", | ||
| "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n", | ||
| "\n", | ||
| "result.toPandas()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Cleanup\n", | ||
| "\n", | ||
| "Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.\n", | ||
| "\n", | ||
| "Deleting the GDS session will release all resources associated with it, and stop incurring costs." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "gds.delete()\n", | ||
| "spark.stop()" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "language_info": { | ||
| "name": "python" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 4 | ||
| } | ||
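
The "Connecting to a Spark Session" cell in the notebook above runs Spark locally with `local[4]` and notes that a remote cluster works similarly. A minimal sketch of that remark, assuming a standalone cluster is reachable at a placeholder address (the master URL below is an assumption, not part of this PR):

```python
# Hedged sketch: connect to a remote standalone Spark cluster instead of local[4].
# The master URL is a placeholder; security, deploy mode, and extra configs depend on your cluster.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .master("spark://spark-master.example.com:7077")  # assumed cluster address
    .appName("GraphAnalytics")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")  # same Arrow setting as the notebook
    .getOrCreate()
)
```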
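The credentials cell mentions `AuraAPICredentials.from_env()` as an alternative to passing `client_id` and `client_secret` explicitly. A short sketch, assuming the credential variables expected by the client are already exported (for example via the `sessions.env` file loaded earlier):

```python
# Hedged sketch: build the GdsSessions manager from environment variables.
# Assumes the Aura API credential variables expected by AuraAPICredentials.from_env()
# are set in the environment; see the client documentation for their exact names.
from graphdatascience.session import AuraAPICredentials, GdsSessions

sessions = GdsSessions(api_credentials=AuraAPICredentials.from_env())
```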
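The four-step `v2/graph.project.fromTriplets` protocol described before the projection cell is independent of Spark; only the upload in step 2 is distributed via `mapInArrow`. The sketch below replays the same `GdsArrowClient` calls on a single, locally built `pyarrow` batch, which can help when debugging the protocol outside a cluster. The graph name and sample data are invented for illustration.

```python
# Hedged sketch: the same triplet import protocol as the notebook, without Spark.
# Only arrow_client methods that appear in the notebook are used; the data is made up.
import time

import pyarrow

arrow_client = gds.arrow_client()

# 1. Start the import and remember the job id
job_id = arrow_client.create_graph_from_triplets("tiny_example", concurrency=1)

# 2. Upload one locally constructed batch of source/target pairs
batch = pyarrow.RecordBatch.from_pydict({"sourceNode": [0, 1, 2], "targetNode": [1, 2, 0]})
arrow_client.upload_triplets(job_id, [batch])

# 3. Signal that no more data will be sent
arrow_client.triplet_load_done(job_id)

# 4. Poll until the import job has finished
while not arrow_client.job_status(job_id).succeeded():
    time.sleep(1)

G_tiny = gds.v2.graph.get("tiny_example")
```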
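The cleanup cell deletes the session through the handle returned by `get_or_create`. As a hedged alternative, sessions can also be listed and deleted through the `GdsSessions` manager; the calls below follow the client documentation rather than code in this PR, so treat them as an assumption.

```python
# Hedged sketch: clean up via the GdsSessions manager instead of the session handle.
# Assumes `sessions` is the GdsSessions instance created earlier in the notebook.
for session in sessions.list():
    print(session)  # inspect which sessions still exist

sessions.delete(session_name="bike_trips")  # delete the example session by name (assumed signature)
spark.stop()  # stop the local Spark session, as in the notebook
```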
how important is this? I wonder if we should make a bigger deal about it in the paragraph, so that it isn't missed