diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py index f443f63f7..f790f9473 100644 --- a/python/tests/test_concurrency.py +++ b/python/tests/test_concurrency.py @@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None: base_builder = f.case(col("value")) def add_case(i: int) -> None: - base_builder.when(lit(i), lit(f"value-{i}")) + nonlocal base_builder + base_builder = base_builder.when(lit(i), lit(f"value-{i}")) _run_in_threads(add_case, count=8) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 0d459f5d4..7847826ac 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state(): case_builder = functions.when(lit(True), lit(1)) with pytest.raises(Exception) as exc_info: - case_builder.otherwise(lit("bad")) + _ = case_builder.otherwise(lit("bad")) err_msg = str(exc_info.value) assert "multiple data types" in err_msg assert "CaseBuilder has already been consumed" not in err_msg - with pytest.raises(Exception) as exc_info: - case_builder.end() + _ = case_builder.end() err_msg = str(exc_info.value) assert "multiple data types" in err_msg @@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state(): expr_end_one = case_builder.end().alias("result") end_one = df.select(expr_end_one).collect() - assert end_one[0].column(0).to_pylist() == ["default-2"] - - expr_end_two = case_builder.end().alias("result") - end_two = df.select(expr_end_two).collect() - assert end_two[0].column(0).to_pylist() == ["default-2"] + assert end_one[0].column(0).to_pylist() == [None] def test_case_builder_when_handles_are_independent(): @@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent(): ] assert result.column(1).to_pylist() == [ "flag-true", - "gt10", - "gt10", + "fallback-two", + "gt20", "fallback-two", ] diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index ea677944e..27297295b 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,101 +15,65 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use crate::{ - errors::{PyDataFusionError, PyDataFusionResult}, - expr::PyExpr, -}; +use crate::{errors::PyDataFusionResult, expr::PyExpr}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; -use parking_lot::{Mutex, MutexGuard}; +use datafusion::prelude::Expr; use pyo3::prelude::*; -struct CaseBuilderHandle<'a> { - guard: MutexGuard<'a, Option>, - builder: Option, -} - -impl<'a> CaseBuilderHandle<'a> { - fn new(mut guard: MutexGuard<'a, Option>) -> PyDataFusionResult { - let builder = guard.take().ok_or_else(|| { - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) - })?; - - Ok(Self { - guard, - builder: Some(builder), - }) - } - - fn builder_mut(&mut self) -> &mut CaseBuilder { - self.builder - .as_mut() - .expect("builder should be present while handle is alive") - } - - fn into_inner(mut self) -> CaseBuilder { - self.builder - .take() - .expect("builder should be present when consuming handle") - } -} - -impl Drop for CaseBuilderHandle<'_> { - fn drop(&mut self) { - if let Some(builder) = self.builder.take() { - *self.guard = Some(builder); - } - } -} - +// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone +#[derive(Clone, Debug)] #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] -#[derive(Clone)] pub struct PyCaseBuilder { - case_builder: Arc>>, -} - -impl From for PyCaseBuilder { - fn from(case_builder: CaseBuilder) -> PyCaseBuilder { - PyCaseBuilder { - case_builder: Arc::new(Mutex::new(Some(case_builder))), - } - } + expr: Option, + when: Vec, + then: Vec, } +#[pymethods] impl PyCaseBuilder { - fn case_builder_handle(&self) -> PyDataFusionResult> { - let guard = self.case_builder.lock(); - CaseBuilderHandle::new(guard) + #[new] + pub fn new(expr: Option) -> Self { + Self { + expr: expr.map(Into::into), + when: vec![], + then: vec![], + } } - pub fn into_case_builder(self) -> PyDataFusionResult { - let guard = self.case_builder.lock(); - CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner) - } -} + pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { + let mut case_builder = self.clone(); + case_builder.when.push(when.into()); + case_builder.then.push(then.into()); -#[pymethods] -impl PyCaseBuilder { - fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - let next_builder = handle.builder_mut().when(when.expr, then.expr); - Ok(next_builder.into()) + case_builder } fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - match handle.builder_mut().otherwise(else_expr.expr) { - Ok(expr) => Ok(expr.clone().into()), - Err(err) => Err(err.into()), - } + println!("otherwise called {self:?}"); + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + Some(Box::new(else_expr.into())), + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } fn end(&self) -> PyDataFusionResult { - let mut handle = self.case_builder_handle()?; - match handle.builder_mut().end() { - Ok(expr) => Ok(expr.clone().into()), - Err(err) => Err(err.into()), - } + println!("end called {self:?}"); + + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + None, + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } } diff --git a/src/functions.rs b/src/functions.rs index 0f9fdf698..5956b67cf 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult { /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn case(expr: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::case(expr.expr).into()) + Ok(PyCaseBuilder::new(Some(expr))) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn when(when: PyExpr, then: PyExpr) -> PyResult { - Ok(datafusion::logical_expr::when(when.expr, then.expr).into()) + Ok(PyCaseBuilder::new(None).when(when, then)) } /// Helper function to find the appropriate window function.