Skip to content

Commit 3dbda4b

Browse files
authored
fix(builtin): limit recursion depth (#870)
Add builtin.MaxDepth (default 10k) to prevent stack overflows when processing deeply nested or cyclic structures in builtin functions. The functions flatten, min, max, mean, and median now return a "recursion depth exceeded" error instead of crashing the runtime. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent ad49544 commit 3dbda4b

File tree

3 files changed

+134
-14
lines changed

3 files changed

+134
-14
lines changed

builtin/builtin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package builtin
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"reflect"
89
"sort"
@@ -16,6 +17,10 @@ import (
1617
var (
1718
Index map[string]int
1819
Names []string
20+
21+
// MaxDepth limits the recursion depth for nested structures.
22+
MaxDepth = 10000
23+
ErrorMaxDepth = errors.New("recursion depth exceeded")
1924
)
2025

2126
func init() {
@@ -377,7 +382,7 @@ var Builtins = []*Function{
377382
{
378383
Name: "max",
379384
Func: func(args ...any) (any, error) {
380-
return minMax("max", runtime.Less, args...)
385+
return minMax("max", runtime.Less, 0, args...)
381386
},
382387
Validate: func(args []reflect.Type) (reflect.Type, error) {
383388
return validateAggregateFunc("max", args)
@@ -386,7 +391,7 @@ var Builtins = []*Function{
386391
{
387392
Name: "min",
388393
Func: func(args ...any) (any, error) {
389-
return minMax("min", runtime.More, args...)
394+
return minMax("min", runtime.More, 0, args...)
390395
},
391396
Validate: func(args []reflect.Type) (reflect.Type, error) {
392397
return validateAggregateFunc("min", args)
@@ -395,7 +400,7 @@ var Builtins = []*Function{
395400
{
396401
Name: "mean",
397402
Func: func(args ...any) (any, error) {
398-
count, sum, err := mean(args...)
403+
count, sum, err := mean(0, args...)
399404
if err != nil {
400405
return nil, err
401406
}
@@ -411,7 +416,7 @@ var Builtins = []*Function{
411416
{
412417
Name: "median",
413418
Func: func(args ...any) (any, error) {
414-
values, err := median(args...)
419+
values, err := median(0, args...)
415420
if err != nil {
416421
return nil, err
417422
}
@@ -940,7 +945,10 @@ var Builtins = []*Function{
940945
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
941946
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
942947
}
943-
ret := flatten(v)
948+
ret, err := flatten(v, 0)
949+
if err != nil {
950+
return nil, 0, err
951+
}
944952
size = uint(len(ret))
945953
return ret, size, nil
946954
},

builtin/builtin_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
722722
})
723723
}
724724
}
725+
726+
func TestBuiltin_flatten_recursion(t *testing.T) {
727+
var s []any
728+
s = append(s, &s) // s contains a pointer to itself
729+
730+
env := map[string]any{
731+
"arr": s,
732+
}
733+
734+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
735+
require.NoError(t, err)
736+
737+
_, err = expr.Run(program, env)
738+
require.Error(t, err)
739+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
740+
}
741+
742+
func TestBuiltin_flatten_recursion_slice(t *testing.T) {
743+
s := make([]any, 1)
744+
s[0] = s
745+
746+
env := map[string]any{
747+
"arr": s,
748+
}
749+
750+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
751+
require.NoError(t, err)
752+
753+
_, err = expr.Run(program, env)
754+
require.Error(t, err)
755+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
756+
}
757+
758+
func TestBuiltin_numerical_recursion(t *testing.T) {
759+
s := make([]any, 1)
760+
s[0] = s
761+
762+
env := map[string]any{
763+
"arr": s,
764+
}
765+
766+
tests := []string{
767+
"max(arr)",
768+
"min(arr)",
769+
"mean(arr)",
770+
"median(arr)",
771+
}
772+
773+
for _, input := range tests {
774+
t.Run(input, func(t *testing.T) {
775+
program, err := expr.Compile(input, expr.Env(env))
776+
require.NoError(t, err)
777+
778+
_, err = expr.Run(program, env)
779+
require.Error(t, err)
780+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
781+
})
782+
}
783+
}
784+
785+
func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
786+
originalMaxDepth := builtin.MaxDepth
787+
defer func() {
788+
builtin.MaxDepth = originalMaxDepth
789+
}()
790+
791+
// Set a small depth limit
792+
builtin.MaxDepth = 2
793+
794+
// Create a deeply nested array (depth 5)
795+
// [1, [2, [3, [4, [5]]]]]
796+
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}
797+
798+
env := map[string]any{
799+
"arr": arr,
800+
}
801+
802+
t.Run("flatten exceeds max depth", func(t *testing.T) {
803+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
804+
require.NoError(t, err)
805+
806+
_, err = expr.Run(program, env)
807+
require.Error(t, err)
808+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
809+
})
810+
811+
t.Run("flatten within max depth", func(t *testing.T) {
812+
// Depth 2: [1, [2]]
813+
shallowArr := []any{1, []any{2}}
814+
envShallow := map[string]any{"arr": shallowArr}
815+
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
816+
require.NoError(t, err)
817+
818+
_, err = expr.Run(program, envShallow)
819+
require.NoError(t, err)
820+
})
821+
}

builtin/lib.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ func String(arg any) any {
253253
return fmt.Sprintf("%v", arg)
254254
}
255255

256-
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
256+
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
257+
if depth > MaxDepth {
258+
return nil, ErrorMaxDepth
259+
}
257260
var val any
258261
for _, arg := range args {
259262
rv := reflect.ValueOf(arg)
260263
switch rv.Kind() {
261264
case reflect.Array, reflect.Slice:
262265
size := rv.Len()
263266
for i := 0; i < size; i++ {
264-
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
267+
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
265268
if err != nil {
266269
return nil, err
267270
}
@@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
294297
return val, nil
295298
}
296299

297-
func mean(args ...any) (int, float64, error) {
300+
func mean(depth int, args ...any) (int, float64, error) {
301+
if depth > MaxDepth {
302+
return 0, 0, ErrorMaxDepth
303+
}
298304
var total float64
299305
var count int
300306

@@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
304310
case reflect.Array, reflect.Slice:
305311
size := rv.Len()
306312
for i := 0; i < size; i++ {
307-
elemCount, elemSum, err := mean(rv.Index(i).Interface())
313+
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
308314
if err != nil {
309315
return 0, 0, err
310316
}
@@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
327333
return count, total, nil
328334
}
329335

330-
func median(args ...any) ([]float64, error) {
336+
func median(depth int, args ...any) ([]float64, error) {
337+
if depth > MaxDepth {
338+
return nil, ErrorMaxDepth
339+
}
331340
var values []float64
332341

333342
for _, arg := range args {
@@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
336345
case reflect.Array, reflect.Slice:
337346
size := rv.Len()
338347
for i := 0; i < size; i++ {
339-
elems, err := median(rv.Index(i).Interface())
348+
elems, err := median(depth+1, rv.Index(i).Interface())
340349
if err != nil {
341350
return nil, err
342351
}
@@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
355364
return values, nil
356365
}
357366

358-
func flatten(arg reflect.Value) []any {
367+
func flatten(arg reflect.Value, depth int) ([]any, error) {
368+
if depth > MaxDepth {
369+
return nil, ErrorMaxDepth
370+
}
359371
ret := []any{}
360372
for i := 0; i < arg.Len(); i++ {
361373
v := deref.Value(arg.Index(i))
362374
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
363-
x := flatten(v)
375+
x, err := flatten(v, depth+1)
376+
if err != nil {
377+
return nil, err
378+
}
364379
ret = append(ret, x...)
365380
} else {
366381
ret = append(ret, v.Interface())
367382
}
368383
}
369-
return ret
384+
return ret, nil
370385
}
371386

372387
func get(params ...any) (out any, err error) {

0 commit comments

Comments
 (0)