|
40 | 40 | "source": [ |
41 | 41 | "## Prerequisites\n", |
42 | 42 | "\n", |
43 | | - "This notebook requires having an AuraDB instance available and have the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.\n", |
44 | | - "\n", |
45 | | - "We also need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`. For more information about setting up pyspark visit https://spark.apache.org/docs/latest/api/python/getting_started/" |
| 43 | + "We also need to have the `graphdatascience` Python library installed, version `1.18` or later, as well as `pyspark`." |
46 | 44 | ] |
47 | 45 | }, |
48 | 46 | { |
|
75 | 73 | "### Connecting to a Spark Session\n", |
76 | 74 | "\n", |
77 | 75 | "To interact with the Spark cluster we need to first instantiate a Spark session. In this example we will use a local Spark session, which will run Spark on the same machine.\n", |
78 | | - "Working with a remote Spark cluster will work similarly." |
| 76 | + "Working with a remote Spark cluster will work similarly. For more information about setting up pyspark visit https://spark.apache.org/docs/latest/api/python/getting_started/" |
79 | 77 | ] |
80 | 78 | }, |
81 | 79 | { |
|
164 | 162 | "source": [ |
165 | 163 | "## Adding a dataset\n", |
166 | 164 | "\n", |
167 | | - "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips).", |
168 | | - "The bike trips form a graph where nodes represent bike renting stations and relationships represent start and end points for a bike rental trip." |
| 165 | + "As the next step we will setup a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike renting stations and relationships represent start and end points for a bike rental trip." |
169 | 166 | ] |
170 | 167 | }, |
171 | 168 | { |
|
206 | 203 | "\n", |
207 | 204 | "We first need to get access to the GDSArrowClient. This client allows us to directly communicate with the Arrow Flight server provided by the session.\n", |
208 | 205 | "\n", |
209 | | - "Our input data already resembles edge triplets, where each of the rows represents an edge from a source station to a target station. This allows us to use the arrows servers graph import from triplets functionality, which requires the following protocol:\n", |
| 206 | + "Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow Server's \"graph import from triplets\" functionality, which requires the following protocol:\n", |
210 | 207 | "\n", |
211 | 208 | "1. Send an action `v2/graph.project.fromTriplets`\n", |
212 | 209 | "   This will initialize the import process and allow us to specify the graph name and settings like `undirected_relationship_types`. It returns a job id, which we use to reference the import job in the following steps.\n",
213 | 210 | "2. Send the data in batches to the Arrow server.\n", |
214 | 211 | "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.\n", |
215 | 212 | "4. Wait for the import process to reach the `DONE` state.\n", |
216 | 213 | "\n", |
217 | | - "While the overall process is straight forward, we need to somehow tell Spark to" |
| 214 | + "The most complicated step here is to run the actual data upload on each spark worker. We will use the `mapInArrow` function to run custom code on each spark worker. Each worker will receive a number of arrow record batches that we can directly send to the GDS session's Arrow server. " |
218 | 223 | ] |
219 | 224 | }, |
220 | 225 | { |
221 | | - "cell_type": "code", |
222 | | - "execution_count": null, |
223 | 226 | "metadata": {}, |
| 227 | + "cell_type": "code",
224 | 233 | "outputs": [], |
| 234 | + "execution_count": null, |
225 | 235 | "source": [ |
| 236 | + "import time\n", |
226 | 237 | "import pandas as pd\n", |
227 | 238 | "import pyarrow\n", |
228 | 239 | "from pyspark.sql import functions\n", |
|
244 | 255 | "\n", |
245 | 256 | "# Select the source target pairs from our source data\n", |
246 | 257 | "source_target_pairs = spark.sql(\"\"\"\n", |
247 | | - " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", |
248 | | - " FROM bike_trips\n", |
249 | | - "\"\"\")\n", |
| 258 | + " SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n", |
| 259 | + " FROM bike_trips\n", |
| 260 | + " \"\"\")\n", |
250 | 261 | "\n", |
251 | 262 | "# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.\n", |
252 | 263 | "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n", |
253 | 264 | "\n", |
254 | 265 | "# Aggregate the batch sizes to receive the row count.\n", |
255 | | - "uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\")).show()\n", |
| 266 | + "aggregated_batch_sizes = uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\"))\n", |
| 267 | + "\n", |
| 268 | + "# Show the result. This will trigger the computation and thus run the data upload.\n", |
| 269 | + "aggregated_batch_sizes.show()\n", |
256 | 270 | "\n", |
257 | 271 | "# 3. Finish the import process\n", |
258 | 272 | "arrow_client.triplet_load_done(job_id)\n", |
259 | 273 | "\n", |
260 | 274 | "# 4. Wait for the import to finish\n", |
261 | 275 | "while not arrow_client.job_status(job_id).succeeded():\n", |
262 | | - " pass\n", |
| 276 | + " time.sleep(1)\n", |
263 | 277 | "\n", |
264 | 278 | "G = gds.v2.graph.get(graph_name)\n", |
265 | 279 | "G" |
266 | 280 | ] |
267 | 281 | }, |
268 | 282 | { |
269 | | - "cell_type": "markdown", |
270 | 283 | "metadata": {}, |
| 284 | + "cell_type": "markdown", |
271 | 285 | "source": [ |
272 | 286 | "## Running Algorithms\n", |
273 | 287 | "\n", |
|
322 | 336 | "# Optional: Repartition the data to make sure it is distributed equally\n", |
323 | 337 | "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n", |
324 | 338 | "\n", |
325 | | - "result.show()" |
| 339 | + "result.toPandas()" |
326 | 340 | ] |
327 | 341 | }, |
328 | 342 | { |
|