Skip to content

Commit 0c49cec

Browse files
committed
Implemented execute_many in transaction
1 parent 91e818a commit 0c49cec

File tree

5 files changed

+50
-23
lines changed

5 files changed

+50
-23
lines changed

python/tests/helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import typing
2+
3+
from psqlpy import Transaction
4+
5+
6+
async def count_rows_in_test_table(
7+
table_name: str,
8+
transaction: Transaction,
9+
) -> int:
10+
query_result: typing.Final = await transaction.execute(
11+
f"SELECT COUNT(*) FROM {table_name}",
12+
)
13+
return query_result.result()[0]["count"]

python/tests/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from psqlpy import PSQLPool, QueryResult, Transaction
44

5+
pytestmark = pytest.mark.anyio
6+
57

6-
@pytest.mark.anyio()
78
async def test_connection_execute(
89
psql_pool: PSQLPool,
910
table_name: str,
@@ -19,7 +20,6 @@ async def test_connection_execute(
1920
assert len(conn_result.result()) == number_database_records
2021

2122

22-
@pytest.mark.anyio()
2323
async def test_connection_transaction(
2424
psql_pool: PSQLPool,
2525
) -> None:

python/tests/test_connection_pool.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from psqlpy import Connection, ConnRecyclingMethod, PSQLPool, QueryResult
44

5+
pytestmark = pytest.mark.anyio
6+
57

6-
@pytest.mark.anyio()
78
async def test_pool_dsn_startup() -> None:
89
"""Test that connection pool can startup with dsn."""
910
pg_pool = PSQLPool(
@@ -14,7 +15,6 @@ async def test_pool_dsn_startup() -> None:
1415
await pg_pool.execute("SELECT 1")
1516

1617

17-
@pytest.mark.anyio()
1818
async def test_pool_execute(
1919
psql_pool: PSQLPool,
2020
table_name: str,
@@ -32,7 +32,6 @@ async def test_pool_execute(
3232
assert len(inner_result) == number_database_records
3333

3434

35-
@pytest.mark.anyio()
3635
async def test_pool_connection(
3736
psql_pool: PSQLPool,
3837
) -> None:
@@ -41,7 +40,6 @@ async def test_pool_connection(
4140
assert isinstance(connection, Connection)
4241

4342

44-
@pytest.mark.anyio()
4543
@pytest.mark.parametrize(
4644
"conn_recycling_method",
4745
[

python/tests/test_cursor.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from psqlpy import Cursor
44

5+
pytestmark = pytest.mark.anyio
6+
57

6-
@pytest.mark.anyio()
78
async def test_cursor_fetch(
89
number_database_records: int,
910
test_cursor: Cursor,
@@ -13,7 +14,6 @@ async def test_cursor_fetch(
1314
assert len(result.result()) == number_database_records // 2
1415

1516

16-
@pytest.mark.anyio()
1717
async def test_cursor_fetch_next(
1818
test_cursor: Cursor,
1919
) -> None:
@@ -22,7 +22,6 @@ async def test_cursor_fetch_next(
2222
assert len(result.result()) == 1
2323

2424

25-
@pytest.mark.anyio()
2625
async def test_cursor_fetch_prior(
2726
test_cursor: Cursor,
2827
) -> None:
@@ -35,7 +34,6 @@ async def test_cursor_fetch_prior(
3534
assert len(result.result()) == 1
3635

3736

38-
@pytest.mark.anyio()
3937
async def test_cursor_fetch_first(
4038
test_cursor: Cursor,
4139
) -> None:
@@ -49,7 +47,6 @@ async def test_cursor_fetch_first(
4947
assert fetch_first.result() == first.result()
5048

5149

52-
@pytest.mark.anyio()
5350
async def test_cursor_fetch_last(
5451
test_cursor: Cursor,
5552
number_database_records: int,
@@ -64,7 +61,6 @@ async def test_cursor_fetch_last(
6461
assert all_res.result()[-1] == last_res.result()[0]
6562

6663

67-
@pytest.mark.anyio()
6864
async def test_cursor_fetch_absolute(
6965
test_cursor: Cursor,
7066
number_database_records: int,
@@ -85,7 +81,6 @@ async def test_cursor_fetch_absolute(
8581
assert all_res.result()[-1] == last_record.result()[0]
8682

8783

88-
@pytest.mark.anyio()
8984
async def test_cursor_fetch_relative(
9085
test_cursor: Cursor,
9186
number_database_records: int,
@@ -107,7 +102,6 @@ async def test_cursor_fetch_relative(
107102
assert not (records.result())
108103

109104

110-
@pytest.mark.anyio()
111105
async def test_cursor_fetch_forward_all(
112106
test_cursor: Cursor,
113107
number_database_records: int,
@@ -124,7 +118,6 @@ async def test_cursor_fetch_forward_all(
124118
)
125119

126120

127-
@pytest.mark.anyio()
128121
async def test_cursor_fetch_backward(
129122
test_cursor: Cursor,
130123
) -> None:
@@ -142,7 +135,6 @@ async def test_cursor_fetch_backward(
142135
assert len(must_not_be_empty.result()) == expected_number_of_results
143136

144137

145-
@pytest.mark.anyio()
146138
async def test_cursor_fetch_backward_all(
147139
test_cursor: Cursor,
148140
) -> None:

python/tests/test_transaction.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3+
import typing
4+
35
import pytest
6+
from tests.helpers import count_rows_in_test_table
47

58
from psqlpy import Cursor, IsolationLevel, PSQLPool, ReadVariant
69
from psqlpy.exceptions import DBTransactionError, RustPSQLDriverPyBaseError
710

11+
pytestmark = pytest.mark.anyio
12+
813

9-
@pytest.mark.anyio()
1014
@pytest.mark.parametrize(
1115
("isolation_level", "deferrable", "read_variant"),
1216
[
@@ -42,7 +46,6 @@ async def test_transaction_init_parameters(
4246
assert read_variant is not ReadVariant.ReadOnly
4347

4448

45-
@pytest.mark.anyio()
4649
async def test_transaction_begin(
4750
psql_pool: PSQLPool,
4851
table_name: str,
@@ -66,7 +69,6 @@ async def test_transaction_begin(
6669
assert len(result.result()) == number_database_records
6770

6871

69-
@pytest.mark.anyio()
7072
async def test_transaction_commit(
7173
psql_pool: PSQLPool,
7274
table_name: str,
@@ -100,7 +102,6 @@ async def test_transaction_commit(
100102
assert len(result.result())
101103

102104

103-
@pytest.mark.anyio()
104105
async def test_transaction_savepoint(
105106
psql_pool: PSQLPool,
106107
table_name: str,
@@ -133,7 +134,6 @@ async def test_transaction_savepoint(
133134
await transaction.commit()
134135

135136

136-
@pytest.mark.anyio()
137137
async def test_transaction_rollback(
138138
psql_pool: PSQLPool,
139139
table_name: str,
@@ -171,7 +171,6 @@ async def test_transaction_rollback(
171171
assert not (result_from_conn.result())
172172

173173

174-
@pytest.mark.anyio()
175174
async def test_transaction_release_savepoint(
176175
psql_pool: PSQLPool,
177176
) -> None:
@@ -194,7 +193,6 @@ async def test_transaction_release_savepoint(
194193
await transaction.savepoint(sp_name_1)
195194

196195

197-
@pytest.mark.anyio()
198196
async def test_transaction_cursor(
199197
psql_pool: PSQLPool,
200198
table_name: str,
@@ -205,3 +203,29 @@ async def test_transaction_cursor(
205203
cursor = await transaction.cursor(f"SELECT * FROM {table_name}")
206204

207205
assert isinstance(cursor, Cursor)
206+
207+
208+
@pytest.mark.parametrize(
209+
("insert_values"),
210+
[
211+
[[1, "name1"], [2, "name2"]],
212+
[[10, "name1"], [20, "name2"], [30, "name3"]],
213+
[[1, "name1"]],
214+
],
215+
)
216+
async def test_transaction_execute_many(
217+
psql_pool: PSQLPool,
218+
table_name: str,
219+
number_database_records: int,
220+
insert_values: list[list[typing.Any]],
221+
) -> None:
222+
connection = await psql_pool.connection()
223+
async with connection.transaction() as transaction:
224+
await transaction.execute_many(
225+
f"INSERT INTO {table_name} VALUES ($1, $2)",
226+
insert_values,
227+
)
228+
assert await count_rows_in_test_table(
229+
table_name,
230+
transaction,
231+
) - number_database_records == len(insert_values)

0 commit comments

Comments
 (0)