Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions Framework/Core/include/Framework/ExpressionHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,6 @@ struct ColumnOperationSpec {
result.type = type;
}
};

/// helper struct used to parse trees
struct NodeRecord {
/// pointer to the actual tree node
Node* node_ptr = nullptr;
size_t index = 0;
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
bool operator!=(NodeRecord const& rhs)
{
return this->node_ptr != rhs.node_ptr;
}
};
} // namespace o2::framework::expressions

#endif // O2_FRAMEWORK_EXPRESSIONS_HELPERS_H_
131 changes: 127 additions & 4 deletions Framework/Core/include/Framework/Expressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Projector;
#include <string>
#include <memory>
#include <set>
#include <stack>
namespace gandiva
{
using Selection = std::shared_ptr<gandiva::SelectionVector>;
Expand Down Expand Up @@ -114,6 +115,8 @@ struct LiteralNode {
{
}

LiteralNode(LiteralNode const& other) = default;

using var_t = LiteralValue::stored_type;
var_t value;
atype::type type = atype::NA;
Expand All @@ -132,6 +135,7 @@ struct BindingNode {
/// An expression tree node corresponding to binary or unary operation
struct OpNode {
OpNode(BasicOp op_) : op{op_} {}
OpNode(OpNode const& other) = default;
BasicOp op;
};

Expand All @@ -147,6 +151,8 @@ struct PlaceholderNode : LiteralNode {
}
}

PlaceholderNode(PlaceholderNode const& other) = default;

void reset(InitContext& context)
{
value = retrieve(context, name.data());
Expand All @@ -156,6 +162,28 @@ struct PlaceholderNode : LiteralNode {
LiteralNode::var_t (*retrieve)(InitContext&, char const*);
};

/// A placeholder node for parameters taken from an array
struct ParameterNode : LiteralNode {
ParameterNode(int index_ = -1)
: LiteralNode((float)0),
index{index_}
{
}

ParameterNode(ParameterNode const&) = default;

template <typename T>
void reset(T value_, int index_ = -1)
{
(*static_cast<LiteralNode*>(this)) = LiteralNode(value_);
if (index_ > 0) {
index = index_;
}
}

int index;
};

/// A conditional node
struct ConditionalNode {
};
Expand All @@ -178,6 +206,10 @@ struct Node {
{
}

Node(ParameterNode&& p) : self{std::forward<ParameterNode>(p)}, left{nullptr}, right{nullptr}, condition{nullptr}
{
}

Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_)
: self{op},
left{std::make_unique<Node>(std::forward<Node>(then_))},
Expand All @@ -196,16 +228,70 @@ struct Node {
right{nullptr},
condition{nullptr} {}

Node(Node const& other)
: self{other.self},
index{other.index}
{
if (other.left != nullptr) {
left = std::make_unique<Node>(*other.left);
}
if (other.right != nullptr) {
right = std::make_unique<Node>(*other.right);
}
if (other.condition != nullptr) {
condition = std::make_unique<Node>(*other.condition);
}
}

/// variant with possible nodes
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode>;
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode, ParameterNode>;
self_t self;
size_t index = 0;
/// pointers to children
std::unique_ptr<Node> left;
std::unique_ptr<Node> right;
std::unique_ptr<Node> condition;
std::unique_ptr<Node> left = nullptr;
std::unique_ptr<Node> right = nullptr;
std::unique_ptr<Node> condition = nullptr;
};

/// helper struct used to parse trees
struct NodeRecord {
/// pointer to the actual tree node
Node* node_ptr = nullptr;
size_t index = 0;
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
bool operator!=(NodeRecord const& rhs)
{
return this->node_ptr != rhs.node_ptr;
}
};

/// Tree-walker helper
template <typename L>
void walk(Node* head, L const& pred)
{
std::stack<NodeRecord> path;
path.emplace(head, 0);
while (!path.empty()) {
auto& top = path.top();
pred(top.node_ptr);

auto* leftp = top.node_ptr->left.get();
auto* rightp = top.node_ptr->right.get();
auto* condp = top.node_ptr->condition.get();
path.pop();

if (leftp != nullptr) {
path.emplace(leftp, 0);
}
if (rightp != nullptr) {
path.emplace(rightp, 0);
}
if (condp != nullptr) {
path.emplace(condp, 0);
}
}
}

/// overloaded operators to build the tree from an expression

#define BINARY_OP_NODES(_operator_, _operation_) \
Expand Down Expand Up @@ -402,6 +488,43 @@ inline Node ifnode(Node&& condition_, Configurable<L1> const& then_, Configurabl
return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward<Node>(condition_)};
}

/// Parameters
inline Node par(int index)
{
return Node{ParameterNode{index}};
}

/// binned functional
template <typename T>
inline Node binned(std::vector<T> const& binning, std::vector<T> const& parameters, Node&& binned, Node&& pexp, Node&& out)
{
int bins = binning.size() - 1;
const auto binned_copy = binned;
const auto out_copy = out;
auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1});
auto* current = &root;
for (auto i = 0; i < bins; ++i) {
current->right = std::make_unique<Node>(ifnode(Node{binned_copy} < binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1}));
current = current->right.get();
}
current->right = std::make_unique<Node>(out);
return root;
}

template <typename T>
Node updateParameters(Node const& pexp, int bins, std::vector<T> const& parameters, int bin)
{
Node result{pexp};
auto updateParameter = [&bins, &parameters, &bin](Node* node) {
if (node->self.index() == 5) {
auto* n = std::get_if<5>(&node->self);
n->reset(parameters[n->index * bins + bin]);
}
};
walk(&result, updateParameter);
return result;
}

/// A struct, containing the root of the expression tree
struct Filter {
Filter() = default;
Expand Down
34 changes: 9 additions & 25 deletions Framework/Core/src/Expressions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ struct PlaceholderNodeHelper {
return DatumSpec{node.value, node.type};
}
};

struct ParameterNodeHelper {
DatumSpec operator()(ParameterNode const& node) const
{
return DatumSpec{node.value, node.type};
}
};
} // namespace

std::shared_ptr<arrow::DataType> concreteArrowType(atype::type type)
Expand Down Expand Up @@ -189,37 +196,13 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec)

void updatePlaceholders(Filter& filter, InitContext& context)
{
std::stack<NodeRecord> path;

// insert the top node into stack
path.emplace(filter.node.get(), 0);

auto updateNode = [&](Node* node) {
if (node->self.index() == 3) {
std::get_if<3>(&node->self)->reset(context);
}
};

// while the stack is not empty
while (!path.empty()) {
auto& top = path.top();
updateNode(top.node_ptr);

auto* leftp = top.node_ptr->left.get();
auto* rightp = top.node_ptr->right.get();
auto* condp = top.node_ptr->condition.get();
path.pop();

if (leftp != nullptr) {
path.emplace(leftp, 0);
}
if (rightp != nullptr) {
path.emplace(rightp, 0);
}
if (condp != nullptr) {
path.emplace(condp, 0);
}
}
expressions::walk(filter.node.get(), updateNode);
}

const char* stringType(atype::type t)
Expand Down Expand Up @@ -267,6 +250,7 @@ Operations createOperations(Filter const& expression)
[lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); },
[bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); },
[ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); },
[pr = ParameterNodeHelper{}](ParameterNode const& node) { return pr(node); },
[](auto&&) { return DatumSpec{}; }},
node->self);
};
Expand Down
27 changes: 26 additions & 1 deletion Framework/Core/test/test_Expressions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "Framework/Configurable.h"
#include "Framework/ExpressionHelpers.h"
#include "Framework/AnalysisDataModel.h"
#include "Framework/AODReaderHelpers.h"
#include <catch_amalgamated.hpp>
#include <arrow/util/config.h>

Expand Down Expand Up @@ -283,3 +282,29 @@ TEST_CASE("TestConditionalExpressions")
auto gandiva_filter2 = createFilter(schema2, gandiva_condition2);
REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }");
}

TEST_CASE("TestBinnedExpressions")
{
std::vector<float> bins{0.5, 1.5, 2.5, 3.5, 4.5};
std::vector<float> params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3};
Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f});
auto pspecs = createOperations(p);
auto schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()});
auto tree = createExpressionTree(pspecs, schema);
REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 2 raw(40000000), (float) fY)), float multiply((const float) 3 raw(40400000), (float) fZ)), float multiply((const float) 4 raw(40800000), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 3.1 raw(40466666), (float) fZ)), float multiply((const float) 4.1 raw(40833333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 1.2 raw(3f99999a), (float) fX), float multiply((const float) 2.2 raw(400ccccd), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 4.2 raw(40866666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 1.3 raw(3fa66666), (float) fX), float multiply((const float) 2.3 raw(40133333), (float) fY)), float multiply((const float) 3.3 raw(40533333), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }");

std::vector<float> binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI};
std::vector<float> parameters{1.0, 1.1, 1.2, 1.3, // par 0
2.0, 2.1, 2.2, 2.3, // par 1
3.0, 3.1, 3.2, 3.3, // par 2
4.0, 4.1, 4.2, 4.3}; // par 3

Projector p2 = binned((std::vector<float>)binning,
(std::vector<float>)parameters,
o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z,
LiteralNode{-1.f});
auto p2specs = createOperations(p2);
auto schema2 = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()});
auto tree2 = createExpressionTree(p2specs, schema2);
REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 2 raw(40000000), (float) fY), (float) fY)), float multiply(float multiply((const float) 3 raw(40400000), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fX), (float) fX), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 1.3 raw(3fa66666), (float) fX), (float) fX), float multiply(float multiply((const float) 2.3 raw(40133333), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.3 raw(40533333), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }");
}