Skip to content

Commit 388db0e

Browse files
authored
Update aggregate probe to be locked only if skipping aggregation (#18766)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> If we always set the probe to locked, we give up the ability to skip partial aggregation while processing a future record batch whose ratio does meet the required threshold to skip. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Lock the aggregation probe only once we have actually decided to skip, not merely when the row-count threshold is met. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> This change should be covered by existing aggregate tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> No
1 parent 8dac8f1 commit 388db0e

File tree

1 file changed

+150
-1
lines changed

1 file changed

+150
-1
lines changed

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ impl SkipAggregationProbe {
194194
if self.input_rows >= self.probe_rows_threshold {
195195
self.should_skip = self.num_groups as f64 / self.input_rows as f64
196196
>= self.probe_ratio_threshold;
197-
self.is_locked = true;
197+
// Set is_locked to true only if we have decided to skip, otherwise we can try to skip
198+
// during processing the next record_batch.
199+
self.is_locked = self.should_skip;
198200
}
199201
}
200202

@@ -1280,6 +1282,7 @@ impl GroupedHashAggregateStream {
12801282
#[cfg(test)]
12811283
mod tests {
12821284
use super::*;
1285+
use crate::execution_plan::ExecutionPlan;
12831286
use crate::test::TestMemoryExec;
12841287
use arrow::array::{Int32Array, Int64Array};
12851288
use arrow::datatypes::{DataType, Field, Schema};
@@ -1395,4 +1398,150 @@ mod tests {
13951398

13961399
Ok(())
13971400
}
1401+
1402+
#[tokio::test]
1403+
async fn test_skip_aggregation_probe_not_locked_until_skip() -> Result<()> {
1404+
// Test that the probe is not locked until we actually decide to skip.
1405+
// This allows us to continue evaluating the skip condition across multiple batches.
1406+
//
1407+
// Scenario:
1408+
// - Batch 1: Hits rows threshold but NOT ratio threshold (low cardinality) -> don't skip
1409+
// - Batch 2: Now hits ratio threshold (high cardinality) -> skip
1410+
//
1411+
// Without the fix, the probe would be locked after batch 1, preventing the skip
1412+
// decision from being made on batch 2.
1413+
1414+
let schema = Arc::new(Schema::new(vec![
1415+
Field::new("group_col", DataType::Int32, false),
1416+
Field::new("value_col", DataType::Int32, false),
1417+
]));
1418+
1419+
// Configure thresholds:
1420+
// - probe_rows_threshold: 100 rows
1421+
// - probe_ratio_threshold: 0.8 (80%)
1422+
let probe_rows_threshold = 100;
1423+
let probe_ratio_threshold = 0.8;
1424+
1425+
// Batch 1: 100 rows with only 10 unique groups
1426+
// Ratio: 10/100 = 0.1 (10%) < 0.8 -> should NOT skip
1427+
// This will hit the rows threshold but not the ratio threshold
1428+
let batch1_rows = 100;
1429+
let batch1_groups = 10;
1430+
let mut group_ids_batch1 = Vec::new();
1431+
for i in 0..batch1_rows {
1432+
group_ids_batch1.push((i % batch1_groups) as i32);
1433+
}
1434+
let values_batch1: Vec<i32> = vec![1; batch1_rows];
1435+
1436+
let batch1 = RecordBatch::try_new(
1437+
Arc::clone(&schema),
1438+
vec![
1439+
Arc::new(Int32Array::from(group_ids_batch1)),
1440+
Arc::new(Int32Array::from(values_batch1)),
1441+
],
1442+
)?;
1443+
1444+
// Batch 2: 350 rows with 350 unique NEW groups (starting from group 10)
1445+
// After batch 2, total: 450 rows, 360 groups
1446+
// Ratio: 360/450 = 0.8 (80%) >= 0.8 -> SHOULD decide to skip
1447+
let batch2_rows = 350;
1448+
let batch2_groups = 350;
1449+
let group_ids_batch2: Vec<i32> = (batch1_groups..(batch1_groups + batch2_groups))
1450+
.map(|x| x as i32)
1451+
.collect();
1452+
let values_batch2: Vec<i32> = vec![1; batch2_rows];
1453+
1454+
let batch2 = RecordBatch::try_new(
1455+
Arc::clone(&schema),
1456+
vec![
1457+
Arc::new(Int32Array::from(group_ids_batch2)),
1458+
Arc::new(Int32Array::from(values_batch2)),
1459+
],
1460+
)?;
1461+
1462+
// Batch 3: This batch should be skipped since we decided to skip after batch 2
1463+
// 100 rows with 100 unique groups (continuing from where batch 2 left off)
1464+
let batch3_rows = 100;
1465+
let batch3_groups = 100;
1466+
let batch3_start_group = batch1_groups + batch2_groups;
1467+
let group_ids_batch3: Vec<i32> = (batch3_start_group
1468+
..(batch3_start_group + batch3_groups))
1469+
.map(|x| x as i32)
1470+
.collect();
1471+
let values_batch3: Vec<i32> = vec![1; batch3_rows];
1472+
1473+
let batch3 = RecordBatch::try_new(
1474+
Arc::clone(&schema),
1475+
vec![
1476+
Arc::new(Int32Array::from(group_ids_batch3)),
1477+
Arc::new(Int32Array::from(values_batch3)),
1478+
],
1479+
)?;
1480+
1481+
let input_partitions = vec![vec![batch1, batch2, batch3]];
1482+
1483+
let runtime = RuntimeEnvBuilder::default().build_arc()?;
1484+
let mut task_ctx = TaskContext::default().with_runtime(runtime);
1485+
1486+
// Configure skip aggregation settings
1487+
let mut session_config = task_ctx.session_config().clone();
1488+
session_config = session_config.set(
1489+
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
1490+
&datafusion_common::ScalarValue::UInt64(Some(probe_rows_threshold)),
1491+
);
1492+
session_config = session_config.set(
1493+
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
1494+
&datafusion_common::ScalarValue::Float64(Some(probe_ratio_threshold)),
1495+
);
1496+
task_ctx = task_ctx.with_session_config(session_config);
1497+
let task_ctx = Arc::new(task_ctx);
1498+
1499+
// Create aggregate: COUNT(*) GROUP BY group_col
1500+
let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
1501+
let aggr_expr = vec![Arc::new(
1502+
AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
1503+
.schema(Arc::clone(&schema))
1504+
.alias("count_value")
1505+
.build()?,
1506+
)];
1507+
1508+
let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
1509+
let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));
1510+
1511+
// Use Partial mode
1512+
let aggregate_exec = AggregateExec::try_new(
1513+
AggregateMode::Partial,
1514+
PhysicalGroupBy::new_single(group_expr),
1515+
aggr_expr,
1516+
vec![None],
1517+
exec,
1518+
Arc::clone(&schema),
1519+
)?;
1520+
1521+
// Execute and collect results
1522+
let mut stream =
1523+
GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?;
1524+
let mut results = Vec::new();
1525+
1526+
while let Some(result) = stream.next().await {
1527+
let batch = result?;
1528+
results.push(batch);
1529+
}
1530+
1531+
// Check that skip aggregation actually happened
1532+
// The key metric is skipped_aggregation_rows
1533+
let metrics = aggregate_exec.metrics().unwrap();
1534+
let skipped_rows = metrics
1535+
.sum_by_name("skipped_aggregation_rows")
1536+
.map(|m| m.as_usize())
1537+
.unwrap_or(0);
1538+
1539+
// We expect batch 3's rows to be skipped (100 rows)
1540+
assert_eq!(
1541+
skipped_rows, batch3_rows,
1542+
"Expected batch 3's rows ({batch3_rows}) to be skipped",
1543+
);
1544+
1545+
Ok(())
1546+
}
13981547
}

0 commit comments

Comments
 (0)