diff --git a/Framework/Core/include/Framework/ExpressionHelpers.h b/Framework/Core/include/Framework/ExpressionHelpers.h index b531a39519272..f881abf7b0e6c 100644 --- a/Framework/Core/include/Framework/ExpressionHelpers.h +++ b/Framework/Core/include/Framework/ExpressionHelpers.h @@ -75,18 +75,6 @@ struct ColumnOperationSpec { result.type = type; } }; - -/// helper struct used to parse trees -struct NodeRecord { - /// pointer to the actual tree node - Node* node_ptr = nullptr; - size_t index = 0; - explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {} - bool operator!=(NodeRecord const& rhs) - { - return this->node_ptr != rhs.node_ptr; - } -}; } // namespace o2::framework::expressions #endif // O2_FRAMEWORK_EXPRESSIONS_HELPERS_H_ diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 1d2883418de71..af89e56f85835 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -41,6 +41,7 @@ class Projector; #include #include #include +#include namespace gandiva { using Selection = std::shared_ptr; @@ -114,6 +115,8 @@ struct LiteralNode { { } + LiteralNode(LiteralNode const& other) = default; + using var_t = LiteralValue::stored_type; var_t value; atype::type type = atype::NA; @@ -132,6 +135,7 @@ struct BindingNode { /// An expression tree node corresponding to binary or unary operation struct OpNode { OpNode(BasicOp op_) : op{op_} {} + OpNode(OpNode const& other) = default; BasicOp op; }; @@ -147,6 +151,8 @@ struct PlaceholderNode : LiteralNode { } } + PlaceholderNode(PlaceholderNode const& other) = default; + void reset(InitContext& context) { value = retrieve(context, name.data()); @@ -156,6 +162,28 @@ struct PlaceholderNode : LiteralNode { LiteralNode::var_t (*retrieve)(InitContext&, char const*); }; +/// A placeholder node for parameters taken from an array +struct ParameterNode : LiteralNode { + ParameterNode(int index_ = -1) + : LiteralNode((float)0), + index{index_} + { + } + + ParameterNode(ParameterNode const&) = default; + + template + void reset(T value_, int index_ = -1) + { + (*static_cast(this)) = LiteralNode(value_); + if (index_ > 0) { + index = index_; + } + } + + int index; +}; + /// A conditional node struct ConditionalNode { }; @@ -178,6 +206,10 @@ struct Node { { } + Node(ParameterNode&& p) : self{std::forward(p)}, left{nullptr}, right{nullptr}, condition{nullptr} + { + } + Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_) : self{op}, left{std::make_unique(std::forward(then_))}, @@ -196,16 +228,70 @@ struct Node { right{nullptr}, condition{nullptr} {} + Node(Node const& other) + : self{other.self}, + index{other.index} + { + if (other.left != nullptr) { + left = std::make_unique(*other.left); + } + if (other.right != nullptr) { + right = std::make_unique(*other.right); + } + if (other.condition != nullptr) { + condition = std::make_unique(*other.condition); + } + } + /// variant with possible nodes - using self_t = std::variant; + using self_t = std::variant; self_t self; size_t index = 0; /// pointers to children - std::unique_ptr left; - std::unique_ptr right; - std::unique_ptr condition; + std::unique_ptr left = nullptr; + std::unique_ptr right = nullptr; + std::unique_ptr condition = nullptr; +}; + +/// helper struct used to parse trees +struct NodeRecord { + /// pointer to the actual tree node + Node* node_ptr = nullptr; + size_t index = 0; + explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {} + bool operator!=(NodeRecord const& rhs) + { + return this->node_ptr != rhs.node_ptr; + } }; +/// Tree-walker helper +template +void walk(Node* head, L const& pred) +{ + std::stack path; + path.emplace(head, 0); + while (!path.empty()) { + auto& top = path.top(); + pred(top.node_ptr); + + auto* leftp = top.node_ptr->left.get(); + auto* rightp = top.node_ptr->right.get(); + auto* condp = top.node_ptr->condition.get(); + path.pop(); + + if (leftp != nullptr) { + path.emplace(leftp, 0); + } + if (rightp != nullptr) { + path.emplace(rightp, 0); + } + if (condp != nullptr) { + path.emplace(condp, 0); + } + } +} + /// overloaded operators to build the tree from an expression #define BINARY_OP_NODES(_operator_, _operation_) \ @@ -402,6 +488,43 @@ inline Node ifnode(Node&& condition_, Configurable const& then_, Configurabl return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward(condition_)}; } +/// Parameters +inline Node par(int index) +{ + return Node{ParameterNode{index}}; +} + +/// binned functional +template +inline Node binned(std::vector const& binning, std::vector const& parameters, Node&& binned, Node&& pexp, Node&& out) +{ + int bins = binning.size() - 1; + const auto binned_copy = binned; + const auto out_copy = out; + auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1}); + auto* current = &root; + for (auto i = 0; i < bins; ++i) { + current->right = std::make_unique(ifnode(Node{binned_copy} < binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1})); + current = current->right.get(); + } + current->right = std::make_unique(out); + return root; +} + +template +Node updateParameters(Node const& pexp, int bins, std::vector const& parameters, int bin) +{ + Node result{pexp}; + auto updateParameter = [&bins, ¶meters, &bin](Node* node) { + if (node->self.index() == 5) { + auto* n = std::get_if<5>(&node->self); + n->reset(parameters[n->index * bins + bin]); + } + }; + walk(&result, updateParameter); + return result; +} + /// A struct, containing the root of the expression tree struct Filter { Filter() = default; diff --git a/Framework/Core/src/Expressions.cxx b/Framework/Core/src/Expressions.cxx index 45bb120b6eb74..6f646515b7837 100644 --- a/Framework/Core/src/Expressions.cxx +++ b/Framework/Core/src/Expressions.cxx @@ -118,6 +118,13 @@ struct PlaceholderNodeHelper { return DatumSpec{node.value, node.type}; } }; + +struct ParameterNodeHelper { + DatumSpec operator()(ParameterNode const& node) const + { + return DatumSpec{node.value, node.type}; + } +}; } // namespace std::shared_ptr concreteArrowType(atype::type type) @@ -189,37 +196,13 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec) void updatePlaceholders(Filter& filter, InitContext& context) { - std::stack path; - - // insert the top node into stack - path.emplace(filter.node.get(), 0); - auto updateNode = [&](Node* node) { if (node->self.index() == 3) { std::get_if<3>(&node->self)->reset(context); } }; - // while the stack is not empty - while (!path.empty()) { - auto& top = path.top(); - updateNode(top.node_ptr); - - auto* leftp = top.node_ptr->left.get(); - auto* rightp = top.node_ptr->right.get(); - auto* condp = top.node_ptr->condition.get(); - path.pop(); - - if (leftp != nullptr) { - path.emplace(leftp, 0); - } - if (rightp != nullptr) { - path.emplace(rightp, 0); - } - if (condp != nullptr) { - path.emplace(condp, 0); - } - } + expressions::walk(filter.node.get(), updateNode); } const char* stringType(atype::type t) @@ -267,6 +250,7 @@ Operations createOperations(Filter const& expression) [lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); }, [bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); }, [ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); }, + [pr = ParameterNodeHelper{}](ParameterNode const& node) { return pr(node); }, [](auto&&) { return DatumSpec{}; }}, node->self); }; diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index 8b08a9a38aa63..2296b5dcbfbc4 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -12,7 +12,6 @@ #include "Framework/Configurable.h" #include "Framework/ExpressionHelpers.h" #include "Framework/AnalysisDataModel.h" -#include "Framework/AODReaderHelpers.h" #include #include @@ -283,3 +282,29 @@ TEST_CASE("TestConditionalExpressions") auto gandiva_filter2 = createFilter(schema2, gandiva_condition2); REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }"); } + +TEST_CASE("TestBinnedExpressions") +{ + std::vector bins{0.5, 1.5, 2.5, 3.5, 4.5}; + std::vector params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3}; + Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f}); + auto pspecs = createOperations(p); + auto schema = std::make_shared(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()}); + auto tree = createExpressionTree(pspecs, schema); + REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 2 raw(40000000), (float) fY)), float multiply((const float) 3 raw(40400000), (float) fZ)), float multiply((const float) 4 raw(40800000), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 3.1 raw(40466666), (float) fZ)), float multiply((const float) 4.1 raw(40833333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 1.2 raw(3f99999a), (float) fX), float multiply((const float) 2.2 raw(400ccccd), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 4.2 raw(40866666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 1.3 raw(3fa66666), (float) fX), float multiply((const float) 2.3 raw(40133333), (float) fY)), float multiply((const float) 3.3 raw(40533333), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); + + std::vector binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI}; + std::vector parameters{1.0, 1.1, 1.2, 1.3, // par 0 + 2.0, 2.1, 2.2, 2.3, // par 1 + 3.0, 3.1, 3.2, 3.3, // par 2 + 4.0, 4.1, 4.2, 4.3}; // par 3 + + Projector p2 = binned((std::vector)binning, + (std::vector)parameters, + o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z, + LiteralNode{-1.f}); + auto p2specs = createOperations(p2); + auto schema2 = std::make_shared(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()}); + auto tree2 = createExpressionTree(p2specs, schema2); + REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 2 raw(40000000), (float) fY), (float) fY)), float multiply(float multiply((const float) 3 raw(40400000), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fX), (float) fX), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 1.3 raw(3fa66666), (float) fX), (float) fX), float multiply(float multiply((const float) 2.3 raw(40133333), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.3 raw(40533333), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); +}