@@ -130,7 +130,7 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
130130 ...
131131
132132 @abstractmethod
133- def is_dialect (self , conn : Connection ) -> bool :
133+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
134134 ...
135135
136136 @abstractmethod
@@ -184,7 +184,8 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
184184 def dialect_update_candidates (self ) -> Optional [Tuple [DialectCode , ...]]:
185185 return MysqlDatabaseDialect ._DIALECT_UPDATE_CANDIDATES
186186
187- def is_dialect (self , conn : Connection ) -> bool :
187+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
188+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
188189 try :
189190 with closing (conn .cursor ()) as cursor :
190191 cursor .execute (self .server_version_query )
@@ -193,7 +194,8 @@ def is_dialect(self, conn: Connection) -> bool:
193194 if "mysql" in column_value .lower ():
194195 return True
195196 except Exception :
196- pass
197+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
198+ conn .rollback ()
197199
198200 return False
199201
@@ -232,16 +234,16 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
232234 "aws_advanced_python_wrapper.utils.pg_exception_handler.SingleAzPgExceptionHandler" )
233235 return PgDatabaseDialect ._exception_handler
234236
235- def is_dialect (self , conn : Connection ) -> bool :
237+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
238+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
236239 try :
237240 with closing (conn .cursor ()) as cursor :
238241 cursor .execute ('SELECT 1 FROM pg_proc LIMIT 1' )
239242 if cursor .fetchone () is not None :
240243 return True
241244 except Exception :
242- # Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
243- # the transaction will be aborted, no further commands can be executed. We need to call rollback here.
244- conn .rollback ()
245+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
246+ conn .rollback ()
245247
246248 return False
247249
@@ -255,7 +257,8 @@ def prepare_conn_props(self, props: Properties):
255257class RdsMysqlDialect (MysqlDatabaseDialect ):
256258 _DIALECT_UPDATE_CANDIDATES = (DialectCode .AURORA_MYSQL , DialectCode .MULTI_AZ_MYSQL )
257259
258- def is_dialect (self , conn : Connection ) -> bool :
260+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
261+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
259262 try :
260263 with closing (conn .cursor ()) as cursor :
261264 cursor .execute (self .server_version_query )
@@ -264,7 +267,8 @@ def is_dialect(self, conn: Connection) -> bool:
264267 if "source distribution" in column_value .lower ():
265268 return True
266269 except Exception :
267- pass
270+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
271+ conn .rollback ()
268272
269273 return False
270274
@@ -280,8 +284,9 @@ class RdsPgDialect(PgDatabaseDialect):
280284 "WHERE name='rds.extensions'" )
281285 _DIALECT_UPDATE_CANDIDATES = (DialectCode .AURORA_PG , DialectCode .MULTI_AZ_PG )
282286
283- def is_dialect (self , conn : Connection ) -> bool :
284- if not super ().is_dialect (conn ):
287+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
288+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
289+ if not super ().is_dialect (conn , driver_dialect ):
285290 return False
286291
287292 try :
@@ -296,9 +301,8 @@ def is_dialect(self, conn: Connection) -> bool:
296301 return True
297302
298303 except Exception :
299- # Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
300- # the transaction will be aborted, no further commands can be executed. We need to call rollback here.
301- conn .rollback ()
304+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
305+ conn .rollback ()
302306 return False
303307
304308 @property
@@ -320,15 +324,17 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
320324 def dialect_update_candidates (self ) -> Optional [Tuple [DialectCode , ...]]:
321325 return AuroraMysqlDialect ._DIALECT_UPDATE_CANDIDATES
322326
323- def is_dialect (self , conn : Connection ) -> bool :
327+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
328+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
324329 try :
325330 with closing (conn .cursor ()) as cursor :
326331 cursor .execute ("SHOW VARIABLES LIKE 'aurora_version'" )
327332 # If variable with such a name is presented then it means it's an Aurora cluster
328333 if cursor .fetchone () is not None :
329334 return True
330335 except Exception :
331- pass
336+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
337+ conn .rollback ()
332338
333339 return False
334340
@@ -358,13 +364,14 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
358364 def dialect_update_candidates (self ) -> Optional [Tuple [DialectCode , ...]]:
359365 return AuroraPgDialect ._DIALECT_UPDATE_CANDIDATES
360366
361- def is_dialect (self , conn : Connection ) -> bool :
362- if not super ().is_dialect (conn ):
367+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
368+ if not super ().is_dialect (conn , driver_dialect ):
363369 return False
364370
365371 has_extensions : bool = False
366372 has_topology : bool = False
367373
374+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
368375 try :
369376 with closing (conn .cursor ()) as cursor :
370377 cursor .execute (self ._EXTENSIONS_QUERY )
@@ -381,9 +388,8 @@ def is_dialect(self, conn: Connection) -> bool:
381388
382389 return has_extensions and has_topology
383390 except Exception :
384- # Executing the select statements will start a transaction, if the queries failed due to invalid syntax,
385- # the transaction will be aborted, no further commands can be executed. We need to call rollback here.
386- conn .rollback ()
391+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
392+ conn .rollback ()
387393
388394 return False
389395
@@ -402,15 +408,17 @@ class MultiAzMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect):
402408 def dialect_update_candidates (self ) -> Optional [Tuple [DialectCode , ...]]:
403409 return None
404410
405- def is_dialect (self , conn : Connection ) -> bool :
411+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
412+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
406413 try :
407414 with closing (conn .cursor ()) as cursor :
408415 cursor .execute (MultiAzMysqlDialect ._TOPOLOGY_QUERY )
409416 records = cursor .fetchall ()
410417 if records is not None and len (records ) > 0 :
411418 return True
412419 except Exception :
413- pass
420+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
421+ conn .rollback ()
414422
415423 return False
416424
@@ -459,14 +467,16 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
459467 "aws_advanced_python_wrapper.utils.pg_exception_handler.MultiAzPgExceptionHandler" )
460468 return MultiAzPgDialect ._exception_handler
461469
462- def is_dialect (self , conn : Connection ) -> bool :
470+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
471+ initial_transaction_status : bool = driver_dialect .is_in_transaction (conn )
463472 try :
464473 with closing (conn .cursor ()) as cursor :
465474 cursor .execute (MultiAzPgDialect ._WRITER_HOST_QUERY )
466475 if cursor .fetchone () is not None :
467476 return True
468477 except Exception :
469- pass
478+ if not initial_transaction_status and driver_dialect .is_in_transaction (conn ):
479+ conn .rollback ()
470480
471481 return False
472482
@@ -511,7 +521,7 @@ def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
511521 def exception_handler (self ) -> Optional [ExceptionHandler ]:
512522 return None
513523
514- def is_dialect (self , conn : Connection ) -> bool :
524+ def is_dialect (self , conn : Connection , driver_dialect : DriverDialect ) -> bool :
515525 return False
516526
517527 def get_host_list_provider_supplier (self ) -> Callable :
@@ -662,7 +672,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
662672 timeout_sec ,
663673 driver_dialect ,
664674 conn )(dialect_candidate .is_dialect )
665- is_dialect = cursor_execute_func_with_timeout (conn )
675+ is_dialect = cursor_execute_func_with_timeout (conn , driver_dialect )
666676 except TimeoutError as e :
667677 raise QueryTimeoutError ("DatabaseDialectManager.QueryForDialectTimeout" ) from e
668678
0 commit comments