Skip to content

Commit 5e59a35

Browse files
committed
fix static check
1 parent 6cb65ec commit 5e59a35

33 files changed

+533
-504
lines changed

aws_advanced_python_wrapper/driver_dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def ping(self, conn: Connection) -> bool:
165165
return True
166166
except Exception:
167167
return False
168-
168+
169169
def get_driver_module(self) -> ModuleType:
170170
raise UnsupportedOperationError(
171171
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "get_driver_module"))

aws_advanced_python_wrapper/driver_dialect_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from aws_advanced_python_wrapper.errors import AwsWrapperError
2525
from aws_advanced_python_wrapper.utils.log import Logger
2626
from aws_advanced_python_wrapper.utils.messages import Messages
27-
from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperties
27+
from aws_advanced_python_wrapper.utils.properties import (Properties,
28+
WrapperProperties)
2829
from aws_advanced_python_wrapper.utils.utils import Utils
2930

3031
logger = Logger(__name__)

aws_advanced_python_wrapper/mysql_driver_dialect.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from __future__ import annotations
1615

1716
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set
1817

18+
import mysql.connector
19+
1920
if TYPE_CHECKING:
2021
from aws_advanced_python_wrapper.hostinfo import HostInfo
2122
from aws_advanced_python_wrapper.pep249 import Connection
@@ -29,11 +30,12 @@
2930
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
3031
from aws_advanced_python_wrapper.utils.decorators import timeout
3132
from aws_advanced_python_wrapper.utils.messages import Messages
32-
from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties
33+
from aws_advanced_python_wrapper.utils.properties import (Properties,
34+
PropertiesUtils,
35+
WrapperProperties)
3336

3437
CMYSQL_ENABLED = False
3538

36-
import mysql.connector
3739
from mysql.connector import MySQLConnection # noqa: E402
3840
from mysql.connector.cursor import MySQLCursor # noqa: E402
3941

aws_advanced_python_wrapper/pg_driver_dialect.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
2929
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
3030
from aws_advanced_python_wrapper.utils.messages import Messages
31-
from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils, WrapperProperties
31+
from aws_advanced_python_wrapper.utils.properties import (Properties,
32+
PropertiesUtils,
33+
WrapperProperties)
3234

3335

3436
class PgDriverDialect(DriverDialect):

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
8888
if self._accept_url_func:
8989
return self._accept_url_func(host_info, props)
9090
url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
91-
return RdsUrlType.RDS_INSTANCE == url_type or RdsUrlType.RDS_WRITER_CLUSTER
91+
return url_type in (RdsUrlType.RDS_INSTANCE, RdsUrlType.RDS_WRITER_CLUSTER)
9292

9393
def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
9494
return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies

aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
class SqlAlchemyDriverDialect(DriverDialect):
3333
_driver_name: str = "SQLAlchemy"
3434
TARGET_DRIVER_CODE: str = "sqlalchemy"
35+
_underlying_driver_dialect = None
3536

3637
def __init__(self, underlying_driver: DriverDialect, props: Properties):
3738
super().__init__(props)
@@ -126,6 +127,6 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
126127
return None
127128

128129
return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn)
129-
130+
130131
def get_driver_module(self) -> ModuleType:
131132
return self._underlying_driver.get_driver_module()

aws_advanced_python_wrapper/tortoise/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
DB_LOOKUP["aws-mysql"] = {
1919
"engine": "aws_advanced_python_wrapper.tortoise.backend.mysql",
2020
"vmap": {
21-
"path": "database",
22-
"hostname": "host",
23-
"port": "port",
24-
"username": "user",
25-
"password": "password",
26-
},
21+
"path": "database",
22+
"hostname": "host",
23+
"port": "port",
24+
"username": "user",
25+
"password": "password",
26+
},
2727
"defaults": {"port": 3306, "charset": "utf8mb4", "sql_mode": "STRICT_TRANS_TABLES"},
2828
"cast": {
2929
"minsize": int,

aws_advanced_python_wrapper/tortoise/backend/base/client.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import asyncio
16-
import mysql.connector
1718
from contextlib import asynccontextmanager
18-
from typing import Any, Callable, Generic
19+
from typing import Any, Callable, Dict, Generic, cast
1920

20-
from tortoise.backends.base.client import BaseDBAsyncClient, T_conn, TransactionalDBClient, TransactionContext
21+
import mysql.connector
22+
from tortoise.backends.base.client import (BaseDBAsyncClient, T_conn,
23+
TransactionalDBClient,
24+
TransactionContext)
2125
from tortoise.connection import connections
2226
from tortoise.exceptions import TransactionManagementError
2327

@@ -26,55 +30,55 @@
2630

2731
class AwsWrapperAsyncConnector:
2832
"""Class for creating and closing AWS wrapper connections."""
29-
33+
3034
@staticmethod
31-
async def ConnectWithAwsWrapper(connect_func: Callable, **kwargs) -> AwsWrapperConnection:
35+
async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper:
3236
"""Create an AWS wrapper connection with async cursor support."""
3337
connection = await asyncio.to_thread(
3438
AwsWrapperConnection.connect, connect_func, **kwargs
3539
)
3640
return AwsConnectionAsyncWrapper(connection)
37-
41+
3842
@staticmethod
39-
async def CloseAwsWrapper(connection: AwsWrapperConnection) -> None:
43+
async def close_aws_wrapper(connection: AwsWrapperConnection) -> None:
4044
"""Close an AWS wrapper connection asynchronously."""
4145
await asyncio.to_thread(connection.close)
4246

4347

4448
class AwsCursorAsyncWrapper:
4549
"""Wraps sync AwsCursor cursor with async support."""
46-
50+
4751
def __init__(self, sync_cursor):
4852
self._cursor = sync_cursor
49-
53+
5054
async def execute(self, query, params=None):
5155
"""Execute a query asynchronously."""
5256
return await asyncio.to_thread(self._cursor.execute, query, params)
53-
57+
5458
async def executemany(self, query, params_list):
5559
"""Execute multiple queries asynchronously."""
5660
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
57-
61+
5862
async def fetchall(self):
5963
"""Fetch all results asynchronously."""
6064
return await asyncio.to_thread(self._cursor.fetchall)
61-
65+
6266
async def fetchone(self):
6367
"""Fetch one result asynchronously."""
6468
return await asyncio.to_thread(self._cursor.fetchone)
65-
69+
6670
async def close(self):
6771
"""Close cursor asynchronously."""
6872
return await asyncio.to_thread(self._cursor.close)
69-
73+
7074
def __getattr__(self, name):
7175
"""Delegate non-async attributes to the wrapped cursor."""
7276
return getattr(self._cursor, name)
7377

7478

7579
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
7680
"""Wraps sync AwsConnection with async cursor support."""
77-
81+
7882
def __init__(self, connection: AwsWrapperConnection):
7983
self._wrapped_connection = connection
8084

@@ -90,40 +94,50 @@ async def cursor(self):
9094
async def rollback(self):
9195
"""Rollback the current transaction."""
9296
return await asyncio.to_thread(self._wrapped_connection.rollback)
93-
97+
9498
async def commit(self):
9599
"""Commit the current transaction."""
96100
return await asyncio.to_thread(self._wrapped_connection.commit)
97-
101+
98102
async def set_autocommit(self, value: bool):
99103
"""Set autocommit mode."""
100104
return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value)
101105

102106
def __getattr__(self, name):
103107
"""Delegate all other attributes/methods to the wrapped connection."""
104108
return getattr(self._wrapped_connection, name)
105-
109+
106110
def __del__(self):
107111
"""Delegate cleanup to wrapped connection."""
108112
if hasattr(self, '_wrapped_connection'):
109113
# Let the wrapped connection handle its own cleanup
110114
pass
111115

112116

117+
class AwsBaseDBAsyncClient(BaseDBAsyncClient):
118+
_template: Dict[str, Any]
119+
120+
121+
class AwsTransactionalDBClient(TransactionalDBClient):
122+
_template: Dict[str, Any]
123+
_parent: AwsBaseDBAsyncClient
124+
pass
125+
126+
113127
class TortoiseAwsClientConnectionWrapper(Generic[T_conn]):
114128
"""Manages acquiring from and releasing connections to a pool."""
115129

116130
__slots__ = ("client", "connection", "connect_func", "with_db")
117131

118132
def __init__(
119-
self,
120-
client: BaseDBAsyncClient,
121-
connect_func: Callable,
133+
self,
134+
client: AwsBaseDBAsyncClient,
135+
connect_func: Callable,
122136
with_db: bool = True
123137
) -> None:
124138
self.connect_func = connect_func
125139
self.client = client
126-
self.connection: T_conn | None = None
140+
self.connection: AwsConnectionAsyncWrapper | None = None
127141
self.with_db = with_db
128142

129143
async def ensure_connection(self) -> None:
@@ -133,22 +147,22 @@ async def ensure_connection(self) -> None:
133147
async def __aenter__(self) -> T_conn:
134148
"""Acquire connection from pool."""
135149
await self.ensure_connection()
136-
self.connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(self.connect_func, **self.client._template)
137-
return self.connection
150+
self.connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(self.connect_func, **self.client._template)
151+
return cast("T_conn", self.connection)
138152

139153
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
140154
"""Close connection and release back to pool."""
141155
if self.connection:
142-
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.connection)
156+
await AwsWrapperAsyncConnector.close_aws_wrapper(self.connection)
143157

144158

145159
class TortoiseAwsClientTransactionContext(TransactionContext):
146160
"""Transaction context that uses a pool to acquire connections."""
147161

148162
__slots__ = ("client", "connection_name", "token")
149163

150-
def __init__(self, client: TransactionalDBClient) -> None:
151-
self.client = client
164+
def __init__(self, client: AwsTransactionalDBClient) -> None:
165+
self.client: AwsTransactionalDBClient = client
152166
self.connection_name = client.connection_name
153167

154168
async def ensure_connection(self) -> None:
@@ -158,13 +172,13 @@ async def ensure_connection(self) -> None:
158172
async def __aenter__(self) -> TransactionalDBClient:
159173
"""Enter transaction context."""
160174
await self.ensure_connection()
161-
175+
162176
# Set the context variable so the current task sees a TransactionWrapper connection
163177
self.token = connections.set(self.connection_name, self.client)
164-
178+
165179
# Create connection and begin transaction
166-
self.client._connection = await AwsWrapperAsyncConnector.ConnectWithAwsWrapper(
167-
mysql.connector.Connect,
180+
self.client._connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(
181+
mysql.connector.Connect,
168182
**self.client._parent._template
169183
)
170184
await self.client.begin()
@@ -181,5 +195,5 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
181195
else:
182196
await self.client.commit()
183197
finally:
184-
await AwsWrapperAsyncConnector.CloseAwsWrapper(self.client._connection)
198+
await AwsWrapperAsyncConnector.close_aws_wrapper(self.client._connection)
185199
connections.reset(self.token)

0 commit comments

Comments
 (0)