Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -1065,15 +1065,19 @@ struct TableIterator : IP, C... {
: IP{policy},
C(columnData[framework::has_type_at_v<C>(all_columns{})])...
{
bind();
if (this->size() != 0) {
bind();
}
}

TableIterator(arrow::ChunkedArray* columnData[sizeof...(C)], IP&& policy)
requires(has_index<C...>)
: IP{policy},
C(columnData[framework::has_type_at_v<C>(all_columns{})])...
{
bind();
if (this->size() != 0) {
bind();
}
// In case we have an index column, we might need to constrain the actual
// number of rows in the view to the range provided by the index.
// FIXME: we should really understand what happens to an index when we
Expand All @@ -1086,14 +1090,18 @@ struct TableIterator : IP, C... {
: IP{static_cast<IP const&>(other)},
C(static_cast<C const&>(other))...
{
bind();
if (this->size() != 0) {
bind();
}
}

// Assignment from another iterator. `other` is received by value; its index
// policy base (IP) is assigned first, then every column base in the pack C...
// is assigned in pack order via the fold expression. Returns *this.
TableIterator& operator=(TableIterator other)
{
  IP::operator=(static_cast<IP const&>(other));
  (void(static_cast<C&>(*this) = static_cast<C>(other)), ...);
  // bind() is called only when size() is non-zero. It must not also be
  // called unconditionally ahead of this check: that would make the guard a
  // no-op and run bind() twice whenever size() is non-zero.
  if (this->size() != 0) {
    bind();
  }
  return *this;
}

Expand All @@ -1102,7 +1110,9 @@ struct TableIterator : IP, C... {
: IP{static_cast<IP const&>(other)},
C(static_cast<C const&>(other))...
{
bind();
if (this->size() != 0) {
bind();
}
}

TableIterator& operator++()
Expand Down Expand Up @@ -1543,18 +1553,22 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
uint64_t offset = 0;
auto out = container.getSliceFor(value, table->asArrowTable(), offset);
auto t = typename T::self_t({out}, offset);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
if (t.tableSize() != 0) {
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
}
return t;
}

// Slice helper for filtered tables.
//
// Builds a soa::Filtered over the base type of T from the arrow table of
// `table` and the row `selection`, and returns it by value.
//
// @param table      source table; provides the arrow table, the index
//                   bindings and the already selected rows
// @param selection  row indices used to construct the result
template <soa::is_filtered_table T>
auto doSliceByHelper(T const* table, gsl::span<const int64_t> const& selection)
{
  auto t = soa::Filtered<typename T::base_t>({table->asArrowTable()}, selection);
  // Binding and filter intersection happen exactly once, and only when
  // tableSize() is non-zero. Repeating these calls outside the guard would
  // defeat it and intersect the selection a second time.
  if (t.tableSize() != 0) {
    table->copyIndexBindings(t);
    t.bindInternalIndicesTo(table);
    t.intersectWithSelection(table->getSelectedRows()); // intersect filters
  }
  return t;
}

Expand All @@ -1563,8 +1577,10 @@ template <soa::is_table T>
// Slice helper for plain (non-filtered) tables: builds a soa::Filtered<T>
// from the arrow table of `table` and the row `selection`, and returns it by
// value. Index bindings are copied from `table` and internal indices bound to
// it only when the result's tableSize() is non-zero; the calls are not
// repeated outside the guard, which would make the check ineffective.
auto doSliceByHelper(T const* table, gsl::span<const int64_t> const& selection)
{
  auto t = soa::Filtered<T>({table->asArrowTable()}, selection);
  if (t.tableSize() != 0) {
    table->copyIndexBindings(t);
    t.bindInternalIndicesTo(table);
  }
  return t;
}

Expand All @@ -1588,12 +1604,16 @@ auto prepareFilteredSlice(T const* table, std::shared_ptr<arrow::Table> slice, u
{
if (offset >= static_cast<uint64_t>(table->tableSize())) {
Filtered<typename T::base_t> fresult{{{slice}}, SelectionVector{}, 0};
table->copyIndexBindings(fresult);
if (fresult.tableSize() != 0) {
table->copyIndexBindings(fresult);
}
return fresult;
}
auto slicedSelection = sliceSelection(table->getSelectedRows(), slice->num_rows(), offset);
Filtered<typename T::base_t> fresult{{{slice}}, std::move(slicedSelection), offset};
table->copyIndexBindings(fresult);
if (fresult.tableSize() != 0) {
table->copyIndexBindings(fresult);
}
return fresult;
}

Expand All @@ -1617,7 +1637,9 @@ auto doSliceByCached(T const* table, framework::expressions::BindingNode const&
auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey<T>(node.name), node.name});
auto [offset, count] = localCache.getSliceFor(value);
auto t = typename T::self_t({table->asArrowTable()->Slice(static_cast<uint64_t>(offset), count)}, static_cast<uint64_t>(offset));
table->copyIndexBindings(t);
if (t.tableSize() != 0) {
table->copyIndexBindings(t);
}
return t;
}

Expand All @@ -1636,12 +1658,16 @@ auto doSliceByCachedUnsorted(T const* table, framework::expressions::BindingNode
auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey<T>(node.name), node.name});
if constexpr (soa::is_filtered_table<T>) {
auto t = typename T::self_t({table->asArrowTable()}, localCache.getSliceFor(value));
t.intersectWithSelection(table->getSelectedRows());
table->copyIndexBindings(t);
if (t.tableSize() != 0) {
t.intersectWithSelection(table->getSelectedRows());
table->copyIndexBindings(t);
}
return t;
} else {
auto t = Filtered<T>({table->asArrowTable()}, localCache.getSliceFor(value));
table->copyIndexBindings(t);
if (t.tableSize() != 0) {
table->copyIndexBindings(t);
}
return t;
}
}
Expand Down Expand Up @@ -3209,12 +3235,16 @@ struct JoinFull : Table<o2::aod::Hash<"JOIN"_h>, D, o2::aod::Hash<"JOIN"_h>, Ts.
// Construct from a single arrow table, forwarded to the base together with
// `offset` (defaults to 0).
// Internal indices are bound to this object once, and only when tableSize()
// is non-zero; an additional unguarded bind would bypass that check.
JoinFull(std::shared_ptr<arrow::Table>&& table, uint64_t offset = 0)
  : base{std::move(table), offset}
{
  if (this->tableSize() != 0) {
    bindInternalIndicesTo(this);
  }
}
// Construct from several arrow tables: they are combined through
// ArrowHelpers::joinTables using base::originalLabels, and the result is
// forwarded to the base together with `offset` (defaults to 0).
// Internal indices are bound to this object once, and only when tableSize()
// is non-zero; an additional unguarded bind would bypass that check.
JoinFull(std::vector<std::shared_ptr<arrow::Table>>&& tables, uint64_t offset = 0)
  : base{ArrowHelpers::joinTables(std::move(tables), std::span{base::originalLabels}), offset}
{
  if (this->tableSize() != 0) {
    bindInternalIndicesTo(this);
  }
}
using base::bindExternalIndices;
using base::bindInternalIndicesTo;
Expand Down
2 changes: 0 additions & 2 deletions Framework/Core/include/Framework/ArrowTableSlicingCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ struct ArrowTableSlicingCache {
constexpr static ServiceKind service_kind = ServiceKind::Stream;

Cache bindingsKeys;
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
std::vector<std::vector<int64_t>> offsets;
std::vector<std::vector<int64_t>> sizes;

Expand Down
80 changes: 35 additions & 45 deletions Framework/Core/src/ArrowTableSlicingCache.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@ void updatePairList(Cache& list, std::string const& binding, std::string const&

std::pair<int64_t, int64_t> SliceInfoPtr::getSliceFor(int value) const
{
int64_t offset = 0;
if (offsets.empty()) {
return {offset, 0};
}
if ((size_t)value >= offsets.size()) {
return {offset, 0};
return {0, 0};
}

return {offsets[value], sizes[value]};
Expand Down Expand Up @@ -68,8 +64,6 @@ ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorte
: bindingsKeys{bsks},
bindingsKeysUnsorted{bsksUnsorted}
{
values.resize(bindingsKeys.size());
counts.resize(bindingsKeys.size());
offsets.resize(bindingsKeys.size());
sizes.resize(bindingsKeys.size());

Expand All @@ -81,10 +75,6 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)
{
bindingsKeys = bsks;
bindingsKeysUnsorted = bsksUnsorted;
values.clear();
values.resize(bindingsKeys.size());
counts.clear();
counts.resize(bindingsKeys.size());
offsets.clear();
offsets.resize(bindingsKeys.size());
sizes.clear();
Expand All @@ -97,8 +87,6 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)

arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table)
{
values[pos].reset();
counts[pos].reset();
offsets[pos].clear();
sizes[pos].clear();
if (table->num_rows() == 0) {
Expand All @@ -109,41 +97,50 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<
throw runtime_error_f("Disabled cache %s/%s update requested", b.c_str(), k.c_str());
}
validateOrder(bindingsKeys[pos], table);
arrow::Datum value_counts;
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
ARROW_ASSIGN_OR_RAISE(value_counts,
arrow::compute::CallFunction("value_counts", {table->GetColumnByName(bindingsKeys[pos].key)},
&options));
auto pair = static_cast<arrow::StructArray>(value_counts.array());
values[pos].reset();
counts[pos].reset();
values[pos] = std::make_shared<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
counts[pos] = std::make_shared<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());

int maxValue = -1;
for (auto i = values[pos]->length() - 1; i >= 0; --i) {
if (values[pos]->Value(i) < 0) {
continue;
} else {
maxValue = values[pos]->Value(i);
auto column = table->GetColumnByName(k);

// starting from the end, find the first non-negative value; in a sorted column it is the largest index
for (auto iChunk = column->num_chunks() - 1; iChunk >= 0; --iChunk) {
auto chunk = static_cast<arrow::NumericArray<arrow::Int32Type>>(column->chunk(iChunk)->data());
for (auto iElement = chunk.length() - 1; iElement >= 0; --iElement) {
auto value = chunk.Value(iElement);
if (value < 0) {
continue;
} else {
maxValue = value;
break;
}
}
if (maxValue >= 0) {
break;
}
}

offsets[pos].resize(maxValue + 1);
sizes[pos].resize(maxValue + 1);
std::fill(offsets[pos].begin(), offsets[pos].end(), 0);
std::fill(sizes[pos].begin(), sizes[pos].end(), 0);
int64_t offset = 0;
for (auto i = 0U; i < values[pos]->length(); ++i) {
auto value = values[pos]->Value(i);
auto count = counts[pos]->Value(i);
if (value >= 0) {
offsets[pos][value] = offset;
sizes[pos][value] = count;

// loop over the index and collect size/offset
int lastValue = std::numeric_limits<int>::max();
int globalRow = 0;
for (auto iChunk = 0; iChunk < column->num_chunks(); ++iChunk) {
auto chunk = static_cast<arrow::NumericArray<arrow::Int32Type>>(column->chunk(iChunk)->data());
for (auto iElement = 0; iElement < chunk.length(); ++iElement) {
auto v = chunk.Value(iElement);
if (v >= 0) {
if (v == lastValue) {
++sizes[pos][v];
} else {
lastValue = v;
++sizes[pos][v];
offsets[pos][v] = globalRow;
}
}
++globalRow;
}
offset += count;
}

return arrow::Status::OK();
}

Expand Down Expand Up @@ -238,13 +235,6 @@ SliceInfoUnsortedPtr ArrowTableSlicingCache::getCacheUnsortedFor(const Entry& bi

SliceInfoPtr ArrowTableSlicingCache::getCacheForPos(int pos) const
{
if (values[pos] == nullptr && counts[pos] == nullptr) {
return {
{}, //
{} //
};
}

return {
gsl::span{offsets[pos].data(), offsets[pos].size()}, //
gsl::span(sizes[pos].data(), sizes[pos].size()) //
Expand Down