@@ -4,20 +4,11 @@ use pyo3::{prelude::*, pyclass};
44use crate :: {
55 exceptions:: rust_errors:: { RustPSQLDriverError , RustPSQLDriverPyResult } ,
66 query_result:: PSQLDriverPyQueryResult ,
7- value_converter:: { convert_parameters, PythonDTO } ,
87} ;
98
10- use super :: {
11- connection:: Connection ,
12- transaction_options:: { IsolationLevel , ReadVariant } ,
13- } ;
14- use crate :: common:: BaseDataBaseQuery ;
15- use std:: {
16- borrow:: Borrow ,
17- collections:: HashSet ,
18- ops:: { Deref , DerefMut } ,
19- sync:: Arc ,
20- } ;
9+ use super :: transaction_options:: { IsolationLevel , ReadVariant } ;
10+ use crate :: common:: { BaseDataBaseQuery , BaseTransactionQuery } ;
11+ use std:: { collections:: HashSet , sync:: Arc } ;
2112
2213// use super::connection::RustConnection;
2314
@@ -1252,7 +1243,6 @@ use std::{
12521243// }
12531244
12541245#[ pyclass]
1255- #[ derive( Clone ) ]
12561246pub struct Transaction {
12571247 pub db_client : Arc < Object > ,
12581248 is_started : bool ,
@@ -1276,9 +1266,28 @@ impl Transaction {
12761266 slf
12771267 }
12781268
1279- async fn __aenter__ < ' a > ( & mut self ) -> RustPSQLDriverPyResult < Transaction > {
1280- self . begin ( ) . await ?;
1281- Ok ( self . clone ( ) )
1269+ async fn __aenter__ < ' a > ( slf : Py < Self > ) -> RustPSQLDriverPyResult < Py < Self > > {
1270+ let ( is_transaction_ready, isolation_level, read_variant, deferrable, db_client) =
1271+ pyo3:: Python :: with_gil ( |gil| {
1272+ let self_ = slf. borrow ( gil) ;
1273+ (
1274+ self_. check_is_transaction_ready ( ) ,
1275+ self_. isolation_level ,
1276+ self_. read_variant ,
1277+ self_. deferrable ,
1278+ self_. db_client . clone ( ) ,
1279+ )
1280+ } ) ;
1281+ is_transaction_ready?;
1282+ db_client
1283+ . start_transaction ( isolation_level, read_variant, deferrable)
1284+ . await ?;
1285+
1286+ Python :: with_gil ( |gil| {
1287+ let mut self_ = slf. borrow_mut ( gil) ;
1288+ self_. is_started = true ;
1289+ } ) ;
1290+ Ok ( slf)
12821291 }
12831292
12841293 #[ allow( clippy:: needless_pass_by_value) ]
@@ -1288,21 +1297,30 @@ impl Transaction {
12881297 exception : Py < PyAny > ,
12891298 _traceback : Py < PyAny > ,
12901299 ) -> RustPSQLDriverPyResult < ( ) > {
1291- let ( is_exception_none, py_err, mut transaction) = pyo3:: Python :: with_gil ( |gil| {
1292- (
1293- exception. is_none ( gil) ,
1294- PyErr :: from_value_bound ( exception. into_bound ( gil) ) ,
1295- slf. borrow_mut ( gil) . clone ( ) ,
1296- )
1297- } ) ;
1298-
1299- if is_exception_none {
1300- transaction. commit ( ) . await ?;
1300+ let ( is_transaction_ready, is_exception_none, py_err, db_client) =
1301+ pyo3:: Python :: with_gil ( |gil| {
1302+ let self_ = slf. borrow ( gil) ;
1303+ (
1304+ self_. check_is_transaction_ready ( ) ,
1305+ exception. is_none ( gil) ,
1306+ PyErr :: from_value_bound ( exception. into_bound ( gil) ) ,
1307+ self_. db_client . clone ( ) ,
1308+ )
1309+ } ) ;
1310+ is_transaction_ready?;
1311+ let exit_result = if is_exception_none {
1312+ db_client. commit ( ) . await ?;
13011313 Ok ( ( ) )
13021314 } else {
1303- transaction . rollback ( ) . await ?;
1315+ db_client . rollback ( ) . await ?;
13041316 Err ( RustPSQLDriverError :: PyError ( py_err) )
1305- }
1317+ } ;
1318+
1319+ pyo3:: Python :: with_gil ( |gil| {
1320+ let mut self_ = slf. borrow_mut ( gil) ;
1321+ self_. is_done = true ;
1322+ } ) ;
1323+ exit_result
13061324 }
13071325
13081326 /// Commit the transaction.
@@ -1317,9 +1335,8 @@ impl Transaction {
13171335 /// 3) Cannot execute `COMMIT` command
13181336 pub async fn commit ( & mut self ) -> RustPSQLDriverPyResult < ( ) > {
13191337 self . check_is_transaction_ready ( ) ?;
1320- self . db_client . batch_execute ( "COMMIT;" ) . await ?;
1338+ self . db_client . commit ( ) . await ?;
13211339 self . is_done = true ;
1322-
13231340 Ok ( ( ) )
13241341 }
13251342
@@ -1350,15 +1367,17 @@ impl Transaction {
13501367 /// 1) Cannot convert python parameters
13511368 /// 2) Cannot execute querystring.
13521369 pub async fn execute (
1353- self_ : pyo3 :: Py < Self > ,
1370+ slf : Py < Self > ,
13541371 querystring : String ,
13551372 parameters : Option < pyo3:: Py < PyAny > > ,
13561373 prepared : Option < bool > ,
13571374 ) -> RustPSQLDriverPyResult < PSQLDriverPyQueryResult > {
1358- let transaction = pyo3:: Python :: with_gil ( |gil| self_. borrow ( gil) . clone ( ) ) ;
1359- transaction. check_is_transaction_ready ( ) ?;
1360- transaction
1361- . db_client
1375+ let ( is_transaction_ready, db_client) = pyo3:: Python :: with_gil ( |gil| {
1376+ let self_ = slf. borrow ( gil) ;
1377+ ( self_. check_is_transaction_ready ( ) , self_. db_client . clone ( ) )
1378+ } ) ;
1379+ is_transaction_ready?;
1380+ db_client
13621381 . psqlpy_query ( querystring, parameters, prepared)
13631382 . await
13641383 }
@@ -1443,7 +1462,6 @@ impl Transaction {
14431462
14441463 Ok ( ( ) )
14451464 }
1446-
14471465 fn check_is_transaction_ready ( & self ) -> RustPSQLDriverPyResult < ( ) > {
14481466 if !self . is_started {
14491467 return Err ( RustPSQLDriverError :: DataBaseTransactionError (
@@ -1455,7 +1473,6 @@ impl Transaction {
14551473 "Transaction is already committed or rolled back" . into ( ) ,
14561474 ) ) ;
14571475 }
1458-
14591476 Ok ( ( ) )
14601477 }
14611478}
0 commit comments