From bf498622f602d4c2a376e1e333cfe075fae65852 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Wed, 9 Apr 2025 11:22:36 +0200 Subject: [PATCH 1/6] DPL Analysis: introduce binned expression --- .../include/Framework/ExpressionHelpers.h | 12 -- .../Core/include/Framework/Expressions.h | 134 +++++++++++++++++- Framework/Core/src/Expressions.cxx | 34 ++--- Framework/Core/test/test_Expressions.cxx | 27 +++- 4 files changed, 165 insertions(+), 42 deletions(-) diff --git a/Framework/Core/include/Framework/ExpressionHelpers.h b/Framework/Core/include/Framework/ExpressionHelpers.h index b531a39519272..f881abf7b0e6c 100644 --- a/Framework/Core/include/Framework/ExpressionHelpers.h +++ b/Framework/Core/include/Framework/ExpressionHelpers.h @@ -75,18 +75,6 @@ struct ColumnOperationSpec { result.type = type; } }; - -/// helper struct used to parse trees -struct NodeRecord { - /// pointer to the actual tree node - Node* node_ptr = nullptr; - size_t index = 0; - explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {} - bool operator!=(NodeRecord const& rhs) - { - return this->node_ptr != rhs.node_ptr; - } -}; } // namespace o2::framework::expressions #endif // O2_FRAMEWORK_EXPRESSIONS_HELPERS_H_ diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 1d2883418de71..db82b64a1e416 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -114,6 +114,8 @@ struct LiteralNode { { } + LiteralNode(LiteralNode const& other) = default; + using var_t = LiteralValue::stored_type; var_t value; atype::type type = atype::NA; @@ -132,6 +134,7 @@ struct BindingNode { /// An expression tree node corresponding to binary or unary operation struct OpNode { OpNode(BasicOp op_) : op{op_} {} + OpNode(OpNode const& other) = default; BasicOp op; }; @@ -147,6 +150,8 @@ struct PlaceholderNode : LiteralNode { } } + PlaceholderNode(PlaceholderNode const& other) = default; + void reset(InitContext& context) { value = retrieve(context, name.data()); @@ -156,6 +161,28 @@ struct PlaceholderNode : LiteralNode { LiteralNode::var_t (*retrieve)(InitContext&, char const*); }; +/// A placeholder node for parameters taken from an array +struct ParameterNode : LiteralNode { + ParameterNode(int index_ = -1) + : LiteralNode((float)0), + index{index_} + { + } + + ParameterNode(ParameterNode const&) = default; + + template + void reset(T value_, int index_ = -1) + { + (*static_cast(this)) = LiteralNode(value_); + if (index_ > 0) { + index = index_; + } + } + + int index; +}; + /// A conditional node struct ConditionalNode { }; @@ -178,6 +205,10 @@ struct Node { { } + Node(ParameterNode&& p) : self{std::forward(p)}, left{nullptr}, right{nullptr}, condition{nullptr} + { + } + Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_) : self{op}, left{std::make_unique(std::forward(then_))}, @@ -196,16 +227,70 @@ struct Node { right{nullptr}, condition{nullptr} {} + Node(Node const& other) + : self{other.self}, + index{other.index} + { + if (other.left != nullptr) { + left = std::make_unique(*other.left); + } + if (other.right != nullptr) { + right = std::make_unique(*other.right); + } + if (other.condition != nullptr) { + condition = std::make_unique(*other.condition); + } + } + /// variant with possible nodes - using self_t = std::variant; + using self_t = std::variant; self_t self; size_t index = 0; /// pointers to children - std::unique_ptr left; - std::unique_ptr right; - std::unique_ptr condition; + std::unique_ptr left = nullptr; + std::unique_ptr right = nullptr; + std::unique_ptr condition = nullptr; }; +/// helper struct used to parse trees +struct NodeRecord { + /// pointer to the actual tree node + Node* node_ptr = nullptr; + size_t index = 0; + explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {} + bool operator!=(NodeRecord const& rhs) + { + return this->node_ptr != rhs.node_ptr; + } +}; + +/// Tree-walker helper +template +void walk(Node* head, L const& pred) +{ + std::stack path; + path.emplace(head, 0); + while (!path.empty()) { + auto& top = path.top(); + pred(top.node_ptr); + + auto* leftp = top.node_ptr->left.get(); + auto* rightp = top.node_ptr->right.get(); + auto* condp = top.node_ptr->condition.get(); + path.pop(); + + if (leftp != nullptr) { + path.emplace(leftp, 0); + } + if (rightp != nullptr) { + path.emplace(rightp, 0); + } + if (condp != nullptr) { + path.emplace(condp, 0); + } + } +} + /// overloaded operators to build the tree from an expression #define BINARY_OP_NODES(_operator_, _operation_) \ @@ -402,6 +487,47 @@ inline Node ifnode(Node&& condition_, Configurable const& then_, Configurabl return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward(condition_)}; } +/// Parameters +inline Node par(int index) +{ + return Node{ParameterNode{index}}; +} + +/// binned functional +template +inline Node binned(std::vector const& binning, std::vector const& parameters, Node&& binned, Node&& pexp, Node&& out) +{ + int bins = binning.size() - 1; + const auto binned_copy = binned; + const auto out_copy = out; + auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1}); + root.right = std::make_unique(ifnode(Node{binned_copy} > binning[0] && Node{binned_copy} <= binning [1], updateParameters(pexp, bins, parameters, 0), LiteralNode{-1})); + auto* current = root.right.get(); + for (auto i = 1; i < bins; ++i) { + current->right = std::make_unique(ifnode(Node{binned_copy} <= binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1})); + current = current->right.get(); + } + current->right = std::make_unique(out); + + return root; +} + +template +Node updateParameters(Node const& pexp, int bins, std::vector const& parameters, int bin) +{ + Node result{pexp}; + auto updateParameter = [&bins, ¶meters, &bin](Node* node) + { + if (node->self.index() == 5) { + auto* n = std::get_if<5>(&node->self); + n->reset(parameters[bin * bins + n->index]); + } + }; + walk(&result, updateParameter); + return result; +} + + /// A struct, containing the root of the expression tree struct Filter { Filter() = default; diff --git a/Framework/Core/src/Expressions.cxx b/Framework/Core/src/Expressions.cxx index 45bb120b6eb74..ade9af1e6f6f3 100644 --- a/Framework/Core/src/Expressions.cxx +++ b/Framework/Core/src/Expressions.cxx @@ -118,6 +118,13 @@ struct PlaceholderNodeHelper { return DatumSpec{node.value, node.type}; } }; + +struct ParameterNodeHelper { + DatumSpec operator()(ParameterNode const& node) const + { + return DatumSpec{node.value, node.type}; + } +}; } // namespace std::shared_ptr concreteArrowType(atype::type type) @@ -189,37 +196,13 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec) void updatePlaceholders(Filter& filter, InitContext& context) { - std::stack path; - - // insert the top node into stack - path.emplace(filter.node.get(), 0); - auto updateNode = [&](Node* node) { if (node->self.index() == 3) { std::get_if<3>(&node->self)->reset(context); } }; - // while the stack is not empty - while (!path.empty()) { - auto& top = path.top(); - updateNode(top.node_ptr); - - auto* leftp = top.node_ptr->left.get(); - auto* rightp = top.node_ptr->right.get(); - auto* condp = top.node_ptr->condition.get(); - path.pop(); - - if (leftp != nullptr) { - path.emplace(leftp, 0); - } - if (rightp != nullptr) { - path.emplace(rightp, 0); - } - if (condp != nullptr) { - path.emplace(condp, 0); - } - } + expressions::walk(filter.node.get(), updateNode); } const char* stringType(atype::type t) @@ -267,6 +250,7 @@ Operations createOperations(Filter const& expression) [lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); }, [bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); }, [ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); }, + [pr = ParameterNodeHelper{}](ParameterNode const& node){ return pr(node); }, [](auto&&) { return DatumSpec{}; }}, node->self); }; diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index 8b08a9a38aa63..e65429e7ce94b 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -12,7 +12,6 @@ #include "Framework/Configurable.h" #include "Framework/ExpressionHelpers.h" #include "Framework/AnalysisDataModel.h" -#include "Framework/AODReaderHelpers.h" #include #include @@ -283,3 +282,29 @@ TEST_CASE("TestConditionalExpressions") auto gandiva_filter2 = createFilter(schema2, gandiva_condition2); REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }"); } + +TEST_CASE("TestBinnedExpressions") +{ + std::vector bins{0.5, 1.5, 2.5, 3.5, 4.5}; + std::vector params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3}; + Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f}); + auto pspecs = createOperations(p); + auto schema = std::make_shared(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()}); + auto tree = createExpressionTree(pspecs, schema); + REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool greater_than((float) fPt, (const float) 0.5 raw(3f000000)) && bool less_than_or_equal_to((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 1.1 raw(3f8ccccd), (float) fY)), float multiply((const float) 1.2 raw(3f99999a), (float) fZ)), float multiply((const float) 1.3 raw(3fa66666), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 2 raw(40000000), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 2.2 raw(400ccccd), (float) fZ)), float multiply((const float) 2.3 raw(40133333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 3 raw(40400000), (float) fX), float multiply((const float) 3.1 raw(40466666), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 3.3 raw(40533333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 4 raw(40800000), (float) fX), float multiply((const float) 4.1 raw(40833333), (float) fY)), float multiply((const float) 4.2 raw(40866666), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); + + std::vector binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI}; + std::vector parameters{1.0, 1.1, 1.2, 1.3, // par 0 + 2.0, 2.1, 2.2, 2.3, // par 1 + 3.0, 3.1, 3.2, 3.3, // par 2 + 4.0, 4.1, 4.2, 4.3}; // par 3 + + Projector p2 = binned((std::vector)binning, + (std::vector)parameters, + o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z, + LiteralNode{-1.f}); + auto p2specs = createOperations(p2); + auto schema2 = std::make_shared(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()}); + auto tree2 = createExpressionTree(p2specs, schema2); + REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool greater_than((float) fPhi, (const float) 0 raw(0)) && bool less_than_or_equal_to((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 2 raw(40000000), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 3 raw(40400000), (float) fX), (float) fX), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 4 raw(40800000), (float) fX), (float) fX), float multiply(float multiply((const float) 4.1 raw(40833333), (float) fY), (float) fY)), float multiply(float multiply((const float) 4.2 raw(40866666), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); +} From 7be5308e58f01b2728fdf0c9d189d0e53e068ad5 Mon Sep 17 00:00:00 2001 From: ALICE Action Bot Date: Fri, 11 Apr 2025 12:04:16 +0000 Subject: [PATCH 2/6] Please consider the following formatting changes --- Framework/Core/include/Framework/Expressions.h | 6 ++---- Framework/Core/src/Expressions.cxx | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index db82b64a1e416..9e3b681dab3e2 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -501,7 +501,7 @@ inline Node binned(std::vector const& binning, std::vector const& paramete const auto binned_copy = binned; const auto out_copy = out; auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1}); - root.right = std::make_unique(ifnode(Node{binned_copy} > binning[0] && Node{binned_copy} <= binning [1], updateParameters(pexp, bins, parameters, 0), LiteralNode{-1})); + root.right = std::make_unique(ifnode(Node{binned_copy} > binning[0] && Node{binned_copy} <= binning[1], updateParameters(pexp, bins, parameters, 0), LiteralNode{-1})); auto* current = root.right.get(); for (auto i = 1; i < bins; ++i) { current->right = std::make_unique(ifnode(Node{binned_copy} <= binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1})); @@ -516,8 +516,7 @@ template Node updateParameters(Node const& pexp, int bins, std::vector const& parameters, int bin) { Node result{pexp}; - auto updateParameter = [&bins, ¶meters, &bin](Node* node) - { + auto updateParameter = [&bins, ¶meters, &bin](Node* node) { if (node->self.index() == 5) { auto* n = std::get_if<5>(&node->self); n->reset(parameters[bin * bins + n->index]); @@ -527,7 +526,6 @@ Node updateParameters(Node const& pexp, int bins, std::vector const& paramete return result; } - /// A struct, containing the root of the expression tree struct Filter { Filter() = default; diff --git a/Framework/Core/src/Expressions.cxx b/Framework/Core/src/Expressions.cxx index ade9af1e6f6f3..6f646515b7837 100644 --- a/Framework/Core/src/Expressions.cxx +++ b/Framework/Core/src/Expressions.cxx @@ -250,7 +250,7 @@ Operations createOperations(Filter const& expression) [lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); }, [bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); }, [ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); }, - [pr = ParameterNodeHelper{}](ParameterNode const& node){ return pr(node); }, + [pr = ParameterNodeHelper{}](ParameterNode const& node) { return pr(node); }, [](auto&&) { return DatumSpec{}; }}, node->self); }; From 95f5ecdf94a57da5c8bc833c6e786803a3170b19 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Fri, 11 Apr 2025 23:42:02 +0200 Subject: [PATCH 3/6] Update Expressions.h --- Framework/Core/include/Framework/Expressions.h | 1 + 1 file changed, 1 insertion(+) diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 9e3b681dab3e2..3acb68d869fe0 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -41,6 +41,7 @@ class Projector; #include #include #include +#include namespace gandiva { using Selection = std::shared_ptr; From 8d2f417e5d3d72dc59278bfca589d45ab25796b5 Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Sat, 12 Apr 2025 15:51:51 +0200 Subject: [PATCH 4/6] Update Expressions.h Fix logic: * bins are defined as [lower : upper) * separate "less than lower" node --- Framework/Core/include/Framework/Expressions.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index 3acb68d869fe0..c0dc4595592ff 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -502,14 +502,12 @@ inline Node binned(std::vector const& binning, std::vector const& paramete const auto binned_copy = binned; const auto out_copy = out; auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1}); - root.right = std::make_unique(ifnode(Node{binned_copy} > binning[0] && Node{binned_copy} <= binning[1], updateParameters(pexp, bins, parameters, 0), LiteralNode{-1})); - auto* current = root.right.get(); - for (auto i = 1; i < bins; ++i) { - current->right = std::make_unique(ifnode(Node{binned_copy} <= binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1})); + auto* current = &root; + for (auto i = 0; i < bins; ++i) { + current->right = std::make_unique(ifnode(Node{binned_copy} < binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1})); current = current->right.get(); } current->right = std::make_unique(out); - return root; } From 3b508068d370ec710798a62a5afd0063ecd1b83b Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Sat, 12 Apr 2025 16:48:50 +0200 Subject: [PATCH 5/6] update test --- Framework/Core/test/test_Expressions.cxx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index e65429e7ce94b..bd4070ebf3815 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -291,7 +291,7 @@ TEST_CASE("TestBinnedExpressions") auto pspecs = createOperations(p); auto schema = std::make_shared(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()}); auto tree = createExpressionTree(pspecs, schema); - REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool greater_than((float) fPt, (const float) 0.5 raw(3f000000)) && bool less_than_or_equal_to((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 1.1 raw(3f8ccccd), (float) fY)), float multiply((const float) 1.2 raw(3f99999a), (float) fZ)), float multiply((const float) 1.3 raw(3fa66666), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 2 raw(40000000), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 2.2 raw(400ccccd), (float) fZ)), float multiply((const float) 2.3 raw(40133333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 3 raw(40400000), (float) fX), float multiply((const float) 3.1 raw(40466666), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 3.3 raw(40533333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 4 raw(40800000), (float) fX), float multiply((const float) 4.1 raw(40833333), (float) fY)), float multiply((const float) 4.2 raw(40866666), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); + REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 1.1 raw(3f8ccccd), (float) fY)), float multiply((const float) 1.2 raw(3f99999a), (float) fZ)), float multiply((const float) 1.3 raw(3fa66666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 2 raw(40000000), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 2.2 raw(400ccccd), (float) fZ)), float multiply((const float) 2.3 raw(40133333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 3 raw(40400000), (float) fX), float multiply((const float) 3.1 raw(40466666), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 3.3 raw(40533333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 4 raw(40800000), (float) fX), float multiply((const float) 4.1 raw(40833333), (float) fY)), float multiply((const float) 4.2 raw(40866666), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); std::vector binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI}; std::vector parameters{1.0, 1.1, 1.2, 1.3, // par 0 @@ -306,5 +306,5 @@ TEST_CASE("TestBinnedExpressions") auto p2specs = createOperations(p2); auto schema2 = std::make_shared(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()}); auto tree2 = createExpressionTree(p2specs, schema2); - REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool greater_than((float) fPhi, (const float) 0 raw(0)) && bool less_than_or_equal_to((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 2 raw(40000000), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 3 raw(40400000), (float) fX), (float) fX), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 4 raw(40800000), (float) fX), (float) fX), float multiply(float multiply((const float) 4.1 raw(40833333), (float) fY), (float) fY)), float multiply(float multiply((const float) 4.2 raw(40866666), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); + REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 2 raw(40000000), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 3 raw(40400000), (float) fX), (float) fX), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 4 raw(40800000), (float) fX), (float) fX), float multiply(float multiply((const float) 4.1 raw(40833333), (float) fY), (float) fY)), float multiply(float multiply((const float) 4.2 raw(40866666), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); } From 6a081b10eb0607756e69eacc7af7fcb0baf4790f Mon Sep 17 00:00:00 2001 From: Anton Alkin Date: Mon, 14 Apr 2025 10:33:29 +0200 Subject: [PATCH 6/6] fix parameter addressing --- Framework/Core/include/Framework/Expressions.h | 2 +- Framework/Core/test/test_Expressions.cxx | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Framework/Core/include/Framework/Expressions.h b/Framework/Core/include/Framework/Expressions.h index c0dc4595592ff..af89e56f85835 100644 --- a/Framework/Core/include/Framework/Expressions.h +++ b/Framework/Core/include/Framework/Expressions.h @@ -518,7 +518,7 @@ Node updateParameters(Node const& pexp, int bins, std::vector const& paramete auto updateParameter = [&bins, ¶meters, &bin](Node* node) { if (node->self.index() == 5) { auto* n = std::get_if<5>(&node->self); - n->reset(parameters[bin * bins + n->index]); + n->reset(parameters[n->index * bins + bin]); } }; walk(&result, updateParameter); diff --git a/Framework/Core/test/test_Expressions.cxx b/Framework/Core/test/test_Expressions.cxx index bd4070ebf3815..2296b5dcbfbc4 100644 --- a/Framework/Core/test/test_Expressions.cxx +++ b/Framework/Core/test/test_Expressions.cxx @@ -291,7 +291,7 @@ TEST_CASE("TestBinnedExpressions") auto pspecs = createOperations(p); auto schema = std::make_shared(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()}); auto tree = createExpressionTree(pspecs, schema); - REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 1.1 raw(3f8ccccd), (float) fY)), float multiply((const float) 1.2 raw(3f99999a), (float) fZ)), float multiply((const float) 1.3 raw(3fa66666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 2 raw(40000000), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 2.2 raw(400ccccd), (float) fZ)), float multiply((const float) 2.3 raw(40133333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 3 raw(40400000), (float) fX), float multiply((const float) 3.1 raw(40466666), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 3.3 raw(40533333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 4 raw(40800000), (float) fX), float multiply((const float) 4.1 raw(40833333), (float) fY)), float multiply((const float) 4.2 raw(40866666), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); + REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 2 raw(40000000), (float) fY)), float multiply((const float) 3 raw(40400000), (float) fZ)), float multiply((const float) 4 raw(40800000), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 3.1 raw(40466666), (float) fZ)), float multiply((const float) 4.1 raw(40833333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 1.2 raw(3f99999a), (float) fX), float multiply((const float) 2.2 raw(400ccccd), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 4.2 raw(40866666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 1.3 raw(3fa66666), (float) fX), float multiply((const float) 2.3 raw(40133333), (float) fY)), float multiply((const float) 3.3 raw(40533333), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }"); std::vector binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI}; std::vector parameters{1.0, 1.1, 1.2, 1.3, // par 0 @@ -306,5 +306,5 @@ TEST_CASE("TestBinnedExpressions") auto p2specs = createOperations(p2); auto schema2 = std::make_shared(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()}); auto tree2 = createExpressionTree(p2specs, schema2); - REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 2 raw(40000000), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 3 raw(40400000), (float) fX), (float) fX), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 4 raw(40800000), (float) fX), (float) fX), float multiply(float multiply((const float) 4.1 raw(40833333), (float) fY), (float) fY)), float multiply(float multiply((const float) 4.2 raw(40866666), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); + REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 2 raw(40000000), (float) fY), (float) fY)), float multiply(float multiply((const float) 3 raw(40400000), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fX), (float) fX), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 1.3 raw(3fa66666), (float) fX), (float) fX), float multiply(float multiply((const float) 2.3 raw(40133333), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.3 raw(40533333), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }"); }