Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None:
base_builder = f.case(col("value"))

def add_case(i: int) -> None:
base_builder.when(lit(i), lit(f"value-{i}"))
nonlocal base_builder
base_builder = base_builder.when(lit(i), lit(f"value-{i}"))

_run_in_threads(add_case, count=8)

Expand Down
15 changes: 5 additions & 10 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state():
case_builder = functions.when(lit(True), lit(1))

with pytest.raises(Exception) as exc_info:
case_builder.otherwise(lit("bad"))
_ = case_builder.otherwise(lit("bad"))

err_msg = str(exc_info.value)
assert "multiple data types" in err_msg
assert "CaseBuilder has already been consumed" not in err_msg

with pytest.raises(Exception) as exc_info:
case_builder.end()
_ = case_builder.end()

err_msg = str(exc_info.value)
assert "multiple data types" in err_msg
Expand All @@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state():

expr_end_one = case_builder.end().alias("result")
end_one = df.select(expr_end_one).collect()
assert end_one[0].column(0).to_pylist() == ["default-2"]

expr_end_two = case_builder.end().alias("result")
end_two = df.select(expr_end_two).collect()
assert end_two[0].column(0).to_pylist() == ["default-2"]
assert end_one[0].column(0).to_pylist() == [None]


def test_case_builder_when_handles_are_independent():
Expand Down Expand Up @@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent():
]
assert result.column(1).to_pylist() == [
"flag-true",
"gt10",
"gt10",
"fallback-two",
"gt20",
"fallback-two",
]

Expand Down
122 changes: 43 additions & 79 deletions src/expr/conditional_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,101 +15,65 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::{
errors::{PyDataFusionError, PyDataFusionResult},
expr::PyExpr,
};
use crate::{errors::PyDataFusionResult, expr::PyExpr};
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
use parking_lot::{Mutex, MutexGuard};
use datafusion::prelude::Expr;
use pyo3::prelude::*;

struct CaseBuilderHandle<'a> {
guard: MutexGuard<'a, Option<CaseBuilder>>,
builder: Option<CaseBuilder>,
}

impl<'a> CaseBuilderHandle<'a> {
fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> {
let builder = guard.take().ok_or_else(|| {
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
})?;

Ok(Self {
guard,
builder: Some(builder),
})
}

fn builder_mut(&mut self) -> &mut CaseBuilder {
self.builder
.as_mut()
.expect("builder should be present while handle is alive")
}

fn into_inner(mut self) -> CaseBuilder {
self.builder
.take()
.expect("builder should be present when consuming handle")
}
}

impl Drop for CaseBuilderHandle<'_> {
fn drop(&mut self) {
if let Some(builder) = self.builder.take() {
*self.guard = Some(builder);
}
}
}

// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone
#[derive(Clone, Debug)]
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
#[derive(Clone)]
pub struct PyCaseBuilder {
case_builder: Arc<Mutex<Option<CaseBuilder>>>,
}

impl From<CaseBuilder> for PyCaseBuilder {
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
PyCaseBuilder {
case_builder: Arc::new(Mutex::new(Some(case_builder))),
}
}
expr: Option<Expr>,
when: Vec<Expr>,
then: Vec<Expr>,
}

#[pymethods]
impl PyCaseBuilder {
fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> {
let guard = self.case_builder.lock();
CaseBuilderHandle::new(guard)
#[new]
pub fn new(expr: Option<PyExpr>) -> Self {
Self {
expr: expr.map(Into::into),
when: vec![],
then: vec![],
}
}

pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
let guard = self.case_builder.lock();
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner)
}
}
pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder {
let mut case_builder = self.clone();
case_builder.when.push(when.into());
case_builder.then.push(then.into());

#[pymethods]
impl PyCaseBuilder {
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
let mut handle = self.case_builder_handle()?;
let next_builder = handle.builder_mut().when(when.expr, then.expr);
Ok(next_builder.into())
case_builder
}

fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
let mut handle = self.case_builder_handle()?;
match handle.builder_mut().otherwise(else_expr.expr) {
Ok(expr) => Ok(expr.clone().into()),
Err(err) => Err(err.into()),
}
println!("otherwise called {self:?}");
let case_builder = CaseBuilder::new(
self.expr.clone().map(Box::new),
self.when.clone(),
self.then.clone(),
Some(Box::new(else_expr.into())),
);

let expr = case_builder.end()?;

Ok(expr.into())
}

fn end(&self) -> PyDataFusionResult<PyExpr> {
let mut handle = self.case_builder_handle()?;
match handle.builder_mut().end() {
Ok(expr) => Ok(expr.clone().into()),
Err(err) => Err(err.into()),
}
println!("end called {self:?}");

let case_builder = CaseBuilder::new(
self.expr.clone().map(Box::new),
self.when.clone(),
self.then.clone(),
None,
);

let expr = case_builder.end()?;

Ok(expr.into())
}
}
4 changes: 2 additions & 2 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult<PyExpr> {
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
#[pyfunction]
fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
Ok(datafusion::logical_expr::case(expr.expr).into())
Ok(PyCaseBuilder::new(Some(expr)))
}

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
#[pyfunction]
fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
Ok(datafusion::logical_expr::when(when.expr, then.expr).into())
Ok(PyCaseBuilder::new(None).when(when, then))
}

/// Helper function to find the appropriate window function.
Expand Down