Skip to content

Commit f566768

Browse files
committed
Continue adding tests
1 parent dc8e844 commit f566768

File tree

4 files changed

+88
-6
lines changed

4 files changed

+88
-6
lines changed

python/psqlpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._internal import (
22
Connection,
3+
Cursor,
34
IsolationLevel,
45
PSQLPool,
56
QueryResult,
@@ -14,4 +15,5 @@
1415
"IsolationLevel",
1516
"ReadVariant",
1617
"Connection",
18+
"Cursor",
1719
]

python/psqlpy/_internal/__init__.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import types
12
from enum import Enum
23
from typing import Any, Dict, List, Optional
34

@@ -153,6 +154,13 @@ class Transaction:
153154
`.transaction()`.
154155
"""
155156

157+
async def __aenter__(self: Self) -> Self: ...
158+
async def __aexit__(
159+
self: Self,
160+
exception_type: type[BaseException] | None,
161+
exception: BaseException | None,
162+
traceback: types.TracebackType | None,
163+
) -> None: ...
156164
async def begin(self: Self) -> None:
157165
"""Start the transaction.
158166

python/tests/test_transaction.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from psqlpy import PSQLPool
3+
from psqlpy import Cursor, PSQLPool
44
from psqlpy.exceptions import DBTransactionError
55

66

@@ -93,3 +93,77 @@ async def test_transaction_savepoint(
9393
assert not len(result.result())
9494

9595
await transaction.commit()
96+
97+
98+
@pytest.mark.anyio
99+
async def test_transaction_rollback(
100+
psql_pool: PSQLPool,
101+
table_name: str,
102+
) -> None:
103+
"""Test that ROLLBACK works correctly."""
104+
connection = await psql_pool.connection()
105+
transaction = connection.transaction()
106+
await transaction.begin()
107+
108+
test_name = "test_name"
109+
await transaction.execute(
110+
f"INSERT INTO {table_name} VALUES ($1, $2)",
111+
parameters=[100, test_name],
112+
)
113+
114+
result = await transaction.execute(
115+
f"SELECT * FROM {table_name} WHERE name = $1",
116+
parameters=[test_name],
117+
)
118+
assert result.result()
119+
120+
await transaction.rollback()
121+
122+
with pytest.raises(expected_exception=DBTransactionError):
123+
await transaction.execute(
124+
f"SELECT * FROM {table_name} WHERE name = $1",
125+
parameters=[test_name],
126+
)
127+
128+
result_from_conn = await psql_pool.execute(
129+
f"INSERT INTO {table_name} VALUES ($1, $2)",
130+
parameters=[100, test_name],
131+
)
132+
133+
assert not (result_from_conn.result())
134+
135+
136+
@pytest.mark.anyio
137+
async def test_transaction_release_savepoint(
138+
psql_pool: PSQLPool,
139+
) -> None:
140+
"""Test that it is possible to acquire and release savepoint."""
141+
connection = await psql_pool.connection()
142+
transaction = connection.transaction()
143+
await transaction.begin()
144+
145+
sp_name_1 = "sp1"
146+
sp_name_2 = "sp2"
147+
148+
await transaction.savepoint(sp_name_1)
149+
150+
with pytest.raises(expected_exception=DBTransactionError):
151+
await transaction.savepoint(sp_name_1)
152+
153+
await transaction.savepoint(sp_name_2)
154+
155+
await transaction.release_savepoint(sp_name_1)
156+
await transaction.savepoint(sp_name_1)
157+
158+
159+
@pytest.mark.anyio
160+
async def test_transaction_cursor(
161+
psql_pool: PSQLPool,
162+
table_name: str,
163+
) -> None:
164+
"""Test that transaction can create cursor."""
165+
connection = await psql_pool.connection()
166+
async with connection.transaction() as transaction:
167+
cursor = await transaction.cursor(f"SELECT * FROM {table_name}")
168+
169+
assert isinstance(cursor, Cursor)

src/driver/transaction.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,9 @@ impl RustTransaction {
396396
));
397397
};
398398

399-
let rollback_savepoint_arc = self.rollback_savepoint.clone();
400-
let is_rollback_exists = {
401-
let rollback_savepoint_guard = rollback_savepoint_arc.read().await;
402-
rollback_savepoint_guard.contains(&rollback_name)
403-
};
399+
let mut rollback_savepoint_guard = self.rollback_savepoint.write().await;
400+
let is_rollback_exists = rollback_savepoint_guard.remove(&rollback_name);
401+
404402
if !is_rollback_exists {
405403
return Err(RustPSQLDriverError::DataBaseTransactionError(
406404
"Don't have rollback with this name".into(),

0 commit comments

Comments
 (0)