Skip to content

Commit dc8e844

Browse files
committed
Continue adding tests
1 parent 390d458 commit dc8e844

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

python/psqlpy/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
"DBPoolConfigurationError",
1919
"UUIDValueConvertError",
2020
"CursorError",
21+
"DBTransactionError",
2122
]

python/tests/test_transaction.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import pytest
2+
3+
from psqlpy import PSQLPool
4+
from psqlpy.exceptions import DBTransactionError
5+
6+
7+
@pytest.mark.anyio
8+
async def test_transaction_begin(
9+
psql_pool: PSQLPool,
10+
table_name: str,
11+
number_database_records: int,
12+
) -> None:
13+
"""Test that transaction must be started with `begin()` method."""
14+
connection = await psql_pool.connection()
15+
transaction = connection.transaction()
16+
17+
with pytest.raises(expected_exception=DBTransactionError):
18+
await transaction.execute(
19+
f"SELECT * FROM {table_name}",
20+
)
21+
22+
await transaction.begin()
23+
24+
result = await transaction.execute(
25+
f"SELECT * FROM {table_name}",
26+
)
27+
28+
assert len(result.result()) == number_database_records
29+
30+
31+
@pytest.mark.anyio
32+
async def test_transaction_commit(
33+
psql_pool: PSQLPool,
34+
table_name: str,
35+
) -> None:
36+
"""Test that transaction commit command."""
37+
connection = await psql_pool.connection()
38+
transaction = connection.transaction()
39+
await transaction.begin()
40+
41+
test_name: str = "test_name"
42+
await transaction.execute(
43+
f"INSERT INTO {table_name} VALUES ($1, $2)",
44+
parameters=[100, test_name],
45+
)
46+
47+
# Make request from other connection, it mustn't know
48+
# about new INSERT data before commit.
49+
result = await psql_pool.execute(
50+
f"SELECT * FROM {table_name} WHERE name = $1",
51+
parameters=[test_name],
52+
)
53+
assert not result.result()
54+
55+
await transaction.commit()
56+
57+
result = await psql_pool.execute(
58+
f"SELECT * FROM {table_name} WHERE name = $1",
59+
parameters=[test_name],
60+
)
61+
62+
assert len(result.result())
63+
64+
65+
@pytest.mark.anyio
66+
async def test_transaction_savepoint(
67+
psql_pool: PSQLPool,
68+
table_name: str,
69+
) -> None:
70+
"""Test that it's possible to rollback to savepoint."""
71+
connection = await psql_pool.connection()
72+
transaction = connection.transaction()
73+
await transaction.begin()
74+
75+
test_name = "test_name"
76+
savepoint_name = "sp1"
77+
await transaction.savepoint(savepoint_name=savepoint_name)
78+
await transaction.execute(
79+
f"INSERT INTO {table_name} VALUES ($1, $2)",
80+
parameters=[100, test_name],
81+
)
82+
result = await transaction.execute(
83+
f"SELECT * FROM {table_name} WHERE name = $1",
84+
parameters=[test_name],
85+
)
86+
assert result.result()
87+
88+
await transaction.rollback_to(savepoint_name=savepoint_name)
89+
result = await psql_pool.execute(
90+
f"SELECT * FROM {table_name} WHERE name = $1",
91+
parameters=[test_name],
92+
)
93+
assert not len(result.result())
94+
95+
await transaction.commit()

0 commit comments

Comments
 (0)