|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from psqlpy import PSQLPool |
| 3 | +from psqlpy import Cursor, PSQLPool |
4 | 4 | from psqlpy.exceptions import DBTransactionError |
5 | 5 |
|
6 | 6 |
|
@@ -93,3 +93,77 @@ async def test_transaction_savepoint( |
93 | 93 | assert not len(result.result()) |
94 | 94 |
|
95 | 95 | await transaction.commit() |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.anyio |
| 99 | +async def test_transaction_rollback( |
| 100 | + psql_pool: PSQLPool, |
| 101 | + table_name: str, |
| 102 | +) -> None: |
| 103 | + """Test that ROLLBACK works correctly.""" |
| 104 | + connection = await psql_pool.connection() |
| 105 | + transaction = connection.transaction() |
| 106 | + await transaction.begin() |
| 107 | + |
| 108 | + test_name = "test_name" |
| 109 | + await transaction.execute( |
| 110 | + f"INSERT INTO {table_name} VALUES ($1, $2)", |
| 111 | + parameters=[100, test_name], |
| 112 | + ) |
| 113 | + |
| 114 | + result = await transaction.execute( |
| 115 | + f"SELECT * FROM {table_name} WHERE name = $1", |
| 116 | + parameters=[test_name], |
| 117 | + ) |
| 118 | + assert result.result() |
| 119 | + |
| 120 | + await transaction.rollback() |
| 121 | + |
| 122 | + with pytest.raises(expected_exception=DBTransactionError): |
| 123 | + await transaction.execute( |
| 124 | + f"SELECT * FROM {table_name} WHERE name = $1", |
| 125 | + parameters=[test_name], |
| 126 | + ) |
| 127 | + |
| 128 | + result_from_conn = await psql_pool.execute( |
| 129 | + f"INSERT INTO {table_name} VALUES ($1, $2)", |
| 130 | + parameters=[100, test_name], |
| 131 | + ) |
| 132 | + |
| 133 | + assert not (result_from_conn.result()) |
| 134 | + |
| 135 | + |
| 136 | +@pytest.mark.anyio |
| 137 | +async def test_transaction_release_savepoint( |
| 138 | + psql_pool: PSQLPool, |
| 139 | +) -> None: |
| 140 | + """Test that it is possible to acquire and release savepoint.""" |
| 141 | + connection = await psql_pool.connection() |
| 142 | + transaction = connection.transaction() |
| 143 | + await transaction.begin() |
| 144 | + |
| 145 | + sp_name_1 = "sp1" |
| 146 | + sp_name_2 = "sp2" |
| 147 | + |
| 148 | + await transaction.savepoint(sp_name_1) |
| 149 | + |
| 150 | + with pytest.raises(expected_exception=DBTransactionError): |
| 151 | + await transaction.savepoint(sp_name_1) |
| 152 | + |
| 153 | + await transaction.savepoint(sp_name_2) |
| 154 | + |
| 155 | + await transaction.release_savepoint(sp_name_1) |
| 156 | + await transaction.savepoint(sp_name_1) |
| 157 | + |
| 158 | + |
| 159 | +@pytest.mark.anyio |
| 160 | +async def test_transaction_cursor( |
| 161 | + psql_pool: PSQLPool, |
| 162 | + table_name: str, |
| 163 | +) -> None: |
| 164 | + """Test that transaction can create cursor.""" |
| 165 | + connection = await psql_pool.connection() |
| 166 | + async with connection.transaction() as transaction: |
| 167 | + cursor = await transaction.cursor(f"SELECT * FROM {table_name}") |
| 168 | + |
| 169 | + assert isinstance(cursor, Cursor) |
0 commit comments