Skip to content

Commit e8eaf9a

Browse files
authored
feat: federated auth plugin (#373)
1 parent deaa932 commit e8eaf9a

File tree

15 files changed

+1558
-106
lines changed

15 files changed

+1558
-106
lines changed

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 382 additions & 0 deletions
Large diffs are not rendered by default.

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from typing import TYPE_CHECKING
1818

19+
from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo
20+
1921
if TYPE_CHECKING:
2022
from boto3 import Session
2123
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
@@ -39,23 +41,6 @@
3941
logger = Logger(__name__)
4042

4143

42-
class TokenInfo:
43-
@property
44-
def token(self):
45-
return self._token
46-
47-
@property
48-
def expiration(self):
49-
return self._expiration
50-
51-
def __init__(self, token: str, expiration: datetime):
52-
self._token = token
53-
self._expiration = expiration
54-
55-
def is_expired(self) -> bool:
56-
return datetime.now() > self._expiration
57-
58-
5944
class IamAuthPlugin(Plugin):
6045
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
6146
# Leave 30 second buffer to prevent time-of-check to time-of-use errors
@@ -86,10 +71,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
8671
if not WrapperProperties.USER.get(props):
8772
raise AwsWrapperError(Messages.get_formatted("IamPlugin.IsNoneOrEmpty", WrapperProperties.USER.name))
8873

89-
host = WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host
74+
host = IamAuthUtils.get_iam_host(props, host_info)
9075
region = WrapperProperties.IAM_REGION.get(props) \
9176
if WrapperProperties.IAM_REGION.get(props) else self._get_rds_region(host)
92-
port = self._get_port(props, host_info)
77+
port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
9378
token_expiration_sec: int = WrapperProperties.IAM_EXPIRATION.get_int(props)
9479

9580
cache_key: str = self._get_cache_key(
@@ -170,27 +155,11 @@ def _generate_authentication_token(self,
170155
def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
171156
return f"{region}:{hostname}:{port}:{user}"
172157

173-
def _get_port(self, props: Properties, host_info: HostInfo) -> int:
174-
if WrapperProperties.IAM_DEFAULT_PORT.get(props):
175-
default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props)
176-
if default_port > 0:
177-
return default_port
178-
else:
179-
logger.debug("IamAuthPlugin.InvalidPort", default_port)
180-
181-
if host_info.is_port_specified():
182-
return host_info.port
183-
184-
if self._plugin_service.database_dialect is not None:
185-
return self._plugin_service.database_dialect.default_port
186-
187-
raise AwsWrapperError(Messages.get("IamAuthPlugin.NoValidPorts"))
188-
189158
def _get_rds_region(self, hostname: Optional[str]) -> str:
190159
rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None
191160

192161
if not rds_region:
193-
exception_message = "IamAuthPlugin.UnsupportedHostname"
162+
exception_message = "RdsUtils.UnsupportedHostname"
194163
logger.debug(exception_message, hostname)
195164
raise AwsWrapperError(Messages.get_formatted(exception_message, hostname))
196165

aws_advanced_python_wrapper/plugin_service.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
from typing import TYPE_CHECKING, ClassVar, List, Type
1818

19+
from aws_advanced_python_wrapper.federated_plugin import \
20+
FederatedAuthPluginFactory
21+
1922
if TYPE_CHECKING:
2023
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2124
from aws_advanced_python_wrapper.driver_dialect_manager import DriverDialectManager
@@ -563,7 +566,8 @@ class PluginManager(CanReleaseResources):
563566
"stale_dns": StaleDnsPluginFactory,
564567
"connect_time": ConnectTimePluginFactory,
565568
"execute_time": ExecuteTimePluginFactory,
566-
"dev": DeveloperPluginFactory
569+
"dev": DeveloperPluginFactory,
570+
"federated_auth": FederatedAuthPluginFactory
567571
}
568572

569573
WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1
@@ -581,7 +585,8 @@ class PluginManager(CanReleaseResources):
581585
AwsSecretsManagerPluginFactory: 700,
582586
ConnectTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,
583587
ExecuteTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,
584-
DeveloperPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN
588+
DeveloperPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,
589+
FederatedAuthPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN
585590
}
586591

587592
def __init__(self, container: PluginServiceManagerContainer, props: Properties):

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
AuroraPgDialect.HasExtensionsTrue=[AuroraPgDialect] has_extensions: True
1818
AuroraPgDialect.HasTopologyTrue=[AuroraPgDialect] has_topology: True
1919

20+
AdfsCredentialsProviderFactory.FailedLogin=[AdfsCredentialsProviderFactory] Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response: '{}'
21+
AdfsCredentialsProviderFactory.GetSamlAssertionFailed=[AdfsCredentialsProviderFactory] Failed to get SAML Assertion due to exception: '{}'
22+
AdfsCredentialsProviderFactory.InvalidHttpsUrl=[AdfsCredentialsProviderFactory] Invalid HTTPS URL: '{}'
23+
AdfsCredentialsProviderFactory.SignOnPagePostActionUrl=[AdfsCredentialsProviderFactory] ADFS SignOn Action URL: '{}'
24+
AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page POST action failed with HTTP status '{}', reason phrase '{}', and response '{}'
25+
AdfsCredentialsProviderFactory.SignOnPageRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page Request Failed with HTTP status '{}', reason phrase '{}', and response '{}'
26+
AdfsCredentialsProviderFactory.SignOnPageUrl=[AdfsCredentialsProviderFactory] ADFS SignOn URL: '{}'
27+
2028
AwsSdk.UnsupportedRegion=[AwsSdk] Unsupported AWS region {}. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html
2129

2230
AwsSecretsManagerPlugin.ConnectException=[AwsSecretsManagerPlugin] Error occurred while opening a connection: {}
@@ -82,6 +90,10 @@ FailoverPlugin.TransactionResolutionUnknownError=[Failover] Transaction resoluti
8290
FailoverPlugin.UnableToConnectToReader=[Failover] Unable to establish SQL connection to the reader instance.
8391
FailoverPlugin.UnableToConnectToWriter=[Failover] Unable to establish SQL connection to the writer instance.
8492

93+
FederatedAuthPlugin.UnhandledException=[FederatedAuthPlugin] Unhandled exception: '{}'
94+
95+
FederatedAuthPluginFactory.UnsupportedIdp=[FederatedAuthPluginFactory] Unsupported Identity Provider '{}'. Please visit to the documentation for supported Identity Providers.
96+
8597
HostAvailabilityStrategy.InvalidInitialBackoffTime=[HostAvailabilityStrategy] Invalid value of {} for configuration parameter `host_availability_strategy_initial_backoff_time`. It must be an integer greater than 1.
8698
HostAvailabilityStrategy.InvalidMaxRetries=[HostAvailabilityStrategy] Invalid value of {} for configuration parameter `host_availability_strategy_max_retries`. It must be an integer greater than 1.
8799

@@ -104,7 +116,6 @@ IamAuthPlugin.GeneratedNewIamToken=[IamAuthPlugin] Generated new IAM token = {}
104116
IamAuthPlugin.InvalidPort=[IamAuthPlugin] Port number: {} is not valid. Port number should be greater than zero. Falling back to default port.
105117
IamAuthPlugin.NoValidPort=[IamAuthPlugin] Unable to determine a valid port.
106118
IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {}
107-
IamAuthPlugin.UnsupportedHostname=[IamAuthPlugin] Unsupported AWS hostname {}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
108119
IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {}
109120

110121
IamPlugin.IsNoneOrEmpty=[IamPlugin] Property "{}" is None or empty.
@@ -185,6 +196,8 @@ RdsTestUtility.InvalidDatabaseEngine=[RdsTestUtility] The detected database engi
185196
RdsTestUtility.MethodNotSupportedForDeployment=[RdsTestUtility] Method '{}' is not supported for the current database engine deployment: '{}'
186197
RdsTestUtility.WriterInstanceNotFound=[RdsTestUtility] Cannot find writer instance for cluster '{}'.
187198

199+
RdsUtils.UnsupportedHostname=[RdsUtils] Unsupported AWS hostname {}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
200+
188201
ReaderFailoverHandler.AttemptingReaderConnection=[ReaderFailoverHandler] Trying to connect to reader: '{}', with properties '{}'
189202
ReaderFailoverHandler.FailedReaderConnection=[ReaderFailoverHandler] Failed to connect to reader: '{}'
190203
ReaderFailoverHandler.InvalidTopology=[ReaderFailoverHandler] '{}' was called with an invalid (None or empty) topology.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from datetime import datetime
18+
from typing import TYPE_CHECKING
19+
20+
if TYPE_CHECKING:
21+
from aws_advanced_python_wrapper.hostinfo import HostInfo
22+
23+
from aws_advanced_python_wrapper.utils.properties import (Properties,
24+
WrapperProperties)
25+
26+
27+
class IamAuthUtils:
28+
29+
@staticmethod
30+
def get_iam_host(props: Properties, host_info: HostInfo):
31+
return WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host
32+
33+
@staticmethod
34+
def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) -> int:
35+
default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props)
36+
if default_port > 0:
37+
return default_port
38+
39+
if host_info.is_port_specified():
40+
return host_info.port
41+
42+
return dialect_default_port
43+
44+
45+
class TokenInfo:
46+
@property
47+
def token(self):
48+
return self._token
49+
50+
@property
51+
def expiration(self):
52+
return self._expiration
53+
54+
def __init__(self, token: str, expiration: datetime):
55+
self._token = token
56+
self._expiration = expiration
57+
58+
def is_expired(self) -> bool:
59+
return datetime.now() > self._expiration

aws_advanced_python_wrapper/utils/properties.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,54 @@ class WrapperProperties:
234234
ROUND_ROBIN_HOST_WEIGHT_PAIRS = WrapperProperty("round_robin_host_weight_pairs",
235235
"Comma separated list of database host-weight pairs in the format of `<host>:<weight>`.",
236236
"")
237+
# Federated Auth Plugin
238+
IDP_ENDPOINT = WrapperProperty("idp_endpoint",
239+
"The hosting URL of the Identity Provider",
240+
None)
241+
242+
IDP_PORT = WrapperProperty("idp_port",
243+
"The hosting port of the Identity Provider",
244+
443)
245+
246+
RELAYING_PARTY_ID = WrapperProperty("rp_identifier",
247+
"The relaying party identifier",
248+
"urn:amazon:webservices")
249+
250+
IAM_ROLE_ARN = WrapperProperty("iam_role_arn",
251+
"The ARN of the IAM Role that is to be assumed.",
252+
None)
253+
254+
IAM_IDP_ARN = WrapperProperty("iam_idp_arn",
255+
"The ARN of the Identity Provider",
256+
None)
257+
258+
IAM_TOKEN_EXPIRATION = WrapperProperty("iam_token_expiration",
259+
"IAM token cache expiration in seconds",
260+
15 * 60 - 30)
261+
262+
IDP_USERNAME = WrapperProperty("idp_username",
263+
"The federated user name",
264+
None)
265+
266+
IDP_PASSWORD = WrapperProperty("idp_password",
267+
"The federated user password",
268+
None)
269+
270+
HTTP_REQUEST_TIMEOUT = WrapperProperty("http_request_connect_timeout",
271+
"The timeout value in seconds to send the HTTP request data used by the FederatedAuthPlugin",
272+
60)
273+
274+
SSL_SECURE = WrapperProperty("ssl_secure",
275+
"Whether the SSL session is to be secure and the server's certificates will be verified",
276+
False)
277+
278+
IDP_NAME = WrapperProperty("idp_name",
279+
"The name of the Identity Provider implementation used",
280+
"adfs")
281+
282+
DB_USER = WrapperProperty("db_user",
283+
"The database user used to access the database",
284+
None)
237285

238286

239287
class PropertiesUtils:

docs/development-guide/IntegrationTests.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ unset FILTER # Done testing the IAM tests, unset FILTER
117117

118118
| Environment Variable Name | Required | Description | Example Value |
119119
|---------------------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|
120-
| `DB_USERNAME` | Yes | The username to access the database. | `admin` |
120+
| `DB_USER` | Yes | The username to access the database. | `admin` |
121121
| `DB_PASSWORD` | Yes | The database cluster password. | `password` |
122122
| `DB_DATABASE_NAME` | No | Name of the database that will be used by the tests. The default database name is test. | `test_db_name` |
123123
| `AURORA_CLUSTER_NAME` | Yes | The database identifier for your Aurora cluster. Must be a unique value to avoid conflicting with existing clusters. | `db-identifier` |
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import mysql.connector
16+
17+
from aws_advanced_python_wrapper import AwsWrapperConnection
18+
19+
if __name__ == "__main__":
20+
with AwsWrapperConnection.connect(
21+
mysql.connector.Connect,
22+
host="database.cluster-xyz.us-east-2.rds.amazonaws.com",
23+
database="mysql",
24+
plugins="federated_auth",
25+
idp_name="adfs",
26+
idp_endpoint="ec2amaz-ab3cdef.example.com",
27+
iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role",
28+
iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example",
29+
iam_region="us-east-2",
30+
idp_username="some_federated_username@example.com",
31+
idp_password="some_password",
32+
user="john",
33+
autocommit=True
34+
) as awsconn, awsconn.cursor() as awscursor:
35+
awscursor.execute("SELECT 1")
36+
37+
res = awscursor.fetchone()
38+
print(res)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import psycopg
16+
17+
from aws_advanced_python_wrapper import AwsWrapperConnection
18+
19+
if __name__ == "__main__":
20+
with AwsWrapperConnection.connect(
21+
psycopg.Connection.connect,
22+
host="database.cluster-xyz.us-east-2.rds.amazonaws.com",
23+
dbname="postgres",
24+
plugins="federated_auth",
25+
idp_name="adfs",
26+
idp_endpoint="ec2amaz-ab3cdef.example.com",
27+
iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role",
28+
iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example",
29+
iam_region="us-east-2",
30+
idp_username="some_federated_username@example.com",
31+
idp_password="some_password",
32+
user="john",
33+
autocommit=True
34+
) as awsconn, awsconn.cursor() as awscursor:
35+
awscursor.execute("SELECT 1")
36+
37+
res = awscursor.fetchone()
38+
print(res)

0 commit comments

Comments
 (0)