Skip to content

Commit 7a7096a

Browse files
authored
refactor: try reduce aggregate hash index cost on hot path (#19072)
* improve * test: add a test to ensure cover all slots * chore: clean
1 parent 6693d34 commit 7a7096a

File tree

2 files changed

+76
-22
lines changed

2 files changed

+76
-22
lines changed

src/query/expression/src/aggregate/aggregate_hashtable.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ impl AggregateHashTable {
9292
arena: Arc<Bump>,
9393
need_init_entry: bool,
9494
) -> Self {
95+
debug_assert!(capacity.is_power_of_two());
9596
let entries = if need_init_entry {
9697
vec![Entry::default(); capacity]
9798
} else {
@@ -110,6 +111,7 @@ impl AggregateHashTable {
110111
entries,
111112
count: 0,
112113
capacity,
114+
capacity_mask: capacity - 1,
113115
},
114116
config,
115117
}

src/query/expression/src/aggregate/hash_index.rs

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,56 @@ pub(super) struct HashIndex {
2222
pub entries: Vec<Entry>,
2323
pub count: usize,
2424
pub capacity: usize,
25+
pub capacity_mask: usize,
26+
}
27+
28+
const INCREMENT_BITS: usize = 5;
29+
30+
/// Derive an odd probing step from the high bits of the hash so the walk spans all slots.
31+
///
32+
/// this will generate a step in the range [1, 2^INCREMENT_BITS) based on hash and always odd.
33+
#[inline(always)]
34+
fn step(hash: u64) -> usize {
35+
((hash >> (64 - INCREMENT_BITS)) as usize) | 1
36+
}
37+
38+
/// Move to the next slot with wrap-around using the power-of-two capacity mask.
39+
///
40+
/// soundness: capacity is always a power of two, so mask is capacity - 1
41+
#[inline(always)]
42+
fn next_slot(slot: usize, hash: u64, mask: usize) -> usize {
43+
(slot + step(hash)) & mask
44+
}
45+
46+
#[inline(always)]
47+
fn init_slot(hash: u64, capacity_mask: usize) -> usize {
48+
hash as usize & capacity_mask
2549
}
2650

2751
impl HashIndex {
2852
pub fn with_capacity(capacity: usize) -> Self {
53+
debug_assert!(capacity.is_power_of_two());
54+
let capacity_mask = capacity - 1;
2955
Self {
3056
entries: vec![Entry::default(); capacity],
3157
count: 0,
3258
capacity,
59+
capacity_mask,
3360
}
3461
}
3562

36-
fn init_slot(&self, hash: u64) -> usize {
37-
hash as usize & (self.capacity - 1)
38-
}
39-
40-
fn find_or_insert(&mut self, mut slot: usize, salt: u16) -> (usize, bool) {
63+
fn find_or_insert(&mut self, mut slot: usize, hash: u64) -> (usize, bool) {
64+
let salt = Entry::hash_to_salt(hash);
4165
let entries = self.entries.as_mut_slice();
4266
loop {
43-
let entry = &mut entries[slot];
67+
debug_assert!(entries.get(slot).is_some());
68+
// SAFETY: slot is always in range
69+
let entry = unsafe { entries.get_unchecked_mut(slot) };
4470
if entry.is_occupied() {
4571
if entry.get_salt() == salt {
4672
return (slot, false);
4773
} else {
48-
slot += 1;
49-
if slot >= self.capacity {
50-
slot = 0;
51-
}
74+
slot = next_slot(slot, hash, self.capacity_mask);
5275
continue;
5376
}
5477
} else {
@@ -59,13 +82,10 @@ impl HashIndex {
5982
}
6083

6184
pub fn probe_slot(&mut self, hash: u64) -> usize {
62-
let mut slot = self.init_slot(hash);
6385
let entries = self.entries.as_mut_slice();
86+
let mut slot = init_slot(hash, self.capacity_mask);
6487
while entries[slot].is_occupied() {
65-
slot += 1;
66-
if slot >= self.capacity {
67-
slot = 0;
68-
}
88+
slot = next_slot(slot, hash, self.capacity_mask);
6989
}
7090
slot as _
7191
}
@@ -159,8 +179,9 @@ impl HashIndex {
159179
slots.extend(
160180
state.group_hashes[..row_count]
161181
.iter()
162-
.map(|hash| self.init_slot(*hash)),
182+
.map(|hash| init_slot(*hash, self.capacity_mask)),
163183
);
184+
let capacity_mask = self.capacity_mask;
164185

165186
let mut new_group_count = 0;
166187
let mut remaining_entries = row_count;
@@ -176,7 +197,7 @@ impl HashIndex {
176197
let hash = state.group_hashes[row];
177198

178199
let is_new;
179-
(*slot, is_new) = self.find_or_insert(*slot, Entry::hash_to_salt(hash));
200+
(*slot, is_new) = self.find_or_insert(*slot, hash);
180201

181202
if is_new {
182203
state.empty_vector[new_entry_count] = row;
@@ -217,13 +238,11 @@ impl HashIndex {
217238
no_match_count = adapter.compare(state, need_compare_count, no_match_count);
218239
}
219240

220-
// 5. Linear probing, just increase iter_times
241+
// 5. Linear probing with hash-derived step
221242
for row in state.no_match_vector[..no_match_count].iter().copied() {
222243
let slot = &mut slots[row];
223-
*slot += 1;
224-
if *slot >= self.capacity {
225-
*slot = 0;
226-
}
244+
let hash = state.group_hashes[row];
245+
*slot = next_slot(*slot, hash, capacity_mask);
227246
}
228247
remaining_entries = no_match_count;
229248
}
@@ -262,6 +281,7 @@ impl<'a> TableAdapter for AdapterImpl<'a> {
262281
#[cfg(test)]
263282
mod tests {
264283
use std::collections::HashMap;
284+
use std::collections::HashSet;
265285

266286
use super::*;
267287
use crate::ProbeState;
@@ -405,6 +425,38 @@ mod tests {
405425
}
406426
}
407427

428+
#[test]
429+
fn test_probe_walk_covers_full_capacity() {
430+
// This test make sure that we can always cover all slots in the table
431+
let capacity = 16;
432+
let capacity_mask = capacity - 1;
433+
434+
for high_bits in 0u64..(1 << INCREMENT_BITS) {
435+
let hash = high_bits << (64 - INCREMENT_BITS);
436+
let mut slot = init_slot(hash, capacity_mask);
437+
let mut visited = HashSet::with_capacity(capacity);
438+
439+
for _ in 0..capacity {
440+
assert!(
441+
visited.insert(slot),
442+
"hash {hash:#x} revisited slot {slot} before covering the table"
443+
);
444+
slot = next_slot(slot, hash, capacity_mask);
445+
}
446+
447+
assert_eq!(
448+
capacity,
449+
visited.len(),
450+
"hash {hash:#x} failed to cover every slot for capacity {capacity}"
451+
);
452+
assert_eq!(
453+
init_slot(hash, capacity_mask),
454+
slot,
455+
"hash {hash:#x} walk did not return to its start after {capacity} steps"
456+
);
457+
}
458+
}
459+
408460
#[test]
409461
fn test_hash_index() {
410462
TestCase {

0 commit comments

Comments
 (0)