Skip to content

Commit 3b33c18

Browse files
authored
fix: rollback transaction on is_dialect exceptions (#341)
1 parent 73934b9 commit 3b33c18

File tree

4 files changed

+61
-53
lines changed

4 files changed

+61
-53
lines changed

aws_advanced_python_wrapper/database_dialect.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
130130
...
131131

132132
@abstractmethod
133-
def is_dialect(self, conn: Connection) -> bool:
133+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
134134
...
135135

136136
@abstractmethod
@@ -184,7 +184,8 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
184184
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
185185
return MysqlDatabaseDialect._DIALECT_UPDATE_CANDIDATES
186186

187-
def is_dialect(self, conn: Connection) -> bool:
187+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
188+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
188189
try:
189190
with closing(conn.cursor()) as cursor:
190191
cursor.execute(self.server_version_query)
@@ -193,7 +194,8 @@ def is_dialect(self, conn: Connection) -> bool:
193194
if "mysql" in column_value.lower():
194195
return True
195196
except Exception:
196-
pass
197+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
198+
conn.rollback()
197199

198200
return False
199201

@@ -232,16 +234,16 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
232234
"aws_advanced_python_wrapper.utils.pg_exception_handler.SingleAzPgExceptionHandler")
233235
return PgDatabaseDialect._exception_handler
234236

235-
def is_dialect(self, conn: Connection) -> bool:
237+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
238+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
236239
try:
237240
with closing(conn.cursor()) as cursor:
238241
cursor.execute('SELECT 1 FROM pg_proc LIMIT 1')
239242
if cursor.fetchone() is not None:
240243
return True
241244
except Exception:
242-
# Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
243-
# the transaction will be aborted, no further commands can be executed. We need to call rollback here.
244-
conn.rollback()
245+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
246+
conn.rollback()
245247

246248
return False
247249

@@ -255,7 +257,8 @@ def prepare_conn_props(self, props: Properties):
255257
class RdsMysqlDialect(MysqlDatabaseDialect):
256258
_DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_MYSQL, DialectCode.MULTI_AZ_MYSQL)
257259

258-
def is_dialect(self, conn: Connection) -> bool:
260+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
261+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
259262
try:
260263
with closing(conn.cursor()) as cursor:
261264
cursor.execute(self.server_version_query)
@@ -264,7 +267,8 @@ def is_dialect(self, conn: Connection) -> bool:
264267
if "source distribution" in column_value.lower():
265268
return True
266269
except Exception:
267-
pass
270+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
271+
conn.rollback()
268272

269273
return False
270274

@@ -280,8 +284,9 @@ class RdsPgDialect(PgDatabaseDialect):
280284
"WHERE name='rds.extensions'")
281285
_DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_PG, DialectCode.MULTI_AZ_PG)
282286

283-
def is_dialect(self, conn: Connection) -> bool:
284-
if not super().is_dialect(conn):
287+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
288+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
289+
if not super().is_dialect(conn, driver_dialect):
285290
return False
286291

287292
try:
@@ -296,9 +301,8 @@ def is_dialect(self, conn: Connection) -> bool:
296301
return True
297302

298303
except Exception:
299-
# Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
300-
# the transaction will be aborted, no further commands can be executed. We need to call rollback here.
301-
conn.rollback()
304+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
305+
conn.rollback()
302306
return False
303307

304308
@property
@@ -320,15 +324,17 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
320324
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
321325
return AuroraMysqlDialect._DIALECT_UPDATE_CANDIDATES
322326

323-
def is_dialect(self, conn: Connection) -> bool:
327+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
328+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
324329
try:
325330
with closing(conn.cursor()) as cursor:
326331
cursor.execute("SHOW VARIABLES LIKE 'aurora_version'")
327332
# If variable with such a name is presented then it means it's an Aurora cluster
328333
if cursor.fetchone() is not None:
329334
return True
330335
except Exception:
331-
pass
336+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
337+
conn.rollback()
332338

333339
return False
334340

@@ -358,13 +364,14 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
358364
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
359365
return AuroraPgDialect._DIALECT_UPDATE_CANDIDATES
360366

361-
def is_dialect(self, conn: Connection) -> bool:
362-
if not super().is_dialect(conn):
367+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
368+
if not super().is_dialect(conn, driver_dialect):
363369
return False
364370

365371
has_extensions: bool = False
366372
has_topology: bool = False
367373

374+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
368375
try:
369376
with closing(conn.cursor()) as cursor:
370377
cursor.execute(self._EXTENSIONS_QUERY)
@@ -381,9 +388,8 @@ def is_dialect(self, conn: Connection) -> bool:
381388

382389
return has_extensions and has_topology
383390
except Exception:
384-
# Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
385-
# the transaction will be aborted, no further commands can be executed. We need to call rollback here.
386-
conn.rollback()
391+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
392+
conn.rollback()
387393

388394
return False
389395

@@ -402,15 +408,17 @@ class MultiAzMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
402408
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
403409
return None
404410

405-
def is_dialect(self, conn: Connection) -> bool:
411+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
412+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
406413
try:
407414
with closing(conn.cursor()) as cursor:
408415
cursor.execute(MultiAzMysqlDialect._TOPOLOGY_QUERY)
409416
records = cursor.fetchall()
410417
if records is not None and len(records) > 0:
411418
return True
412419
except Exception:
413-
pass
420+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
421+
conn.rollback()
414422

415423
return False
416424

@@ -459,14 +467,16 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
459467
"aws_advanced_python_wrapper.utils.pg_exception_handler.MultiAzPgExceptionHandler")
460468
return MultiAzPgDialect._exception_handler
461469

462-
def is_dialect(self, conn: Connection) -> bool:
470+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
471+
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
463472
try:
464473
with closing(conn.cursor()) as cursor:
465474
cursor.execute(MultiAzPgDialect._WRITER_HOST_QUERY)
466475
if cursor.fetchone() is not None:
467476
return True
468477
except Exception:
469-
pass
478+
if not initial_transaction_status and driver_dialect.is_in_transaction(conn):
479+
conn.rollback()
470480

471481
return False
472482

@@ -511,7 +521,7 @@ def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
511521
def exception_handler(self) -> Optional[ExceptionHandler]:
512522
return None
513523

514-
def is_dialect(self, conn: Connection) -> bool:
524+
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
515525
return False
516526

517527
def get_host_list_provider_supplier(self) -> Callable:
@@ -662,7 +672,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
662672
timeout_sec,
663673
driver_dialect,
664674
conn)(dialect_candidate.is_dialect)
665-
is_dialect = cursor_execute_func_with_timeout(conn)
675+
is_dialect = cursor_execute_func_with_timeout(conn, driver_dialect)
666676
except TimeoutError as e:
667677
raise QueryTimeoutError("DatabaseDialectManager.QueryForDialectTimeout") from e
668678

aws_advanced_python_wrapper/pg_driver_dialect.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ class PgDriverDialect(DriverDialect):
3838
TARGET_DRIVER_CODE: str = "psycopg"
3939

4040
# https://www.psycopg.org/psycopg3/docs/api/pq.html#psycopg.pq.TransactionStatus
41-
PSYCOPG_ACTIVE_TRANSACTION_STATUS = 1
42-
PSYCOPG_IN_TRANSACTION_STATUS = 2
41+
PSYCOPG_IDLE_TRANSACTION_STATUS = 0
4342

4443
_dialect_code: str = DriverDialectCodes.PSYCOPG
4544
_network_bound_methods: Set[str] = {
@@ -78,7 +77,7 @@ def abort_connection(self, conn: Connection):
7877
def is_in_transaction(self, conn: Connection) -> bool:
7978
if isinstance(conn, psycopg.Connection):
8079
status: int = conn.info.transaction_status
81-
return status == self.PSYCOPG_ACTIVE_TRANSACTION_STATUS or status == self.PSYCOPG_IN_TRANSACTION_STATUS
80+
return status != self.PSYCOPG_IDLE_TRANSACTION_STATUS
8281

8382
raise UnsupportedOperationError(Messages.get_formatted(
8483
"DriverDialect.UnsupportedOperationError",

tests/integration/container/test_basic_connectivity.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def test_proxied_wrapper_connection_failed(
110110
# That is expected exception. Test pass.
111111
assert True
112112

113-
@pytest.mark.skip("Currently failing, will be fixed in another PR")
114113
@enable_on_num_instances(min_instances=2)
115114
@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.MULTI_AZ])
116115
@enable_on_features([TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED])

tests/unit/test_dialect.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,96 +99,96 @@ def mock_default_behavior(mock_conn, mock_cursor, mock_fetchone_row):
9999
mock_cursor.fetchone.return_value = mock_fetchone_row
100100

101101

102-
def test_pg_is_dialect(mock_conn, mock_cursor, mock_session, pg_dialect):
103-
assert pg_dialect.is_dialect(mock_conn)
102+
def test_pg_is_dialect(mock_conn, mock_cursor, mock_session, pg_dialect, mock_driver_dialect):
103+
assert pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
104104

105105
mock_cursor.fetchone.return_value = None
106-
assert not pg_dialect.is_dialect(mock_conn)
106+
assert not pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
107107

108108

109-
def test_mysql_is_dialect(mock_conn, mock_cursor, mock_session, mysql_dialect):
109+
def test_mysql_is_dialect(mock_conn, mock_cursor, mock_session, mysql_dialect, mock_driver_dialect):
110110
records = [("some_value", "some_value"), ("some_value", "mysql")]
111111
mock_cursor.__iter__.return_value = records
112112

113-
assert mysql_dialect.is_dialect(mock_conn)
113+
assert mysql_dialect.is_dialect(mock_conn, mock_driver_dialect)
114114

115115
records = [("some_value", "some_value"), ("some_value", "some_value")]
116116
mock_cursor.__iter__.return_value = records
117117

118-
assert not mysql_dialect.is_dialect(mock_conn)
118+
assert not mysql_dialect.is_dialect(mock_conn, mock_driver_dialect)
119119

120120

121121
@patch('aws_advanced_python_wrapper.database_dialect.super')
122-
def test_rds_mysql_is_dialect(mock_super, mock_cursor, mock_conn, rds_mysql_dialect):
122+
def test_rds_mysql_is_dialect(mock_super, mock_cursor, mock_conn, rds_mysql_dialect, mock_driver_dialect):
123123
mock_super().is_dialect.return_value = True
124124

125125
records = [("some_value", "some_value"), ("some_value", "source distribution")]
126126
mock_cursor.__iter__.return_value = records
127127

128-
assert rds_mysql_dialect.is_dialect(mock_conn)
128+
assert rds_mysql_dialect.is_dialect(mock_conn, mock_driver_dialect)
129129

130130
records = [("some_value", "some_value"), ("some_value", "some_value")]
131131
mock_cursor.__iter__.return_value = records
132132

133-
assert not rds_mysql_dialect.is_dialect(mock_conn)
133+
assert not rds_mysql_dialect.is_dialect(mock_conn, mock_driver_dialect)
134134

135135
mock_super().is_dialect.return_value = False
136136

137-
assert not rds_mysql_dialect.is_dialect(mock_conn)
137+
assert not rds_mysql_dialect.is_dialect(mock_conn, mock_driver_dialect)
138138

139139

140-
def test_aurora_mysql_is_dialect(mock_conn, mock_cursor):
140+
def test_aurora_mysql_is_dialect(mock_conn, mock_cursor, mock_driver_dialect):
141141
mock_conn.cursor.return_value = mock_cursor
142142
mock_cursor.fetchone.return_value = None
143143

144144
dialect = AuroraMysqlDialect()
145-
assert dialect.is_dialect(mock_conn) is False
145+
assert dialect.is_dialect(mock_conn, mock_driver_dialect) is False
146146

147147
mock_cursor.fetchone.return_value = ('aurora_version', '3.0.0')
148-
assert dialect.is_dialect(mock_conn) is True
148+
assert dialect.is_dialect(mock_conn, mock_driver_dialect) is True
149149

150150

151151
@patch('aws_advanced_python_wrapper.database_dialect.super')
152-
def test_aurora_pg_is_dialect(mock_super, mock_conn, mock_cursor):
152+
def test_aurora_pg_is_dialect(mock_super, mock_conn, mock_cursor, mock_driver_dialect):
153153
aurora_pg_dialect = AuroraPgDialect()
154154
mock_conn.cursor.return_value = mock_cursor
155155
mock_super().is_dialect.return_value = True
156156

157157
records = [("aurora_stat_utils", "aurora_stat_utils"), ("some_value", "source distribution")]
158158
mock_cursor.__iter__.return_value = records
159159

160-
assert aurora_pg_dialect.is_dialect(mock_conn)
160+
assert aurora_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
161161

162162
mock_cursor.fetchone.return_value = None
163163

164-
assert not aurora_pg_dialect.is_dialect(mock_conn)
164+
assert not aurora_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
165165

166166
mock_cursor.fetchone.return_value = records
167167
mock_super().is_dialect.return_value = False
168168

169-
assert not aurora_pg_dialect.is_dialect(mock_conn)
169+
assert not aurora_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
170170

171171

172172
@patch('aws_advanced_python_wrapper.database_dialect.super')
173-
def test_rds_pg_is_dialect(mock_super, mock_cursor, mock_conn, rds_pg_dialect):
173+
def test_rds_pg_is_dialect(mock_super, mock_cursor, mock_conn, rds_pg_dialect, mock_driver_dialect):
174174
mock_super().is_dialect.return_value = True
175175

176176
mock_cursor.__iter__.return_value = [(False, False), (True, False)]
177177

178-
assert rds_pg_dialect.is_dialect(mock_conn)
178+
assert rds_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
179179

180180
mock_cursor.__iter__.return_value = [(False, False), (False, False)]
181181

182-
assert not rds_pg_dialect.is_dialect(mock_conn)
182+
assert not rds_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
183183

184184
mock_cursor.__iter__.return_value = []
185185

186-
assert not rds_pg_dialect.is_dialect(mock_conn)
186+
assert not rds_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
187187

188188
mock_cursor.fetchone.return_value = [(False, False), (True, False)]
189189
mock_super().is_dialect.return_value = False
190190

191-
assert not rds_pg_dialect.is_dialect(mock_conn)
191+
assert not rds_pg_dialect.is_dialect(mock_conn, mock_driver_dialect)
192192

193193

194194
def test_get_dialect_custom_dialect(mock_custom_dialect, mock_driver_dialect):

0 commit comments

Comments
 (0)