Skip to content

Commit 2546889

Browse files
authored
feat: okta support (#555)
1 parent 0e9b21e commit 2546889

18 files changed

+904
-290
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, Dict, Optional, Protocol
18+
19+
import boto3
20+
21+
if TYPE_CHECKING:
22+
from aws_advanced_python_wrapper.utils.properties import Properties
23+
24+
from abc import abstractmethod
25+
26+
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
27+
28+
29+
class CredentialsProviderFactory(Protocol):
30+
@abstractmethod
31+
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
32+
...
33+
34+
35+
class SamlCredentialsProviderFactory(CredentialsProviderFactory):
36+
37+
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
38+
saml_assertion: str = self.get_saml_assertion(props)
39+
session = boto3.Session()
40+
41+
sts_client = session.client(
42+
'sts',
43+
region_name=region
44+
)
45+
46+
response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
47+
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
48+
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
49+
SAMLAssertion=saml_assertion,
50+
)
51+
52+
return response.get('Credentials')
53+
54+
def get_saml_assertion(self, props: Properties):
55+
...

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 36 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
from __future__ import annotations
1616

17-
from abc import abstractmethod
1817
from html import unescape
1918
from re import DOTALL, findall, search
20-
from typing import TYPE_CHECKING, List, Protocol
21-
from urllib.parse import urlencode, urlparse
19+
from typing import TYPE_CHECKING, List
20+
from urllib.parse import urlencode
2221

23-
from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo
22+
from aws_advanced_python_wrapper.credentials_provider_factory import (
23+
CredentialsProviderFactory, SamlCredentialsProviderFactory)
24+
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
25+
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils
2426

2527
if TYPE_CHECKING:
2628
from boto3 import Session
@@ -32,7 +34,6 @@
3234
from datetime import datetime, timedelta
3335
from typing import Callable, Dict, Optional, Set
3436

35-
import boto3
3637
import requests
3738

3839
from aws_advanced_python_wrapper.errors import AwsWrapperError
@@ -58,6 +59,10 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory:
5859
self._credentials_provider_factory = credentials_provider_factory
5960
self._session = session
6061

62+
telemetry_factory = self._plugin_service.get_telemetry_factory()
63+
self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
64+
self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))
65+
6166
@property
6267
def subscribed_methods(self) -> Set[str]:
6368
return self._SUBSCRIBED_METHODS
@@ -73,14 +78,15 @@ def connect(
7378
return self._connect(host_info, props, connect_func)
7479

7580
def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
76-
self._check_idp_credentials_with_fallback(props)
81+
SamlUtils.check_idp_credentials_with_fallback(props)
7782

7883
host = IamAuthUtils.get_iam_host(props, host_info)
7984
port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
80-
region: str = self._get_rds_region(host, props)
85+
region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session)
8186

82-
cache_key: str = self._get_cache_key(
83-
WrapperProperties.DB_USER.get(props),
87+
user = WrapperProperties.DB_USER.get(props)
88+
cache_key: str = IamAuthUtils.get_cache_key(
89+
user,
8490
host,
8591
port,
8692
region
@@ -89,17 +95,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
8995
token_info = FederatedAuthPlugin._token_cache.get(cache_key)
9096

9197
if token_info is not None and not token_info.is_expired():
92-
logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token)
98+
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
9399
self._plugin_service.driver_dialect.set_password(props, token_info.token)
94100
else:
95-
self._update_authentication_token(host_info, props, region, cache_key)
101+
self._update_authentication_token(host_info, props, user, region, cache_key)
96102

97-
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
103+
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
98104

99105
try:
100106
return connect_func()
101107
except Exception:
102-
self._update_authentication_token(host_info, props, region, cache_key)
108+
self._update_authentication_token(host_info, props, user, region, cache_key)
103109

104110
try:
105111
return connect_func()
@@ -121,77 +127,25 @@ def force_connect(
121127
def _update_authentication_token(self,
122128
host_info: HostInfo,
123129
props: Properties,
130+
user: Optional[str],
124131
region: str,
125132
cache_key: str) -> None:
126133
token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
127134
token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
128135
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
129136
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)
130137

131-
token: str = self._generate_authentication_token(props, host_info.host, port, region, credentials)
132-
logger.debug("IamAuthPlugin.GeneratedNewIamToken", token)
138+
self._fetch_token_counter.inc()
139+
token: str = IamAuthUtils.generate_authentication_token(
140+
self._plugin_service,
141+
user,
142+
host_info.host,
143+
port,
144+
region,
145+
credentials,
146+
self._session)
133147
WrapperProperties.PASSWORD.set(props, token)
134-
FederatedAuthPlugin._token_cache[token] = TokenInfo(token, token_expiry)
135-
136-
def _get_rds_region(self, hostname: Optional[str], props: Properties) -> str:
137-
rds_region = WrapperProperties.IAM_REGION.get(props)
138-
if rds_region is None or rds_region == "":
139-
rds_region = self._rds_utils.get_rds_region(hostname)
140-
141-
if not rds_region:
142-
error_message = "RdsUtils.UnsupportedHostname"
143-
logger.debug(error_message, hostname)
144-
raise AwsWrapperError(Messages.get_formatted(error_message, hostname))
145-
146-
session = self._session if self._session else boto3.Session()
147-
if rds_region not in session.get_available_regions("rds"):
148-
error_message = "AwsSdk.UnsupportedRegion"
149-
logger.debug(error_message, rds_region)
150-
raise AwsWrapperError(Messages.get_formatted(error_message, rds_region))
151-
152-
return rds_region
153-
154-
def _generate_authentication_token(self,
155-
props: Properties,
156-
host_name: Optional[str],
157-
port: Optional[int],
158-
region: Optional[str],
159-
credentials: Optional[Dict[str, str]]) -> str:
160-
session = self._session if self._session else boto3.Session()
161-
162-
if credentials is not None:
163-
client = session.client(
164-
'rds',
165-
region_name=region,
166-
aws_access_key_id=credentials.get('AccessKeyId'),
167-
aws_secret_access_key=credentials.get('SecretAccessKey'),
168-
aws_session_token=credentials.get('SessionToken')
169-
)
170-
else:
171-
client = session.client(
172-
'rds',
173-
region_name=region
174-
)
175-
176-
user = WrapperProperties.USER.get(props)
177-
token = client.generate_db_auth_token(
178-
DBHostname=host_name,
179-
Port=port,
180-
DBUsername=user
181-
)
182-
183-
client.close()
184-
185-
return token
186-
187-
def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
188-
return f"{region}:{hostname}:{port}:{user}"
189-
190-
def _check_idp_credentials_with_fallback(self, props: Properties) -> None:
191-
if WrapperProperties.IDP_USERNAME.get(props) is None:
192-
WrapperProperties.IDP_USERNAME.set(props, WrapperProperties.USER.name)
193-
if WrapperProperties.IDP_PASSWORD.get(props) is None:
194-
WrapperProperties.IDP_PASSWORD.set(props, WrapperProperties.PASSWORD.name)
148+
FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
195149

196150

197151
class FederatedAuthPluginFactory(PluginFactory):
@@ -205,35 +159,6 @@ def get_credentials_provider_factory(self, plugin_service: PluginService, props:
205159
raise AwsWrapperError(Messages.get_formatted("FederatedAuthPluginFactory.UnsupportedIdp", idp_name))
206160

207161

208-
class CredentialsProviderFactory(Protocol):
209-
@abstractmethod
210-
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
211-
...
212-
213-
214-
class SamlCredentialsProviderFactory(CredentialsProviderFactory):
215-
216-
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
217-
saml_assertion: str = self.get_saml_assertion(props)
218-
session = boto3.Session()
219-
220-
sts_client = session.client(
221-
'sts',
222-
region_name=region
223-
)
224-
225-
response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
226-
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
227-
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
228-
SAMLAssertion=saml_assertion
229-
)
230-
231-
return response.get('Credentials')
232-
233-
def get_saml_assertion(self, props: Properties):
234-
...
235-
236-
237162
class AdfsCredentialsProviderFactory(SamlCredentialsProviderFactory):
238163
_INPUT_TAG_PATTERN = r"<input(.+?)/>"
239164
_FORM_ACTION_PATTERN = r"<form.*?action=\"([^\"]+)\""
@@ -274,32 +199,22 @@ def get_saml_assertion(self, props: Properties):
274199

275200
def _get_sign_in_page_body(self, url: str, props: Properties) -> str:
276201
logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url)
277-
self._validate_url(url)
202+
SamlUtils.validate_url(url)
278203
r = requests.get(url,
279204
verify=WrapperProperties.SSL_SECURE.get_bool(props),
280205
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))
281206

282-
# Check HTTP Status Code is 2xx Success
283-
if r.status_code / 100 != 2:
284-
error_message = "AdfsCredentialsProviderFactory.SignOnPageRequestFailed"
285-
logger.debug(error_message, r.status_code, r.reason, r.text)
286-
raise AwsWrapperError(Messages.get_formatted(error_message, r.status_code, r.reason, r.text))
287-
207+
SamlUtils.validate_response(r)
288208
return r.text
289209

290210
def _post_form_action_body(self, uri: str, parameters: Dict[str, str], props: Properties) -> str:
291211
logger.debug("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl", uri)
292-
self._validate_url(uri)
212+
SamlUtils.validate_url(uri)
293213
r = requests.post(uri, data=urlencode(parameters),
294214
verify=WrapperProperties.SSL_SECURE.get_bool(props),
295215
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))
296216
# Check HTTP Status Code is 2xx Success
297-
if r.status_code / 100 != 2:
298-
error_message = "AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed"
299-
logger.debug(error_message, r.status_code, r.reason, r.text)
300-
raise AwsWrapperError(
301-
Messages.get_formatted(error_message, r.status_code, r.reason, r.text))
302-
217+
SamlUtils.validate_response(r)
303218
return r.text
304219

305220
def _get_sign_in_page_url(self, props) -> str:
@@ -308,7 +223,7 @@ def _get_sign_in_page_url(self, props) -> str:
308223
relaying_party_id = WrapperProperties.RELAYING_PARTY_ID.get(props)
309224
url = f"https://{idp_endpoint}:{idp_port}/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={relaying_party_id}"
310225
if idp_endpoint is None or relaying_party_id is None:
311-
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
226+
error_message = "SamlUtils.InvalidHttpsUrl"
312227
logger.debug(error_message, url)
313228
raise AwsWrapperError(Messages.get_formatted(error_message, url))
314229

@@ -319,7 +234,7 @@ def _get_form_action_url(self, props: Properties, action: str) -> str:
319234
idp_port = WrapperProperties.IDP_PORT.get(props)
320235
url = f"https://{idp_endpoint}:{idp_port}{action}"
321236
if idp_endpoint is None:
322-
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
237+
error_message = "SamlUtils.InvalidHttpsUrl"
323238
logger.debug(error_message, url)
324239
raise AwsWrapperError(
325240
Messages.get_formatted(error_message, url))
@@ -373,10 +288,3 @@ def _get_form_action_from_html_body(self, body: str) -> str:
373288
return unescape(match.group(1))
374289

375290
return ""
376-
377-
def _validate_url(self, url: str) -> None:
378-
result = urlparse(url)
379-
if not result.scheme or not search(self._HTTPS_URL_PATTERN, url):
380-
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
381-
logger.debug(error_message, url)
382-
raise AwsWrapperError(Messages.get_formatted(error_message, url))

0 commit comments

Comments
 (0)