Skip to content

Commit dc960e4

Browse files
committed
continue implementing transaction
1 parent dcf3713 commit dc960e4

File tree

2 files changed

+105
-37
lines changed

2 files changed

+105
-37
lines changed

src/common.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use deadpool_postgres::Object;
22
use pyo3::{types::PyModule, PyAny, PyResult, Python};
33

44
use crate::{
5+
driver::transaction_options::{IsolationLevel, ReadVariant},
56
exceptions::rust_errors::RustPSQLDriverPyResult,
67
query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult},
78
value_converter::{convert_parameters, PythonDTO, QueryParameter},
@@ -31,6 +32,56 @@ pub fn add_module(
3132
Ok(())
3233
}
3334

35+
pub trait BaseTransactionQuery {
36+
fn start_transaction(
37+
&self,
38+
isolation_level: Option<IsolationLevel>,
39+
read_variant: Option<ReadVariant>,
40+
defferable: Option<bool>,
41+
) -> impl std::future::Future<Output = RustPSQLDriverPyResult<()>> + Send;
42+
fn commit(&self) -> impl std::future::Future<Output = RustPSQLDriverPyResult<()>> + Send;
43+
fn rollback(&self) -> impl std::future::Future<Output = RustPSQLDriverPyResult<()>> + Send;
44+
}
45+
46+
impl BaseTransactionQuery for Object {
47+
async fn start_transaction(
48+
&self,
49+
isolation_level: Option<IsolationLevel>,
50+
read_variant: Option<ReadVariant>,
51+
deferrable: Option<bool>,
52+
) -> RustPSQLDriverPyResult<()> {
53+
let mut querystring = "START TRANSACTION".to_string();
54+
55+
if let Some(level) = isolation_level {
56+
let level = &level.to_str_level();
57+
querystring.push_str(format!(" ISOLATION LEVEL {level}").as_str());
58+
};
59+
60+
querystring.push_str(match read_variant {
61+
Some(ReadVariant::ReadOnly) => " READ ONLY",
62+
Some(ReadVariant::ReadWrite) => " READ WRITE",
63+
None => "",
64+
});
65+
66+
querystring.push_str(match deferrable {
67+
Some(true) => " DEFERRABLE",
68+
Some(false) => " NOT DEFERRABLE",
69+
None => "",
70+
});
71+
self.batch_execute(&querystring).await?;
72+
73+
Ok(())
74+
}
75+
async fn commit(&self) -> RustPSQLDriverPyResult<()> {
76+
self.batch_execute("COMMIT;").await?;
77+
Ok(())
78+
}
79+
async fn rollback(&self) -> RustPSQLDriverPyResult<()> {
80+
self.batch_execute("ROLLBACK;").await?;
81+
Ok(())
82+
}
83+
}
84+
3485
pub trait BaseDataBaseQuery {
3586
fn psqlpy_query_one(
3687
&self,

src/driver/transaction.rs

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,11 @@ use pyo3::{prelude::*, pyclass};
44
use crate::{
55
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
66
query_result::PSQLDriverPyQueryResult,
7-
value_converter::{convert_parameters, PythonDTO},
87
};
98

10-
use super::{
11-
connection::Connection,
12-
transaction_options::{IsolationLevel, ReadVariant},
13-
};
14-
use crate::common::BaseDataBaseQuery;
15-
use std::{
16-
borrow::Borrow,
17-
collections::HashSet,
18-
ops::{Deref, DerefMut},
19-
sync::Arc,
20-
};
9+
use super::transaction_options::{IsolationLevel, ReadVariant};
10+
use crate::common::{BaseDataBaseQuery, BaseTransactionQuery};
11+
use std::{collections::HashSet, sync::Arc};
2112

2213
// use super::connection::RustConnection;
2314

@@ -1252,7 +1243,6 @@ use std::{
12521243
// }
12531244

12541245
#[pyclass]
1255-
#[derive(Clone)]
12561246
pub struct Transaction {
12571247
pub db_client: Arc<Object>,
12581248
is_started: bool,
@@ -1276,9 +1266,28 @@ impl Transaction {
12761266
slf
12771267
}
12781268

1279-
async fn __aenter__<'a>(&mut self) -> RustPSQLDriverPyResult<Transaction> {
1280-
self.begin().await?;
1281-
Ok(self.clone())
1269+
async fn __aenter__<'a>(slf: Py<Self>) -> RustPSQLDriverPyResult<Py<Self>> {
1270+
let (is_transaction_ready, isolation_level, read_variant, deferrable, db_client) =
1271+
pyo3::Python::with_gil(|gil| {
1272+
let self_ = slf.borrow(gil);
1273+
(
1274+
self_.check_is_transaction_ready(),
1275+
self_.isolation_level,
1276+
self_.read_variant,
1277+
self_.deferrable,
1278+
self_.db_client.clone(),
1279+
)
1280+
});
1281+
is_transaction_ready?;
1282+
db_client
1283+
.start_transaction(isolation_level, read_variant, deferrable)
1284+
.await?;
1285+
1286+
Python::with_gil(|gil| {
1287+
let mut self_ = slf.borrow_mut(gil);
1288+
self_.is_started = true;
1289+
});
1290+
Ok(slf)
12821291
}
12831292

12841293
#[allow(clippy::needless_pass_by_value)]
@@ -1288,21 +1297,30 @@ impl Transaction {
12881297
exception: Py<PyAny>,
12891298
_traceback: Py<PyAny>,
12901299
) -> RustPSQLDriverPyResult<()> {
1291-
let (is_exception_none, py_err, mut transaction) = pyo3::Python::with_gil(|gil| {
1292-
(
1293-
exception.is_none(gil),
1294-
PyErr::from_value_bound(exception.into_bound(gil)),
1295-
slf.borrow_mut(gil).clone(),
1296-
)
1297-
});
1298-
1299-
if is_exception_none {
1300-
transaction.commit().await?;
1300+
let (is_transaction_ready, is_exception_none, py_err, db_client) =
1301+
pyo3::Python::with_gil(|gil| {
1302+
let self_ = slf.borrow(gil);
1303+
(
1304+
self_.check_is_transaction_ready(),
1305+
exception.is_none(gil),
1306+
PyErr::from_value_bound(exception.into_bound(gil)),
1307+
self_.db_client.clone(),
1308+
)
1309+
});
1310+
is_transaction_ready?;
1311+
let exit_result = if is_exception_none {
1312+
db_client.commit().await?;
13011313
Ok(())
13021314
} else {
1303-
transaction.rollback().await?;
1315+
db_client.rollback().await?;
13041316
Err(RustPSQLDriverError::PyError(py_err))
1305-
}
1317+
};
1318+
1319+
pyo3::Python::with_gil(|gil| {
1320+
let mut self_ = slf.borrow_mut(gil);
1321+
self_.is_done = true;
1322+
});
1323+
exit_result
13061324
}
13071325

13081326
/// Commit the transaction.
@@ -1317,9 +1335,8 @@ impl Transaction {
13171335
/// 3) Cannot execute `COMMIT` command
13181336
pub async fn commit(&mut self) -> RustPSQLDriverPyResult<()> {
13191337
self.check_is_transaction_ready()?;
1320-
self.db_client.batch_execute("COMMIT;").await?;
1338+
self.db_client.commit().await?;
13211339
self.is_done = true;
1322-
13231340
Ok(())
13241341
}
13251342

@@ -1350,15 +1367,17 @@ impl Transaction {
13501367
/// 1) Cannot convert python parameters
13511368
/// 2) Cannot execute querystring.
13521369
pub async fn execute(
1353-
self_: pyo3::Py<Self>,
1370+
slf: Py<Self>,
13541371
querystring: String,
13551372
parameters: Option<pyo3::Py<PyAny>>,
13561373
prepared: Option<bool>,
13571374
) -> RustPSQLDriverPyResult<PSQLDriverPyQueryResult> {
1358-
let transaction = pyo3::Python::with_gil(|gil| self_.borrow(gil).clone());
1359-
transaction.check_is_transaction_ready()?;
1360-
transaction
1361-
.db_client
1375+
let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| {
1376+
let self_ = slf.borrow(gil);
1377+
(self_.check_is_transaction_ready(), self_.db_client.clone())
1378+
});
1379+
is_transaction_ready?;
1380+
db_client
13621381
.psqlpy_query(querystring, parameters, prepared)
13631382
.await
13641383
}
@@ -1443,7 +1462,6 @@ impl Transaction {
14431462

14441463
Ok(())
14451464
}
1446-
14471465
fn check_is_transaction_ready(&self) -> RustPSQLDriverPyResult<()> {
14481466
if !self.is_started {
14491467
return Err(RustPSQLDriverError::DataBaseTransactionError(
@@ -1455,7 +1473,6 @@ impl Transaction {
14551473
"Transaction is already committed or rolled back".into(),
14561474
));
14571475
}
1458-
14591476
Ok(())
14601477
}
14611478
}

0 commit comments

Comments
 (0)