Skip to content

Commit d789559

Browse files
committed
Check db env through cypher and allow aura_instance_id to be filled from env
1 parent de481b9 commit d789559

File tree

8 files changed

+79
-35
lines changed

8 files changed

+79
-35
lines changed

doc/modules/ROOT/pages/graph-analytics-serverless.adoc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,9 @@ sessions.get_or_create(
9898
| Name | Type | Optional | Default | Description
9999
| session_name | str | no | - | Name of the session. Must be unique within the project.
100100
| memory | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/session_memory[SessionMemory] | no | - | Amount of memory available to the session.
101-
| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/dbms_connection_info[DbmsConnectionInfo] | yes | None | Bolt server URL, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. Alternatively to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object].
101+
| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/dbms_connection_info[DbmsConnectionInfo] | yes | None | Bolt server URI, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. Alternatively to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object].
102102
| ttl | datetime.timedelta | yes | 1h | Time-to-live for the session.
103103
| cloud_location | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/cloud_location[CloudLocation] | yes | None | Aura-supported cloud provider and region where the GDS Session will run. Required for the Self-managed and Standalone types.
104-
| aura_instance_id | str | yes | None | Aura instance ID for the session. Required for the Attached type, if the id could not be derived from the DB connection.
105104
| timeout | int | yes | None | Seconds to wait for the session to enter Ready state. If the time is exceeded, an error will be returned.
106105
| neo4j_driver_options | dict[str, any] | yes | None | Additional options passed to the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified.
107106
| arrow_client_options | dict[str, any] | yes | None | Additional options passed to the Arrow Flight Client used to connect to the Session.
@@ -317,7 +316,13 @@ from graphdatascience.session import SessionMemory, DbmsConnectionInfo, GdsSessi
317316
318317
sessions = GdsSessions(api_credentials=AuraAPICredentials(os.environ["CLIENT_ID"], os.environ["CLIENT_SECRET"]))
319318
320-
db_connection = DbmsConnectionInfo(os.environ["DB_URI"], os.environ["DB_USER"], os.environ["DB_PASSWORD"])
319+
# you can also use DbmsConnectionInfo.from_env() to load credentials from environment variables
320+
db_connection = DbmsConnectionInfo(
321+
uri=os.environ["NEO4J_URI"],
322+
username=os.environ["NEO4J_USERNAME"],
323+
password=os.environ["NEO4J_PASSWORD"],
324+
aura_instance_id=os.environ["AURA_INSTANCEID"]
325+
)
321326
gds = sessions.get_or_create(
322327
session_name="my-new-session",
323328
memory=SessionMemory.m_8GB,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
2+
3+
4+
class DbEnvironmentResolver:
5+
@staticmethod
6+
def hosted_in_aura(db_runner: Neo4jQueryRunner) -> bool:
7+
return (
8+
db_runner.run_retryable_cypher("""
9+
CALL dbms.components() YIELD name, versions
10+
WHERE name = "Neo4j Kernel"
11+
UNWIND versions as v
12+
WITH name, v
13+
WHERE v ENDS WITH "aura"
14+
RETURN count(*) <> 0
15+
""").squeeze()
16+
is True
17+
)

graphdatascience/session/dbms_connection_info.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import os
44
from dataclasses import dataclass
5-
from urllib.parse import urlparse
65

76
from neo4j import Auth, basic_auth
87

@@ -20,6 +19,7 @@ class DbmsConnectionInfo:
2019
database: str | None = None
2120
# Optional: typed authentication, used instead of username/password. Supports for example a token. See https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods
2221
auth: Auth | None = None
22+
aura_instance_id: str | None = None
2323

2424
def __post_init__(self) -> None:
2525
# Validate auth fields
@@ -39,14 +39,6 @@ def get_auth(self) -> Auth | None:
3939
auth = basic_auth(self.username, self.password)
4040
return auth
4141

42-
def hosted_in_aura(self) -> bool:
43-
"""
44-
Returns:
45-
True if the DBMS is hosted in Aura, False otherwise.
46-
"""
47-
host = urlparse(self.uri).hostname
48-
return host is not None and host.endswith("databases.neo4j.io")
49-
5042
@staticmethod
5143
def from_env() -> DbmsConnectionInfo:
5244
"""
@@ -56,10 +48,12 @@ def from_env() -> DbmsConnectionInfo:
5648
- NEO4J_USERNAME
5749
- NEO4J_PASSWORD
5850
- NEO4J_DATABASE
51+
- AURA_INSTANCEID
5952
"""
6053
uri = os.environ["NEO4J_URI"]
6154
username = os.environ.get("NEO4J_USERNAME", "neo4j")
6255
password = os.environ["NEO4J_PASSWORD"]
6356
database = os.environ.get("NEO4J_DATABASE")
57+
aura_instance_id = os.environ.get("AURA_INSTANCEID")
6458

65-
return DbmsConnectionInfo(uri, username, password, database)
59+
return DbmsConnectionInfo(uri, username, password, database, aura_instance_id=aura_instance_id)

graphdatascience/session/dedicated_sessions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77

88
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
9+
from graphdatascience.query_runner.db_environment_resolver import DbEnvironmentResolver
910
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1011
from graphdatascience.session.algorithm_category import AlgorithmCategory
1112
from graphdatascience.session.aura_api import AuraApi
@@ -61,7 +62,6 @@ def get_or_create(
6162
db_connection: DbmsConnectionInfo | None = None,
6263
ttl: timedelta | None = None,
6364
cloud_location: CloudLocation | None = None,
64-
aura_instance_id: str | None = None,
6565
timeout: int | None = None,
6666
neo4j_driver_options: dict[str, Any] | None = None,
6767
arrow_client_options: dict[str, Any] | None = None,
@@ -75,17 +75,21 @@ def get_or_create(
7575
else:
7676
db_runner = self._create_db_runner(db_connection, neo4j_driver_options)
7777

78-
aura_instance_id = AuraApi.extract_id(db_connection.uri) if not aura_instance_id else aura_instance_id
78+
aura_instance_id = (
79+
db_connection.aura_instance_id
80+
if db_connection.aura_instance_id
81+
else AuraApi.extract_id(db_connection.uri)
82+
)
83+
if not aura_instance_id and DbEnvironmentResolver.hosted_in_aura(db_runner):
84+
raise ValueError(
85+
f"Could not derive Aura instance id from the URI `{db_connection.uri}`. Please specify the `aura_instance_id` in the `db_connection` argument."
86+
)
87+
7988
aura_db_instance = self._aura_api.list_instance(aura_instance_id)
8089

8190
if aura_db_instance is None:
8291
if not cloud_location:
83-
if db_connection.hosted_in_aura():
84-
raise ValueError(
85-
f"Could not derive Aura instance id from the URI `{db_connection.uri}`. Please provide the instance id via the `aura_instance_id` argument, or specify a cloud location if the DBMS is self-managed."
86-
)
87-
else:
88-
raise ValueError("cloud_location must be provided for sessions against a self-managed DB.")
92+
raise ValueError("cloud_location must be provided for sessions against a self-managed DB.")
8993

9094
session_details = self._get_or_create_self_managed_session(
9195
session_name, memory.value, cloud_location, ttl

graphdatascience/session/gds_sessions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def get_or_create(
107107
timeout: int | None = None,
108108
neo4j_driver_config: dict[str, Any] | None = None,
109109
arrow_client_options: dict[str, Any] | None = None,
110-
aura_instance_id: str | None = None,
111110
) -> AuraGraphDataScience:
112111
"""
113112
Retrieves an existing session with the given session name and database connection,
@@ -125,7 +124,6 @@ def get_or_create(
125124
timeout (int | None): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.
126125
neo4j_driver_config (dict[str, Any] | None): Optional configuration for the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified..
127126
arrow_client_options (dict[str, Any] | None): Optional configuration for the Arrow Flight client.
128-
aura_instance_id (str | None): The Aura instance id. Required if the database is in Aura but its instance id cannot be inferred from the connection information.
129127
Returns:
130128
AuraGraphDataScience: The session.
131129
"""
@@ -134,7 +132,6 @@ def get_or_create(
134132
memory,
135133
db_connection=db_connection,
136134
ttl=ttl,
137-
aura_instance_id=aura_instance_id,
138135
cloud_location=cloud_location,
139136
timeout=timeout,
140137
neo4j_driver_options=neo4j_driver_config,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pytest
2+
3+
from graphdatascience.query_runner.db_environment_resolver import DbEnvironmentResolver
4+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
5+
6+
7+
@pytest.mark.only_on_aura
8+
def test_hosted_in_aura_aura_dbms(aura_runner: Neo4jQueryRunner) -> None:
9+
assert DbEnvironmentResolver.hosted_in_aura(aura_runner)
10+
11+
12+
@pytest.mark.skip_on_aura
13+
def test_hosted_in_aura_self_managed_dbms(runner: Neo4jQueryRunner) -> None:
14+
assert not DbEnvironmentResolver.hosted_in_aura(runner)

graphdatascience/tests/unit/session/test_dbms_connection_info.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,3 @@ def test_dbms_connection_info_fail_on_auth_and_username() -> None:
3939
)
4040
else:
4141
assert False, "Expected ValueError was not raised"
42-
43-
44-
def test_dbms_connection_info_hosted_in_aura() -> None:
45-
assert DbmsConnectionInfo(uri="bolt://something.databases.neo4j.io").hosted_in_aura()
46-
assert DbmsConnectionInfo(uri="bolt://something.databases.neo4j.io:7474").hosted_in_aura()
47-
48-
assert not DbmsConnectionInfo(uri="bolt://something.neo4j.com").hosted_in_aura()

graphdatascience/tests/unit/test_dedicated_sessions.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def test_create_standalone_session(mocker: MockerFixture, aura_api: AuraApi) ->
431431
sessions = DedicatedSessions(aura_api)
432432

433433
patch_construct_client(mocker)
434-
patch_neo4j_query_runner(mocker)
434+
patch_neo4j_query_runner(mocker, hosted_in_aura=False)
435435

436436
ttl = timedelta(hours=42)
437437

@@ -526,6 +526,22 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None:
526526
assert [i.name for i in sessions.list()] == ["my-session"]
527527

528528

529+
def test_get_or_create_with_explicit_aura_instance_id(mocker: MockerFixture, aura_api: AuraApi) -> None:
530+
db = _setup_db_instance(aura_api)
531+
sessions = DedicatedSessions(aura_api)
532+
patch_construct_client(mocker)
533+
patch_neo4j_query_runner(mocker)
534+
535+
sessions.get_or_create(
536+
"my-session",
537+
SessionMemory.m_8GB,
538+
DbmsConnectionInfo(
539+
"neo4j+s://foo.bar", "dbuser", "db_pw", aura_instance_id=db.id
540+
), # not part of list instances result
541+
cloud_location=None,
542+
)
543+
544+
529545
def test_get_or_create_expired_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
530546
db = _setup_db_instance(aura_api)
531547

@@ -601,7 +617,7 @@ def test_get_or_create_for_auradb_with_cloud_location(mocker: MockerFixture, aur
601617

602618
def test_get_or_create_for_without_cloud_location(mocker: MockerFixture, aura_api: AuraApi) -> None:
603619
sessions = DedicatedSessions(aura_api)
604-
patch_neo4j_query_runner(mocker)
620+
patch_neo4j_query_runner(mocker, hosted_in_aura=False)
605621

606622
with pytest.raises(
607623
ValueError, match=re.escape("cloud_location must be provided for sessions against a self-managed DB.")
@@ -621,7 +637,7 @@ def test_get_or_create_for_non_derivable_aura_instance_id(mocker: MockerFixture,
621637
with pytest.raises(
622638
ValueError,
623639
match=re.escape(
624-
"Could not derive Aura instance id from the URI `neo4j+s://06cba79f.databases.neo4j.io`. Please provide the instance id via the `aura_instance_id` argument, or specify a cloud location if the DBMS is self-managed."
640+
"Could not derive Aura instance id from the URI `neo4j+s://06cba79f.databases.neo4j.io`. Please specify the `aura_instance_id` in the `db_connection` argument."
625641
),
626642
):
627643
sessions.get_or_create(
@@ -808,12 +824,16 @@ def _setup_db_instance(aura_api: AuraApi) -> InstanceCreateDetails:
808824
return aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1")
809825

810826

811-
def patch_neo4j_query_runner(mocker: MockerFixture) -> None:
827+
def patch_neo4j_query_runner(mocker: MockerFixture, hosted_in_aura: bool = True) -> None:
812828
mocker.patch(
813829
"graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create_for_db",
814830
lambda *args, **kwargs: kwargs,
815831
)
816832
mocker.patch("graphdatascience.session.dedicated_sessions.DedicatedSessions._validate_db_connection")
833+
mocker.patch(
834+
"graphdatascience.query_runner.db_environment_resolver.DbEnvironmentResolver.hosted_in_aura",
835+
lambda *args, **kwargs: hosted_in_aura,
836+
)
817837

818838

819839
def patch_construct_client(mocker: MockerFixture) -> None:

0 commit comments

Comments
 (0)