From 92894b5426a48516f6e446fccd6c2c8c6ca17499 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 14:09:10 +0800 Subject: [PATCH 01/31] Refactor schema, config, dataframe, and expression classes to use RwLock and Mutex for interior mutability --- src/common/schema.rs | 115 +++++++++++++++++++++++++++-------- src/config.rs | 39 +++++++----- src/dataframe.rs | 33 ++++++---- src/expr/conditional_expr.rs | 67 ++++++++++++++++---- src/expr/literal.rs | 6 +- src/functions.rs | 8 +-- src/record_batch.rs | 8 +-- 7 files changed, 201 insertions(+), 75 deletions(-) diff --git a/src/common/schema.rs b/src/common/schema.rs index 752c39bde..d32e401e9 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -16,7 +16,7 @@ // under the License. use std::fmt::{self, Display, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::{any::Any, borrow::Cow}; use arrow::datatypes::Schema; @@ -25,6 +25,7 @@ use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::Constraints; use datafusion::datasource::TableType; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource}; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use datafusion::logical_expr::utils::split_conjunction; @@ -33,17 +34,13 @@ use crate::sql::logical::PyLogicalPlan; use super::{data_type::DataTypeMap, function::SqlFunction}; -#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)] +#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)] #[derive(Debug, Clone)] pub struct SqlSchema { - #[pyo3(get, set)] - pub name: String, - #[pyo3(get, set)] - pub tables: Vec, - #[pyo3(get, set)] - pub views: Vec, - #[pyo3(get, set)] - pub functions: Vec, + name: Arc>, + tables: Arc>>, + views: Arc>>, + functions: Arc>>, } #[pyclass(name = "SqlTable", module = "datafusion.common", subclass)] @@ -104,28 +101,98 @@ impl SqlSchema { #[new] pub fn new(schema_name: &str) -> Self { Self { - name: schema_name.to_owned(), - 
tables: Vec::new(), - views: Vec::new(), - functions: Vec::new(), + name: Arc::new(RwLock::new(schema_name.to_owned())), + tables: Arc::new(RwLock::new(Vec::new())), + views: Arc::new(RwLock::new(Vec::new())), + functions: Arc::new(RwLock::new(Vec::new())), } } + #[getter] + fn name(&self) -> PyResult { + Ok(self + .name + .read() + .map_err(|_| PyRuntimeError::new_err("failed to read schema name"))? + .clone()) + } + + #[setter] + fn set_name(&self, value: String) -> PyResult<()> { + *self + .name + .write() + .map_err(|_| PyRuntimeError::new_err("failed to write schema name"))? = value; + Ok(()) + } + + #[getter] + fn tables(&self) -> PyResult> { + Ok(self + .tables + .read() + .map_err(|_| PyRuntimeError::new_err("failed to read schema tables"))? + .clone()) + } + + #[setter] + fn set_tables(&self, tables: Vec) -> PyResult<()> { + *self + .tables + .write() + .map_err(|_| PyRuntimeError::new_err("failed to write schema tables"))? = tables; + Ok(()) + } + + #[getter] + fn views(&self) -> PyResult> { + Ok(self + .views + .read() + .map_err(|_| PyRuntimeError::new_err("failed to read schema views"))? + .clone()) + } + + #[setter] + fn set_views(&self, views: Vec) -> PyResult<()> { + *self + .views + .write() + .map_err(|_| PyRuntimeError::new_err("failed to write schema views"))? = views; + Ok(()) + } + + #[getter] + fn functions(&self) -> PyResult> { + Ok(self + .functions + .read() + .map_err(|_| PyRuntimeError::new_err("failed to read schema functions"))? + .clone()) + } + + #[setter] + fn set_functions(&self, functions: Vec) -> PyResult<()> { + *self + .functions + .write() + .map_err(|_| PyRuntimeError::new_err("failed to write schema functions"))? 
= functions; + Ok(()) + } + pub fn table_by_name(&self, table_name: &str) -> Option { - for tbl in &self.tables { - if tbl.name.eq(table_name) { - return Some(tbl.clone()); - } - } - None + let tables = self.tables.read().expect("failed to read schema tables"); + tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned() } - pub fn add_table(&mut self, table: SqlTable) { - self.tables.push(table); + pub fn add_table(&self, table: SqlTable) { + let mut tables = self.tables.write().expect("failed to write schema tables"); + tables.push(table); } - pub fn drop_table(&mut self, table_name: String) { - self.tables.retain(|x| !x.name.eq(&table_name)); + pub fn drop_table(&self, table_name: String) { + let mut tables = self.tables.write().expect("failed to write schema tables"); + tables.retain(|x| !x.name.eq(&table_name)); } } diff --git a/src/config.rs b/src/config.rs index 20f22196c..487375dca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -15,18 +15,20 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::{Arc, RwLock}; + use pyo3::prelude::*; use pyo3::types::*; use datafusion::config::ConfigOptions; -use crate::errors::PyDataFusionResult; +use crate::errors::{PyDataFusionError, PyDataFusionResult}; use crate::utils::py_obj_to_scalar_value; -#[pyclass(name = "Config", module = "datafusion", subclass)] +#[pyclass(name = "Config", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub(crate) struct PyConfig { - config: ConfigOptions, + config: Arc>, } #[pymethods] @@ -34,7 +36,7 @@ impl PyConfig { #[new] fn py_new() -> Self { Self { - config: ConfigOptions::new(), + config: Arc::new(RwLock::new(ConfigOptions::new())), } } @@ -42,13 +44,16 @@ impl PyConfig { #[staticmethod] pub fn from_env() -> PyDataFusionResult { Ok(Self { - config: ConfigOptions::from_env()?, + config: Arc::new(RwLock::new(ConfigOptions::from_env()?)), }) } /// Get a configuration option - pub fn get<'py>(&mut self, key: &str, py: Python<'py>) -> PyResult> { - let options = self.config.to_owned(); + pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { + let options = self + .config + .read() + .map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?; for entry in options.entries() { if entry.key == key { return Ok(entry.value.into_pyobject(py)?); @@ -58,25 +63,31 @@ impl PyConfig { } /// Set a configuration option - pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { + pub fn set(&self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { let scalar_value = py_obj_to_scalar_value(py, value)?; - self.config.set(key, scalar_value.to_string().as_str())?; + let mut options = self + .config + .write() + .map_err(|_| PyDataFusionError::Common("failed to lock configuration".to_string()))?; + options.set(key, scalar_value.to_string().as_str())?; Ok(()) } /// Get all configuration options - pub fn get_all(&mut self, py: Python) -> PyResult { + pub fn get_all(&self, py: Python) -> PyResult { 
let dict = PyDict::new(py); - let options = self.config.to_owned(); + let options = self + .config + .read() + .map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?; for entry in options.entries() { dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?; } Ok(dict.into()) } - fn __repr__(&mut self, py: Python) -> PyResult { - let dict = self.get_all(py); - match dict { + fn __repr__(&self, py: Python) -> PyResult { + match self.get_all(py) { Ok(result) => Ok(format!("Config({result})")), Err(err) => Ok(format!("Error: {:?}", err.to_string())), } diff --git a/src/dataframe.rs b/src/dataframe.rs index 5882acf76..6f679c9b2 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; use std::ffi::CString; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow::compute::can_cast_types; @@ -284,13 +284,13 @@ impl PyParquetColumnOptions { /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. -#[pyclass(name = "DataFrame", module = "datafusion", subclass)] +#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub struct PyDataFrame { df: Arc, // In IPython environment cache batches between __repr__ and _repr_html_ calls. 
- batches: Option<(Vec, bool)>, + batches: Arc, bool)>>>, } impl PyDataFrame { @@ -298,16 +298,24 @@ impl PyDataFrame { pub fn new(df: DataFrame) -> Self { Self { df: Arc::new(df), - batches: None, + batches: Arc::new(Mutex::new(None)), } } - fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult { + fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; - let should_cache = *is_ipython_env(py) && self.batches.is_none(); - let (batches, has_more) = match self.batches.take() { + let (cached_batches, should_cache) = { + let mut cache = self.batches.lock().map_err(|_| { + PyDataFusionError::Common("failed to lock DataFrame display cache".to_string()) + })?; + let should_cache = *is_ipython_env(py) && cache.is_none(); + let batches = cache.take(); + (batches, should_cache) + }; + + let (batches, has_more) = match cached_batches { Some(b) => b, None => wait_for_future( py, @@ -346,7 +354,10 @@ impl PyDataFrame { let html_str: String = html_result.extract()?; if should_cache { - self.batches = Some((batches, has_more)); + let mut cache = self.batches.lock().map_err(|_| { + PyDataFusionError::Common("failed to lock DataFrame display cache".to_string()) + })?; + *cache = Some((batches.clone(), has_more)); } Ok(html_str) @@ -376,7 +387,7 @@ impl PyDataFrame { } } - fn __repr__(&mut self, py: Python) -> PyDataFusionResult { + fn __repr__(&self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, false) } @@ -411,7 +422,7 @@ impl PyDataFrame { Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } - fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { + fn _repr_html_(&self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, true) } @@ -874,7 +885,7 @@ impl PyDataFrame { #[pyo3(signature = (requested_schema=None))] fn __arrow_c_stream__<'py>( - &'py mut 
self, + &'py self, py: Python<'py>, requested_schema: Option>, ) -> PyDataFusionResult> { diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index fe3af2e25..c89f5b47d 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,40 +15,81 @@ // specific language governing permissions and limitations // under the License. -use crate::{errors::PyDataFusionResult, expr::PyExpr}; +use std::sync::{Arc, Mutex}; + +use crate::{ + errors::{PyDataFusionError, PyDataFusionResult}, + expr::PyExpr, +}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; use pyo3::prelude::*; -#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass)] +#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] +#[derive(Clone)] pub struct PyCaseBuilder { - pub case_builder: CaseBuilder, + case_builder: Arc>>, } impl From for CaseBuilder { fn from(case_builder: PyCaseBuilder) -> Self { - case_builder.case_builder + case_builder + .case_builder + .lock() + .expect("Case builder mutex poisoned") + .take() + .expect("CaseBuilder has already been consumed") } } impl From for PyCaseBuilder { fn from(case_builder: CaseBuilder) -> PyCaseBuilder { - PyCaseBuilder { case_builder } + PyCaseBuilder { + case_builder: Arc::new(Mutex::new(Some(case_builder))), + } + } +} + +impl PyCaseBuilder { + fn lock_case_builder( + &self, + ) -> PyDataFusionResult>> { + self.case_builder + .lock() + .map_err(|_| PyDataFusionError::Common("failed to lock CaseBuilder".to_string())) + } + + fn take_case_builder(&self) -> PyDataFusionResult { + let mut guard = self.lock_case_builder()?; + guard.take().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + }) + } + + fn store_case_builder(&self, builder: CaseBuilder) -> PyDataFusionResult<()> { + let mut guard = self.lock_case_builder()?; + *guard = Some(builder); + Ok(()) } } #[pymethods] impl PyCaseBuilder { - fn when(&mut self, when: 
PyExpr, then: PyExpr) -> PyCaseBuilder { - PyCaseBuilder { - case_builder: self.case_builder.when(when.expr, then.expr), - } + fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { + let mut builder = self.take_case_builder()?; + let next_builder = builder.when(when.expr, then.expr); + self.store_case_builder(next_builder)?; + Ok(self.clone()) } - fn otherwise(&mut self, else_expr: PyExpr) -> PyDataFusionResult { - Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into()) + fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { + let mut builder = self.take_case_builder()?; + let expr = builder.otherwise(else_expr.expr)?; + Ok(expr.clone().into()) } - fn end(&mut self) -> PyDataFusionResult { - Ok(self.case_builder.end()?.clone().into()) + fn end(&self) -> PyDataFusionResult { + let builder = self.take_case_builder()?; + let expr = builder.end()?; + Ok(expr.clone().into()) } } diff --git a/src/expr/literal.rs b/src/expr/literal.rs index 561242c9c..8a589b55a 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -19,7 +19,7 @@ use crate::errors::PyDataFusionError; use datafusion::{common::ScalarValue, logical_expr::expr::FieldMetadata}; use pyo3::{prelude::*, IntoPyObjectExt}; -#[pyclass(name = "Literal", module = "datafusion.expr", subclass)] +#[pyclass(name = "Literal", module = "datafusion.expr", subclass, frozen)] #[derive(Clone)] pub struct PyLiteral { pub value: ScalarValue, @@ -71,7 +71,7 @@ impl PyLiteral { extract_scalar_value!(self, Float64) } - pub fn value_decimal128(&mut self) -> PyResult<(Option, u8, i8)> { + pub fn value_decimal128(&self) -> PyResult<(Option, u8, i8)> { match &self.value { ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), other => Err(unexpected_literal_value(other)), @@ -122,7 +122,7 @@ impl PyLiteral { extract_scalar_value!(self, Time64Nanosecond) } - pub fn value_timestamp(&mut self) -> PyResult<(Option, Option)> { + pub fn value_timestamp(&self) -> 
PyResult<(Option, Option)> { match &self.value { ScalarValue::TimestampNanosecond(iv, tz) | ScalarValue::TimestampMicrosecond(iv, tz) diff --git a/src/functions.rs b/src/functions.rs index e92cf053f..0f9fdf698 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -230,17 +230,13 @@ fn col(name: &str) -> PyResult { /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn case(expr: PyExpr) -> PyResult { - Ok(PyCaseBuilder { - case_builder: datafusion::logical_expr::case(expr.expr), - }) + Ok(datafusion::logical_expr::case(expr.expr).into()) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn when(when: PyExpr, then: PyExpr) -> PyResult { - Ok(PyCaseBuilder { - case_builder: datafusion::logical_expr::when(when.expr, then.expr), - }) + Ok(datafusion::logical_expr::when(when.expr, then.expr).into()) } /// Helper function to find the appropriate window function. diff --git a/src/record_batch.rs b/src/record_batch.rs index a85f05423..c3658cf4b 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -28,7 +28,7 @@ use pyo3::prelude::*; use pyo3::{pyclass, pymethods, PyObject, PyResult, Python}; use tokio::sync::Mutex; -#[pyclass(name = "RecordBatch", module = "datafusion", subclass)] +#[pyclass(name = "RecordBatch", module = "datafusion", subclass, frozen)] pub struct PyRecordBatch { batch: RecordBatch, } @@ -46,7 +46,7 @@ impl From for PyRecordBatch { } } -#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)] +#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass, frozen)] pub struct PyRecordBatchStream { stream: Arc>, } @@ -61,12 +61,12 @@ impl PyRecordBatchStream { #[pymethods] impl PyRecordBatchStream { - fn next(&mut self, py: Python) -> PyResult { + fn next(&self, py: Python) -> PyResult { let stream = self.stream.clone(); wait_for_future(py, next_stream(stream, true))? 
} - fn __next__(&mut self, py: Python) -> PyResult { + fn __next__(&self, py: Python) -> PyResult { self.next(py) } From 7030cec32c60ff16a2575f45b8c5f717aabbb086 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 15:10:01 +0800 Subject: [PATCH 02/31] Add error handling to CaseBuilder methods to preserve builder state --- python/tests/test_expr.py | 18 ++++++++++++++++++ src/expr/conditional_expr.rs | 18 ++++++++++++++---- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 810d419cf..481319b01 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -200,6 +200,24 @@ def traverse_logical_plan(plan): assert not variant.negated() +def test_case_builder_error_preserves_builder_state(): + case_builder = functions.when(lit(True), lit(1)) + + with pytest.raises(Exception) as exc_info: + case_builder.otherwise(lit("bad")) + + err_msg = str(exc_info.value) + assert "multiple data types" in err_msg + assert "CaseBuilder has already been consumed" not in err_msg + + with pytest.raises(Exception) as exc_info: + case_builder.end() + + err_msg = str(exc_info.value) + assert "multiple data types" in err_msg + assert "CaseBuilder has already been consumed" not in err_msg + + def test_expr_getitem() -> None: ctx = SessionContext() data = { diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index c89f5b47d..d851dc8e7 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -83,13 +83,23 @@ impl PyCaseBuilder { fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { let mut builder = self.take_case_builder()?; - let expr = builder.otherwise(else_expr.expr)?; - Ok(expr.clone().into()) + match builder.otherwise(else_expr.expr) { + Ok(expr) => Ok(expr.clone().into()), + Err(err) => { + self.store_case_builder(builder)?; + Err(err.into()) + } + } } fn end(&self) -> PyDataFusionResult { let builder = self.take_case_builder()?; - let 
expr = builder.end()?; - Ok(expr.clone().into()) + match builder.end() { + Ok(expr) => Ok(expr.clone().into()), + Err(err) => { + self.store_case_builder(builder)?; + Err(err.into()) + } + } } } From dba5c6ad621cd93459cd6281357ce68673cbfe26 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 17:42:30 +0800 Subject: [PATCH 03/31] Refactor to use parking_lot for interior mutability in schema, config, dataframe, and conditional expression modules --- Cargo.lock | 1 + Cargo.toml | 1 + src/common/schema.rs | 55 +++++++++--------------------------- src/config.rs | 20 ++++--------- src/dataframe.rs | 12 ++++---- src/expr/conditional_expr.rs | 43 ++++++++++++---------------- 6 files changed, 45 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf67256c3..2b62a69dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1594,6 +1594,7 @@ dependencies = [ "log", "mimalloc", "object_store", + "parking_lot", "prost", "prost-types", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index f1d1a0236..2c48bdd5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ futures = "0.3" object_store = { version = "0.12.3", features = ["aws", "gcp", "azure", "http"] } url = "2" log = "0.4.27" +parking_lot = "0.12" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/src/common/schema.rs b/src/common/schema.rs index d32e401e9..71dbc56d1 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -16,7 +16,7 @@ // under the License. 
use std::fmt::{self, Display, Formatter}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::{any::Any, borrow::Cow}; use arrow::datatypes::Schema; @@ -25,7 +25,6 @@ use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::Constraints; use datafusion::datasource::TableType; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource}; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use datafusion::logical_expr::utils::split_conjunction; @@ -34,6 +33,8 @@ use crate::sql::logical::PyLogicalPlan; use super::{data_type::DataTypeMap, function::SqlFunction}; +use parking_lot::RwLock; + #[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)] #[derive(Debug, Clone)] pub struct SqlSchema { @@ -110,88 +111,60 @@ impl SqlSchema { #[getter] fn name(&self) -> PyResult { - Ok(self - .name - .read() - .map_err(|_| PyRuntimeError::new_err("failed to read schema name"))? - .clone()) + Ok(self.name.read().clone()) } #[setter] fn set_name(&self, value: String) -> PyResult<()> { - *self - .name - .write() - .map_err(|_| PyRuntimeError::new_err("failed to write schema name"))? = value; + *self.name.write() = value; Ok(()) } #[getter] fn tables(&self) -> PyResult> { - Ok(self - .tables - .read() - .map_err(|_| PyRuntimeError::new_err("failed to read schema tables"))? - .clone()) + Ok(self.tables.read().clone()) } #[setter] fn set_tables(&self, tables: Vec) -> PyResult<()> { - *self - .tables - .write() - .map_err(|_| PyRuntimeError::new_err("failed to write schema tables"))? = tables; + *self.tables.write() = tables; Ok(()) } #[getter] fn views(&self) -> PyResult> { - Ok(self - .views - .read() - .map_err(|_| PyRuntimeError::new_err("failed to read schema views"))? - .clone()) + Ok(self.views.read().clone()) } #[setter] fn set_views(&self, views: Vec) -> PyResult<()> { - *self - .views - .write() - .map_err(|_| PyRuntimeError::new_err("failed to write schema views"))? 
= views; + *self.views.write() = views; Ok(()) } #[getter] fn functions(&self) -> PyResult> { - Ok(self - .functions - .read() - .map_err(|_| PyRuntimeError::new_err("failed to read schema functions"))? - .clone()) + Ok(self.functions.read().clone()) } #[setter] fn set_functions(&self, functions: Vec) -> PyResult<()> { - *self - .functions - .write() - .map_err(|_| PyRuntimeError::new_err("failed to write schema functions"))? = functions; + *self.functions.write() = functions; Ok(()) } pub fn table_by_name(&self, table_name: &str) -> Option { - let tables = self.tables.read().expect("failed to read schema tables"); + let tables = self.tables.read(); tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned() } pub fn add_table(&self, table: SqlTable) { - let mut tables = self.tables.write().expect("failed to write schema tables"); + let mut tables = self.tables.write(); tables.push(table); } pub fn drop_table(&self, table_name: String) { - let mut tables = self.tables.write().expect("failed to write schema tables"); + let mut tables = self.tables.write(); tables.retain(|x| !x.name.eq(&table_name)); } } diff --git a/src/config.rs b/src/config.rs index 487375dca..dcaab1066 100644 --- a/src/config.rs +++ b/src/config.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::{Arc, RwLock}; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::*; use datafusion::config::ConfigOptions; -use crate::errors::{PyDataFusionError, PyDataFusionResult}; +use crate::errors::PyDataFusionResult; use crate::utils::py_obj_to_scalar_value; +use parking_lot::RwLock; #[pyclass(name = "Config", module = "datafusion", subclass, frozen)] #[derive(Clone)] @@ -50,10 +51,7 @@ impl PyConfig { /// Get a configuration option pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { - let options = self - .config - .read() - .map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?; + let options = self.config.read(); for entry in options.entries() { if entry.key == key { return Ok(entry.value.into_pyobject(py)?); @@ -65,10 +63,7 @@ impl PyConfig { /// Set a configuration option pub fn set(&self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { let scalar_value = py_obj_to_scalar_value(py, value)?; - let mut options = self - .config - .write() - .map_err(|_| PyDataFusionError::Common("failed to lock configuration".to_string()))?; + let mut options = self.config.write(); options.set(key, scalar_value.to_string().as_str())?; Ok(()) } @@ -76,10 +71,7 @@ impl PyConfig { /// Get all configuration options pub fn get_all(&self, py: Python) -> PyResult { let dict = PyDict::new(py); - let options = self - .config - .read() - .map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?; + let options = self.config.read(); for entry in options.entries() { dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?; } diff --git a/src/dataframe.rs b/src/dataframe.rs index 6f679c9b2..50d72c6a4 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; use std::ffi::CString; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader}; use 
arrow::compute::can_cast_types; @@ -58,6 +58,8 @@ use crate::{ expr::{sort_expr::PySortExpr, PyExpr}, }; +use parking_lot::Mutex; + // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 // - we have not decided on the table_provider approach yet // this is an interim implementation @@ -307,9 +309,7 @@ impl PyDataFrame { let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; let (cached_batches, should_cache) = { - let mut cache = self.batches.lock().map_err(|_| { - PyDataFusionError::Common("failed to lock DataFrame display cache".to_string()) - })?; + let mut cache = self.batches.lock(); let should_cache = *is_ipython_env(py) && cache.is_none(); let batches = cache.take(); (batches, should_cache) @@ -354,9 +354,7 @@ impl PyDataFrame { let html_str: String = html_result.extract()?; if should_cache { - let mut cache = self.batches.lock().map_err(|_| { - PyDataFusionError::Common("failed to lock DataFrame display cache".to_string()) - })?; + let mut cache = self.batches.lock(); *cache = Some((batches.clone(), has_more)); } diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index d851dc8e7..816c75bf2 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::{Arc, Mutex}; +use std::sync::Arc; use crate::{ errors::{PyDataFusionError, PyDataFusionResult}, @@ -24,23 +24,14 @@ use crate::{ use datafusion::logical_expr::conditional_expressions::CaseBuilder; use pyo3::prelude::*; +use parking_lot::{Mutex, MutexGuard}; + #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] #[derive(Clone)] pub struct PyCaseBuilder { case_builder: Arc>>, } -impl From for CaseBuilder { - fn from(case_builder: PyCaseBuilder) -> Self { - case_builder - .case_builder - .lock() - .expect("Case builder mutex poisoned") - .take() - .expect("CaseBuilder has already been consumed") - } -} - impl From for PyCaseBuilder { fn from(case_builder: CaseBuilder) -> PyCaseBuilder { PyCaseBuilder { @@ -50,25 +41,27 @@ impl From for PyCaseBuilder { } impl PyCaseBuilder { - fn lock_case_builder( - &self, - ) -> PyDataFusionResult>> { - self.case_builder - .lock() - .map_err(|_| PyDataFusionError::Common("failed to lock CaseBuilder".to_string())) + fn lock_case_builder(&self) -> MutexGuard<'_, Option> { + self.case_builder.lock() } fn take_case_builder(&self) -> PyDataFusionResult { - let mut guard = self.lock_case_builder()?; + let mut guard = self.lock_case_builder(); guard.take().ok_or_else(|| { PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) }) } - fn store_case_builder(&self, builder: CaseBuilder) -> PyDataFusionResult<()> { - let mut guard = self.lock_case_builder()?; + fn store_case_builder(&self, builder: CaseBuilder) { + let mut guard = self.lock_case_builder(); *guard = Some(builder); - Ok(()) + } + + pub fn into_case_builder(self) -> PyDataFusionResult { + let mut guard = self.case_builder.lock(); + guard.take().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + }) } } @@ -77,7 +70,7 @@ impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { let mut builder = self.take_case_builder()?; let next_builder = 
builder.when(when.expr, then.expr); - self.store_case_builder(next_builder)?; + self.store_case_builder(next_builder); Ok(self.clone()) } @@ -86,7 +79,7 @@ impl PyCaseBuilder { match builder.otherwise(else_expr.expr) { Ok(expr) => Ok(expr.clone().into()), Err(err) => { - self.store_case_builder(builder)?; + self.store_case_builder(builder); Err(err.into()) } } @@ -97,7 +90,7 @@ impl PyCaseBuilder { match builder.end() { Ok(expr) => Ok(expr.clone().into()), Err(err) => { - self.store_case_builder(builder)?; + self.store_case_builder(builder); Err(err.into()) } } From cfc9f2cf64e17b949feca8a8ce60899c4c0d4e22 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 18:24:26 +0800 Subject: [PATCH 04/31] Add concurrency tests for SqlSchema, Config, and DataFrame --- python/tests/test_concurrency.py | 102 +++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 python/tests/test_concurrency.py diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py new file mode 100644 index 000000000..92f7c386b --- /dev/null +++ b/python/tests/test_concurrency.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +import pyarrow as pa + +from datafusion import Config, SessionContext, col, lit +from datafusion.common import SqlSchema +from datafusion import functions as f + + +def _run_in_threads(fn, count: int = 8) -> None: + with ThreadPoolExecutor(max_workers=count) as executor: + futures = [executor.submit(fn, i) for i in range(count)] + for future in futures: + # Propagate any exception raised in the worker thread. + future.result() + + +def test_concurrent_access_to_shared_structures() -> None: + """Exercise SqlSchema, Config, and DataFrame concurrently.""" + + schema = SqlSchema("concurrency") + config = Config() + ctx = SessionContext() + + batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int32())], names=["value"]) + df = ctx.create_dataframe([[batch]]) + + config_key = "datafusion.execution.batch_size" + expected_rows = batch.num_rows + + def worker(index: int) -> None: + schema.name = f"concurrency-{index}" + assert schema.name.startswith("concurrency-") + # Exercise getters that use internal locks. + assert isinstance(schema.tables, list) + assert isinstance(schema.views, list) + assert isinstance(schema.functions, list) + + config.set(config_key, str(1024 + index)) + assert config.get(config_key) is not None + # Access the full config map to stress lock usage. 
+ assert config_key in config.get_all() + + batches = df.collect() + assert sum(batch.num_rows for batch in batches) == expected_rows + + _run_in_threads(worker, count=12) + + +def test_case_builder_reuse_from_multiple_threads() -> None: + """Ensure the case builder can be safely reused across threads.""" + + ctx = SessionContext() + values = pa.array([0, 1, 2, 3, 4], type=pa.int32()) + df = ctx.create_dataframe([[pa.record_batch([values], names=["value"])]]) + + base_builder = f.case(col("value")) + + def add_case(i: int) -> None: + base_builder.when(lit(i), lit(f"value-{i}")) + + _run_in_threads(add_case, count=8) + + with ThreadPoolExecutor(max_workers=2) as executor: + otherwise_future = executor.submit(base_builder.otherwise, lit("default")) + case_expr = otherwise_future.result() + + result = df.select(case_expr.alias("label")).collect() + assert sum(batch.num_rows for batch in result) == len(values) + + predicate_builder = f.when(col("value") == lit(0), lit("zero")) + + def add_predicate(i: int) -> None: + predicate_builder.when(col("value") == lit(i + 1), lit(f"value-{i + 1}")) + + _run_in_threads(add_predicate, count=4) + + with ThreadPoolExecutor(max_workers=2) as executor: + end_future = executor.submit(predicate_builder.end) + predicate_expr = end_future.result() + + result = df.select(predicate_expr.alias("label")).collect() + assert sum(batch.num_rows for batch in result) == len(values) From 03a1022a0d6077f83dfc4b5a6e1f0526387b5d01 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 19:29:54 +0800 Subject: [PATCH 05/31] Add tests for CaseBuilder to ensure builder state is preserved on success --- python/tests/test_expr.py | 23 +++++++++++++++++++++++ src/expr/conditional_expr.rs | 10 ++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 481319b01..2de536df9 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -218,6 +218,29 @@ def 
test_case_builder_error_preserves_builder_state(): assert "CaseBuilder has already been consumed" not in err_msg +def test_case_builder_success_preserves_builder_state(): + ctx = SessionContext() + df = ctx.from_pydict({"flag": [False]}, name="tbl") + + case_builder = functions.when(col("flag"), lit("true")) + + expr_default_one = case_builder.otherwise(lit("default-1")).alias("result") + result_one = df.select(expr_default_one).collect() + assert result_one[0].column(0).to_pylist() == ["default-1"] + + expr_default_two = case_builder.otherwise(lit("default-2")).alias("result") + result_two = df.select(expr_default_two).collect() + assert result_two[0].column(0).to_pylist() == ["default-2"] + + expr_end_one = case_builder.end().alias("result") + end_one = df.select(expr_end_one).collect() + assert end_one[0].column(0).to_pylist() == ["default-2"] + + expr_end_two = case_builder.end().alias("result") + end_two = df.select(expr_end_two).collect() + assert end_two[0].column(0).to_pylist() == ["default-2"] + + def test_expr_getitem() -> None: ctx = SessionContext() data = { diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index 816c75bf2..8987ad0d9 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -77,7 +77,10 @@ impl PyCaseBuilder { fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { let mut builder = self.take_case_builder()?; match builder.otherwise(else_expr.expr) { - Ok(expr) => Ok(expr.clone().into()), + Ok(expr) => { + self.store_case_builder(builder); + Ok(expr.clone().into()) + } Err(err) => { self.store_case_builder(builder); Err(err.into()) @@ -88,7 +91,10 @@ impl PyCaseBuilder { fn end(&self) -> PyDataFusionResult { let builder = self.take_case_builder()?; match builder.end() { - Ok(expr) => Ok(expr.clone().into()), + Ok(expr) => { + self.store_case_builder(builder); + Ok(expr.clone().into()) + } Err(err) => { self.store_case_builder(builder); Err(err.into()) From 
d6cdfe3e034ffedfade84933c0fdfbffbff72d56 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 19:43:15 +0800 Subject: [PATCH 06/31] Add test for independent handles in CaseBuilder to verify behavior --- python/tests/test_expr.py | 36 ++++++++++++++++++++++++++++++++++++ src/expr/conditional_expr.rs | 6 +++--- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 2de536df9..b4aa77cef 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -241,6 +241,42 @@ def test_case_builder_success_preserves_builder_state(): assert end_two[0].column(0).to_pylist() == ["default-2"] +def test_case_builder_when_handles_are_independent(): + ctx = SessionContext() + df = ctx.from_pydict( + { + "flag": [True, False, False, False], + "value": [1, 15, 25, 5], + }, + name="tbl", + ) + + base_builder = functions.when(col("flag"), lit("flag-true")) + + first_builder = base_builder.when(col("value") > lit(10), lit("gt10")) + second_builder = base_builder.when(col("value") > lit(20), lit("gt20")) + + first_builder = first_builder.when(lit(True), lit("final-one")) + + expr_first = first_builder.otherwise(lit("fallback-one")).alias("first") + expr_second = second_builder.otherwise(lit("fallback-two")).alias("second") + + result = df.select(expr_first, expr_second).collect()[0] + + assert result.column(0).to_pylist() == [ + "flag-true", + "gt10", + "gt10", + "final-one", + ] + assert result.column(1).to_pylist() == [ + "flag-true", + "fallback-two", + "gt20", + "fallback-two", + ] + + def test_expr_getitem() -> None: ctx = SessionContext() data = { diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index 8987ad0d9..d104b0bbb 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -68,10 +68,10 @@ impl PyCaseBuilder { #[pymethods] impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut builder = 
self.take_case_builder()?; + let builder = self.take_case_builder()?; let next_builder = builder.when(when.expr, then.expr); - self.store_case_builder(next_builder); - Ok(self.clone()) + self.store_case_builder(next_builder.clone()); + Ok(next_builder.into()) } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { From 17559374232f5210038859eaf2b219c4816c7fd0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 21:19:05 +0800 Subject: [PATCH 07/31] Fix CaseBuilder to preserve state correctly in when() method --- python/tests/test_expr.py | 4 ++-- src/expr/conditional_expr.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index b4aa77cef..6be9be4bc 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -271,8 +271,8 @@ def test_case_builder_when_handles_are_independent(): ] assert result.column(1).to_pylist() == [ "flag-true", - "fallback-two", - "gt20", + "gt10", + "gt10", "fallback-two", ] diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index d104b0bbb..c2c5a2a38 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -68,9 +68,9 @@ impl PyCaseBuilder { #[pymethods] impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let builder = self.take_case_builder()?; + let mut builder = self.take_case_builder()?; let next_builder = builder.when(when.expr, then.expr); - self.store_case_builder(next_builder.clone()); + self.store_case_builder(builder); Ok(next_builder.into()) } From b6ce4aee4b9ba9eed7bf52e2edcafe94b5f91c88 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 22:10:33 +0800 Subject: [PATCH 08/31] Refactor to use named constant for boolean literals in test_expr.py --- python/tests/test_concurrency.py | 3 +-- python/tests/test_expr.py | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tests/test_concurrency.py 
b/python/tests/test_concurrency.py index 92f7c386b..65fc5cb2c 100644 --- a/python/tests/test_concurrency.py +++ b/python/tests/test_concurrency.py @@ -20,10 +20,9 @@ from concurrent.futures import ThreadPoolExecutor import pyarrow as pa - from datafusion import Config, SessionContext, col, lit -from datafusion.common import SqlSchema from datafusion import functions as f +from datafusion.common import SqlSchema def _run_in_threads(fn, count: int = 8) -> None: diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6be9be4bc..1d4d11942 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -20,6 +20,10 @@ import pyarrow as pa import pytest + +# Avoid passing boolean literals positionally (FBT003). Use a named constant +# so linters don't see a bare True/False literal in a function call. +_TRUE = True from datafusion import ( SessionContext, col, @@ -201,7 +205,7 @@ def traverse_logical_plan(plan): def test_case_builder_error_preserves_builder_state(): - case_builder = functions.when(lit(True), lit(1)) + case_builder = functions.when(lit(_TRUE), lit(1)) with pytest.raises(Exception) as exc_info: case_builder.otherwise(lit("bad")) @@ -256,7 +260,7 @@ def test_case_builder_when_handles_are_independent(): first_builder = base_builder.when(col("value") > lit(10), lit("gt10")) second_builder = base_builder.when(col("value") > lit(20), lit("gt20")) - first_builder = first_builder.when(lit(True), lit("final-one")) + first_builder = first_builder.when(lit(_TRUE), lit("final-one")) expr_first = first_builder.otherwise(lit("fallback-one")).alias("first") expr_second = second_builder.otherwise(lit("fallback-two")).alias("second") From fd504a1ec146022f7307da16110b9c0db86b90e0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 22:12:28 +0800 Subject: [PATCH 09/31] fix ruff errors --- python/tests/test_expr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tests/test_expr.py 
b/python/tests/test_expr.py index 1d4d11942..e925b0ee2 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -20,10 +20,6 @@ import pyarrow as pa import pytest - -# Avoid passing boolean literals positionally (FBT003). Use a named constant -# so linters don't see a bare True/False literal in a function call. -_TRUE = True from datafusion import ( SessionContext, col, @@ -57,6 +53,10 @@ ensure_expr_list, ) +# Avoid passing boolean literals positionally (FBT003). Use a named constant +# so linters don't see a bare True/False literal in a function call. +_TRUE = True + @pytest.fixture def test_ctx(): From 4b01772bf47bf34a47d8b5977239e216f090a0b8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 26 Sep 2025 22:27:44 +0800 Subject: [PATCH 10/31] Refactor to introduce type aliases for cached batches in dataframe.rs --- src/dataframe.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 50d72c6a4..7a167187b 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -60,6 +60,11 @@ use crate::{ use parking_lot::Mutex; +// Type aliases to simplify very complex types used in this file and +// avoid compiler complaints about deeply nested types in struct fields. +type CachedBatches = Option<(Vec, bool)>; +type SharedCachedBatches = Arc>; + // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 // - we have not decided on the table_provider approach yet // this is an interim implementation @@ -292,7 +297,7 @@ pub struct PyDataFrame { df: Arc, // In IPython environment cache batches between __repr__ and _repr_html_ calls. 
- batches: Arc, bool)>>>, + batches: SharedCachedBatches, } impl PyDataFrame { From 1c734100d3c58766643f3d74d100bda8fce3b6b2 Mon Sep 17 00:00:00 2001 From: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:08:57 -0400 Subject: [PATCH 11/31] Cherry pick from #1252 --- src/common/data_type.rs | 17 ++++++++++++----- src/common/df_schema.rs | 2 +- src/common/function.rs | 2 +- src/common/schema.rs | 8 ++++---- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 4d7743397..3cbe31332 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -37,7 +37,7 @@ impl From for ScalarValue { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "RexType", module = "datafusion.common")] pub enum RexType { Alias, Literal, @@ -56,6 +56,7 @@ pub enum RexType { /// and manageable location. Therefore this structure exists /// to map those types and provide a simple place for developers /// to map types from one system to another. +// TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now #[derive(Debug, Clone)] #[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)] pub struct DataTypeMap { @@ -577,7 +578,7 @@ impl DataTypeMap { /// Since `DataType` exists in another package we cannot make that happen here so we wrap /// `DataType` as `PyDataType` This exists solely to satisfy those constraints. 
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "DataType", module = "datafusion.common")] +#[pyclass(frozen, name = "DataType", module = "datafusion.common")] pub struct PyDataType { pub data_type: DataType, } @@ -635,7 +636,7 @@ impl From for PyDataType { /// Represents the possible Python types that can be mapped to the SQL types #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "PythonType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "PythonType", module = "datafusion.common")] pub enum PythonType { Array, Bool, @@ -655,7 +656,7 @@ pub enum PythonType { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "SqlType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "SqlType", module = "datafusion.common")] pub enum SqlType { ANY, ARRAY, @@ -713,7 +714,13 @@ pub enum SqlType { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "NullTreatment", module = "datafusion.common")] +#[pyclass( + frozen, + eq, + eq_int, + name = "NullTreatment", + module = "datafusion.common" +)] pub enum NullTreatment { IGNORE_NULLS, RESPECT_NULLS, diff --git a/src/common/df_schema.rs b/src/common/df_schema.rs index 4e1d84060..eb62469cf 100644 --- a/src/common/df_schema.rs +++ b/src/common/df_schema.rs @@ -21,7 +21,7 @@ use datafusion::common::DFSchema; use pyo3::prelude::*; #[derive(Debug, Clone)] -#[pyclass(name = "DFSchema", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "DFSchema", module = "datafusion.common", subclass)] pub struct PyDFSchema { schema: Arc, } diff --git a/src/common/function.rs b/src/common/function.rs index a8d752f16..bc6f23160 100644 --- a/src/common/function.rs +++ b/src/common/function.rs @@ -22,7 +22,7 @@ use 
pyo3::prelude::*; use super::data_type::PyDataType; -#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "SqlFunction", module = "datafusion.common", subclass)] #[derive(Debug, Clone)] pub struct SqlFunction { pub name: String, diff --git a/src/common/schema.rs b/src/common/schema.rs index 71dbc56d1..14ab630d3 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -248,7 +248,7 @@ fn is_supported_push_down_expr(_expr: &Expr) -> bool { true } -#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "SqlStatistics", module = "datafusion.common", subclass)] #[derive(Debug, Clone)] pub struct SqlStatistics { row_count: f64, @@ -267,7 +267,7 @@ impl SqlStatistics { } } -#[pyclass(name = "Constraints", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Constraints", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyConstraints { pub constraints: Constraints, @@ -292,7 +292,7 @@ impl Display for PyConstraints { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "TableType", module = "datafusion.common")] pub enum PyTableType { Base, View, @@ -319,7 +319,7 @@ impl From for PyTableType { } } -#[pyclass(name = "TableSource", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "TableSource", module = "datafusion.common", subclass)] #[derive(Clone)] pub struct PyTableSource { pub table_source: Arc, From d5914c22de2688a55cc9b22cdeaf19022c6b4b53 Mon Sep 17 00:00:00 2001 From: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:31:25 -0400 Subject: [PATCH 12/31] Add most expr - cherry pick from #1252 --- src/expr/aggregate.rs | 2 +- src/expr/aggregate_expr.rs | 7 ++++++- src/expr/alias.rs | 2 +- src/expr/analyze.rs | 2 +- src/expr/between.rs | 2 +- src/expr/binary_expr.rs | 2 +- 
src/expr/bool_expr.rs | 2 +- src/expr/case.rs | 2 +- src/expr/cast.rs | 2 +- src/expr/column.rs | 2 +- src/expr/copy_to.rs | 4 ++-- src/expr/create_catalog.rs | 2 +- src/expr/create_catalog_schema.rs | 7 ++++++- src/expr/create_external_table.rs | 7 ++++++- src/expr/create_function.rs | 18 ++++++++++++---- src/expr/create_index.rs | 2 +- src/expr/create_memory_table.rs | 7 ++++++- src/expr/create_view.rs | 2 +- src/expr/describe_table.rs | 2 +- src/expr/distinct.rs | 2 +- src/expr/dml.rs | 2 +- src/expr/drop_catalog_schema.rs | 7 ++++++- src/expr/drop_function.rs | 2 +- src/expr/drop_table.rs | 2 +- src/expr/drop_view.rs | 2 +- src/expr/empty_relation.rs | 2 +- src/expr/exists.rs | 2 +- src/expr/explain.rs | 2 +- src/expr/extension.rs | 2 +- src/expr/filter.rs | 2 +- src/expr/grouping_set.rs | 2 +- src/expr/in_list.rs | 2 +- src/expr/in_subquery.rs | 2 +- src/expr/indexed_field.rs | 2 +- src/expr/join.rs | 6 +++--- src/expr/like.rs | 6 +++--- src/expr/limit.rs | 2 +- src/expr/placeholder.rs | 2 +- src/expr/projection.rs | 2 +- src/expr/recursive_query.rs | 2 +- src/expr/repartition.rs | 4 ++-- src/expr/scalar_subquery.rs | 2 +- src/expr/scalar_variable.rs | 2 +- src/expr/signature.rs | 2 +- src/expr/sort.rs | 2 +- src/expr/sort_expr.rs | 2 +- src/expr/statement.rs | 34 +++++++++++++++++++++++-------- src/expr/subquery.rs | 2 +- src/expr/subquery_alias.rs | 2 +- src/expr/table_scan.rs | 2 +- src/expr/union.rs | 2 +- src/expr/unnest.rs | 2 +- src/expr/unnest_expr.rs | 2 +- src/expr/values.rs | 2 +- src/expr/window.rs | 11 +++++++--- src/sql/logical.rs | 2 +- src/unparser/dialect.rs | 2 +- src/unparser/mod.rs | 2 +- 58 files changed, 134 insertions(+), 76 deletions(-) diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index fd4393271..4af7c755a 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -28,7 +28,7 @@ use crate::errors::py_type_err; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Aggregate", module = 
"datafusion.expr", subclass)] +#[pyclass(frozen, name = "Aggregate", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAggregate { aggregate: Aggregate, diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs index 7c5d3d31f..72ba0638f 100644 --- a/src/expr/aggregate_expr.rs +++ b/src/expr/aggregate_expr.rs @@ -20,7 +20,12 @@ use datafusion::logical_expr::expr::AggregateFunction; use pyo3::prelude::*; use std::fmt::{Display, Formatter}; -#[pyclass(name = "AggregateFunction", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "AggregateFunction", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyAggregateFunction { aggr: AggregateFunction, diff --git a/src/expr/alias.rs b/src/expr/alias.rs index 40746f200..588c00fdf 100644 --- a/src/expr/alias.rs +++ b/src/expr/alias.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use datafusion::logical_expr::expr::Alias; -#[pyclass(name = "Alias", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Alias", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAlias { alias: Alias, diff --git a/src/expr/analyze.rs b/src/expr/analyze.rs index e8081e95b..c7caeebc8 100644 --- a/src/expr/analyze.rs +++ b/src/expr/analyze.rs @@ -23,7 +23,7 @@ use super::logical_node::LogicalNode; use crate::common::df_schema::PyDFSchema; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Analyze", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Analyze", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAnalyze { analyze: Analyze, diff --git a/src/expr/between.rs b/src/expr/between.rs index 817f1baae..1f61599a3 100644 --- a/src/expr/between.rs +++ b/src/expr/between.rs @@ -20,7 +20,7 @@ use datafusion::logical_expr::expr::Between; use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; -#[pyclass(name = "Between", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = 
"Between", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyBetween { between: Between, diff --git a/src/expr/binary_expr.rs b/src/expr/binary_expr.rs index 740299211..94379583c 100644 --- a/src/expr/binary_expr.rs +++ b/src/expr/binary_expr.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::BinaryExpr; use pyo3::prelude::*; -#[pyclass(name = "BinaryExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "BinaryExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyBinaryExpr { expr: BinaryExpr, diff --git a/src/expr/bool_expr.rs b/src/expr/bool_expr.rs index e67e25d74..22eabdb88 100644 --- a/src/expr/bool_expr.rs +++ b/src/expr/bool_expr.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use super::PyExpr; -#[pyclass(name = "Not", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Not", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyNot { expr: Expr, diff --git a/src/expr/case.rs b/src/expr/case.rs index 92e28ba56..1a7369826 100644 --- a/src/expr/case.rs +++ b/src/expr/case.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::Case; use pyo3::prelude::*; -#[pyclass(name = "Case", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Case", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCase { case: Case, diff --git a/src/expr/cast.rs b/src/expr/cast.rs index b8faea634..03e2b8476 100644 --- a/src/expr/cast.rs +++ b/src/expr/cast.rs @@ -19,7 +19,7 @@ use crate::{common::data_type::PyDataType, expr::PyExpr}; use datafusion::logical_expr::{Cast, TryCast}; use pyo3::prelude::*; -#[pyclass(name = "Cast", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Cast", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCast { cast: Cast, diff --git a/src/expr/column.rs b/src/expr/column.rs index 50f316f1c..300079481 100644 --- a/src/expr/column.rs +++ 
b/src/expr/column.rs @@ -18,7 +18,7 @@ use datafusion::common::Column; use pyo3::prelude::*; -#[pyclass(name = "Column", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Column", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyColumn { pub col: Column, diff --git a/src/expr/copy_to.rs b/src/expr/copy_to.rs index c2f7c61d4..422ab77f4 100644 --- a/src/expr/copy_to.rs +++ b/src/expr/copy_to.rs @@ -28,7 +28,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "CopyTo", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CopyTo", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCopyTo { copy: CopyTo, @@ -114,7 +114,7 @@ impl PyCopyTo { } } -#[pyclass(name = "FileType", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "FileType", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyFileType { file_type: Arc, diff --git a/src/expr/create_catalog.rs b/src/expr/create_catalog.rs index d2d2ee8f6..361387894 100644 --- a/src/expr/create_catalog.rs +++ b/src/expr/create_catalog.rs @@ -27,7 +27,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateCatalog", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateCatalog", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateCatalog { create: CreateCatalog, diff --git a/src/expr/create_catalog_schema.rs b/src/expr/create_catalog_schema.rs index e794962f5..cb3be2d30 100644 --- a/src/expr/create_catalog_schema.rs +++ b/src/expr/create_catalog_schema.rs @@ -27,7 +27,12 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateCatalogSchema", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateCatalogSchema", + module = "datafusion.expr", + subclass +)] 
#[derive(Clone)] pub struct PyCreateCatalogSchema { create: CreateCatalogSchema, diff --git a/src/expr/create_external_table.rs b/src/expr/create_external_table.rs index 3e35af006..920d0d613 100644 --- a/src/expr/create_external_table.rs +++ b/src/expr/create_external_table.rs @@ -29,7 +29,12 @@ use crate::common::df_schema::PyDFSchema; use super::{logical_node::LogicalNode, sort_expr::PySortExpr}; -#[pyclass(name = "CreateExternalTable", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateExternalTable", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateExternalTable { create: CreateExternalTable, diff --git a/src/expr/create_function.rs b/src/expr/create_function.rs index c02ceebb1..1b663b466 100644 --- a/src/expr/create_function.rs +++ b/src/expr/create_function.rs @@ -30,7 +30,7 @@ use super::PyExpr; use crate::common::{data_type::PyDataType, df_schema::PyDFSchema}; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "CreateFunction", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateFunction", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateFunction { create: CreateFunction, @@ -54,21 +54,31 @@ impl Display for PyCreateFunction { } } -#[pyclass(name = "OperateFunctionArg", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "OperateFunctionArg", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyOperateFunctionArg { arg: OperateFunctionArg, } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "Volatility", module = "datafusion.expr")] +#[pyclass(frozen, eq, eq_int, name = "Volatility", module = "datafusion.expr")] pub enum PyVolatility { Immutable, Stable, Volatile, } -#[pyclass(name = "CreateFunctionBody", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateFunctionBody", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct 
PyCreateFunctionBody { body: CreateFunctionBody, diff --git a/src/expr/create_index.rs b/src/expr/create_index.rs index 0f4b5011a..7b68df305 100644 --- a/src/expr/create_index.rs +++ b/src/expr/create_index.rs @@ -27,7 +27,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, sort_expr::PySortExpr}; -#[pyclass(name = "CreateIndex", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateIndex", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateIndex { create: CreateIndex, diff --git a/src/expr/create_memory_table.rs b/src/expr/create_memory_table.rs index 37f4d3420..15aaa810b 100644 --- a/src/expr/create_memory_table.rs +++ b/src/expr/create_memory_table.rs @@ -24,7 +24,12 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateMemoryTable", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateMemoryTable", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateMemoryTable { create: CreateMemoryTable, diff --git a/src/expr/create_view.rs b/src/expr/create_view.rs index 718e404d0..49b3b6199 100644 --- a/src/expr/create_view.rs +++ b/src/expr/create_view.rs @@ -24,7 +24,7 @@ use crate::{errors::py_type_err, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateView", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateView", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateView { create: CreateView, diff --git a/src/expr/describe_table.rs b/src/expr/describe_table.rs index 6c48f3c77..315026fef 100644 --- a/src/expr/describe_table.rs +++ b/src/expr/describe_table.rs @@ -28,7 +28,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "DescribeTable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, 
name = "DescribeTable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDescribeTable { describe: DescribeTable, diff --git a/src/expr/distinct.rs b/src/expr/distinct.rs index 889e7099d..5770b849d 100644 --- a/src/expr/distinct.rs +++ b/src/expr/distinct.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Distinct", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Distinct", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDistinct { distinct: Distinct, diff --git a/src/expr/dml.rs b/src/expr/dml.rs index 251e336cc..4437a9de9 100644 --- a/src/expr/dml.rs +++ b/src/expr/dml.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "DmlStatement", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DmlStatement", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDmlStatement { dml: DmlStatement, diff --git a/src/expr/drop_catalog_schema.rs b/src/expr/drop_catalog_schema.rs index b4a4c521c..7008bcd24 100644 --- a/src/expr/drop_catalog_schema.rs +++ b/src/expr/drop_catalog_schema.rs @@ -28,7 +28,12 @@ use crate::common::df_schema::PyDFSchema; use super::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropCatalogSchema", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "DropCatalogSchema", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyDropCatalogSchema { drop: DropCatalogSchema, diff --git a/src/expr/drop_function.rs b/src/expr/drop_function.rs index fca9eb94b..42ad3e1fe 100644 --- a/src/expr/drop_function.rs +++ b/src/expr/drop_function.rs @@ -27,7 +27,7 @@ use super::logical_node::LogicalNode; use crate::common::df_schema::PyDFSchema; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropFunction", module = "datafusion.expr", 
subclass)] +#[pyclass(frozen, name = "DropFunction", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropFunction { drop: DropFunction, diff --git a/src/expr/drop_table.rs b/src/expr/drop_table.rs index 3f442539a..6ff4f01c4 100644 --- a/src/expr/drop_table.rs +++ b/src/expr/drop_table.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "DropTable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DropTable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropTable { drop: DropTable, diff --git a/src/expr/drop_view.rs b/src/expr/drop_view.rs index 6196c8bb5..b2aff4e9b 100644 --- a/src/expr/drop_view.rs +++ b/src/expr/drop_view.rs @@ -28,7 +28,7 @@ use crate::common::df_schema::PyDFSchema; use super::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropView", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DropView", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropView { drop: DropView, diff --git a/src/expr/empty_relation.rs b/src/expr/empty_relation.rs index 758213423..797a8c02d 100644 --- a/src/expr/empty_relation.rs +++ b/src/expr/empty_relation.rs @@ -22,7 +22,7 @@ use std::fmt::{self, Display, Formatter}; use super::logical_node::LogicalNode; -#[pyclass(name = "EmptyRelation", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "EmptyRelation", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyEmptyRelation { empty: EmptyRelation, diff --git a/src/expr/exists.rs b/src/expr/exists.rs index 693357836..392bfcb9e 100644 --- a/src/expr/exists.rs +++ b/src/expr/exists.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use super::subquery::PySubquery; -#[pyclass(name = "Exists", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Exists", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExists { exists: 
Exists, diff --git a/src/expr/explain.rs b/src/expr/explain.rs index fc02fe2b5..71b7b2c13 100644 --- a/src/expr/explain.rs +++ b/src/expr/explain.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, errors::py_type_err, sql::logical::Py use super::logical_node::LogicalNode; -#[pyclass(name = "Explain", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Explain", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExplain { explain: Explain, diff --git a/src/expr/extension.rs b/src/expr/extension.rs index 1e3fbb199..7d913ff8c 100644 --- a/src/expr/extension.rs +++ b/src/expr/extension.rs @@ -22,7 +22,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Extension", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Extension", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExtension { pub node: Extension, diff --git a/src/expr/filter.rs b/src/expr/filter.rs index 4fcb600cd..76338d139 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -24,7 +24,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Filter", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Filter", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyFilter { filter: Filter, diff --git a/src/expr/grouping_set.rs b/src/expr/grouping_set.rs index 63a1c0b50..107dd9370 100644 --- a/src/expr/grouping_set.rs +++ b/src/expr/grouping_set.rs @@ -18,7 +18,7 @@ use datafusion::logical_expr::GroupingSet; use pyo3::prelude::*; -#[pyclass(name = "GroupingSet", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "GroupingSet", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyGroupingSet { grouping_set: GroupingSet, diff --git a/src/expr/in_list.rs b/src/expr/in_list.rs index 5dfd8d8eb..e2e6d7832 100644 --- a/src/expr/in_list.rs +++ 
b/src/expr/in_list.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::expr::InList; use pyo3::prelude::*; -#[pyclass(name = "InList", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "InList", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyInList { in_list: InList, diff --git a/src/expr/in_subquery.rs b/src/expr/in_subquery.rs index 306b68a6e..6d4a38e49 100644 --- a/src/expr/in_subquery.rs +++ b/src/expr/in_subquery.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use super::{subquery::PySubquery, PyExpr}; -#[pyclass(name = "InSubquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "InSubquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyInSubquery { in_subquery: InSubquery, diff --git a/src/expr/indexed_field.rs b/src/expr/indexed_field.rs index a22dc6b27..1dfa0ed2f 100644 --- a/src/expr/indexed_field.rs +++ b/src/expr/indexed_field.rs @@ -22,7 +22,7 @@ use std::fmt::{Display, Formatter}; use super::literal::PyLiteral; -#[pyclass(name = "GetIndexedField", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "GetIndexedField", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyGetIndexedField { indexed_field: GetIndexedField, diff --git a/src/expr/join.rs b/src/expr/join.rs index 7b7e0d9dd..3fde874d5 100644 --- a/src/expr/join.rs +++ b/src/expr/join.rs @@ -25,7 +25,7 @@ use crate::expr::{logical_node::LogicalNode, PyExpr}; use crate::sql::logical::PyLogicalPlan; #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[pyclass(name = "JoinType", module = "datafusion.expr")] +#[pyclass(frozen, name = "JoinType", module = "datafusion.expr")] pub struct PyJoinType { join_type: JoinType, } @@ -60,7 +60,7 @@ impl Display for PyJoinType { } #[derive(Debug, Clone, Copy)] -#[pyclass(name = "JoinConstraint", module = "datafusion.expr")] +#[pyclass(frozen, name = "JoinConstraint", module = "datafusion.expr")] pub struct PyJoinConstraint { 
join_constraint: JoinConstraint, } @@ -87,7 +87,7 @@ impl PyJoinConstraint { } } -#[pyclass(name = "Join", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Join", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyJoin { join: Join, diff --git a/src/expr/like.rs b/src/expr/like.rs index f180f5d4c..0a36dcd92 100644 --- a/src/expr/like.rs +++ b/src/expr/like.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use crate::expr::PyExpr; -#[pyclass(name = "Like", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Like", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLike { like: Like, @@ -79,7 +79,7 @@ impl PyLike { } } -#[pyclass(name = "ILike", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ILike", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyILike { like: Like, @@ -137,7 +137,7 @@ impl PyILike { } } -#[pyclass(name = "SimilarTo", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SimilarTo", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySimilarTo { like: Like, diff --git a/src/expr/limit.rs b/src/expr/limit.rs index 92552814e..cf6971fb3 100644 --- a/src/expr/limit.rs +++ b/src/expr/limit.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Limit", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Limit", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLimit { limit: Limit, diff --git a/src/expr/placeholder.rs b/src/expr/placeholder.rs index 4ac2c47e3..268263d41 100644 --- a/src/expr/placeholder.rs +++ b/src/expr/placeholder.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use crate::common::data_type::PyDataType; -#[pyclass(name = "Placeholder", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Placeholder", module = "datafusion.expr", subclass)] 
#[derive(Clone)] pub struct PyPlaceholder { placeholder: Placeholder, diff --git a/src/expr/projection.rs b/src/expr/projection.rs index b5a9ef34a..b2d5db79b 100644 --- a/src/expr/projection.rs +++ b/src/expr/projection.rs @@ -25,7 +25,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Projection", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Projection", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyProjection { pub projection: Projection, diff --git a/src/expr/recursive_query.rs b/src/expr/recursive_query.rs index 2517b7417..fe047315e 100644 --- a/src/expr/recursive_query.rs +++ b/src/expr/recursive_query.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "RecursiveQuery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "RecursiveQuery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyRecursiveQuery { query: RecursiveQuery, diff --git a/src/expr/repartition.rs b/src/expr/repartition.rs index 48b5e7041..ee6d1dc45 100644 --- a/src/expr/repartition.rs +++ b/src/expr/repartition.rs @@ -24,13 +24,13 @@ use crate::{errors::py_type_err, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "Repartition", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Repartition", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyRepartition { repartition: Repartition, } -#[pyclass(name = "Partitioning", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Partitioning", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyPartitioning { partitioning: Partitioning, diff --git a/src/expr/scalar_subquery.rs b/src/expr/scalar_subquery.rs index 9d35f28a9..e58d66e19 100644 --- a/src/expr/scalar_subquery.rs +++ b/src/expr/scalar_subquery.rs @@ -20,7 +20,7 @@ 
use pyo3::prelude::*; use super::subquery::PySubquery; -#[pyclass(name = "ScalarSubquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ScalarSubquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyScalarSubquery { subquery: Subquery, diff --git a/src/expr/scalar_variable.rs b/src/expr/scalar_variable.rs index 7b50ba241..f3c128a4c 100644 --- a/src/expr/scalar_variable.rs +++ b/src/expr/scalar_variable.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use crate::common::data_type::PyDataType; -#[pyclass(name = "ScalarVariable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ScalarVariable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyScalarVariable { data_type: DataType, diff --git a/src/expr/signature.rs b/src/expr/signature.rs index e85763555..e2c23dce9 100644 --- a/src/expr/signature.rs +++ b/src/expr/signature.rs @@ -19,7 +19,7 @@ use datafusion::logical_expr::{TypeSignature, Volatility}; use pyo3::prelude::*; #[allow(dead_code)] -#[pyclass(name = "Signature", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Signature", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySignature { type_signature: TypeSignature, diff --git a/src/expr/sort.rs b/src/expr/sort.rs index 79a8aee50..d5ea07fdd 100644 --- a/src/expr/sort.rs +++ b/src/expr/sort.rs @@ -25,7 +25,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::sort_expr::PySortExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Sort", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Sort", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySort { sort: Sort, diff --git a/src/expr/sort_expr.rs b/src/expr/sort_expr.rs index e2df6b963..3f279027e 100644 --- a/src/expr/sort_expr.rs +++ b/src/expr/sort_expr.rs @@ -20,7 +20,7 @@ use datafusion::logical_expr::SortExpr; use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; -#[pyclass(name = 
"SortExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SortExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySortExpr { pub(crate) sort: SortExpr, diff --git a/src/expr/statement.rs b/src/expr/statement.rs index 83774cda1..1ea4f9f7f 100644 --- a/src/expr/statement.rs +++ b/src/expr/statement.rs @@ -25,7 +25,12 @@ use crate::{common::data_type::PyDataType, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "TransactionStart", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "TransactionStart", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyTransactionStart { transaction_start: TransactionStart, @@ -56,7 +61,13 @@ impl LogicalNode for PyTransactionStart { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TransactionAccessMode", module = "datafusion.expr")] +#[pyclass( + frozen, + eq, + eq_int, + name = "TransactionAccessMode", + module = "datafusion.expr" +)] pub enum PyTransactionAccessMode { ReadOnly, ReadWrite, @@ -84,6 +95,7 @@ impl TryFrom for TransactionAccessMode { #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass( + frozen, eq, eq_int, name = "TransactionIsolationLevel", @@ -161,7 +173,7 @@ impl PyTransactionStart { } } -#[pyclass(name = "TransactionEnd", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "TransactionEnd", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTransactionEnd { transaction_end: TransactionEnd, @@ -192,7 +204,13 @@ impl LogicalNode for PyTransactionEnd { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TransactionConclusion", module = "datafusion.expr")] +#[pyclass( + frozen, + eq, + eq_int, + name = "TransactionConclusion", + module = "datafusion.expr" +)] pub enum PyTransactionConclusion { Commit, Rollback, @@ -236,7 +254,7 @@ impl 
PyTransactionEnd { } } -#[pyclass(name = "SetVariable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SetVariable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySetVariable { set_variable: SetVariable, @@ -284,7 +302,7 @@ impl PySetVariable { } } -#[pyclass(name = "Prepare", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Prepare", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyPrepare { prepare: Prepare, @@ -352,7 +370,7 @@ impl PyPrepare { } } -#[pyclass(name = "Execute", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Execute", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExecute { execute: Execute, @@ -409,7 +427,7 @@ impl PyExecute { } } -#[pyclass(name = "Deallocate", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Deallocate", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDeallocate { deallocate: Deallocate, diff --git a/src/expr/subquery.rs b/src/expr/subquery.rs index 77f56f9a9..785cf7d1a 100644 --- a/src/expr/subquery.rs +++ b/src/expr/subquery.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Subquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Subquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySubquery { subquery: Subquery, diff --git a/src/expr/subquery_alias.rs b/src/expr/subquery_alias.rs index 3302e7f23..ab1229bfe 100644 --- a/src/expr/subquery_alias.rs +++ b/src/expr/subquery_alias.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "SubqueryAlias", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SubqueryAlias", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySubqueryAlias { subquery_alias: SubqueryAlias, diff --git 
a/src/expr/table_scan.rs b/src/expr/table_scan.rs index 329964687..34a140df3 100644 --- a/src/expr/table_scan.rs +++ b/src/expr/table_scan.rs @@ -24,7 +24,7 @@ use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; use crate::{common::df_schema::PyDFSchema, expr::PyExpr}; -#[pyclass(name = "TableScan", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "TableScan", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTableScan { table_scan: TableScan, diff --git a/src/expr/union.rs b/src/expr/union.rs index e0b221398..b7b589650 100644 --- a/src/expr/union.rs +++ b/src/expr/union.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Union", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Union", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnion { union_: Union, diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs index c8833347f..7ed7919b1 100644 --- a/src/expr/unnest.rs +++ b/src/expr/unnest.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Unnest", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Unnest", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnnest { unnest_: Unnest, diff --git a/src/expr/unnest_expr.rs b/src/expr/unnest_expr.rs index 634186ed8..2cdf46a59 100644 --- a/src/expr/unnest_expr.rs +++ b/src/expr/unnest_expr.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use super::PyExpr; -#[pyclass(name = "UnnestExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "UnnestExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnnestExpr { unnest: Unnest, diff --git a/src/expr/values.rs b/src/expr/values.rs index fb2692230..63d94ce00 100644 --- 
a/src/expr/values.rs +++ b/src/expr/values.rs @@ -25,7 +25,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "Values", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Values", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyValues { values: Values, diff --git a/src/expr/window.rs b/src/expr/window.rs index 77ecb71aa..2723007ec 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -30,13 +30,13 @@ use std::fmt::{self, Display, Formatter}; use super::py_expr_list; -#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyWindowExpr { window: Window, } -#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyWindowFrame { window_frame: WindowFrame, @@ -54,7 +54,12 @@ impl From for PyWindowFrame { } } -#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "WindowFrameBound", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyWindowFrameBound { frame_bound: WindowFrameBound, diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 97d320470..47ea39fdc 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -63,7 +63,7 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes}; use crate::expr::logical_node::LogicalNode; -#[pyclass(name = "LogicalPlan", module = "datafusion", subclass)] +#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyLogicalPlan { pub(crate) plan: Arc, diff --git a/src/unparser/dialect.rs b/src/unparser/dialect.rs index caeef9949..5df0a0c2e 100644 --- a/src/unparser/dialect.rs +++ b/src/unparser/dialect.rs @@ -22,7 +22,7 
@@ use datafusion::sql::unparser::dialect::{ }; use pyo3::prelude::*; -#[pyclass(name = "Dialect", module = "datafusion.unparser", subclass)] +#[pyclass(frozen, name = "Dialect", module = "datafusion.unparser", subclass)] #[derive(Clone)] pub struct PyDialect { pub dialect: Arc, diff --git a/src/unparser/mod.rs b/src/unparser/mod.rs index b4b0fed10..f234345a7 100644 --- a/src/unparser/mod.rs +++ b/src/unparser/mod.rs @@ -25,7 +25,7 @@ use pyo3::{exceptions::PyValueError, prelude::*}; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Unparser", module = "datafusion.unparser", subclass)] +#[pyclass(frozen, name = "Unparser", module = "datafusion.unparser", subclass)] #[derive(Clone)] pub struct PyUnparser { dialect: Arc, From fe3ad12af546c0014e6dfaa26823f6707e1950a2 Mon Sep 17 00:00:00 2001 From: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:44:54 -0400 Subject: [PATCH 13/31] Add source root - cherry pick #1252 --- src/catalog.rs | 6 +++--- src/config.rs | 5 ++++- src/context.rs | 6 +++--- src/dataframe.rs | 6 +++--- src/expr.rs | 4 ++-- src/physical_plan.rs | 2 +- src/store.rs | 15 ++++++++++----- src/substrait.rs | 8 ++++---- src/udaf.rs | 2 +- src/udf.rs | 2 +- src/udtf.rs | 2 +- src/udwf.rs | 2 +- 12 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 17d4ec3b8..b5fa3da72 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -36,19 +36,19 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PySchema { pub schema: Arc, } -#[pyclass(name = "RawTable", module = 
"datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawTable", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PyTable { pub table: Arc, diff --git a/src/config.rs b/src/config.rs index dcaab1066..bd4ff24a8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,7 @@ -// Licensed to the Apache Software Foundation (ASF) under one +// Licensed to the Apache use parking_lot::RwLock; + +#[pyclass(name = "Config", module = "datafusion", subclass, frozen)] +#[derive(Clone)]Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file diff --git a/src/context.rs b/src/context.rs index 0ccb03261..e3f978ee1 100644 --- a/src/context.rs +++ b/src/context.rs @@ -77,7 +77,7 @@ use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext -#[pyclass(name = "SessionConfig", module = "datafusion", subclass)] +#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PySessionConfig { pub config: SessionConfig, @@ -170,7 +170,7 @@ impl PySessionConfig { } /// Runtime options for a SessionContext -#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)] +#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)] #[derive(Clone)] pub struct PyRuntimeEnvBuilder { pub builder: RuntimeEnvBuilder, @@ -257,7 +257,7 @@ impl PyRuntimeEnvBuilder { } /// `PySQLOptions` allows you to specify options to the sql execution. 
-#[pyclass(name = "SQLOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)] #[derive(Clone)] pub struct PySQLOptions { pub options: SQLOptions, diff --git a/src/dataframe.rs b/src/dataframe.rs index 7a167187b..5c0b1c385 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -68,7 +68,7 @@ type SharedCachedBatches = Arc>; // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 // - we have not decided on the table_provider approach yet // this is an interim implementation -#[pyclass(name = "TableProvider", module = "datafusion")] +#[pyclass(frozen, name = "TableProvider", module = "datafusion")] pub struct PyTableProvider { provider: Arc, } @@ -195,7 +195,7 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult< } /// Python mapping of `ParquetOptions` (includes just the writer-related options). -#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ParquetWriterOptions", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PyParquetWriterOptions { options: ParquetOptions, @@ -256,7 +256,7 @@ impl PyParquetWriterOptions { } /// Python mapping of `ParquetColumnOptions`. 
-#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ParquetColumnOptions", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PyParquetColumnOptions { options: ParquetColumnOptions, diff --git a/src/expr.rs b/src/expr.rs index e2c53025c..c9eddaa2d 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -115,7 +115,7 @@ pub mod window; use sort_expr::{to_sort_expressions, PySortExpr}; /// A PyExpr that can be used on a DataFrame -#[pyclass(name = "RawExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "RawExpr", module = "datafusion.expr", subclass)] #[derive(Debug, Clone)] pub struct PyExpr { pub expr: Expr, @@ -637,7 +637,7 @@ impl PyExpr { } } -#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] #[derive(Debug, Clone)] pub struct PyExprFuncBuilder { pub builder: ExprFuncBuilder, diff --git a/src/physical_plan.rs b/src/physical_plan.rs index 49db643e1..4994b0114 100644 --- a/src/physical_plan.rs +++ b/src/physical_plan.rs @@ -24,7 +24,7 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes}; use crate::{context::PySessionContext, errors::PyDataFusionResult}; -#[pyclass(name = "ExecutionPlan", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ExecutionPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyExecutionPlan { pub plan: Arc, diff --git a/src/store.rs b/src/store.rs index 1e5fab472..998681854 100644 --- a/src/store.rs +++ b/src/store.rs @@ -36,7 +36,12 @@ pub enum StorageContexts { HTTP(PyHttpContext), } -#[pyclass(name = "LocalFileSystem", module = "datafusion.store", subclass)] +#[pyclass( + frozen, + name = "LocalFileSystem", + module = "datafusion.store", + subclass +)] #[derive(Debug, Clone)] pub struct PyLocalFileSystemContext { pub inner: Arc, @@ -62,7 +67,7 @@ impl PyLocalFileSystemContext { } } -#[pyclass(name = 
"MicrosoftAzure", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "MicrosoftAzure", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyMicrosoftAzureContext { pub inner: Arc, @@ -134,7 +139,7 @@ impl PyMicrosoftAzureContext { } } -#[pyclass(name = "GoogleCloud", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "GoogleCloud", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyGoogleCloudContext { pub inner: Arc, @@ -164,7 +169,7 @@ impl PyGoogleCloudContext { } } -#[pyclass(name = "AmazonS3", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "AmazonS3", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyAmazonS3Context { pub inner: Arc, @@ -223,7 +228,7 @@ impl PyAmazonS3Context { } } -#[pyclass(name = "Http", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "Http", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyHttpContext { pub url: String, diff --git a/src/substrait.rs b/src/substrait.rs index f1936b05e..291892cf8 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -27,7 +27,7 @@ use datafusion_substrait::serializer; use datafusion_substrait::substrait::proto::Plan; use prost::Message; -#[pyclass(name = "Plan", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Plan", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PyPlan { pub plan: Plan, @@ -59,7 +59,7 @@ impl From for PyPlan { /// A PySubstraitSerializer is a representation of a Serializer that is capable of both serializing /// a `LogicalPlan` instance to Substrait Protobuf bytes and also deserialize Substrait Protobuf bytes /// to a valid `LogicalPlan` instance. 
-#[pyclass(name = "Serde", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Serde", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitSerializer; @@ -112,7 +112,7 @@ impl PySubstraitSerializer { } } -#[pyclass(name = "Producer", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Producer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitProducer; @@ -129,7 +129,7 @@ impl PySubstraitProducer { } } -#[pyclass(name = "Consumer", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Consumer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitConsumer; diff --git a/src/udaf.rs b/src/udaf.rs index 78f4e2b0c..eab4581df 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -155,7 +155,7 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { } /// Represents an AggregateUDF -#[pyclass(name = "AggregateUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyAggregateUDF { pub(crate) function: AggregateUDF, diff --git a/src/udf.rs b/src/udf.rs index de1e3f18c..a9249d6c8 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -81,7 +81,7 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation { } /// Represents a PyScalarUDF -#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyScalarUDF { pub(crate) function: ScalarUDF, diff --git a/src/udtf.rs b/src/udtf.rs index db16d6c05..55f306b17 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -31,7 +31,7 @@ use pyo3::exceptions::PyNotImplementedError; use pyo3::types::{PyCapsule, PyTuple}; /// Represents a user defined table function -#[pyclass(name = "TableFunction", module = "datafusion")] +#[pyclass(frozen, name = 
"TableFunction", module = "datafusion")] #[derive(Debug, Clone)] pub struct PyTableFunction { pub(crate) name: String, diff --git a/src/udwf.rs b/src/udwf.rs index 70a66e38f..ceeaa0ef1 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -210,7 +210,7 @@ pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFac } /// Represents an WindowUDF -#[pyclass(name = "WindowUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "WindowUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyWindowUDF { pub(crate) function: WindowUDF, From 509850e77b8680b985b61d34a4aa49969bb338c2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 28 Sep 2025 19:55:47 +0800 Subject: [PATCH 14/31] Fix license comment formatting in config.rs --- src/config.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index bd4ff24a8..e45efc066 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,4 @@ -// Licensed to the Apache use parking_lot::RwLock; - -#[pyclass(name = "Config", module = "datafusion", subclass, frozen)] -#[derive(Clone)]Foundation (ASF) under one +// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. 
The ASF licenses this file @@ -28,7 +25,6 @@ use datafusion::config::ConfigOptions; use crate::errors::PyDataFusionResult; use crate::utils::py_obj_to_scalar_value; use parking_lot::RwLock; - #[pyclass(name = "Config", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub(crate) struct PyConfig { From c95e8b18211fea4cc4fa256199beda92c8adacd4 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 12:50:28 +0800 Subject: [PATCH 15/31] Refactor caching logic to use a local variable for IPython environment check --- src/dataframe.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 5c0b1c385..555a8500d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -313,9 +313,11 @@ impl PyDataFrame { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; + let is_ipython = *is_ipython_env(py); + let (cached_batches, should_cache) = { let mut cache = self.batches.lock(); - let should_cache = *is_ipython_env(py) && cache.is_none(); + let should_cache = is_ipython && cache.is_none(); let batches = cache.take(); (batches, should_cache) }; From 799e8fb52bfb47b66828f831dd63f18f51b63c55 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 14:02:00 +0800 Subject: [PATCH 16/31] Add test for ensuring exposed pyclasses default to frozen --- python/tests/test_pyclass_frozen.py | 76 +++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 python/tests/test_pyclass_frozen.py diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py new file mode 100644 index 000000000..e677b3727 --- /dev/null +++ b/python/tests/test_pyclass_frozen.py @@ -0,0 +1,76 @@ +"""Ensure exposed pyclasses default to frozen.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator + +PYCLASS_RE = 
re.compile(r"#\[\s*pyclass\s*(?:\((?P<args>.*?)\))?\s*\]", re.DOTALL) +ARG_STRING_RE = re.compile(r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"") +STRUCT_NAME_RE = re.compile(r"\b(?:pub\s+)?(?:struct|enum)\s+(?P<name>[A-Za-z_][A-Za-z0-9_]*)") + + +@dataclass +class PyClass: + module: str + name: str + frozen: bool + source: Path + + +def iter_pyclasses(root: Path) -> Iterator[PyClass]: + for path in root.rglob("*.rs"): + text = path.read_text(encoding="utf8") + for match in PYCLASS_RE.finditer(text): + args = match.group("args") or "" + frozen = re.search(r"\bfrozen\b", args) is not None + + module = None + name = None + for arg_match in ARG_STRING_RE.finditer(args): + key = arg_match.group("key") + value = arg_match.group("value") + if key == "module": + module = value + elif key == "name": + name = value + + remainder = text[match.end() :] + struct_match = STRUCT_NAME_RE.search(remainder) + struct_name = struct_match.group("name") if struct_match else None + + yield PyClass( + module=module or "datafusion", + name=name or struct_name or "", + frozen=frozen, + source=path, + ) + + +def test_pyclasses_are_frozen() -> None: + allowlist = { + # NOTE: Any new exceptions must include a justification comment in the Rust source + # and, ideally, a follow-up issue to remove the exemption. 
+ ("datafusion.common", "SqlTable"), + ("datafusion.common", "SqlView"), + ("datafusion.common", "DataTypeMap"), + ("datafusion.expr", "TryCast"), + ("datafusion.expr", "WriteOp"), + } + + unfrozen = [ + pyclass + for pyclass in iter_pyclasses(Path("src")) + if not pyclass.frozen and (pyclass.module, pyclass.name) not in allowlist + ] + + assert not unfrozen, ( + "Found pyclasses missing `frozen`; add them to the allowlist only with a " + "justification comment and follow-up plan:\n" + + "\n".join( + f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})" + for pyclass in unfrozen + ) + ) From 6de60bc230c543de396225434671181adcda4dd2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 14:02:29 +0800 Subject: [PATCH 17/31] Add PyO3 class mutability guidelines reference to contributor guide --- docs/source/contributor-guide/ffi.rst | 61 +++++++++++++++++++ .../source/contributor-guide/introduction.rst | 4 ++ 2 files changed, 65 insertions(+) diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index e201db71e..e8a0398b8 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -137,6 +137,67 @@ and you want to create a sharable FFI counterpart, you could write: let my_provider = MyTableProvider::default(); let ffi_provider = FFI_TableProvider::new(Arc::new(my_provider), false, None); +.. _ffi_pyclass_mutability: + +PyO3 class mutability guidelines +-------------------------------- + +PyO3 bindings should present immutable wrappers whenever a struct stores shared or +interior-mutable state. In practice this means that any ``#[pyclass]`` containing an +``Arc>`` or similar synchronized primitive must opt into ``#[pyclass(frozen)]`` +unless there is a compelling reason not to. + +The :mod:`datafusion` configuration helpers illustrate the preferred pattern. 
The +``PyConfig`` class in :file:`src/config.rs` stores an ``Arc>`` and is +explicitly frozen so callers interact with configuration state through provided methods +instead of mutating the container directly: + +.. code-block:: rust + + #[pyclass(name = "Config", module = "datafusion", subclass, frozen)] + #[derive(Clone)] + pub(crate) struct PyConfig { + config: Arc>, + } + +The same approach applies to execution contexts. ``PySessionContext`` in +:file:`src/context.rs` stays frozen even though it shares mutable state internally via +``SessionContext``. This ensures PyO3 tracks borrows correctly while Python-facing APIs +clone the inner ``SessionContext`` or return new wrappers instead of mutating the +existing instance in place: + +.. code-block:: rust + + #[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)] + #[derive(Clone)] + pub struct PySessionContext { + pub ctx: SessionContext, + } + +Occasionally a type must remain mutable—for example when PyO3 attribute setters need to +update fields directly. In these rare cases add an inline justification so reviewers and +future contributors understand why ``frozen`` is unsafe to enable. ``DataTypeMap`` in +:file:`src/common/data_type.rs` includes such a comment because PyO3 still needs to track +field updates: + +.. code-block:: rust + + // TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now + #[derive(Debug, Clone)] + #[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)] + pub struct DataTypeMap { + #[pyo3(get, set)] + pub arrow_type: PyDataType, + #[pyo3(get, set)] + pub python_type: PythonType, + #[pyo3(get, set)] + pub sql_type: SqlType, + } + +When reviewers encounter a mutable ``#[pyclass]`` without a comment, they should request +an explanation or ask that ``frozen`` be added. Keeping these wrappers frozen by default +helps avoid subtle bugs stemming from PyO3's interior mutability tracking. 
+ If you were interfacing with a library that provided the above ``FFI_TableProvider`` and you needed to turn it back into an ``TableProvider``, you can turn it into a ``ForeignTableProvider`` with implements the ``TableProvider`` trait. diff --git a/docs/source/contributor-guide/introduction.rst b/docs/source/contributor-guide/introduction.rst index 6cb05c62d..33c2b274c 100644 --- a/docs/source/contributor-guide/introduction.rst +++ b/docs/source/contributor-guide/introduction.rst @@ -26,6 +26,10 @@ We welcome and encourage contributions of all kinds, such as: In addition to submitting new PRs, we have a healthy tradition of community members reviewing each other’s PRs. Doing so is a great way to help the community as well as get more familiar with Rust and the relevant codebases. +Before opening a pull request that touches PyO3 bindings, please review the +:ref:`PyO3 class mutability guidelines ` so you can flag missing +``#[pyclass(frozen)]`` annotations during development and review. 
+ How to develop -------------- From b213bd4ce6e78d3c19ad66737838fa56b3b438d1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 14:36:09 +0800 Subject: [PATCH 18/31] Mark boolean expression classes as frozen for immutability --- src/expr/bool_expr.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/expr/bool_expr.rs b/src/expr/bool_expr.rs index 22eabdb88..0d2b051e6 100644 --- a/src/expr/bool_expr.rs +++ b/src/expr/bool_expr.rs @@ -51,7 +51,7 @@ impl PyNot { } } -#[pyclass(name = "IsNotNull", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotNull", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotNull { expr: Expr, @@ -81,7 +81,7 @@ impl PyIsNotNull { } } -#[pyclass(name = "IsNull", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNull", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNull { expr: Expr, @@ -111,7 +111,7 @@ impl PyIsNull { } } -#[pyclass(name = "IsTrue", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsTrue", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsTrue { expr: Expr, @@ -141,7 +141,7 @@ impl PyIsTrue { } } -#[pyclass(name = "IsFalse", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsFalse", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsFalse { expr: Expr, @@ -171,7 +171,7 @@ impl PyIsFalse { } } -#[pyclass(name = "IsUnknown", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsUnknown", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsUnknown { expr: Expr, @@ -201,7 +201,7 @@ impl PyIsUnknown { } } -#[pyclass(name = "IsNotTrue", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotTrue", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotTrue { expr: Expr, @@ -231,7 +231,7 @@ impl PyIsNotTrue { } 
} -#[pyclass(name = "IsNotFalse", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotFalse", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotFalse { expr: Expr, @@ -261,7 +261,7 @@ impl PyIsNotFalse { } } -#[pyclass(name = "IsNotUnknown", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotUnknown", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotUnknown { expr: Expr, @@ -291,7 +291,7 @@ impl PyIsNotUnknown { } } -#[pyclass(name = "Negative", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Negative", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyNegative { expr: Expr, From 64faca278ca9079602d6ff8fb06b0eef08acefc8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 15:14:34 +0800 Subject: [PATCH 19/31] Refactor PyCaseBuilder methods to eliminate redundant take/store logic --- src/expr/conditional_expr.rs | 53 ++++++++++++------------------------ 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index c2c5a2a38..07a64dbd3 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -45,18 +45,6 @@ impl PyCaseBuilder { self.case_builder.lock() } - fn take_case_builder(&self) -> PyDataFusionResult { - let mut guard = self.lock_case_builder(); - guard.take().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - }) - } - - fn store_case_builder(&self, builder: CaseBuilder) { - let mut guard = self.lock_case_builder(); - *guard = Some(builder); - } - pub fn into_case_builder(self) -> PyDataFusionResult { let mut guard = self.case_builder.lock(); guard.take().ok_or_else(|| { @@ -68,37 +56,30 @@ impl PyCaseBuilder { #[pymethods] impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut builder = self.take_case_builder()?; + let mut 
guard = self.lock_case_builder(); + let builder = guard.as_mut().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + })?; let next_builder = builder.when(when.expr, then.expr); - self.store_case_builder(builder); Ok(next_builder.into()) } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - let mut builder = self.take_case_builder()?; - match builder.otherwise(else_expr.expr) { - Ok(expr) => { - self.store_case_builder(builder); - Ok(expr.clone().into()) - } - Err(err) => { - self.store_case_builder(builder); - Err(err.into()) - } - } + let mut guard = self.lock_case_builder(); + let builder = guard.as_mut().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + })?; + builder + .otherwise(else_expr.expr) + .map(|expr| expr.into()) + .map_err(Into::into) } fn end(&self) -> PyDataFusionResult { - let builder = self.take_case_builder()?; - match builder.end() { - Ok(expr) => { - self.store_case_builder(builder); - Ok(expr.clone().into()) - } - Err(err) => { - self.store_case_builder(builder); - Err(err.into()) - } - } + let mut guard = self.lock_case_builder(); + let builder = guard.as_mut().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + })?; + builder.end().map(|expr| expr.into()).map_err(Into::into) } } From 5caec09e886d762eebc4b8439f25081a3e6f00ea Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 16:57:48 +0800 Subject: [PATCH 20/31] Refactor PyConfig methods to improve readability by encapsulating configuration reads --- src/config.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index e45efc066..74d87de01 100644 --- a/src/config.rs +++ b/src/config.rs @@ -50,8 +50,12 @@ impl PyConfig { /// Get a configuration option pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { - let options = self.config.read(); - for 
entry in options.entries() { + let entries = { + let options = self.config.read(); + options.entries() + }; + + for entry in entries { if entry.key == key { return Ok(entry.value.into_pyobject(py)?); } @@ -69,10 +73,14 @@ impl PyConfig { /// Get all configuration options pub fn get_all(&self, py: Python) -> PyResult { + let entries = { + let options = self.config.read(); + options.entries() + }; + let dict = PyDict::new(py); - let options = self.config.read(); - for entry in options.entries() { - dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?; + for entry in entries { + dict.set_item(entry.key, entry.value.into_pyobject(py)?)?; } Ok(dict.into()) } From a905154da94a39fcb14dd348a1e540de75d78abf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 18:06:12 +0800 Subject: [PATCH 21/31] Resolve patch apply conflicts for CaseBuilder concurrency improvements - Added CaseBuilderHandle guard that keeps the underlying CaseBuilder alive while holding the mutex and restores it on drop - Updated when, otherwise, and end methods to operate through the guard and consume the builder explicitly - This prevents transient None states during concurrent access and improves thread safety --- python/tests/test_expr.py | 21 ++++++++++ src/expr/conditional_expr.rs | 80 +++++++++++++++++++++++++----------- 2 files changed, 76 insertions(+), 25 deletions(-) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index e925b0ee2..619b78603 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -16,6 +16,7 @@ # under the License. 
import re +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone import pyarrow as pa @@ -281,6 +282,26 @@ def test_case_builder_when_handles_are_independent(): ] +def test_case_builder_when_thread_safe(): + case_builder = functions.when(lit(_TRUE), lit(1)) + + def build_expr(value: int) -> bool: + builder = case_builder.when(lit(_TRUE), lit(value)) + builder.otherwise(lit(value)) + return True + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(build_expr, idx) for idx in range(16)] + results = [future.result() for future in futures] + + assert all(results) + + # Ensure the shared builder remains usable after concurrent `when` calls. + follow_up_builder = case_builder.when(lit(_TRUE), lit(42)) + assert isinstance(follow_up_builder, type(case_builder)) + follow_up_builder.otherwise(lit(7)) + + def test_expr_getitem() -> None: ctx = SessionContext() data = { diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index 07a64dbd3..ea677944e 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -22,9 +22,46 @@ use crate::{ expr::PyExpr, }; use datafusion::logical_expr::conditional_expressions::CaseBuilder; +use parking_lot::{Mutex, MutexGuard}; use pyo3::prelude::*; -use parking_lot::{Mutex, MutexGuard}; +struct CaseBuilderHandle<'a> { + guard: MutexGuard<'a, Option>, + builder: Option, +} + +impl<'a> CaseBuilderHandle<'a> { + fn new(mut guard: MutexGuard<'a, Option>) -> PyDataFusionResult { + let builder = guard.take().ok_or_else(|| { + PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) + })?; + + Ok(Self { + guard, + builder: Some(builder), + }) + } + + fn builder_mut(&mut self) -> &mut CaseBuilder { + self.builder + .as_mut() + .expect("builder should be present while handle is alive") + } + + fn into_inner(mut self) -> CaseBuilder { + self.builder + .take() + .expect("builder should be present when consuming handle") + } +} + 
+impl Drop for CaseBuilderHandle<'_> { + fn drop(&mut self) { + if let Some(builder) = self.builder.take() { + *self.guard = Some(builder); + } + } +} #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] #[derive(Clone)] @@ -41,45 +78,38 @@ impl From for PyCaseBuilder { } impl PyCaseBuilder { - fn lock_case_builder(&self) -> MutexGuard<'_, Option> { - self.case_builder.lock() + fn case_builder_handle(&self) -> PyDataFusionResult> { + let guard = self.case_builder.lock(); + CaseBuilderHandle::new(guard) } pub fn into_case_builder(self) -> PyDataFusionResult { - let mut guard = self.case_builder.lock(); - guard.take().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - }) + let guard = self.case_builder.lock(); + CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner) } } #[pymethods] impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut guard = self.lock_case_builder(); - let builder = guard.as_mut().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - let next_builder = builder.when(when.expr, then.expr); + let mut handle = self.case_builder_handle()?; + let next_builder = handle.builder_mut().when(when.expr, then.expr); Ok(next_builder.into()) } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - let mut guard = self.lock_case_builder(); - let builder = guard.as_mut().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - builder - .otherwise(else_expr.expr) - .map(|expr| expr.into()) - .map_err(Into::into) + let mut handle = self.case_builder_handle()?; + match handle.builder_mut().otherwise(else_expr.expr) { + Ok(expr) => Ok(expr.clone().into()), + Err(err) => Err(err.into()), + } } fn end(&self) -> PyDataFusionResult { - let mut guard = self.lock_case_builder(); - let builder = guard.as_mut().ok_or_else(|| { - 
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - builder.end().map(|expr| expr.into()).map_err(Into::into) + let mut handle = self.case_builder_handle()?; + match handle.builder_mut().end() { + Ok(expr) => Ok(expr.clone().into()), + Err(err) => Err(err.into()), + } } } From 428839df63a085f1b63c850e886d03dcc29addab Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 18:13:43 +0800 Subject: [PATCH 22/31] Resolve Config optimization conflicts for improved read/write concurrency - Released Config read guard before converting values to Python objects in get and get_all - Ensures locks are held only while collecting scalar entries, not during expensive Python object conversion - Added regression test that runs Config.get_all and Config.set concurrently to guard against read/write contention regressions - Improves overall performance by reducing lock contention in multi-threaded scenarios --- python/tests/test_concurrency.py | 24 ++++++++++++++++++++++++ src/config.rs | 27 +++++++++++++++++---------- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py index 65fc5cb2c..f443f63f7 100644 --- a/python/tests/test_concurrency.py +++ b/python/tests/test_concurrency.py @@ -65,6 +65,30 @@ def worker(index: int) -> None: _run_in_threads(worker, count=12) +def test_config_set_during_get_all() -> None: + """Ensure config writes proceed while another thread reads all entries.""" + + config = Config() + key = "datafusion.execution.batch_size" + + def reader() -> None: + for _ in range(200): + # get_all should not hold the lock while converting to Python objects + config.get_all() + + def writer() -> None: + for index in range(200): + config.set(key, str(1024 + index)) + + with ThreadPoolExecutor(max_workers=2) as executor: + reader_future = executor.submit(reader) + writer_future = executor.submit(writer) + reader_future.result(timeout=10) + 
writer_future.result(timeout=10) + + assert config.get(key) is not None + + def test_case_builder_reuse_from_multiple_threads() -> None: """Ensure the case builder can be safely reused across threads.""" diff --git a/src/config.rs b/src/config.rs index 74d87de01..3a2060776 100644 --- a/src/config.rs +++ b/src/config.rs @@ -50,17 +50,20 @@ impl PyConfig { /// Get a configuration option pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { - let entries = { + let value = { let options = self.config.read(); - options.entries() + options + .entries() + .iter() + .find(|entry| entry.key == key) + .map(|entry| entry.value.clone()) }; - for entry in entries { - if entry.key == key { - return Ok(entry.value.into_pyobject(py)?); - } + if let Some(value) = value { + Ok(value.into_pyobject(py)?) + } else { + Ok(None::.into_pyobject(py)?) } - Ok(None::.into_pyobject(py)?) } /// Set a configuration option @@ -75,12 +78,16 @@ impl PyConfig { pub fn get_all(&self, py: Python) -> PyResult { let entries = { let options = self.config.read(); - options.entries() + options + .entries() + .into_iter() + .map(|entry| (entry.key.to_string(), entry.value.clone())) + .collect::>() }; let dict = PyDict::new(py); - for entry in entries { - dict.set_item(entry.key, entry.value.into_pyobject(py)?)?; + for (key, value) in entries { + dict.set_item(key, value.into_pyobject(py)?)?; } Ok(dict.into()) } From 34a60785eb2db1dcf585a0f10949547cc205f1b0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 29 Sep 2025 19:28:08 +0800 Subject: [PATCH 23/31] Refactor PyConfig get methods for improved readability and performance --- src/config.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/config.rs b/src/config.rs index 3a2060776..1726e5d9b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -50,19 +50,17 @@ impl PyConfig { /// Get a configuration option pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { - let value = { + let value: 
Option> = { let options = self.config.read(); options .entries() - .iter() - .find(|entry| entry.key == key) - .map(|entry| entry.value.clone()) + .into_iter() + .find_map(|entry| (entry.key == key).then_some(entry.value.clone())) }; - if let Some(value) = value { - Ok(value.into_pyobject(py)?) - } else { - Ok(None::.into_pyobject(py)?) + match value { + Some(value) => Ok(value.into_pyobject(py)?), + None => Ok(None::.into_pyobject(py)?), } } @@ -76,13 +74,13 @@ impl PyConfig { /// Get all configuration options pub fn get_all(&self, py: Python) -> PyResult { - let entries = { + let entries: Vec<(String, Option)> = { let options = self.config.read(); options .entries() .into_iter() - .map(|entry| (entry.key.to_string(), entry.value.clone())) - .collect::>() + .map(|entry| (entry.key.clone(), entry.value.clone())) + .collect() }; let dict = PyDict::new(py); From 09d9ab8225fa5b31311807fd64f295bd122f4639 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 1 Oct 2025 21:26:59 +0800 Subject: [PATCH 24/31] Refactor test_expr.py to replace positional boolean literals with named constants for improved linting compliance --- pyproject.toml | 3 +++ python/tests/test_expr.py | 14 +++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index edecc4588..69d31ec9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,9 @@ convention = "google" [tool.ruff.lint.pycodestyle] max-doc-length = 88 +[tool.ruff.lint.flake8-boolean-trap] +extend-allowed-calls = ["lit", "datafusion.lit"] + # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] "python/tests/*" = [ diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 619b78603..0d459f5d4 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -54,10 +54,6 @@ ensure_expr_list, ) -# Avoid passing boolean literals positionally (FBT003). 
Use a named constant -# so linters don't see a bare True/False literal in a function call. -_TRUE = True - @pytest.fixture def test_ctx(): @@ -206,7 +202,7 @@ def traverse_logical_plan(plan): def test_case_builder_error_preserves_builder_state(): - case_builder = functions.when(lit(_TRUE), lit(1)) + case_builder = functions.when(lit(True), lit(1)) with pytest.raises(Exception) as exc_info: case_builder.otherwise(lit("bad")) @@ -261,7 +257,7 @@ def test_case_builder_when_handles_are_independent(): first_builder = base_builder.when(col("value") > lit(10), lit("gt10")) second_builder = base_builder.when(col("value") > lit(20), lit("gt20")) - first_builder = first_builder.when(lit(_TRUE), lit("final-one")) + first_builder = first_builder.when(lit(True), lit("final-one")) expr_first = first_builder.otherwise(lit("fallback-one")).alias("first") expr_second = second_builder.otherwise(lit("fallback-two")).alias("second") @@ -283,10 +279,10 @@ def test_case_builder_when_handles_are_independent(): def test_case_builder_when_thread_safe(): - case_builder = functions.when(lit(_TRUE), lit(1)) + case_builder = functions.when(lit(True), lit(1)) def build_expr(value: int) -> bool: - builder = case_builder.when(lit(_TRUE), lit(value)) + builder = case_builder.when(lit(True), lit(value)) builder.otherwise(lit(value)) return True @@ -297,7 +293,7 @@ def build_expr(value: int) -> bool: assert all(results) # Ensure the shared builder remains usable after concurrent `when` calls. 
- follow_up_builder = case_builder.when(lit(_TRUE), lit(42)) + follow_up_builder = case_builder.when(lit(True), lit(42)) assert isinstance(follow_up_builder, type(case_builder)) follow_up_builder.otherwise(lit(7)) From 2df2f5f654d7b18fa3163a57b1aa380e2f0cee06 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 1 Oct 2025 21:50:31 +0800 Subject: [PATCH 25/31] fix ruff errors --- python/tests/test_pyclass_frozen.py | 33 +++++++++++++++++++---------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py index e677b3727..4280204b4 100644 --- a/python/tests/test_pyclass_frozen.py +++ b/python/tests/test_pyclass_frozen.py @@ -7,9 +7,17 @@ from pathlib import Path from typing import Iterator -PYCLASS_RE = re.compile(r"#\[\s*pyclass\s*(?:\((?P.*?)\))?\s*\]", re.DOTALL) -ARG_STRING_RE = re.compile(r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P[^\"]+)\"") -STRUCT_NAME_RE = re.compile(r"\b(?:pub\s+)?(?:struct|enum)\s+(?P[A-Za-z_][A-Za-z0-9_]*)") +PYCLASS_RE = re.compile( + r"#\[\s*pyclass\s*(?:\((?P.*?)\))?\s*\]", + re.DOTALL, +) +ARG_STRING_RE = re.compile( + r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P[^\"]+)\"", +) +STRUCT_NAME_RE = re.compile( + r"\b(?:pub\s+)?(?:struct|enum)\s+" + r"(?P[A-Za-z_][A-Za-z0-9_]*)", +) @dataclass @@ -51,8 +59,9 @@ def iter_pyclasses(root: Path) -> Iterator[PyClass]: def test_pyclasses_are_frozen() -> None: allowlist = { - # NOTE: Any new exceptions must include a justification comment in the Rust source - # and, ideally, a follow-up issue to remove the exemption. + # NOTE: Any new exceptions must include a justification comment + # in the Rust source and, ideally, a follow-up issue to remove + # the exemption. 
("datafusion.common", "SqlTable"), ("datafusion.common", "SqlView"), ("datafusion.common", "DataTypeMap"), @@ -66,11 +75,13 @@ def test_pyclasses_are_frozen() -> None: if not pyclass.frozen and (pyclass.module, pyclass.name) not in allowlist ] - assert not unfrozen, ( - "Found pyclasses missing `frozen`; add them to the allowlist only with a " - "justification comment and follow-up plan:\n" + - "\n".join( - f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})" + if unfrozen: + msg = ( + "Found pyclasses missing `frozen`; add them to the allowlist only " + "with a justification comment and follow-up plan:\n" + ) + msg += "\n".join( + (f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})") for pyclass in unfrozen ) - ) + assert not unfrozen, msg From 2c76271da0eb6de9a4dfbb4958064be9280641b8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 1 Oct 2025 21:51:57 +0800 Subject: [PATCH 26/31] Add license header to test_pyclass_frozen.py for compliance --- python/tests/test_pyclass_frozen.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py index 4280204b4..189ea8dec 100644 --- a/python/tests/test_pyclass_frozen.py +++ b/python/tests/test_pyclass_frozen.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + """Ensure exposed pyclasses default to frozen.""" from __future__ import annotations From 8a52e23e9a5ceee2b5f06abfe6834dc597951cba Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 4 Oct 2025 10:50:33 -0400 Subject: [PATCH 27/31] Alternate approach to case expression --- src/expr/conditional_expr.rs | 100 ++++++++++------------------------- 1 file changed, 28 insertions(+), 72 deletions(-) diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index ea677944e..4b9fa57e3 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,101 +15,57 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use crate::{ - errors::{PyDataFusionError, PyDataFusionResult}, - expr::PyExpr, -}; +use crate::{errors::PyDataFusionResult, expr::PyExpr}; +use datafusion::common::{exec_err, DataFusionError}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; -use parking_lot::{Mutex, MutexGuard}; +use datafusion::prelude::Expr; use pyo3::prelude::*; -struct CaseBuilderHandle<'a> { - guard: MutexGuard<'a, Option>, - builder: Option, -} - -impl<'a> CaseBuilderHandle<'a> { - fn new(mut guard: MutexGuard<'a, Option>) -> PyDataFusionResult { - let builder = guard.take().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - - Ok(Self { - guard, - builder: Some(builder), - }) - } - - fn builder_mut(&mut self) -> &mut CaseBuilder { - self.builder - .as_mut() - .expect("builder should be present while handle is alive") - } - - fn into_inner(mut self) -> CaseBuilder { - self.builder - .take() - .expect("builder should be present when consuming handle") - } -} - -impl Drop for CaseBuilderHandle<'_> { - fn drop(&mut self) { - if let Some(builder) = self.builder.take() { - *self.guard = Some(builder); - } - } -} - #[pyclass(name = 
"CaseBuilder", module = "datafusion.expr", subclass, frozen)] -#[derive(Clone)] pub struct PyCaseBuilder { - case_builder: Arc>>, + case_builder: CaseBuilder, } impl From for PyCaseBuilder { fn from(case_builder: CaseBuilder) -> PyCaseBuilder { - PyCaseBuilder { - case_builder: Arc::new(Mutex::new(Some(case_builder))), - } + PyCaseBuilder { case_builder } } } -impl PyCaseBuilder { - fn case_builder_handle(&self) -> PyDataFusionResult> { - let guard = self.case_builder.lock(); - CaseBuilderHandle::new(guard) - } +// TODO(tsaucer) upstream make CaseBuilder impl Clone +fn builder_clone(case_builder: &CaseBuilder) -> Result { + let Expr::Case(case) = case_builder.end()? else { + return exec_err!("CaseBuilder returned an invalid expression"); + }; - pub fn into_case_builder(self) -> PyDataFusionResult { - let guard = self.case_builder.lock(); - CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner) - } + let (when_expr, then_expr) = case + .when_then_expr + .iter() + .map(|(w, t)| (w.as_ref().to_owned(), t.as_ref().to_owned())) + .unzip(); + + Ok(CaseBuilder::new( + case.expr, + when_expr, + then_expr, + case.else_expr, + )) } #[pymethods] impl PyCaseBuilder { fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - let next_builder = handle.builder_mut().when(when.expr, then.expr); - Ok(next_builder.into()) + let case_builder = builder_clone(&self.case_builder)?.when(when.expr, then.expr); + Ok(PyCaseBuilder { case_builder }) } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - match handle.builder_mut().otherwise(else_expr.expr) { - Ok(expr) => Ok(expr.clone().into()), - Err(err) => Err(err.into()), - } + Ok(builder_clone(&self.case_builder)? + .otherwise(else_expr.expr)? 
+            .into())
     }
 
     fn end(&self) -> PyDataFusionResult {
-        let mut handle = self.case_builder_handle()?;
-        match handle.builder_mut().end() {
-            Ok(expr) => Ok(expr.clone().into()),
-            Err(err) => Err(err.into()),
-        }
+        Ok(builder_clone(&self.case_builder)?.end()?.into())
     }
 }
 

From 1b97b41735d625563e515417f62fe33f0312bbb9 Mon Sep 17 00:00:00 2001
From: Tim Saucer 
Date: Sat, 4 Oct 2025 16:42:28 -0400
Subject: [PATCH 28/31] Replace the case builder with stored expressions that
 are applied as required

---
 src/expr/conditional_expr.rs | 77 ++++++++++++++++++++----------------
 src/functions.rs             |  4 +-
 2 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs
index 4b9fa57e3..bf9ae9287 100644
--- a/src/expr/conditional_expr.rs
+++ b/src/expr/conditional_expr.rs
@@ -16,56 +16,65 @@
 // under the License.
 
 use crate::{errors::PyDataFusionResult, expr::PyExpr};
-use datafusion::common::{exec_err, DataFusionError};
 use datafusion::logical_expr::conditional_expressions::CaseBuilder;
 use datafusion::prelude::Expr;
 use pyo3::prelude::*;
 
+// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone
+#[derive(Clone, Debug)]
 #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
 pub struct PyCaseBuilder {
-    case_builder: CaseBuilder,
+    expr: Option,
+    when: Vec,
+    then: Vec,
 }
 
-impl From for PyCaseBuilder {
-    fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
-        PyCaseBuilder { case_builder }
+#[pymethods]
+impl PyCaseBuilder {
+    #[new]
+    pub fn new(expr: Option) -> Self {
+        Self {
+            expr: expr.map(Into::into),
+            when: vec![],
+            then: vec![],
+        }
     }
-}
-
-// TODO(tsaucer) upstream make CaseBuilder impl Clone
-fn builder_clone(case_builder: &CaseBuilder) -> Result {
-    let Expr::Case(case) = case_builder.end()? 
else { - return exec_err!("CaseBuilder returned an invalid expression"); - }; - let (when_expr, then_expr) = case - .when_then_expr - .iter() - .map(|(w, t)| (w.as_ref().to_owned(), t.as_ref().to_owned())) - .unzip(); + pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { + println!("when called {self:?}"); + let mut case_builder = self.clone(); + case_builder.when.push(when.into()); + case_builder.then.push(then.into()); - Ok(CaseBuilder::new( - case.expr, - when_expr, - then_expr, - case.else_expr, - )) -} - -#[pymethods] -impl PyCaseBuilder { - fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let case_builder = builder_clone(&self.case_builder)?.when(when.expr, then.expr); - Ok(PyCaseBuilder { case_builder }) + case_builder } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - Ok(builder_clone(&self.case_builder)? - .otherwise(else_expr.expr)? - .into()) + println!("otherwise called {self:?}"); + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + Some(Box::new(else_expr.into())), + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } fn end(&self) -> PyDataFusionResult { - Ok(builder_clone(&self.case_builder)?.end()?.into()) + println!("end called {self:?}"); + + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + None, + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } } diff --git a/src/functions.rs b/src/functions.rs index 0f9fdf698..5956b67cf 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult { /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. 
#[pyfunction] fn case(expr: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::case(expr.expr).into()) + Ok(PyCaseBuilder::new(Some(expr))) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn when(when: PyExpr, then: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::when(when.expr, then.expr).into()) + Ok(PyCaseBuilder::new(None).when(when, then)) } /// Helper function to find the appropriate window function. From fc27bd508e768e6fa4dba795cc2338502787dbe0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 4 Oct 2025 16:43:34 -0400 Subject: [PATCH 29/31] Update unit tests --- python/tests/test_concurrency.py | 3 ++- python/tests/test_expr.py | 15 +++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py index f443f63f7..f790f9473 100644 --- a/python/tests/test_concurrency.py +++ b/python/tests/test_concurrency.py @@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None: base_builder = f.case(col("value")) def add_case(i: int) -> None: - base_builder.when(lit(i), lit(f"value-{i}")) + nonlocal base_builder + base_builder = base_builder.when(lit(i), lit(f"value-{i}")) _run_in_threads(add_case, count=8) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 0d459f5d4..7847826ac 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state(): case_builder = functions.when(lit(True), lit(1)) with pytest.raises(Exception) as exc_info: - case_builder.otherwise(lit("bad")) + _ = case_builder.otherwise(lit("bad")) err_msg = str(exc_info.value) assert "multiple data types" in err_msg assert "CaseBuilder has already been consumed" not in err_msg - with pytest.raises(Exception) as exc_info: - case_builder.end() + _ = case_builder.end() err_msg = str(exc_info.value) assert "multiple data types" in 
err_msg @@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state(): expr_end_one = case_builder.end().alias("result") end_one = df.select(expr_end_one).collect() - assert end_one[0].column(0).to_pylist() == ["default-2"] - - expr_end_two = case_builder.end().alias("result") - end_two = df.select(expr_end_two).collect() - assert end_two[0].column(0).to_pylist() == ["default-2"] + assert end_one[0].column(0).to_pylist() == [None] def test_case_builder_when_handles_are_independent(): @@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent(): ] assert result.column(1).to_pylist() == [ "flag-true", - "gt10", - "gt10", + "fallback-two", + "gt20", "fallback-two", ] From d247d6479ac627e0877471b4ac8a706d289142eb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 6 Oct 2025 16:11:11 +0800 Subject: [PATCH 30/31] Refactor case and when functions to utilize PyCaseBuilder for improved clarity and functionality --- python/tests/test_concurrency.py | 3 +- python/tests/test_expr.py | 15 ++-- src/expr/conditional_expr.rs | 123 +++++++++++-------------------- src/functions.rs | 4 +- 4 files changed, 53 insertions(+), 92 deletions(-) diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py index f443f63f7..f790f9473 100644 --- a/python/tests/test_concurrency.py +++ b/python/tests/test_concurrency.py @@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None: base_builder = f.case(col("value")) def add_case(i: int) -> None: - base_builder.when(lit(i), lit(f"value-{i}")) + nonlocal base_builder + base_builder = base_builder.when(lit(i), lit(f"value-{i}")) _run_in_threads(add_case, count=8) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 0d459f5d4..7847826ac 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state(): case_builder = functions.when(lit(True), lit(1)) with pytest.raises(Exception) as 
exc_info: - case_builder.otherwise(lit("bad")) + _ = case_builder.otherwise(lit("bad")) err_msg = str(exc_info.value) assert "multiple data types" in err_msg assert "CaseBuilder has already been consumed" not in err_msg - with pytest.raises(Exception) as exc_info: - case_builder.end() + _ = case_builder.end() err_msg = str(exc_info.value) assert "multiple data types" in err_msg @@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state(): expr_end_one = case_builder.end().alias("result") end_one = df.select(expr_end_one).collect() - assert end_one[0].column(0).to_pylist() == ["default-2"] - - expr_end_two = case_builder.end().alias("result") - end_two = df.select(expr_end_two).collect() - assert end_two[0].column(0).to_pylist() == ["default-2"] + assert end_one[0].column(0).to_pylist() == [None] def test_case_builder_when_handles_are_independent(): @@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent(): ] assert result.column(1).to_pylist() == [ "flag-true", - "gt10", - "gt10", + "fallback-two", + "gt20", "fallback-two", ] diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index ea677944e..bf9ae9287 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,101 +15,66 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - -use crate::{ - errors::{PyDataFusionError, PyDataFusionResult}, - expr::PyExpr, -}; +use crate::{errors::PyDataFusionResult, expr::PyExpr}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; -use parking_lot::{Mutex, MutexGuard}; +use datafusion::prelude::Expr; use pyo3::prelude::*; -struct CaseBuilderHandle<'a> { - guard: MutexGuard<'a, Option>, - builder: Option, -} - -impl<'a> CaseBuilderHandle<'a> { - fn new(mut guard: MutexGuard<'a, Option>) -> PyDataFusionResult { - let builder = guard.take().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - - Ok(Self { - guard, - builder: Some(builder), - }) - } - - fn builder_mut(&mut self) -> &mut CaseBuilder { - self.builder - .as_mut() - .expect("builder should be present while handle is alive") - } - - fn into_inner(mut self) -> CaseBuilder { - self.builder - .take() - .expect("builder should be present when consuming handle") - } -} - -impl Drop for CaseBuilderHandle<'_> { - fn drop(&mut self) { - if let Some(builder) = self.builder.take() { - *self.guard = Some(builder); - } - } -} - +// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone +#[derive(Clone, Debug)] #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] -#[derive(Clone)] pub struct PyCaseBuilder { - case_builder: Arc>>, -} - -impl From for PyCaseBuilder { - fn from(case_builder: CaseBuilder) -> PyCaseBuilder { - PyCaseBuilder { - case_builder: Arc::new(Mutex::new(Some(case_builder))), - } - } + expr: Option, + when: Vec, + then: Vec, } +#[pymethods] impl PyCaseBuilder { - fn case_builder_handle(&self) -> PyDataFusionResult> { - let guard = self.case_builder.lock(); - CaseBuilderHandle::new(guard) + #[new] + pub fn new(expr: Option) -> Self { + Self { + expr: expr.map(Into::into), + when: vec![], + then: vec![], + } } - pub fn into_case_builder(self) -> PyDataFusionResult { - let guard = self.case_builder.lock(); - 
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner) - } -} + pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { + println!("when called {self:?}"); + let mut case_builder = self.clone(); + case_builder.when.push(when.into()); + case_builder.then.push(then.into()); -#[pymethods] -impl PyCaseBuilder { - fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - let next_builder = handle.builder_mut().when(when.expr, then.expr); - Ok(next_builder.into()) + case_builder } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - match handle.builder_mut().otherwise(else_expr.expr) { - Ok(expr) => Ok(expr.clone().into()), - Err(err) => Err(err.into()), - } + println!("otherwise called {self:?}"); + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + Some(Box::new(else_expr.into())), + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } fn end(&self) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - match handle.builder_mut().end() { - Ok(expr) => Ok(expr.clone().into()), - Err(err) => Err(err.into()), - } + println!("end called {self:?}"); + + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + None, + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } } diff --git a/src/functions.rs b/src/functions.rs index 0f9fdf698..5956b67cf 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult { /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. 
#[pyfunction] fn case(expr: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::case(expr.expr).into()) + Ok(PyCaseBuilder::new(Some(expr))) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn when(when: PyExpr, then: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::when(when.expr, then.expr).into()) + Ok(PyCaseBuilder::new(None).when(when, then)) } /// Helper function to find the appropriate window function. From 9536e02ac58877372c82d7163a0370d0b94f08dd Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Oct 2025 07:05:22 -0400 Subject: [PATCH 31/31] Update src/expr/conditional_expr.rs --- src/expr/conditional_expr.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index bf9ae9287..27297295b 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -41,7 +41,6 @@ impl PyCaseBuilder { } pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { - println!("when called {self:?}"); let mut case_builder = self.clone(); case_builder.when.push(when.into()); case_builder.then.push(then.into());