diff --git a/extensions/standard/array.cpp b/extensions/standard/array.cpp index 54f5420..e7ef51f 100644 --- a/extensions/standard/array.cpp +++ b/extensions/standard/array.cpp @@ -187,3 +187,76 @@ Value replace(const std::vector& args) { arr = result; return arr; } + + +// 拼接给定的所有列表 +Value concat(const std::vector& args) { + std::vector result; + for (const auto &i : args) { + if (!i.is_array()) { + L_ERR("Given parameter(s) have no-array element"); + return LAMINA_NULL; + } + for (const auto &j : std::get >(i.data)) { + result.push_back(j); + } + } + return Value(result); +} + +// 列表切片 +Value slice(const std::vector& args) { + check_cpp_function_argv_x(args, 3, 4); + if (!args[0].is_array()) { + L_ERR("slice() requires a list"); + return LAMINA_NULL; + } + std::vector val = std::get >(args[0].data); + std::vector result; + int begin = (int) args[1].as_number(), end = (int) args[2].as_number(), step = 1; + if (end < begin) step = -1; + if (args.size() >= 4) { + step = (int) args[3].as_number(); + } + for (int i = begin; i != end; i += step) { + int it = (i >= 0) ? i : (int(val.size()) + i); + if (it < 0 || it >= val.size()) { + break; + } + result.push_back(val[it]); + } + return Value(result); +} + +// 排序给定的列表,第二个参数表示比较器 +Value _sort(const std::vector& args) { + check_cpp_function_argv_x(args, 1, 2); + if (!args[0].is_array()) { + L_ERR("sort() requires a list"); + return LAMINA_NULL; + } + std::vector val = std::get >(args[0].data); + for (auto &i : val) { + if (!i.is_comparable()) { + L_ERR("Array has uncomparable object"); + return args[0]; // A failure, not an error. + } + } + std::function comparer; + if (args.size() >= 2) { + if (!args[1].is_lambda()) { + L_ERR("Comparer must be a lambda/function"); + return LAMINA_NULL; + } + const auto func = std::get>(args[1].data); + comparer = [&func](const Value &a, const Value &b) -> bool { + return Interpreter::call_function(func.get(), {a, b}).as_bool(); + }; + } else { + comparer = [](const Value &a, const Value &b) -> bool { + return a < b; + }; + } + sort(val.begin(), val.end(), comparer); + return Value(val); +} \ No newline at end of file diff --git a/extensions/standard/standard.hpp b/extensions/standard/standard.hpp index 69188c3..a30f14a 100644 --- a/extensions/standard/standard.hpp +++ b/extensions/standard/standard.hpp @@ -150,6 +150,15 @@ Value map(const std::vector& args); // 替换内容:需3个参数(原字符串/容器、目标值、替换值),替换所有匹配的目标值并返回新结果 Value replace(const std::vector& args); +// 拼接给定的所有列表 +Value concat(const std::vector& args); + +// 切片,可以为 [a:b] 或 [a:b:c] +Value slice(const std::vector& args); + +// 排序给定的列表,第二个参数表示比较器 +Value _sort(const std::vector& args); + // 变量表 Value vars(const std::vector& args); @@ -298,6 +307,10 @@ inline std::unordered_map register_builtins = LAMINA_FUNC("find", find), LAMINA_FUNC("map", map), LAMINA_FUNC("replace", replace), + + LAMINA_FUNC("concat", concat), + LAMINA_FUNC("slice", slice), + LAMINA_FUNC("sort", _sort), // CAS数学模块:封装符号计算相关的解析、化简、求导等功能 LAMINA_MODULE("cas", LAMINA_VERSION, { diff --git a/interpreter/eval.cpp b/interpreter/eval.cpp index 1cc6d0b..1095575 100644 --- a/interpreter/eval.cpp +++ b/interpreter/eval.cpp @@ -74,7 +74,7 @@ Value HANDLE_BINARYEXPR_ADD(Value* l, Value* r) { return Value(l->to_string() + r->to_string()); } else if (ltype & VALUE_IS_ARRAY && rtype & VALUE_IS_ARRAY) { // Vector addition - return l->vector_add(r); + return l->vector_add(*r); // 只要有一方是 Irrational 或 Symbolic,优先生成符号表达式 } else if (((ltype & VALUE_IS_IRRATIONAL) || (ltype & VALUE_IS_SYMBOLIC) || (rtype & VALUE_IS_IRRATIONAL) || (rtype & VALUE_IS_SYMBOLIC)) && (ltype & VALUE_IS_NUMERIC) && (rtype & VALUE_IS_NUMERIC)) { std::shared_ptr leftExpr = GET_SYMBOLICEXPR(l, ltype); @@ -111,6 +111,7 @@ Value HANDLE_BINARYEXPR_ADD(Value* l, Value* r) { Value HANDLE_BINARYEXPR_STR_ADD_STR(Value* l, Value* r) { std::string ls = l->to_string(); std::string rs = r->to_string(); + /* try { // Try to parse both strings as CAS expressions and combine them symbolically LaminaCAS::Parser pl(ls); @@ -257,6 +258,8 @@ Value HANDLE_BINARYEXPR_STR_ADD_STR(Value* l, Value* r) { } catch (...) { // parsing failed for one or both strings; fall back to normal concatenation } + */ + return Value(ls + rs); } Value Interpreter::eval_LiteralExpr(const LiteralExpr* node) { @@ -337,11 +340,13 @@ Value Interpreter::eval_CallExpr(const CallExpr* call) { return {}; } // User function - return Interpreter::call_function(func.get(), args, self); + return Interpreter::call_function(func.get(), args, self, left.in_module ? Value(left.in_module) : LAMINA_NULL); } if (std::holds_alternative>(left.data)) { push_frame("", " "); + push_scope(); + set_module_as(left.in_module ? Value(left.in_module) : LAMINA_NULL); Value result; std::shared_ptr func; @@ -350,9 +355,11 @@ Value Interpreter::eval_CallExpr(const CallExpr* call) { result = func->function(args); } catch (...) { pop_frame(); + pop_scope(); throw; } pop_frame(); + pop_scope(); return result; } @@ -360,7 +367,7 @@ Value Interpreter::eval_CallExpr(const CallExpr* call) { return {}; } -Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vector& args, Value self) { +Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vector& args, Value self, Value module) { if (func == nullptr) { std::cerr << "Error: Function at '" << func << "' is null" << std::endl; return Value(""); @@ -369,6 +376,7 @@ Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vectorname, "