1414
1515from __future__ import annotations
1616
17- from abc import abstractmethod
1817from html import unescape
1918from re import DOTALL , findall , search
20- from typing import TYPE_CHECKING , List , Protocol
21- from urllib .parse import urlencode , urlparse
19+ from typing import TYPE_CHECKING , List
20+ from urllib .parse import urlencode
2221
23- from aws_advanced_python_wrapper .utils .iamutils import IamAuthUtils , TokenInfo
22+ from aws_advanced_python_wrapper .credentials_provider_factory import (
23+ CredentialsProviderFactory , SamlCredentialsProviderFactory )
24+ from aws_advanced_python_wrapper .utils .iam_utils import IamAuthUtils , TokenInfo
25+ from aws_advanced_python_wrapper .utils .saml_utils import SamlUtils
2426
2527if TYPE_CHECKING :
2628 from boto3 import Session
3234from datetime import datetime , timedelta
3335from typing import Callable , Dict , Optional , Set
3436
35- import boto3
3637import requests
3738
3839from aws_advanced_python_wrapper .errors import AwsWrapperError
@@ -58,6 +59,10 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory:
5859 self ._credentials_provider_factory = credentials_provider_factory
5960 self ._session = session
6061
62+ telemetry_factory = self ._plugin_service .get_telemetry_factory ()
63+ self ._fetch_token_counter = telemetry_factory .create_counter ("federated.fetch_token.count" )
64+ self ._cache_size_gauge = telemetry_factory .create_gauge ("federated.token_cache.size" , lambda : len (FederatedAuthPlugin ._token_cache ))
65+
6166 @property
6267 def subscribed_methods (self ) -> Set [str ]:
6368 return self ._SUBSCRIBED_METHODS
@@ -73,14 +78,15 @@ def connect(
7378 return self ._connect (host_info , props , connect_func )
7479
7580 def _connect (self , host_info : HostInfo , props : Properties , connect_func : Callable ) -> Connection :
76- self . _check_idp_credentials_with_fallback (props )
81+ SamlUtils . check_idp_credentials_with_fallback (props )
7782
7883 host = IamAuthUtils .get_iam_host (props , host_info )
7984 port = IamAuthUtils .get_port (props , host_info , self ._plugin_service .database_dialect .default_port )
80- region : str = self . _get_rds_region ( host , props )
85+ region : str = IamAuthUtils . get_rds_region ( self . _rds_utils , host , props , self . _session )
8186
82- cache_key : str = self ._get_cache_key (
83- WrapperProperties .DB_USER .get (props ),
87+ user = WrapperProperties .DB_USER .get (props )
88+ cache_key : str = IamAuthUtils .get_cache_key (
89+ user ,
8490 host ,
8591 port ,
8692 region
@@ -89,17 +95,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
8995 token_info = FederatedAuthPlugin ._token_cache .get (cache_key )
9096
9197 if token_info is not None and not token_info .is_expired ():
92- logger .debug ("IamAuthPlugin.UseCachedIamToken " , token_info .token )
98+ logger .debug ("FederatedAuthPlugin.UseCachedToken " , token_info .token )
9399 self ._plugin_service .driver_dialect .set_password (props , token_info .token )
94100 else :
95- self ._update_authentication_token (host_info , props , region , cache_key )
101+ self ._update_authentication_token (host_info , props , user , region , cache_key )
96102
97- WrapperProperties .USER .set (props , WrapperProperties .DB_USER .get (props ))
103+ WrapperProperties .USER .set (props , WrapperProperties .DB_USER .get (props ))
98104
99105 try :
100106 return connect_func ()
101107 except Exception :
102- self ._update_authentication_token (host_info , props , region , cache_key )
108+ self ._update_authentication_token (host_info , props , user , region , cache_key )
103109
104110 try :
105111 return connect_func ()
@@ -121,77 +127,25 @@ def force_connect(
121127 def _update_authentication_token (self ,
122128 host_info : HostInfo ,
123129 props : Properties ,
130+ user : Optional [str ],
124131 region : str ,
125132 cache_key : str ) -> None :
126133 token_expiration_sec : int = WrapperProperties .IAM_TOKEN_EXPIRATION .get_int (props )
127134 token_expiry : datetime = datetime .now () + timedelta (seconds = token_expiration_sec )
128135 port : int = IamAuthUtils .get_port (props , host_info , self ._plugin_service .database_dialect .default_port )
129136 credentials : Optional [Dict [str , str ]] = self ._credentials_provider_factory .get_aws_credentials (region , props )
130137
131- token : str = self ._generate_authentication_token (props , host_info .host , port , region , credentials )
132- logger .debug ("IamAuthPlugin.GeneratedNewIamToken" , token )
138+ self ._fetch_token_counter .inc ()
139+ token : str = IamAuthUtils .generate_authentication_token (
140+ self ._plugin_service ,
141+ user ,
142+ host_info .host ,
143+ port ,
144+ region ,
145+ credentials ,
146+ self ._session )
133147 WrapperProperties .PASSWORD .set (props , token )
134- FederatedAuthPlugin ._token_cache [token ] = TokenInfo (token , token_expiry )
135-
136- def _get_rds_region (self , hostname : Optional [str ], props : Properties ) -> str :
137- rds_region = WrapperProperties .IAM_REGION .get (props )
138- if rds_region is None or rds_region == "" :
139- rds_region = self ._rds_utils .get_rds_region (hostname )
140-
141- if not rds_region :
142- error_message = "RdsUtils.UnsupportedHostname"
143- logger .debug (error_message , hostname )
144- raise AwsWrapperError (Messages .get_formatted (error_message , hostname ))
145-
146- session = self ._session if self ._session else boto3 .Session ()
147- if rds_region not in session .get_available_regions ("rds" ):
148- error_message = "AwsSdk.UnsupportedRegion"
149- logger .debug (error_message , rds_region )
150- raise AwsWrapperError (Messages .get_formatted (error_message , rds_region ))
151-
152- return rds_region
153-
154- def _generate_authentication_token (self ,
155- props : Properties ,
156- host_name : Optional [str ],
157- port : Optional [int ],
158- region : Optional [str ],
159- credentials : Optional [Dict [str , str ]]) -> str :
160- session = self ._session if self ._session else boto3 .Session ()
161-
162- if credentials is not None :
163- client = session .client (
164- 'rds' ,
165- region_name = region ,
166- aws_access_key_id = credentials .get ('AccessKeyId' ),
167- aws_secret_access_key = credentials .get ('SecretAccessKey' ),
168- aws_session_token = credentials .get ('SessionToken' )
169- )
170- else :
171- client = session .client (
172- 'rds' ,
173- region_name = region
174- )
175-
176- user = WrapperProperties .USER .get (props )
177- token = client .generate_db_auth_token (
178- DBHostname = host_name ,
179- Port = port ,
180- DBUsername = user
181- )
182-
183- client .close ()
184-
185- return token
186-
187- def _get_cache_key (self , user : Optional [str ], hostname : Optional [str ], port : int , region : Optional [str ]) -> str :
188- return f"{ region } :{ hostname } :{ port } :{ user } "
189-
190- def _check_idp_credentials_with_fallback (self , props : Properties ) -> None :
191- if WrapperProperties .IDP_USERNAME .get (props ) is None :
192- WrapperProperties .IDP_USERNAME .set (props , WrapperProperties .USER .name )
193- if WrapperProperties .IDP_PASSWORD .get (props ) is None :
194- WrapperProperties .IDP_PASSWORD .set (props , WrapperProperties .PASSWORD .name )
148+ FederatedAuthPlugin ._token_cache [cache_key ] = TokenInfo (token , token_expiry )
195149
196150
197151class FederatedAuthPluginFactory (PluginFactory ):
@@ -205,35 +159,6 @@ def get_credentials_provider_factory(self, plugin_service: PluginService, props:
205159 raise AwsWrapperError (Messages .get_formatted ("FederatedAuthPluginFactory.UnsupportedIdp" , idp_name ))
206160
207161
208- class CredentialsProviderFactory (Protocol ):
209- @abstractmethod
210- def get_aws_credentials (self , region : str , props : Properties ) -> Optional [Dict [str , str ]]:
211- ...
212-
213-
214- class SamlCredentialsProviderFactory (CredentialsProviderFactory ):
215-
216- def get_aws_credentials (self , region : str , props : Properties ) -> Optional [Dict [str , str ]]:
217- saml_assertion : str = self .get_saml_assertion (props )
218- session = boto3 .Session ()
219-
220- sts_client = session .client (
221- 'sts' ,
222- region_name = region
223- )
224-
225- response : Dict [str , Dict [str , str ]] = sts_client .assume_role_with_saml (
226- RoleArn = WrapperProperties .IAM_ROLE_ARN .get (props ),
227- PrincipalArn = WrapperProperties .IAM_IDP_ARN .get (props ),
228- SAMLAssertion = saml_assertion
229- )
230-
231- return response .get ('Credentials' )
232-
233- def get_saml_assertion (self , props : Properties ):
234- ...
235-
236-
237162class AdfsCredentialsProviderFactory (SamlCredentialsProviderFactory ):
238163 _INPUT_TAG_PATTERN = r"<input(.+?)/>"
239164 _FORM_ACTION_PATTERN = r"<form.*?action=\"([^\"]+)\""
@@ -274,32 +199,22 @@ def get_saml_assertion(self, props: Properties):
274199
275200 def _get_sign_in_page_body (self , url : str , props : Properties ) -> str :
276201 logger .debug ("AdfsCredentialsProviderFactory.SignOnPageUrl" , url )
277- self . _validate_url (url )
202+ SamlUtils . validate_url (url )
278203 r = requests .get (url ,
279204 verify = WrapperProperties .SSL_SECURE .get_bool (props ),
280205 timeout = WrapperProperties .HTTP_REQUEST_TIMEOUT .get_int (props ))
281206
282- # Check HTTP Status Code is 2xx Success
283- if r .status_code / 100 != 2 :
284- error_message = "AdfsCredentialsProviderFactory.SignOnPageRequestFailed"
285- logger .debug (error_message , r .status_code , r .reason , r .text )
286- raise AwsWrapperError (Messages .get_formatted (error_message , r .status_code , r .reason , r .text ))
287-
207+ SamlUtils .validate_response (r )
288208 return r .text
289209
290210 def _post_form_action_body (self , uri : str , parameters : Dict [str , str ], props : Properties ) -> str :
291211 logger .debug ("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl" , uri )
292- self . _validate_url (uri )
212+ SamlUtils . validate_url (uri )
293213 r = requests .post (uri , data = urlencode (parameters ),
294214 verify = WrapperProperties .SSL_SECURE .get_bool (props ),
295215 timeout = WrapperProperties .HTTP_REQUEST_TIMEOUT .get_int (props ))
296216 # Check HTTP Status Code is 2xx Success
297- if r .status_code / 100 != 2 :
298- error_message = "AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed"
299- logger .debug (error_message , r .status_code , r .reason , r .text )
300- raise AwsWrapperError (
301- Messages .get_formatted (error_message , r .status_code , r .reason , r .text ))
302-
217+ SamlUtils .validate_response (r )
303218 return r .text
304219
305220 def _get_sign_in_page_url (self , props ) -> str :
@@ -308,7 +223,7 @@ def _get_sign_in_page_url(self, props) -> str:
308223 relaying_party_id = WrapperProperties .RELAYING_PARTY_ID .get (props )
309224 url = f"https://{ idp_endpoint } :{ idp_port } /adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={ relaying_party_id } "
310225 if idp_endpoint is None or relaying_party_id is None :
311- error_message = "AdfsCredentialsProviderFactory .InvalidHttpsUrl"
226+ error_message = "SamlUtils .InvalidHttpsUrl"
312227 logger .debug (error_message , url )
313228 raise AwsWrapperError (Messages .get_formatted (error_message , url ))
314229
@@ -319,7 +234,7 @@ def _get_form_action_url(self, props: Properties, action: str) -> str:
319234 idp_port = WrapperProperties .IDP_PORT .get (props )
320235 url = f"https://{ idp_endpoint } :{ idp_port } { action } "
321236 if idp_endpoint is None :
322- error_message = "AdfsCredentialsProviderFactory .InvalidHttpsUrl"
237+ error_message = "SamlUtils .InvalidHttpsUrl"
323238 logger .debug (error_message , url )
324239 raise AwsWrapperError (
325240 Messages .get_formatted (error_message , url ))
@@ -373,10 +288,3 @@ def _get_form_action_from_html_body(self, body: str) -> str:
373288 return unescape (match .group (1 ))
374289
375290 return ""
376-
377- def _validate_url (self , url : str ) -> None :
378- result = urlparse (url )
379- if not result .scheme or not search (self ._HTTPS_URL_PATTERN , url ):
380- error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
381- logger .debug (error_message , url )
382- raise AwsWrapperError (Messages .get_formatted (error_message , url ))
0 commit comments