Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,13 @@ func coalesceErrs(errs ...error) error {
}

var ErrInvalidOperation = errors.New("invalid operation")

type ErrInvalidFunctionArg struct {
Index int
Expected string
Got string
}

func (e *ErrInvalidFunctionArg) Error() string {
return fmt.Sprintf("arg %d: expected %s, got %s", e.Index, e.Expected, e.Got)
}
95 changes: 95 additions & 0 deletions functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package rulekit

import (
"fmt"
"strings"
)

type FunctionValue struct {
fn string
args *ArrayValue
}

func (f *FunctionValue) Eval(ctx *Ctx) Result {
fn, ok := StdlibFuncs[f.fn]
if !ok {
return Result{
Error: fmt.Errorf("unknown function %q", f.fn),
EvaluatedRule: f,
}
}

if len(fn.Args) != len(f.args.vals) {
return Result{
Error: fmt.Errorf("function %q expects %d arguments, got %d", f.fn, len(fn.Args), len(f.args.vals)),
EvaluatedRule: f,
}
}

argMap := make(map[string]any, len(f.args.vals))
for i, arg := range f.args.vals {
res := arg.Eval(ctx)
if !res.Ok() {
return res
}
argMap[fn.Args[i].Name] = res.Value
}
res := fn.Eval(argMap)
res.EvaluatedRule = f
return res
}

func (f *FunctionValue) String() string {
return f.fn + "(" + f.args.String() + ")"
}

func newFunctionValue(fn string, args []Rule) *FunctionValue {
argsArr := newArrayValue(args)
argsArr.raw = strings.TrimPrefix(argsArr.raw, "[")
argsArr.raw = strings.TrimSuffix(argsArr.raw, "]")
return &FunctionValue{
fn: fn,
args: argsArr,
}
}

func (f *FunctionValue) ValidateStdlibFnArgs() error {
if stdlibFn, ok := StdlibFuncs[f.fn]; ok {
if len(stdlibFn.Args) != len(f.args.vals) {
return fmt.Errorf("function %q expects %d arguments, got %d", f.fn, len(stdlibFn.Args), len(f.args.vals))
}
}
return nil
}

type Function struct {
// Args is an optional list of arguments that the function expects.
// If set, rulekit will ensure validity of the arguments and pass them as a named map to the Eval function.
Args []FunctionArg
// Eval is the function that will be called with the arguments.
// EvaluatedRule will be set by Rulekit.
Eval func(map[string]any) Result
}

type FunctionArg struct {
Name string
Type string
}

func IndexFnArg[T any](args map[string]any, idx int, name string) (T, error) {
var zeroVal T

valAny, ok := args[name]
if !ok {
return zeroVal, fmt.Errorf("unrecognized argument name %q", name)
}
val, ok := valAny.(T)
if !ok {
return zeroVal, &ErrInvalidFunctionArg{
Index: idx,
Expected: fmt.Sprintf("%T", valAny),
Got: fmt.Sprintf("%T", valAny),
}
}
return val, nil
}
29 changes: 29 additions & 0 deletions functions_stdlib.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package rulekit

import (
"fmt"
"strings"
)

var StdlibFuncs = map[string]*Function{
"starts_with": {
Args: []FunctionArg{
{Name: "value"},
{Name: "prefix"},
},
Eval: func(args map[string]any) Result {
value, err := IndexFnArg[any](args, 0, "value")
if err != nil {
return Result{Error: err}
}
prefix, err := IndexFnArg[any](args, 1, "prefix")
if err != nil {
return Result{Error: err}
}

return Result{
Value: strings.HasPrefix(fmt.Sprint(value), fmt.Sprint(prefix)),
}
},
},
}
41 changes: 41 additions & 0 deletions functions_stdlib_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package rulekit

import (
"errors"
"net"
"testing"

"github.com/stretchr/testify/require"
)

func TestFn_InvalidArgs(t *testing.T) {
_, err := Parse("some_none_stdlib_fn()")
require.NoError(t, err)

assertRulep(t, "unknown_fn()", nil).Error(errors.New(`unknown function "unknown_fn"`))
assertRulep(t, "unknown_fn(some_args)", nil).Error(errors.New(`unknown function "unknown_fn"`))
}

func TestFn_StartsWith(t *testing.T) {
r := MustParse(`starts_with(url, "https://")`)
assertRule(t, r, kv{"url": "https://example.com"}).Pass()
assertRule(t, r, kv{"url": "http://example.com"}).Fail()
assertRule(t, r, kv{"url": "invalid-url"}).Fail()

// non-string args
assertRulep(t, `starts_with(ip, "10.0")`, kv{"ip": net.ParseIP("10.0.0.1")}).Pass()
assertRulep(t, `starts_with(code, 5)`, kv{"code": 500}).Pass()
assertRulep(t, `starts_with(code, "5")`, kv{"code": 500}).Pass()
assertRulep(t, `starts_with(code, 5)`, kv{"code": 404}).Fail()
assertRulep(t, `starts_with(starts_with("https://example.com", "https://"), "true")`, nil).Pass()

// parser errors
assertParseErrorValue(t, "starts_with()", `syntax error at line 1:14:
starts_with()
^
function "starts_with" expects 2 arguments, got 0`)
assertParseErrorValue(t, "starts_with(arg1)", `syntax error at line 1:18:
starts_with(arg1)
^
function "starts_with" expects 2 arguments, got 1`)
}
Loading