Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:

env:
CARGO_TERM_COLOR: always
MINIMUM_NOIR_VERSION: v0.36.0
MINIMUM_NOIR_VERSION: v1.0.0-beta.4

jobs:
noir-version-list:
Expand Down
2 changes: 1 addition & 1 deletion Nargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ authors = [""]
compiler_version = ">=0.36.0"

[dependencies]
sort = { tag = "v0.2.3", git = "https://github.com/noir-lang/noir_sort" }
sort = { tag = "v0.3.0", git = "https://github.com/noir-lang/noir_sort" }
75 changes: 35 additions & 40 deletions src/lib.nr
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
mod mut_sparse_array;
use dep::sort::sort_advanced;

unconstrained fn __sort_field_as_u32(lhs: Field, rhs: Field) -> bool {
unconstrained fn __sort(lhs: u32, rhs: u32) -> bool {
// lhs.lt(rhs)
lhs as u32 < rhs as u32
lhs < rhs
}

fn assert_sorted(lhs: Field, rhs: Field) {
let result = (rhs - lhs - 1);
result.assert_max_bit_size::<32>();
fn assert_sorted(lhs: u32, rhs: u32) {
assert(lhs < rhs);
}

/**
Expand All @@ -24,10 +23,10 @@ fn assert_sorted(lhs: Field, rhs: Field) {
**/
struct MutSparseArrayBase<let N: u32, T, ComparisonFuncs> {
values: [T; N + 3],
keys: [Field; N + 2],
linked_keys: [Field; N + 2],
tail_ptr: Field,
maximum: Field,
keys: [u32; N + 2],
linked_keys: [u32; N + 2],
tail_ptr: u32,
maximum: u32,
}

struct U32RangeTraits {}
Expand All @@ -47,9 +46,9 @@ pub struct MutSparseArray<let N: u32, T> {
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
**/
pub struct SparseArray<let N: u32, T> {
keys: [Field; N + 2],
keys: [u32; N + 2],
values: [T; N + 3],
maximum: Field, // can be up to 2^32
maximum: u32, // can be up to 2^32 - 1
}
impl<let N: u32, T> SparseArray<N, T>
where
Expand All @@ -59,15 +58,16 @@ where
/**
* @brief construct a SparseArray
**/
pub(crate) fn create(_keys: [Field; N], _values: [T; N], size: Field) -> Self {
pub(crate) fn create(_keys: [u32; N], _values: [T; N], size: u32) -> Self {
assert(size >= 1);
let _maximum = size - 1;
let mut r: Self =
SparseArray { keys: [0; N + 2], values: [T::default(); N + 3], maximum: _maximum };

// for any valid index, we want to ensure the following is satified:
// self.keys[X] <= index <= self.keys[X+1]
// this requires us to sort hte keys, and insert a startpoint and endpoint
let sorted_keys = sort_advanced(_keys, __sort_field_as_u32, assert_sorted);
let sorted_keys = sort_advanced(_keys, __sort, assert_sorted);

// insert start and endpoints
r.keys[0] = 0;
Expand Down Expand Up @@ -103,45 +103,41 @@ where
// because `self.keys` is sorted, we can simply validate that
// sorted_keys.sorted[0] < 2^32
// sorted_keys.sorted[N-1] < maximum
sorted_keys.sorted[0].assert_max_bit_size::<32>();
_maximum.assert_max_bit_size::<32>();
(_maximum - sorted_keys.sorted[N - 1]).assert_max_bit_size::<32>();
assert(_maximum >= sorted_keys.sorted[N - 1]);
r
}

/**
* @brief determine whether `target` is present in `self.keys`
* @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
**/
unconstrained fn search_for_key(self, target: Field) -> (Field, Field) {
unconstrained fn search_for_key(self, target: u32) -> (bool, u32) {
let mut found = false;
let mut found_index = 0;
let mut found_index: u32 = 0;
let mut previous_less_than_or_equal_to_target = false;
for i in 0..N + 2 {
// if target = 0xffffffff we need to be able to add 1 here, so use u64
let current_less_than_or_equal_to_target = self.keys[i] as u64 <= target as u64;
if (self.keys[i] == target) {
found = true;
found_index = i as Field;
found_index = i;
break;
}
if (previous_less_than_or_equal_to_target & !current_less_than_or_equal_to_target) {
found_index = i as Field - 1;
found_index = i - 1;
break;
}
previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target;
}
(found as Field, found_index)
(found, found_index)
}

/**
* @brief return element `idx` from the sparse array
* @details cost is 14.5 gates per lookup
**/
fn get(self, idx: Field) -> T {
fn get(self, idx: u32) -> T {
let (found, found_index) = unsafe { self.search_for_key(idx) };
// bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
assert(found * found == found);

// OK! So we have the following cases to check
// 1. if `found` then `self.keys[found_index] == idx`
Expand All @@ -152,15 +148,13 @@ where
// combine the two into the following single statement:
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
let lhs = self.keys[found_index];
let rhs = self.keys[found_index + 1 - found];
let lhs_condition = idx - lhs - 1 + found;
let rhs_condition = rhs - 1 + found - idx;
lhs_condition.assert_max_bit_size::<32>();
rhs_condition.assert_max_bit_size::<32>();
let rhs = self.keys[found_index + 1 - found as u32];
assert(lhs + 1 - found as u32 <= idx);
assert(idx <= rhs + found as u32 - 1);

// self.keys[i] maps to self.values[i+1]
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
let value_index = (found_index + 1) * found;
let value_index = (found_index + 1) * found as u32;
self.values[value_index]
}
}
Expand All @@ -179,7 +173,7 @@ mod test {

for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == 0);
assert(example.get(i) == 0);
}
}
}
Expand All @@ -188,34 +182,35 @@ mod test {
fn test_sparse_lookup_boundary_cases() {
// what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
let example = SparseArray::create(
[0, 99999, 7, 0xffffffff],
[0, 99999, 7, 0xfffffffe],
[123, 101112, 789, 456],
0x100000000,
0xffffffff,
);

assert(example.get(0) == 123);
assert(example.get(99999) == 101112);
assert(example.get(7) == 789);
assert(example.get(0xffffffff) == 456);
assert(example.get(0xfffffffe) == 0);
assert(example.get(0xfffffffe) == 456);
assert(example.get(0xfffffffd) == 0);
}

#[test(should_fail_with = "call to assert_max_bit_size")]
#[test(should_fail)]
fn test_sparse_lookup_overflow() {
let example = SparseArray::create([1, 5, 7, 99999], [123, 456, 789, 101112], 100000);

assert(example.get(100000) == 0);
}

/**
#[test(should_fail_with = "call to assert_max_bit_size")]
fn test_sparse_lookup_boundary_case_overflow() {
let example =
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0x100000000);

assert(example.get(0x100000000) == 0);
}

#[test(should_fail_with = "call to assert_max_bit_size")]
**/
#[test(should_fail)]
fn test_sparse_lookup_key_exceeds_maximum() {
let example =
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0xffffffff);
Expand All @@ -236,7 +231,7 @@ mod test {

for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == 0);
assert(example.get(i) == 0);
}
}
}
Expand Down Expand Up @@ -272,7 +267,7 @@ mod test {
assert(example.get(99) == values[1]);
for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == F::default());
assert(example.get(i) == F::default());
}
}
}
Expand Down
Loading