@@ -4,22 +4,14 @@ use pyo3::{prelude::*, pyclass};
44use crate :: {
55 exceptions:: rust_errors:: { RustPSQLDriverError , RustPSQLDriverPyResult } ,
66 query_result:: PSQLDriverPyQueryResult ,
7- value_converter:: { convert_parameters, PythonDTO } ,
87} ;
98
109use super :: {
11- connection:: Connection ,
1210 cursor:: Cursor ,
1311 transaction_options:: { IsolationLevel , ReadVariant } ,
1412} ;
15- use crate :: common:: ObjectQueryTrait ;
16- use std:: {
17- borrow:: Borrow ,
18- collections:: HashSet ,
19- ops:: { Deref , DerefMut } ,
20- sync:: Arc ,
21- } ;
22-
13+ use crate :: common:: { ObjectQueryTrait , TransactionObjectTrait } ;
14+ use std:: { borrow:: Borrow , collections:: HashSet , sync:: Arc } ;
2315// use super::connection::RustConnection;
2416
2517// /// Transaction for internal use only.
@@ -1253,7 +1245,6 @@ use std::{
12531245// }
12541246
12551247#[ pyclass]
1256- #[ derive( Clone ) ]
12571248pub struct Transaction {
12581249 pub db_client : Arc < Object > ,
12591250 is_started : bool ,
@@ -1277,9 +1268,40 @@ impl Transaction {
12771268 slf
12781269 }
12791270
1280- async fn __aenter__ < ' a > ( & mut self ) -> RustPSQLDriverPyResult < Transaction > {
1281- self . begin ( ) . await ?;
1282- Ok ( self . clone ( ) )
1271+ async fn __aenter__ < ' a > ( slf : Py < Self > ) -> RustPSQLDriverPyResult < Py < Self > > {
1272+ let ( is_started, is_done, isolation_level, read_variant, deferrable, db_client) =
1273+ pyo3:: Python :: with_gil ( |gil| {
1274+ let self_ = slf. borrow ( gil) ;
1275+ (
1276+ self_. is_started ,
1277+ self_. is_done ,
1278+ self_. isolation_level ,
1279+ self_. read_variant ,
1280+ self_. deferrable ,
1281+ self_. db_client . clone ( ) ,
1282+ )
1283+ } ) ;
1284+
1285+ if is_started {
1286+ return Err ( RustPSQLDriverError :: DataBaseTransactionError (
1287+ "Transaction is already started" . into ( ) ,
1288+ ) ) ;
1289+ }
1290+
1291+ if is_done {
1292+ return Err ( RustPSQLDriverError :: DataBaseTransactionError (
1293+ "Transaction is already committed or rolled back" . into ( ) ,
1294+ ) ) ;
1295+ }
1296+ db_client
1297+ . start_transaction ( isolation_level, read_variant, deferrable)
1298+ . await ?;
1299+
1300+ Python :: with_gil ( |gil| {
1301+ let mut self_ = slf. borrow_mut ( gil) ;
1302+ self_. is_started = true ;
1303+ } ) ;
1304+ Ok ( slf)
12831305 }
12841306
12851307 #[ allow( clippy:: needless_pass_by_value) ]
@@ -1289,21 +1311,30 @@ impl Transaction {
12891311 exception : Py < PyAny > ,
12901312 _traceback : Py < PyAny > ,
12911313 ) -> RustPSQLDriverPyResult < ( ) > {
1292- let ( is_exception_none, py_err, mut transaction) = pyo3:: Python :: with_gil ( |gil| {
1293- (
1294- exception. is_none ( gil) ,
1295- PyErr :: from_value_bound ( exception. into_bound ( gil) ) ,
1296- slf. borrow_mut ( gil) . clone ( ) ,
1297- )
1298- } ) ;
1299-
1300- if is_exception_none {
1301- transaction. commit ( ) . await ?;
1314+ let ( is_transaction_ready, is_exception_none, py_err, db_client) =
1315+ pyo3:: Python :: with_gil ( |gil| {
1316+ let self_ = slf. borrow ( gil) ;
1317+ (
1318+ self_. check_is_transaction_ready ( ) ,
1319+ exception. is_none ( gil) ,
1320+ PyErr :: from_value_bound ( exception. into_bound ( gil) ) ,
1321+ self_. db_client . clone ( ) ,
1322+ )
1323+ } ) ;
1324+ is_transaction_ready?;
1325+ let exit_result = if is_exception_none {
1326+ db_client. commit ( ) . await ?;
13021327 Ok ( ( ) )
13031328 } else {
1304- transaction . rollback ( ) . await ?;
1329+ db_client . rollback ( ) . await ?;
13051330 Err ( RustPSQLDriverError :: PyError ( py_err) )
1306- }
1331+ } ;
1332+
1333+ pyo3:: Python :: with_gil ( |gil| {
1334+ let mut self_ = slf. borrow_mut ( gil) ;
1335+ self_. is_done = true ;
1336+ } ) ;
1337+ exit_result
13071338 }
13081339
13091340 /// Commit the transaction.
@@ -1318,9 +1349,8 @@ impl Transaction {
13181349 /// 3) Cannot execute `COMMIT` command
13191350 pub async fn commit ( & mut self ) -> RustPSQLDriverPyResult < ( ) > {
13201351 self . check_is_transaction_ready ( ) ?;
1321- self . db_client . batch_execute ( "COMMIT;" ) . await ?;
1352+ self . db_client . commit ( ) . await ?;
13221353 self . is_done = true ;
1323-
13241354 Ok ( ( ) )
13251355 }
13261356
@@ -1351,15 +1381,17 @@ impl Transaction {
13511381 /// 1) Cannot convert python parameters
13521382 /// 2) Cannot execute querystring.
13531383 pub async fn execute (
1354- self_ : pyo3 :: Py < Self > ,
1384+ slf : Py < Self > ,
13551385 querystring : String ,
13561386 parameters : Option < pyo3:: Py < PyAny > > ,
13571387 prepared : Option < bool > ,
13581388 ) -> RustPSQLDriverPyResult < PSQLDriverPyQueryResult > {
1359- let transaction = pyo3:: Python :: with_gil ( |gil| self_. borrow ( gil) . clone ( ) ) ;
1360- transaction. check_is_transaction_ready ( ) ?;
1361- transaction
1362- . db_client
1389+ let ( is_transaction_ready, db_client) = pyo3:: Python :: with_gil ( |gil| {
1390+ let self_ = slf. borrow ( gil) ;
1391+ ( self_. check_is_transaction_ready ( ) , self_. db_client . clone ( ) )
1392+ } ) ;
1393+ is_transaction_ready?;
1394+ db_client
13631395 . psqlpy_query ( querystring, parameters, prepared)
13641396 . await
13651397 }
@@ -1463,7 +1495,6 @@ impl Transaction {
14631495
14641496 Ok ( ( ) )
14651497 }
1466-
14671498 fn check_is_transaction_ready ( & self ) -> RustPSQLDriverPyResult < ( ) > {
14681499 if !self . is_started {
14691500 return Err ( RustPSQLDriverError :: DataBaseTransactionError (
@@ -1475,7 +1506,6 @@ impl Transaction {
14751506 "Transaction is already committed or rolled back" . into ( ) ,
14761507 ) ) ;
14771508 }
1478-
14791509 Ok ( ( ) )
14801510 }
14811511}
0 commit comments