Skip to content

Commit 76ad8e2

Browse files
committed
Added cursor first implementation
1 parent 77790e2 commit 76ad8e2

File tree

10 files changed

+220
-3
lines changed

10 files changed

+220
-3
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,4 @@ jobs:
112112
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
113113
with:
114114
command: upload
115-
args: --non-interactive --skip-existing *
115+
args: --non-interactive --skip-existing *

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ byteorder = "1.5.0"
2525
postgres-protocol = "0.6.6"
2626
chrono = "0.4.33"
2727
chrono-tz = "0.8.5"
28-
uuid = "1.7.0"
28+
uuid = { version = "1.7.0", features = ["v4"] }
2929
serde_json = "1.0.113"

python/psql_rust_driver/_internal/__init__.pyi

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,34 @@ class IsolationLevel(Enum):
2121
Serializable = 4
2222

2323

24+
class Cursor:
25+
"""Represent opened cursor in a transaction.
26+
27+
It can be used as an asynchronous iterator.
28+
"""
29+
30+
async def fetch(
31+
self: Self,
32+
fetch_number: int | None = None,
33+
) -> QueryResult:
34+
"""Fetch next <fetch_number> rows.
35+
36+
By default fetches 10 next rows.
37+
38+
### Parameters:
39+
- `fetch_number`: how many rows need to fetch.
40+
41+
### Returns:
42+
result as `QueryResult`.
43+
"""
44+
45+
def __aiter__(self: Self) -> Self:
46+
...
47+
48+
async def __anext__(self: Self) -> QueryResult:
49+
...
50+
51+
2452
class Transaction:
2553
"""Single connection for executing queries.
2654
@@ -207,6 +235,50 @@ class Transaction:
207235
```
208236
"""
209237

238+
async def cursor(
239+
self: Self,
240+
querystring: str,
241+
parameters: List[Any] | None = None,
242+
fetch_number: int | None = None,
243+
) -> Cursor:
244+
"""Create new cursor object.
245+
246+
Cursor can be used as an asynchronous iterator.
247+
248+
### Parameters:
249+
- `querystring`: querystring to execute.
250+
- `parameters`: list of parameters to pass in the query.
251+
- `fetch_number`: how many rows need to fetch.
252+
253+
### Returns:
254+
new initialized cursor.
255+
256+
### Example:
257+
```python
258+
import asyncio
259+
260+
from psql_rust_driver import PSQLPool, QueryResult
261+
262+
263+
async def main() -> None:
264+
db_pool = PSQLPool()
265+
await db_pool.startup()
266+
267+
connection = await db_pool.connection()
268+
transaction = await connection.transaction()
269+
270+
cursor = await transaction.cursor(
271+
querystring="SELECT * FROM users WHERE username = $1",
272+
parameters=["Some_Username"],
273+
fetch_number=5,
274+
)
275+
276+
async for fetched_result in cursor:
277+
dict_result: List[Dict[Any, Any]] = fetched_result.result()
278+
... # do something with this result.
279+
```
280+
"""
281+
210282

211283
class Connection:
212284
"""Connection from Database Connection Pool.
@@ -228,6 +300,9 @@ class Connection:
228300
- `querystring`: querystring to execute.
229301
- `parameters`: list of parameters to pass in the query.
230302
303+
### Returns:
304+
query result as `QueryResult`
305+
231306
### Example:
232307
```python
233308
import asyncio

src/driver/connection.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ impl RustConnection {
6262
is_done: Arc::new(tokio::sync::RwLock::new(false)),
6363
rollback_savepoint: Arc::new(tokio::sync::RwLock::new(HashSet::new())),
6464
isolation_level: isolation_level,
65+
cursor_num: Default::default(),
6566
};
6667

6768
Transaction {

src/driver/connection_pool.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ impl RustPSQLPool {
136136
is_done: Arc::new(tokio::sync::RwLock::new(false)),
137137
rollback_savepoint: Arc::new(tokio::sync::RwLock::new(HashSet::new())),
138138
isolation_level: isolation_level,
139+
cursor_num: Default::default(),
139140
};
140141

141142
Ok(Transaction {

src/driver/cursor.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use deadpool_postgres::Object;
2+
use pyo3::{exceptions::PyStopAsyncIteration, pyclass, pymethods, PyAny, PyObject, PyRef, Python};
3+
use std::sync::Arc;
4+
5+
use crate::{
6+
common::rustengine_future, exceptions::rust_errors::RustPSQLDriverPyResult,
7+
query_result::PSQLDriverPyQueryResult,
8+
};
9+
10+
#[pyclass]
11+
pub struct Cursor {
12+
db_client: Arc<tokio::sync::RwLock<Object>>,
13+
cursor_name: String,
14+
fetch_number: usize,
15+
}
16+
17+
impl Cursor {
18+
pub fn new(
19+
db_client: Arc<tokio::sync::RwLock<Object>>,
20+
cursor_name: String,
21+
fetch_number: usize,
22+
) -> Self {
23+
return Cursor {
24+
db_client,
25+
cursor_name,
26+
fetch_number,
27+
};
28+
}
29+
}
30+
31+
#[pymethods]
32+
impl Cursor {
33+
pub fn fetch<'a>(
34+
&'a self,
35+
py: Python<'a>,
36+
fetch_number: Option<usize>,
37+
) -> RustPSQLDriverPyResult<&PyAny> {
38+
let db_client_arc = self.db_client.clone();
39+
let cursor_name = self.cursor_name.clone();
40+
let fetch_number = match fetch_number {
41+
Some(usize) => usize,
42+
None => self.fetch_number.clone(),
43+
};
44+
45+
rustengine_future(py, async move {
46+
let db_client_guard = db_client_arc.read().await;
47+
let result = db_client_guard
48+
.query(
49+
format!("FETCH {fetch_number} FROM {cursor_name}").as_str(),
50+
&[],
51+
)
52+
.await?;
53+
Ok(PSQLDriverPyQueryResult::new(result))
54+
})
55+
}
56+
57+
#[must_use]
58+
pub fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
59+
slf
60+
}
61+
62+
pub fn __anext__(&self, py: Python<'_>) -> RustPSQLDriverPyResult<Option<PyObject>> {
63+
let db_client_arc = self.db_client.clone();
64+
let cursor_name = self.cursor_name.clone();
65+
let fetch_number = self.fetch_number.clone();
66+
67+
let future = rustengine_future(py, async move {
68+
let db_client_guard = db_client_arc.read().await;
69+
let result = db_client_guard
70+
.query(
71+
format!("FETCH {fetch_number} FROM {cursor_name}").as_str(),
72+
&[],
73+
)
74+
.await?;
75+
76+
if result.len() == 0 {
77+
return Err(PyStopAsyncIteration::new_err("Error").into());
78+
};
79+
80+
Ok(PSQLDriverPyQueryResult::new(result))
81+
});
82+
83+
Ok(Some(future?.into()))
84+
}
85+
}

src/driver/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod connection;
22
pub mod connection_pool;
3+
pub mod cursor;
34
pub mod transaction;
45
pub mod transaction_options;

src/driver/transaction.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010
value_converter::{convert_parameters, PythonDTO},
1111
};
1212

13-
use super::transaction_options::IsolationLevel;
13+
use super::{cursor::Cursor, transaction_options::IsolationLevel};
1414

1515
/// Transaction for internal use only.
1616
///
@@ -22,6 +22,7 @@ pub struct RustTransaction {
2222
pub rollback_savepoint: Arc<tokio::sync::RwLock<HashSet<String>>>,
2323

2424
pub isolation_level: Option<IsolationLevel>,
25+
pub cursor_num: usize,
2526
}
2627

2728
impl RustTransaction {
@@ -393,6 +394,34 @@ impl RustTransaction {
393394

394395
Ok(())
395396
}
397+
398+
pub async fn inner_cursor<'a>(
399+
&'a mut self,
400+
querystring: String,
401+
parameters: Vec<PythonDTO>,
402+
fetch_number: usize,
403+
) -> RustPSQLDriverPyResult<Cursor> {
404+
let db_client_arc = self.db_client.clone();
405+
let db_client_arc2 = self.db_client.clone();
406+
let db_client_guard = db_client_arc.read().await;
407+
408+
let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(parameters.len());
409+
for param in parameters.iter() {
410+
vec_parameters.push(param);
411+
}
412+
413+
let cursor_name = format!("cur{}", self.cursor_num);
414+
db_client_guard
415+
.execute(
416+
&format!("DECLARE {} CURSOR FOR {querystring}", cursor_name),
417+
&vec_parameters.into_boxed_slice(),
418+
)
419+
.await?;
420+
421+
self.cursor_num = self.cursor_num + 1;
422+
423+
Ok(Cursor::new(db_client_arc2, cursor_name, fetch_number))
424+
}
396425
}
397426

398427
#[pyclass()]
@@ -619,4 +648,25 @@ impl Transaction {
619648
Ok(())
620649
})
621650
}
651+
652+
pub fn cursor<'a>(
653+
&'a self,
654+
py: Python<'a>,
655+
querystring: String,
656+
parameters: Option<&'a PyAny>,
657+
fetch_number: Option<usize>,
658+
) -> RustPSQLDriverPyResult<&PyAny> {
659+
let transaction_arc = self.transaction.clone();
660+
let mut params: Vec<PythonDTO> = vec![];
661+
if let Some(parameters) = parameters {
662+
params = convert_parameters(parameters)?
663+
}
664+
665+
rustengine_future(py, async move {
666+
let mut transaction_guard = transaction_arc.write().await;
667+
Ok(transaction_guard
668+
.inner_cursor(querystring, params, fetch_number.unwrap_or(10))
669+
.await?)
670+
})
671+
}
622672
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use pyo3::{pymodule, types::PyModule, PyResult, Python};
1515
fn psql_rust_engine(py: Python<'_>, pymod: &PyModule) -> PyResult<()> {
1616
pymod.add_class::<driver::connection_pool::PSQLPool>()?;
1717
pymod.add_class::<driver::transaction::Transaction>()?;
18+
pymod.add_class::<driver::cursor::Cursor>()?;
1819
pymod.add_class::<driver::transaction_options::IsolationLevel>()?;
1920
pymod.add_class::<query_result::PSQLDriverPyQueryResult>()?;
2021
add_module(py, pymod, "extra_types", extra_types_module)?;

0 commit comments

Comments
 (0)