|
16 | 16 |
|
17 | 17 | from typing import TYPE_CHECKING |
18 | 18 |
|
| 19 | +from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo |
| 20 | + |
19 | 21 | if TYPE_CHECKING: |
20 | 22 | from boto3 import Session |
21 | 23 | from aws_advanced_python_wrapper.driver_dialect import DriverDialect |
|
39 | 41 | logger = Logger(__name__) |
40 | 42 |
|
41 | 43 |
|
42 | | -class TokenInfo: |
43 | | - @property |
44 | | - def token(self): |
45 | | - return self._token |
46 | | - |
47 | | - @property |
48 | | - def expiration(self): |
49 | | - return self._expiration |
50 | | - |
51 | | - def __init__(self, token: str, expiration: datetime): |
52 | | - self._token = token |
53 | | - self._expiration = expiration |
54 | | - |
55 | | - def is_expired(self) -> bool: |
56 | | - return datetime.now() > self._expiration |
57 | | - |
58 | | - |
59 | 44 | class IamAuthPlugin(Plugin): |
60 | 45 | _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} |
61 | 46 | # Leave 30 second buffer to prevent time-of-check to time-of-use errors |
@@ -86,10 +71,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl |
86 | 71 | if not WrapperProperties.USER.get(props): |
87 | 72 | raise AwsWrapperError(Messages.get_formatted("IamPlugin.IsNoneOrEmpty", WrapperProperties.USER.name)) |
88 | 73 |
|
89 | | - host = WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host |
| 74 | + host = IamAuthUtils.get_iam_host(props, host_info) |
90 | 75 | region = WrapperProperties.IAM_REGION.get(props) \ |
91 | 76 | if WrapperProperties.IAM_REGION.get(props) else self._get_rds_region(host) |
92 | | - port = self._get_port(props, host_info) |
| 77 | + port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) |
93 | 78 | token_expiration_sec: int = WrapperProperties.IAM_EXPIRATION.get_int(props) |
94 | 79 |
|
95 | 80 | cache_key: str = self._get_cache_key( |
@@ -170,27 +155,11 @@ def _generate_authentication_token(self, |
170 | 155 | def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: |
171 | 156 | return f"{region}:{hostname}:{port}:{user}" |
172 | 157 |
|
173 | | - def _get_port(self, props: Properties, host_info: HostInfo) -> int: |
174 | | - if WrapperProperties.IAM_DEFAULT_PORT.get(props): |
175 | | - default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props) |
176 | | - if default_port > 0: |
177 | | - return default_port |
178 | | - else: |
179 | | - logger.debug("IamAuthPlugin.InvalidPort", default_port) |
180 | | - |
181 | | - if host_info.is_port_specified(): |
182 | | - return host_info.port |
183 | | - |
184 | | - if self._plugin_service.database_dialect is not None: |
185 | | - return self._plugin_service.database_dialect.default_port |
186 | | - |
187 | | - raise AwsWrapperError(Messages.get("IamAuthPlugin.NoValidPorts")) |
188 | | - |
189 | 158 | def _get_rds_region(self, hostname: Optional[str]) -> str: |
190 | 159 | rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None |
191 | 160 |
|
192 | 161 | if not rds_region: |
193 | | - exception_message = "IamAuthPlugin.UnsupportedHostname" |
| 162 | + exception_message = "RdsUtils.UnsupportedHostname" |
194 | 163 | logger.debug(exception_message, hostname) |
195 | 164 | raise AwsWrapperError(Messages.get_formatted(exception_message, hostname)) |
196 | 165 |
|
|
0 commit comments