Commit 3dfae11

Merge pull request #1022 from DarthMax/sessions_expose_arrow_client
Session + Spark example
2 parents e8845eb + 8a3a49b commit 3dfae11

File tree

13 files changed: +463 -33 lines changed

examples/graph-analytics-serverless-spark.ipynb
Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"tags": [
"aura"
]
},
"source": [
"# Aura Graph Analytics with Spark"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb) in the Neo4j Graph Data Science Client GitHub repository.\n",
"\n",
"The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.\n",
"\n",
"We consider a graph of bicycle rentals, which we're using as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually return results back to Spark.\n",
"In this notebook we will focus on the interaction with Apache Spark, and will not cover all possible actions using GDS Sessions. We refer to the other tutorials for additional details."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"We need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"# This allows loading required secrets from a `.env` file in the local directory.\n",
"# This can include Aura API credentials and database credentials.\n",
"# If the file does not exist, this is a no-op.\n",
"load_dotenv(\"sessions.env\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Connecting to a Spark Session\n",
"\n",
"To interact with the Spark cluster, we first need to instantiate a Spark session. In this example we will use a local Spark session, which runs Spark on the same machine.\n",
"Working with a remote Spark cluster works similarly. For more information about setting up PySpark, visit https://spark.apache.org/docs/latest/api/python/getting_started/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n",
"\n",
"# Enable Arrow-based columnar data transfers\n",
"spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Aura API credentials\n",
"\n",
"The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from graphdatascience.session import AuraAPICredentials, GdsSessions\n",
"\n",
"# you can also use AuraAPICredentials.from_env() to load credentials from environment variables\n",
"api_credentials = AuraAPICredentials(\n",
"    client_id=os.environ[\"CLIENT_ID\"],\n",
"    client_secret=os.environ[\"CLIENT_SECRET\"],\n",
"    # If your account is a member of several projects, you must also specify the project ID to use\n",
"    project_id=os.environ.get(\"PROJECT_ID\", None),\n",
")\n",
"\n",
"sessions = GdsSessions(api_credentials=api_credentials)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a new session\n",
"\n",
"A new session is created by calling `sessions.get_or_create()` with the following parameters:\n",
"\n",
"* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n",
"* The session memory.\n",
"* The cloud location.\n",
"* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n",
"\n",
"See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import timedelta\n",
"\n",
"from graphdatascience.session import CloudLocation, SessionMemory\n",
"\n",
"# Create a GDS session!\n",
"gds = sessions.get_or_create(\n",
"    # we give it a representative name\n",
"    session_name=\"bike_trips\",\n",
"    memory=SessionMemory.m_2GB,\n",
"    ttl=timedelta(minutes=30),\n",
"    cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n",
")"
]
},
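{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional, illustrative check (a minimal sketch, assuming the `sessions` object created above and the `GdsSessions.list()` method available in recent client versions), we can list the sessions visible to our credentials and confirm that `bike_trips` exists. Reconnecting later works by calling `get_or_create` again with the same session name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: list the sessions visible to these credentials.\n",
"# This is an illustrative check, not required for the rest of the notebook.\n",
"sessions.list()"
]
},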
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding a dataset\n",
"\n",
"As the next step, we will set up a dataset in Spark. In this example we will use the New York bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike rental stations and each relationship connects the start and end stations of a trip."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import os\n",
"import zipfile\n",
"\n",
"import requests\n",
"\n",
"download_path = \"bike_trips_data\"\n",
"if not os.path.exists(download_path):\n",
"    url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n",
"\n",
"    response = requests.get(url)\n",
"    response.raise_for_status()\n",
"\n",
"    # Unzip the content\n",
"    with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n",
"        z.extractall(download_path)\n",
"\n",
"df = spark.read.csv(download_path, header=True, inferSchema=True)\n",
"df.createOrReplaceTempView(\"bike_trips\")\n",
"df.limit(10).show()"
]
},
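{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before projecting, it can help to sanity-check the data with plain Spark SQL. The following cell is a small illustrative sketch (it assumes the `bike_trips` view registered above and the `start_station_id` / `end_station_id` columns used later in this notebook): it counts the trips per station pair."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: count trips per (start, end) station pair.\n",
"trip_counts = spark.sql(\"\"\"\n",
"    SELECT start_station_id, end_station_id, COUNT(*) AS trips\n",
"    FROM bike_trips\n",
"    GROUP BY start_station_id, end_station_id\n",
"    ORDER BY trips DESC\n",
"    \"\"\")\n",
"\n",
"trip_counts.limit(10).show()"
]
},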
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projecting Graphs\n",
"\n",
"Now that we have our dataset available within our Spark session, it is time to project it to the GDS Session.\n",
"\n",
"We first need to get access to the GDSArrowClient. This client allows us to communicate directly with the Arrow Flight server provided by the session.\n",
"\n",
"Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow server's \"graph import from triplets\" functionality, which requires the following protocol:\n",
"\n",
"1. Send an action `v2/graph.project.fromTriplets`.\n",
"   This initializes the import process and allows us to specify the graph name and settings like `undirected_relationship_types`. It returns a job ID, which we need to reference the import job in the following steps.\n",
"2. Send the data in batches to the Arrow server.\n",
"3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n",
"4. Wait for the import process to reach the `DONE` state.\n",
"\n",
"The most complicated step is running the actual data upload on each Spark worker. We will use the `mapInArrow` function to run custom code on each Spark worker. Each worker receives a number of Arrow record batches that we can send directly to the GDS session's Arrow server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"import pandas as pd\n",
"import pyarrow\n",
"from pyspark.sql import functions\n",
"\n",
"graph_name = \"bike_trips\"\n",
"\n",
"arrow_client = gds.arrow_client()\n",
"\n",
"# 1. Start the import process\n",
"job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n",
"\n",
"\n",
"# Define a function that receives an arrow batch and uploads it to the GDS session\n",
"def upload_batch(iterator):\n",
"    for batch in iterator:\n",
"        arrow_client.upload_triplets(job_id, [batch])\n",
"        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n",
"\n",
"\n",
"# Select the source target pairs from our source data\n",
"source_target_pairs = spark.sql(\"\"\"\n",
"    SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n",
"    FROM bike_trips\n",
"    \"\"\")\n",
"\n",
"# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n",
"uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n",
"\n",
"# Aggregate the batch sizes to receive the row count.\n",
"aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n",
"\n",
"# Show the result. This will trigger the computation and thus run the data upload.\n",
"aggregated_batch_sizes.show()\n",
"\n",
"# 3. Finish the import process\n",
"arrow_client.triplet_load_done(job_id)\n",
"\n",
"# 4. Wait for the import to finish\n",
"while not arrow_client.job_status(job_id).succeeded():\n",
"    time.sleep(1)\n",
"\n",
"G = gds.v2.graph.get(graph_name)\n",
"G"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running Algorithms\n",
"\n",
"We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Running PageRank ...\")\n",
"pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sending the computation result back to Spark\n",
"\n",
"Once the computation is done, we might want to use the result further in Spark.\n",
"We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.\n",
"Retrieving the data is a bit more complicated, since we need an input DataFrame in order to trigger computations on the Spark workers.\n",
"We use a range with as many rows as we have workers in our cluster as our driving table.\n",
"On the workers we will disregard the input and instead stream the computation data from the GDS Session."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1. Start the node property export on the GDS session\n",
"job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n",
"\n",
"\n",
"# Define a function that receives data from the GDS Session and turns it into data batches\n",
"def retrieve_data(ignored):\n",
"    stream_data = arrow_client.stream_job(G.name(), job_id)\n",
"    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n",
"    for b in batches:\n",
"        yield b\n",
"\n",
"\n",
"# Create DataFrame with a single column and one row per worker\n",
"input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n",
"# 2. Stream the data from the GDS Session into the Spark workers\n",
"received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n",
"# Optional: Repartition the data to make sure it is distributed equally\n",
"result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n",
"\n",
"result.toPandas()"
]
},
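{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustrative follow-up (a sketch added here, assuming only the `result` DataFrame from the previous cell with its `nodeId` and `pagerank` columns), we can keep working with the streamed scores in plain Spark SQL, for example to look at the highest-ranked stations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Register the streamed PageRank scores as a view and query them with Spark SQL.\n",
"result.createOrReplaceTempView(\"station_pagerank\")\n",
"\n",
"top_stations = spark.sql(\"\"\"\n",
"    SELECT nodeId, pagerank\n",
"    FROM station_pagerank\n",
"    ORDER BY pagerank DESC\n",
"    LIMIT 10\n",
"    \"\"\")\n",
"\n",
"top_stations.show()"
]
},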
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup\n",
"\n",
"Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.\n",
"\n",
"Deleting the GDS session will release all resources associated with it, and stop incurring costs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gds.delete()\n",
"spark.stop()"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
