Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 2 additions & 49 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -2157,61 +2157,14 @@ void emptyColumnLabel();

namespace row_helpers
{
template <soa::is_persistent_column... Cs>
std::array<arrow::ChunkedArray*, sizeof...(Cs)> getArrowColumns(arrow::Table* table, framework::pack<Cs...>)
{
return std::array<arrow::ChunkedArray*, sizeof...(Cs)>{o2::soa::getIndexFromLabel(table, Cs::columnLabel())...};
}

template <soa::is_persistent_column... Cs>
std::array<std::shared_ptr<arrow::Array>, sizeof...(Cs)> getChunks(arrow::Table* table, framework::pack<Cs...>, uint64_t ci)
{
return std::array<std::shared_ptr<arrow::Array>, sizeof...(Cs)>{o2::soa::getIndexFromLabel(table, Cs::columnLabel())->chunk(ci)...};
}

template <typename T, soa::is_persistent_column C>
typename C::type getSingleRowData(arrow::Table* table, T& rowIterator, uint64_t ci = std::numeric_limits<uint64_t>::max(), uint64_t ai = std::numeric_limits<uint64_t>::max(), uint64_t globalIndex = std::numeric_limits<uint64_t>::max())
{
if (ci == std::numeric_limits<uint64_t>::max() || ai == std::numeric_limits<uint64_t>::max()) {
auto colIterator = static_cast<C>(rowIterator).getIterator();
ci = colIterator.mCurrentChunk;
ai = *(colIterator.mCurrentPos) - colIterator.mFirstIndex;
}
return std::static_pointer_cast<o2::soa::arrow_array_for_t<typename C::type>>(o2::soa::getIndexFromLabel(table, C::columnLabel())->chunk(ci))->raw_values()[ai];
}

template <typename T, soa::is_dynamic_column C>
typename C::type getSingleRowData(arrow::Table*, T& rowIterator, uint64_t ci = std::numeric_limits<uint64_t>::max(), uint64_t ai = std::numeric_limits<uint64_t>::max(), uint64_t globalIndex = std::numeric_limits<uint64_t>::max())
{
if (globalIndex != std::numeric_limits<uint64_t>::max() && globalIndex != *std::get<0>(rowIterator.getIndices())) {
rowIterator.setCursor(globalIndex);
}
return rowIterator.template getDynamicColumn<C>();
}

template <typename T, soa::is_index_column C>
typename C::type getSingleRowData(arrow::Table*, T& rowIterator, uint64_t ci = std::numeric_limits<uint64_t>::max(), uint64_t ai = std::numeric_limits<uint64_t>::max(), uint64_t globalIndex = std::numeric_limits<uint64_t>::max())
{
if (globalIndex != std::numeric_limits<uint64_t>::max() && globalIndex != *std::get<0>(rowIterator.getIndices())) {
rowIterator.setCursor(globalIndex);
}
return rowIterator.template getId<C>();
}

template <typename T, typename... Cs>
std::tuple<typename Cs::type...> getRowData(arrow::Table* table, T rowIterator, uint64_t ci = std::numeric_limits<uint64_t>::max(), uint64_t ai = std::numeric_limits<uint64_t>::max(), uint64_t globalIndex = std::numeric_limits<uint64_t>::max())
{
return std::make_tuple(getSingleRowData<T, Cs>(table, rowIterator, ci, ai, globalIndex)...);
}

namespace
{
template <typename R, typename T, typename C>
R getColumnValue(const T& rowIterator)
{
return static_cast<R>(static_cast<C>(rowIterator).get());
}

namespace
{
template <typename R, typename T>
using ColumnGetterFunction = R (*)(const T&);

Expand Down
73 changes: 5 additions & 68 deletions Framework/Core/include/Framework/ASoAHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,76 +76,13 @@ void dataSizeVariesBetweenColumns();
template <template <typename... Cs> typename BP, typename T, typename... Cs>
std::vector<BinningIndex> groupTable(const T& table, const BP<Cs...>& binningPolicy, int minCatSize, int outsider)
{
arrow::Table* arrowTable = table.asArrowTable().get();
auto rowIterator = table.begin();

uint64_t ind = 0;
uint64_t selInd = 0;
gsl::span<int64_t const> selectedRows;
std::vector<BinningIndex> groupedIndices;

// Separate check to account for Filtered size different from arrow table
if (table.size() == 0) {
return groupedIndices;
}

if constexpr (soa::is_filtered_table<T>) {
selectedRows = table.getSelectedRows(); // vector<int64_t>
}

auto persistentColumns = typename BP<Cs...>::persistent_columns_t{};
constexpr auto persistentColumnsCount = pack_size(persistentColumns);
auto arrowColumns = o2::soa::row_helpers::getArrowColumns(arrowTable, persistentColumns);
auto chunksCount = arrowColumns[0]->num_chunks();
for (int i = 1; i < persistentColumnsCount; i++) {
if (arrowColumns[i]->num_chunks() != chunksCount) {
dataSizeVariesBetweenColumns();
}
}

for (uint64_t ci = 0; ci < chunksCount; ++ci) {
auto chunks = o2::soa::row_helpers::getChunks(arrowTable, persistentColumns, ci);
auto chunkLength = std::get<0>(chunks)->length();
for_<persistentColumnsCount - 1>([&chunks, &chunkLength](auto i) {
if (std::get<i.value + 1>(chunks)->length() != chunkLength) {
dataSizeVariesBetweenColumns();
}
});

if constexpr (soa::is_filtered_table<T>) {
if (selectedRows[ind] >= selInd + chunkLength) {
selInd += chunkLength;
continue; // Go to the next chunk, no value selected in this chunk
}
}

uint64_t ai = 0;
while (ai < chunkLength) {
if constexpr (soa::is_filtered_table<T>) {
ai += selectedRows[ind] - selInd;
selInd = selectedRows[ind];
}

auto values = binningPolicy.getBinningValues(rowIterator, arrowTable, ci, ai, ind);
auto val = binningPolicy.getBin(values);
if (val != outsider) {
groupedIndices.emplace_back(val, ind);
}
ind++;

if constexpr (soa::is_filtered_table<T>) {
if (ind >= selectedRows.size()) {
break;
}
} else {
ai++;
}
}

if constexpr (soa::is_filtered_table<T>) {
if (ind == selectedRows.size()) {
break;
}
for (auto rowIterator : table) {
auto values = binningPolicy.getBinningValues(rowIterator);
auto val = binningPolicy.getBin(values);
if (val != outsider) {
groupedIndices.emplace_back(val, *std::get<1>(rowIterator.getIndices()));
}
}

Expand Down
40 changes: 23 additions & 17 deletions Framework/Core/include/Framework/BinningPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,28 +241,28 @@ struct FlexibleBinningPolicy<std::tuple<Ls...>, Ts...> : BinningPolicyBase<sizeo
}

template <typename T, typename T2>
auto getBinningValue(T& rowIterator, arrow::Table* table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValue(T& rowIterator, uint64_t globalIndex = -1) const
{
if (globalIndex != -1) {
rowIterator.setCursor(globalIndex);
}
if constexpr (has_type<T2>(pack<Ls...>{})) {
if (globalIndex != -1) {
rowIterator.setCursor(globalIndex);
}
return std::get<T2>(mBinningFunctions)(rowIterator);
} else {
return soa::row_helpers::getSingleRowData<T, T2>(table, rowIterator, ci, ai, globalIndex);
return soa::row_helpers::getColumnValue<typename T2::type, T, T2>(rowIterator);
}
}

template <typename T>
auto getBinningValues(T& rowIterator, arrow::Table* table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(T& rowIterator, uint64_t globalIndex = -1) const
{
return std::make_tuple(getBinningValue<T, Ts>(rowIterator, table, ci, ai, globalIndex)...);
return std::make_tuple(getBinningValue<T, Ts>(rowIterator, globalIndex)...);
}

template <typename T>
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t globalIndex = -1) const
{
return getBinningValues(rowIterator, table.asArrowTable().get(), ci, ai, globalIndex);
return getBinningValues(rowIterator, globalIndex);
}

template <typename... T2s>
Expand All @@ -284,15 +284,18 @@ struct ColumnBinningPolicy : BinningPolicyBase<sizeof...(Ts)> {
}

template <typename T>
auto getBinningValues(T& rowIterator, arrow::Table* table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(T& rowIterator, uint64_t globalIndex = -1) const
{
return std::make_tuple(soa::row_helpers::getSingleRowData<T, Ts>(table, rowIterator, ci, ai, globalIndex)...);
if (globalIndex != -1) {
rowIterator.setCursor(globalIndex);
}
return std::make_tuple(soa::row_helpers::getColumnValue<typename Ts::type, T, Ts>(rowIterator)...);
}

template <typename T>
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t globalIndex = -1) const
{
return getBinningValues(rowIterator, table.asArrowTable().get(), ci, ai, globalIndex);
return getBinningValues(rowIterator, globalIndex);
}

int getBin(std::tuple<typename Ts::type...> const& data) const
Expand All @@ -309,15 +312,18 @@ struct NoBinningPolicy {
NoBinningPolicy() = default;

template <typename T>
auto getBinningValues(T& rowIterator, arrow::Table* table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(T& rowIterator, uint64_t globalIndex = -1) const
{
return std::make_tuple(soa::row_helpers::getSingleRowData<T, C>(table, rowIterator, ci, ai, globalIndex));
if (globalIndex != -1) {
rowIterator.setCursor(globalIndex);
}
return std::make_tuple(soa::row_helpers::getColumnValue<typename C::type, T, C>(rowIterator));
}

template <typename T>
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t ci = -1, uint64_t ai = -1, uint64_t globalIndex = -1) const
auto getBinningValues(typename T::iterator rowIterator, T& table, uint64_t globalIndex = -1) const
{
return getBinningValues(rowIterator, table.asArrowTable().get(), ci, ai, globalIndex);
return getBinningValues(rowIterator, globalIndex);
}

int getBin(std::tuple<typename C::type> const& data) const
Expand Down