From 081275ed4eeed9aebd28d881a45a1125d043ba0c Mon Sep 17 00:00:00 2001 From: Alexis Filipozzi Date: Fri, 21 Feb 2025 09:58:43 +0100 Subject: [PATCH] feat: add injector that support contextual injection --- .github/workflows/go.yml | 63 ++++ .gitignore | 1 + .golangci.yaml | 152 +++++++++ README.md | 52 +++ binding.go | 36 ++ condition.go | 29 ++ errors.go | 68 ++++ go.mod | 11 + go.sum | 10 + injector.go | 323 ++++++++++++++++++ injector_test.go | 687 +++++++++++++++++++++++++++++++++++++++ module.go | 235 +++++++++++++ scope.go | 199 ++++++++++++ scope_test.go | 295 +++++++++++++++++ special.go | 60 ++++ 15 files changed, 2221 insertions(+) create mode 100644 .github/workflows/go.yml create mode 100644 .gitignore create mode 100644 .golangci.yaml create mode 100644 binding.go create mode 100644 condition.go create mode 100644 errors.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 injector.go create mode 100644 injector_test.go create mode 100644 module.go create mode 100644 scope.go create mode 100644 scope_test.go create mode 100644 special.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..4adf30b --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,63 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Build + run: go build -v ./... + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Test + run: go test -coverprofile=coverage.out ./... + + - name: Upload coverage report + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.out + fail_ci_if_error: true + + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + + - name: Install golangci-lint + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.5 + + - name: Run golangci-lint + run: $(go env GOPATH)/bin/golangci-lint run \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..f637590 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,152 @@ +linters-settings: + gosec: + excludes: + - G306 + - G115 + depguard: + # new configuration + rules: + logger: + deny: + # logging is allowed only by logutils.Log, + # logrus is allowed to use only in logutils package. + - pkg: "github.com/sirupsen/logrus" + desc: logging is allowed only by logutils.Log + dupl: + threshold: 100 + funlen: + lines: -1 # the number of lines (code + empty lines) is not a right metric and leads to code without empty line or one-liner. + statements: 75 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - dupImport # https://github.com/go-critic/go-critic/issues/845 + - ifElseChain + - octalLiteral + - whyNoLint + gocyclo: + min-complexity: 15 + gofmt: + rewrite-rules: + - pattern: 'interface{}' + replacement: 'any' + goimports: + local-prefixes: github.com/golangci/golangci-lint + + govet: + settings: + printf: + funcs: + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf + enable: + - nilness + - shadow + errorlint: + asserts: false + lll: + line-length: 140 + nolintlint: + allow-unused: false # report any unused nolint directives + require-explanation: false # don't require an explanation for nolint directives + require-specific: false # don't require nolint directives to be specific about which linter is being skipped + revive: + rules: + - name: unexported-return + disabled: true + - name: unused-parameter + +linters: + disable-all: true + enable: + - bodyclose + - depguard + - dogsled + - dupl + - errcheck + - errorlint + - funlen + - gocheckcompilerdirectives + - gochecknoinits + - goconst + - gocritic + - gocyclo + - gofmt + - goimports + - goprintffuncname + - gosec + - gosimple + - govet + - ineffassign + - lll + - nakedret + - noctx + - nolintlint + - revive + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + + # don't enable: + # - asciicheck + # - scopelint + # - gochecknoglobals + # - gocognit + # - godot + # - godox + # - goerr113 + # - interfacer + # - maligned + # - nestif + # - prealloc + # - testpackage + # - wsl + +issues: + # Excluding configuration per-path, per-linter, per-text and per-source + exclude-rules: + - path: pkg/golinters/errcheck.go + text: "SA1019: errCfg.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: lsc.Errcheck.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: e.cfg.Run.Deadline is deprecated: Deadline exists for historical compatibility and should not be used." + + - path: pkg/golinters/gofumpt.go + text: "SA1019: settings.LangVersion is deprecated: use the global `run.go` instead." + - path: pkg/golinters/staticcheck_common.go + text: "SA1019: settings.GoVersion is deprecated: use the global `run.go` instead." + - path: pkg/lint/lintersdb/manager.go + text: "SA1019: (.+).(GoVersion|LangVersion) is deprecated: use the global `run.go` instead." + - path: pkg/golinters/unused.go + text: "rangeValCopy: each iteration copies 160 bytes \\(consider pointers or indexing\\)" + - path: test/(fix|linters)_test.go + text: "string `gocritic.go` has 3 occurrences, make it a constant" + + # Due to a change inside go-critic v0.10.0, some reports have been removed, + # but as we run analysis with the previous version of golangci-lint this leads to a paradoxical situation. + # This exclusion will be removed when the next version of golangci-lint (v1.56.0) will be released. + - path: pkg/golinters/nolintlint/nolintlint.go + text: "hugeParam: (i|b) is heavy \\(\\d+ bytes\\); consider passing it by pointer" + exclude-dirs: + - test/testdata_etc # test files + - internal/cache # extracted from Go code + - internal/renameio # extracted from Go code + - internal/robustio # extracted from Go code + +run: + timeout: 5m diff --git a/README.md b/README.md index e69de29..47f0ad7 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,52 @@ +# go-inject + +## Description + +`go-inject` is a dependency injection library that support contextual scope. + +## Example + +```go +package myapp + +import ( + "context" + + "github.com/illuin-tech/goinject" +) + +type key int +const myScopeKey key = 0 + +const MyScope = "MyScope" + +// define function to declare your own scope in context +func WithMyScopeEnabled(ctx context.Context) context.Context { + return goinject.WithContextualScopeEnabled(ctx, myScopeKey) +} + +func ShutdownMyContextScoped(ctx context.Context) { + goinject.ShutdownContextualScope(ctx, myScopeKey) +} + +// define injection modules +var Module = goinject.Module("myModule", + goinject.RegisterScope(MyScope, goinject.NewContextualScope(myScopeKey)), + goinject.Provide(func() string { + return "Hello world from scope" + }, goinject.In(MyScope)), +) + +func main() { + ctx := context.Background() + + // enable scope + ctx = WithMyScopeEnabled(ctx) + defer ShutdownMyContextScoped(ctx) + + injector, _ := goinject.NewInjector(Module) + _ = injector.Invoke(ctx, func(hello string) { + println(hello) + }) +} +``` \ No newline at end of file diff --git a/binding.go b/binding.go new file mode 100644 index 0000000..5fee497 --- /dev/null +++ b/binding.go @@ -0,0 +1,36 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" +) + +// binding defines a type mapped to a more concrete type +type binding struct { + typeof reflect.Type + provider reflect.Value + providedType reflect.Type + annotatedWith string + scope string + destroyMethod func(value reflect.Value) +} + +func (b *binding) create(ctx context.Context, injector *Injector) (reflect.Value, error) { + res, err := injector.callFunctionWithArgumentInstance(ctx, b.provider) + if err != nil { + return reflect.Value{}, + fmt.Errorf("failed to call provider function for type %q: %w", b.providedType.String(), err) + } + if b.provider.Type().NumOut() == 2 { + errValue := res[1].Interface() + if errValue != nil { + err, _ = errValue.(error) + } + } + if err != nil { + return res[0], fmt.Errorf("provider for type %q returned error: %w", b.providedType.String(), err) + } else { + return res[0], nil + } +} diff --git a/condition.go b/condition.go new file mode 100644 index 0000000..33a4d77 --- /dev/null +++ b/condition.go @@ -0,0 +1,29 @@ +package goinject + +import "os" + +type Conditional interface { + evaluate() bool +} + +type environmentVariableConditional struct { + name string + havingValue string + matchIfMissing bool +} + +func (c *environmentVariableConditional) evaluate() bool { + val, ok := os.LookupEnv(c.name) + if !ok { + return c.matchIfMissing + } + return val == c.havingValue +} + +func OnEnvironmentVariable(name, havingValue string, matchIfMissing bool) Conditional { + return &environmentVariableConditional{ + name: name, + havingValue: havingValue, + matchIfMissing: matchIfMissing, + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..afb2918 --- /dev/null +++ b/errors.go @@ -0,0 +1,68 @@ +package goinject + +import ( + "fmt" + "reflect" +) + +type invalidInputError struct { + message string +} + +var _ error = &invalidInputError{} + +func newInvalidInputError(msg string) *invalidInputError { + return &invalidInputError{msg} +} + +func (e *invalidInputError) Error() string { return e.message } + +type injectionError struct { + rType reflect.Type + annotation string + cause error +} + +var _ error = &injectionError{} + +func newInjectionError(typ reflect.Type, annotation string, cause error) *injectionError { + return &injectionError{typ, annotation, cause} +} + +func (e *injectionError) Error() string { + return fmt.Sprintf("Got error while resolving type %s (with annotation %q):\n%s", e.rType.String(), e.annotation, e.cause) +} + +func (e *injectionError) Unwrap() error { return e.cause } + +type contextScopedNotActiveError struct { +} + +var _ error = &contextScopedNotActiveError{} + +func newContextScopedNotActiveError() *contextScopedNotActiveError { + return &contextScopedNotActiveError{} +} + +func (e *contextScopedNotActiveError) Error() string { return "Scope is not active" } + +type injectorConfigurationError struct { + message string + cause error +} + +var _ error = &injectorConfigurationError{} + +func newInjectorConfigurationError(message string, cause error) *injectorConfigurationError { + return &injectorConfigurationError{message, cause} +} + +func (e *injectorConfigurationError) Error() string { + if e.cause == nil { + return e.message + } else { + return fmt.Sprintf("%s:\n%s", e.message, e.cause) + } +} + +func (e *injectorConfigurationError) Unwrap() error { return e.cause } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..364c9fb --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/illuin-tech/goinject + +go 1.24.0 + +require github.com/stretchr/testify v1.10.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..713a0b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/injector.go b/injector.go new file mode 100644 index 0000000..d4976ff --- /dev/null +++ b/injector.go @@ -0,0 +1,323 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +var errorReflectType = reflect.TypeFor[error]() +var invocationContextReflectType = reflect.TypeFor[InvocationContext]() + +// Injector defines bindings & scopes +type Injector struct { + bindings map[reflect.Type]map[string][]*binding // list of available bindings by type and annotations + scopes map[string]Scope // Scope by names + singletonScope *singletonScope +} + +// NewInjector builds up a new Injector out of a list of Modules with singleton scope +func NewInjector(options ...Option) (*Injector, error) { + mod := &configuration{ + bindings: make(map[*binding]bool), + scopes: make(map[string]Scope), + } + + for _, o := range options { + err := o.apply(mod) + if err != nil { + return nil, err + } + } + + singletonScope := newSingletonScope() + mod.scopes[Singleton] = singletonScope + mod.scopes[PerLookUp] = newPerLookUpScope() + + injector := &Injector{ + bindings: make(map[reflect.Type]map[string][]*binding), + scopes: make(map[string]Scope), + singletonScope: singletonScope, + } + + injectorType := reflect.TypeFor[*Injector]() + injectorBinding := &binding{ + typeof: injectorType, + provider: reflect.ValueOf(func() *Injector { return injector }), + providedType: injectorType, + scope: Singleton, + } + + injector.scopes = mod.scopes + for b := range mod.bindings { + _, ok := injector.bindings[b.typeof] + if !ok { + injector.bindings[b.typeof] = make(map[string][]*binding) + } + injector.bindings[b.typeof][b.annotatedWith] = append(injector.bindings[b.typeof][b.annotatedWith], b) + } + + injector.bindings[injectorType] = make(map[string][]*binding) + injector.bindings[injectorType][""] = []*binding{injectorBinding} + + err := injector.eagerlyCreateSingletons() + if err != nil { + return nil, err + } + return injector, nil +} + +// Shutdown clear underlying singleton scope +func (injector *Injector) Shutdown() { + injector.singletonScope.Shutdown() + injector.bindings = make(map[reflect.Type]map[string][]*binding) + injector.scopes = make(map[string]Scope) +} + +// Invoke will execute the parameter function (which must be a function that optionally can return an error). +// argument of function will be resolved by the injector using configured providers & scope. +func (injector *Injector) Invoke(ctx context.Context, function any) error { + if function == nil { + return newInvalidInputError("can't invoke on nil") + } + fvalue := reflect.ValueOf(function) + ftype := fvalue.Type() + if ftype.Kind() != reflect.Func { + return newInvalidInputError( + fmt.Sprintf("can't invoke non-function %v (type %v)", function, ftype)) + } + + if ftype.NumOut() > 1 || (ftype.NumOut() == 1 && !ftype.Out(0).AssignableTo(errorReflectType)) { + return newInvalidInputError("can't invoke on function whose return type is not error or no return type") + } + + res, err := injector.callFunctionWithArgumentInstance(ctx, fvalue) + if err != nil { + return fmt.Errorf("failed to call invokation function: %w", err) + } + if ftype.NumOut() == 1 { + invokationError := res[0].Interface().(error) + if invokationError != nil { + return fmt.Errorf("invokation returned error: %w", invokationError) + } + } + return nil +} + +func (injector *Injector) eagerlyCreateSingletons() error { + for _, bindingsByAnnotation := range injector.bindings { + for _, bindingList := range bindingsByAnnotation { + for _, b := range bindingList { + if b.scope == Singleton { + _, err := injector.getScopedInstanceFromBinding(nil, b) //nolint:staticcheck + if err != nil { + return fmt.Errorf("failed to get singleton instance: %w", err) + } + } + } + } + } + return nil +} + +func (injector *Injector) callFunctionWithArgumentInstance( + ctx context.Context, + fValue reflect.Value, +) ([]reflect.Value, error) { + fType := fValue.Type() + in := make([]reflect.Value, fType.NumIn()) + var err error + for i := 0; i < fType.NumIn(); i++ { + if in[i], err = injector.getFunctionArgumentInstance(ctx, fType.In(i)); err != nil { + return []reflect.Value{}, fmt.Errorf("failed to resolve function argument #%d: %w", i, err) + } + } + + res := fValue.Call(in) + return res, nil +} + +func (injector *Injector) getFunctionArgumentInstance(ctx context.Context, argType reflect.Type) (reflect.Value, error) { + if EmbedsParams(argType) { + return injector.createEmbeddedParams(ctx, argType) + } else { + return injector.getInstanceOfAnnotatedType(ctx, argType, "", false) + } +} + +func (injector *Injector) createEmbeddedParams(ctx context.Context, embeddedType reflect.Type) (reflect.Value, error) { + if embeddedType.Kind() == reflect.Ptr { + n := reflect.New(embeddedType.Elem()) + return n, injector.setParamFields(ctx, n.Elem()) + } else { // struct + n := reflect.New(embeddedType).Elem() + return n, injector.setParamFields(ctx, n) + } +} + +func (injector *Injector) setParamFields( + ctx context.Context, + paramValue reflect.Value, +) error { + embeddedType := paramValue.Type() + for fieldIndex := 0; fieldIndex < embeddedType.NumField(); fieldIndex++ { + field := paramValue.Field(fieldIndex) + if field.Type() == _paramType { + continue + } + if tag, ok := embeddedType.Field(fieldIndex).Tag.Lookup("inject"); ok { + if !field.CanSet() { + return newInjectionError(field.Type(), tag, fmt.Errorf("use inject tag on unsettable field")) + } + + var optional bool + for _, option := range strings.Split(tag, ",") { + if strings.TrimSpace(option) == "optional" { + optional = true + } + } + tag = strings.Split(tag, ",")[0] + + instance, err := injector.getInstanceOfAnnotatedType(ctx, field.Type(), tag, optional) + if err != nil { + return newInjectionError(field.Type(), tag, err) + } + if instance.IsValid() { + field.Set(instance) + } else if optional { + continue + } else { + return newInjectionError(field.Type(), tag, fmt.Errorf("cannot get valid instance from scope")) + } + } + } + return nil +} + +// getInstanceOfAnnotatedType resolves a type request within the injector +func (injector *Injector) getInstanceOfAnnotatedType( + ctx context.Context, + t reflect.Type, + annotation string, + optional bool, +) (reflect.Value, error) { + // if is slice, return as multi bindings + if t.Kind() == reflect.Slice { + bindings := injector.findBindingsForAnnotatedType(t.Elem(), annotation) + if len(bindings) > 0 { + n := reflect.MakeSlice(t, 0, len(bindings)) + for _, binding := range bindings { + r, err := injector.getScopedInstanceFromBinding(ctx, binding) + if err != nil { + return reflect.Value{}, err + } + n = reflect.Append(n, r) + } + return n, nil + } else if optional { + return reflect.MakeSlice(t, 0, 0), nil + } else { + return reflect.MakeSlice(t, 0, 0), newInjectionError(t.Elem(), annotation, + fmt.Errorf("did not found binding, expected at least one")) + } + } + + // check if there is a binding for this type & annotation + bindings := injector.findBindingsForAnnotatedType(t, annotation) + if len(bindings) > 1 { + return reflect.Value{}, + newInjectionError(t, annotation, fmt.Errorf("found multiple bindings expected one")) + } else if len(bindings) == 1 { + return injector.getScopedInstanceFromBinding(ctx, bindings[0]) + } else if injector.isProviderType(t) { + return injector.createProviderValue(t, annotation, optional), nil + } else if t == invocationContextReflectType { + return reflect.ValueOf(ctx), nil + } else if optional { + return reflect.Value{}, nil + } else { + return reflect.Value{}, + newInjectionError(t, annotation, fmt.Errorf("did not found binding, expected one")) + } +} + +func (injector *Injector) isProviderType(t reflect.Type) bool { + return t.Kind() == reflect.Func && + t.NumIn() == 1 && t.In(0) == invocationContextReflectType && + t.NumOut() == 2 && t.Out(1) == errorReflectType +} + +func (injector *Injector) createProviderValue( + t reflect.Type, + annotation string, + optional bool, +) reflect.Value { + bindingType := t.Out(0) + return reflect.MakeFunc(t, func(args []reflect.Value) (results []reflect.Value) { + ctx := args[0].Interface().(context.Context) + instance, err := injector.getInstanceOfAnnotatedType(ctx, bindingType, annotation, optional) + var instanceVal reflect.Value + if instance.IsValid() { + instanceVal = instance + } else { + instanceVal = reflect.Zero(bindingType) + } + var errVal reflect.Value + if err != nil { + errVal = reflect.ValueOf(err) + } else { + errVal = reflect.Zero(errorReflectType) + } + return []reflect.Value{ + instanceVal, + errVal, + } + }) +} + +func (injector *Injector) findBindingsForAnnotatedType( + t reflect.Type, + annotation string, +) []*binding { + if _, ok := injector.bindings[t]; ok && len(injector.bindings[t][annotation]) > 0 { + bindings := injector.bindings[t][annotation] + res := make([]*binding, len(bindings)) + copy(res, bindings) + return res + } + + return []*binding{} +} + +func (injector *Injector) getScopedInstanceFromBinding( + ctx context.Context, + binding *binding, +) (reflect.Value, error) { + scope, err := injector.getScopeFromBinding(binding) + if err != nil { + return reflect.Value{}, err + } + val, err := scope.ResolveBinding(ctx, binding, func() (Instance, error) { + val, creationError := binding.create(ctx, injector) + destroyMethod := binding.destroyMethod + if creationError == nil && destroyMethod != nil && !val.IsZero() { + scope.RegisterDestructionCallback( + ctx, + func() { destroyMethod(val) }, + ) + } + return Instance(val), creationError + }) + return reflect.Value(val), err +} + +func (injector *Injector) getScopeFromBinding( + binding *binding, +) (Scope, error) { + if scope, ok := injector.scopes[binding.scope]; ok { + return scope, nil + } + return nil, newInjectionError( + binding.typeof, binding.annotatedWith, fmt.Errorf("unknown scope %q for binding", binding.scope)) +} diff --git a/injector_test.go b/injector_test.go new file mode 100644 index 0000000..1f73eb5 --- /dev/null +++ b/injector_test.go @@ -0,0 +1,687 @@ +package goinject + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Parent struct { +} + +type Child struct { + parent *Parent +} + +func TestShouldReturnFromProvider(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Parent { return &Parent{} }), + Provide(func(parent *Parent) *Child { return &Child{parent: parent} }), + ) + assert.Nil(t, err) + ctx := context.Background() + var parent *Parent + err = injector.Invoke(ctx, func(p *Parent) { + parent = p + }) + assert.Nil(t, err) + var child *Child + err = injector.Invoke(ctx, func(c *Child) { + child = c + }) + assert.Nil(t, err) + assert.Same(t, parent, child.parent) + }) +} + +func TestProvideShouldAcceptErrorReturnProviders(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() (*Parent, error) { return &Parent{}, nil }, In(PerLookUp)), + Provide(func(_ *Parent) (*Child, error) { return nil, fmt.Errorf("failed to create child") }, In(PerLookUp)), + ) + assert.Nil(t, err) + ctx := context.Background() + t.Run("And return type if no error", func(t *testing.T) { + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + t.Run("And return error otherwise", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ *Child) { + assert.Fail(t, "should not be reached") + }) + assert.ErrorContains(t, err, "failed to create child") + }) + }) +} + +func TestUseUnknownScopeShouldReturnError(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Parent { return &Parent{} }, In("unknown")), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Parent) { + assert.Fail(t, "should not be reached") + }) + assert.ErrorContains(t, err, "unknown scope \"unknown\" for binding") + }) +} + +type TestInvokeParamOptional struct { + Params + ParentA *Parent `inject:", optional"` + ParentB *Parent `inject:"B"` +} + +func TestInvokeWithOptional(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("using param struct argument", func(t *testing.T) { + injector, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, Named("B")), + ) + assert.Nil(t, err) + var parentA *Parent + var parentB *Parent + ctx := context.Background() + err = injector.Invoke(ctx, func(param TestInvokeParamOptional) { + parentA = param.ParentA + parentB = param.ParentB + }) + assert.Nil(t, err) + assert.Nil(t, parentA) + assert.NotNil(t, parentB) + }) + + t.Run("using param pointer argument", func(t *testing.T) { + injector, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, Named("B")), + ) + assert.Nil(t, err) + var parentA *Parent + var parentB *Parent + ctx := context.Background() + err = injector.Invoke(ctx, func(param *TestInvokeParamOptional) { + parentA = param.ParentA + parentB = param.ParentB + }) + assert.Nil(t, err) + assert.Nil(t, parentA) + assert.NotNil(t, parentB) + }) + }) +} + +type Color struct { + name string +} + +type TestInvokeParamAnnotated struct { + Params + Color *Color `inject:"red"` +} + +func TestInvokeWithAnnotation(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "red"} }, Named("red")), + Provide(func() *Color { return &Color{name: "blue"} }, Named("blue")), + ) + assert.Nil(t, err) + var color *Color + ctx := context.Background() + err = injector.Invoke(ctx, func(param TestInvokeParamAnnotated) { + color = param.Color + }) + assert.NotNil(t, color) + assert.Equal(t, "red", color.name) + assert.Nil(t, err) + }) +} + +func TestInvokeShouldReturnErrorIfExpectedSingleBindingButMultipleFound(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "blue"} }), + Provide(func() *Color { return &Color{name: "red"} }), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Color) { + assert.Fail(t, "should not be reached") + }) + assert.NotNil(t, err) + // verify error tree contains an injection error + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, + "failed to call invokation function: failed to resolve"+ + " function argument #0: Got error while resolving type *goinject.Color"+ + " (with annotation \"\"):\nfound multiple bindings expected one", + err.Error()) + }) +} + +type Red *Color +type Blue *Color + +func TestInvokeUsingTypeDefinition(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Color { return &Color{name: "blue"} }, As(Type[Blue]())), + Provide(func() *Color { return &Color{name: "red"} }, As(Type[Red]())), + ) + assert.Nil(t, err) + var color Red + ctx := context.Background() + err = injector.Invoke(ctx, func(c Red) { + color = c + }) + assert.NotNil(t, color) + assert.Equal(t, "red", color.name) + assert.Nil(t, err) + }) +} + +func TestInstallModuleShouldInstallBindingsOnce(t *testing.T) { + assert.NotPanics(t, func() { + subModule := Module("sub", Provide(func() *Parent { + return &Parent{} + }, Named("parent-in-sub"))) + parentModuleA := Module("parent-a", subModule) + parentModuleB := Module("parent-b", subModule) + injector, err := NewInjector( + parentModuleA, + parentModuleB, + ) + assert.Nil(t, err) + assert.NotNil(t, injector) + assert.Equal(t, 1, len(injector.bindings[reflect.TypeFor[*Parent]()])) + assert.Equal(t, 2, len(injector.bindings)) // we add a binding for *Injector + }) +} + +type Shape interface { + Name() string +} + +type Rectangle struct { +} + +func (r *Rectangle) Name() string { + return "rectangle" +} + +type Square struct { +} + +func (s *Square) Name() string { + return "square" +} + +func TestBindToInterface(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Rectangle { + return &Rectangle{} + }, As(Type[Shape]())), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(s Shape) { + assert.IsType(t, &Rectangle{}, s) + }) + assert.Nil(t, err) + }) +} + +func TestInjectorShouldBeProvided(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(i *Injector) { + assert.Same(t, i, injector) + }) + assert.Nil(t, err) + }) +} + +type WithRefCount struct { + refCount int +} + +func TestInjectorShutdownShouldShutdownSingletonScope(t *testing.T) { + assert.NotPanics(t, func() { + refCount := 0 + injector, err := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, WithDestroy(func(_ *WithRefCount) { + refCount-- + }), In(Singleton)), + ) + assert.Nil(t, err) + ctx := context.Background() + + // singleton should be created eagerly + assert.Equal(t, 1, refCount) + + err = injector.Invoke(ctx, func(c *WithRefCount) { + assert.Equal(t, 1, refCount) + assert.Equal(t, 0, c.refCount) + }) + + assert.Nil(t, err) + assert.Equal(t, 1, refCount) + injector.Shutdown() + assert.Equal(t, 0, refCount) + assert.Equal(t, 0, len(injector.bindings)) + }) +} + +func TestNewInjectorShouldReturnErrorIfEagerlyCreatedSingletonReturnError(t *testing.T) { + returnedErr := fmt.Errorf("provider error") + assert.NotPanics(t, func() { + _, err := NewInjector( + Provide(func() (*WithRefCount, error) { + return nil, returnedErr + }), + ) + assert.ErrorIs(t, err, returnedErr) + assert.Equal(t, "failed to get singleton instance: provider for type \"*goinject.WithRefCount\" "+ + "returned error: provider error", err.Error()) + }) +} + +type MultiBindOptionalInvokeParams struct { + Params + Shapes []Shape `inject:",optional"` +} + +func TestMultiBind(t *testing.T) { + t.Run("Using multiple interface implementation", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *Rectangle { + return &Rectangle{} + }, As(Type[Shape]())), + Provide(func() *Square { + return &Square{} + }, As(Type[Shape]())), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(shapes []Shape) { + var names []string + for _, shape := range shapes { + names = append(names, shape.Name()) + } + assert.Contains(t, names, "square") + assert.Contains(t, names, "rectangle") + }) + assert.Nil(t, err) + }) + }) + + t.Run("Should not throw error if not found and optional", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(params MultiBindOptionalInvokeParams) { + assert.Empty(t, params.Shapes) + }) + assert.Nil(t, err) + }) + }) + + t.Run("Should throw error if not found and not optional", func(t *testing.T) { + assert.NotPanics(t, func() { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ []Shape) { + assert.Fail(t, "should not be reached") + }) + assert.NotNil(t, err) + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, "failed to call invokation function: failed to resolve function argument #0: "+ + "Got error while resolving type goinject.Shape (with annotation \"\"):\n"+ + "did not found binding, expected at least one", err.Error()) + }) + }) +} + +type WithProvider struct { + provider Provider[*WithRefCount] +} + +type WithProviderParam struct { + Params + Provider Provider[*WithRefCount] `inject:",optional"` +} + +func TestProvider(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("Get from provider should re-ask scope (with per-lookup)", func(t *testing.T) { + refCount := 0 + injector, rootError := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, In(PerLookUp)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref1, err := w.provider(ctx) + assert.Nil(t, err) + ref2, err := w.provider(ctx) + assert.Nil(t, err) + assert.NotEqual(t, ref2, ref1) + assert.Equal(t, 0, ref1.refCount) + assert.Equal(t, 1, ref2.refCount) + }) + assert.Nil(t, rootError) + }) + + t.Run("Get from provider should re-ask scope (with singleton)", func(t *testing.T) { + refCount := 0 + injector, rootError := NewInjector( + Provide(func() *WithRefCount { + res := &WithRefCount{refCount: refCount} + refCount++ + return res + }, In(Singleton)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref1, err := w.provider(ctx) + assert.Nil(t, err) + ref2, err := w.provider(ctx) + assert.Nil(t, err) + assert.Same(t, ref2, ref1) + }) + assert.Nil(t, rootError) + }) + + t.Run("Provider with optional should return zero value if not present", func(t *testing.T) { + injector, rootError := NewInjector() + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w WithProviderParam) { + ref, err := w.Provider(ctx) + assert.Nil(t, err) + assert.Nil(t, ref) + }) + assert.Nil(t, rootError) + }) + + t.Run("Provider should return error", func(t *testing.T) { + injector, rootError := NewInjector( + Provide(func() (*WithRefCount, error) { + return nil, fmt.Errorf("test error") + }, In(PerLookUp)), + Provide(func(p Provider[*WithRefCount]) *WithProvider { + return &WithProvider{ + provider: p, + } + }, In(PerLookUp)), + ) + assert.Nil(t, rootError) + ctx := context.Background() + + rootError = injector.Invoke(ctx, func(w *WithProvider) { + ref, err := w.provider(ctx) + assert.Nil(t, ref) + assert.NotNil(t, err) + assert.Equal(t, "provider for type \"*goinject.WithRefCount\" returned error: test error", err.Error()) + }) + assert.Nil(t, rootError) + }) + }) +} + +func TestConditional(t *testing.T) { + t.Run("Test conditional env var should not register binding if no match", func(t *testing.T) { + t.Setenv("TEST", "CASE-KO") + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", false), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(_ *Parent) { + assert.Fail(t, "inaccessible") + }) + assert.NotNil(t, err) + var expectedErrorType *injectionError + assert.ErrorAs(t, err, &expectedErrorType) + assert.Equal(t, + "failed to call invokation function: failed to resolve function argument #0: "+ + "Got error while resolving type *goinject.Parent (with annotation \"\"):\ndid not found binding, "+ + "expected one", + err.Error(), + ) + }) + + t.Run("Test conditional env var should register binding if match", func(t *testing.T) { + t.Setenv("TEST", "CASE-OK") + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", false), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + + t.Run("Test conditional env var should register binding if no match but match missing", func(t *testing.T) { + injector, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OO", true), + Provide(func() (*Parent, error) { return &Parent{}, nil }), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func(parent *Parent) { + assert.NotNil(t, parent) + }) + assert.Nil(t, err) + }) + + t.Run("Test When should return binding configuration errors", func(t *testing.T) { + _, err := NewInjector( + When(OnEnvironmentVariable("TEST", "CASE-OK", true), + Provide(nil), + ), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "cannot accept nil provider", err.Error()) + }) +} + +func TestInvokeError(t *testing.T) { + t.Run("Invoke should not accept nil", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, nil) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke on nil", err.Error()) + }) + + t.Run("Invoke should only accept function", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, true) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke non-function true (type bool)", err.Error()) + }) + + t.Run("Invoke should only accept function returning error", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + err = injector.Invoke(ctx, func() *Parent { return nil }) + assert.NotNil(t, err) + assert.IsType(t, err, &invalidInputError{}) + assert.Equal(t, "can't invoke on function whose return type is not error or no return type", err.Error()) + }) + + t.Run("Invoke should return error if function return error", func(t *testing.T) { + injector, err := NewInjector() + assert.Nil(t, err) + ctx := context.Background() + invokationFnReturnedError := fmt.Errorf("returned error") + err = injector.Invoke(ctx, func() error { return invokationFnReturnedError }) + assert.NotNil(t, err) + assert.ErrorIs(t, err, invokationFnReturnedError) + }) +} + +func TestInjectorConfigurationError(t *testing.T) { + t.Run("Provide cannot accept nil", func(t *testing.T) { + _, err := NewInjector( + Provide(nil)) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "cannot accept nil provider", err.Error()) + }) + + t.Run("Provider should use function as argument", func(t *testing.T) { + _, err := NewInjector( + Provide(true)) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "provider argument should be a function", err.Error()) + }) + + t.Run("Provider function should return an instance", func(t *testing.T) { + _, err := NewInjector( + Provide(func() {})) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "expected a function that return an instance and optionally an error", err.Error()) + }) + + t.Run("Provider function cannot return multiple types (except error)", func(t *testing.T) { + _, err := NewInjector( + Provide(func() (*Parent, *Child) { + return &Parent{}, &Child{} + })) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "second return type of provider should be an error", err.Error()) + }) + + t.Run("Module should return nested errors", func(t *testing.T) { + _, err := NewInjector( + Module("test.Module", + Provide(nil)), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, "error while installing module test.Module:\ncannot accept nil provider", err.Error()) + }) + + t.Run("As provider annotation should raise error if not assignable", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, As(Type[*Child]())), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\ncannot assign "+ + "*goinject.Parent to *goinject.Child as specified in As argument", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a function", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(true)), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy"+ + " must be a function with one argument returning void", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a function of provided type", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(func(_ *Child) {})), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy"+ + " must be a function with one argument returning void", + err.Error(), + ) + }) + + t.Run("WithDestroy should raise an error if not a void function of provided type", func(t *testing.T) { + _, err := NewInjector( + Provide(func() *Parent { + return &Parent{} + }, WithDestroy(func(_ *Parent) error { + return nil + })), + ) + assert.NotNil(t, err) + assert.IsType(t, err, &injectorConfigurationError{}) + assert.Equal(t, + "got error while configuring provider for provided type *goinject.Parent:\nargument of WithDestroy "+ + "must be a function with one argument returning void", err.Error(), + ) + }) +} diff --git a/module.go b/module.go new file mode 100644 index 0000000..98fb697 --- /dev/null +++ b/module.go @@ -0,0 +1,235 @@ +package goinject + +import ( + "fmt" + "reflect" +) + +type configuration struct { + bindings map[*binding]bool + scopes map[string]Scope +} + +// Option enable to configure the given injector +type Option interface { + apply(*configuration) error +} + +type moduleOption struct { + name string + options []Option +} + +func (o *moduleOption) apply(mod *configuration) error { + for _, opt := range o.options { + err := opt.apply(mod) + if err != nil { + return newInjectorConfigurationError( + fmt.Sprintf("error while installing module %s", o.name), err) + } + } + return nil +} + +// Module group a list of Option in order to easily reuse them. +// the Module name is used in error when applying Option to easily find misconfigured options. +func Module(name string, opts ...Option) Option { + mo := &moduleOption{ + name: name, + options: opts, + } + return mo +} + +type provideOption struct { + constructor any + annotations []Annotation +} + +func (o *provideOption) apply(mod *configuration) error { + if o.constructor == nil { + return newInjectorConfigurationError("cannot accept nil provider", nil) + } + providerFncValue := reflect.ValueOf(o.constructor) + fncType := providerFncValue.Type() + if fncType.Kind() != reflect.Func { + return newInjectorConfigurationError("provider argument should be a function", nil) + } + if fncType.NumOut() > 2 || fncType.NumOut() == 0 { + return newInjectorConfigurationError("expected a function that return an instance and optionally an error", nil) + } + if fncType.NumOut() == 2 && !fncType.Out(1).AssignableTo(reflect.TypeOf(new(error)).Elem()) { + return newInjectorConfigurationError("second return type of provider should be an error", nil) + } + b := &binding{} + b.provider = providerFncValue + b.providedType = fncType.Out(0) + b.typeof = b.providedType + b.scope = Singleton + + for _, a := range o.annotations { + err := a.apply(b) + if err != nil { + return newInjectorConfigurationError( + fmt.Sprintf("got error while configuring provider for provided type %s", b.providedType), + err, + ) + } + } + + mod.bindings[b] = true + return nil +} + +// Provide define a binding from a function constructor that must return the provided instance (and optionally an error) +// arguments of the constructor parameter will be resolved by the injector itself. +// Provide enable to annotate the created binding using Annotation +func Provide(constructor any, annotations ...Annotation) Option { + return &provideOption{ + constructor: constructor, + annotations: annotations, + } +} + +type registerScopeOption struct { + name string + scope Scope +} + +func (o *registerScopeOption) apply(mod *configuration) error { + mod.scopes[o.name] = o.scope + return nil +} + +// RegisterScope register a new Scope with a name +func RegisterScope(name string, scope Scope) Option { + return ®isterScopeOption{ + name: name, + scope: scope, + } +} + +type whenOption struct { + condition Conditional + options []Option +} + +func (o *whenOption) apply(mod *configuration) error { + if o.condition.evaluate() { + for _, opt := range o.options { + if err := opt.apply(mod); err != nil { + return err + } + } + } + + return nil +} + +// When enable to group a list of Option that will be applied only if the given Conditional evaluate to true +func When(condition Conditional, options ...Option) Option { + return &whenOption{ + condition: condition, + options: options, + } +} + +// Annotation are used to configured bindings created by the Provide function +type Annotation interface { + apply(b *binding) error +} + +type asAnnotation struct { + target AsType +} + +func (a *asAnnotation) apply(b *binding) error { + targetType := a.target.getType() + if !b.providedType.AssignableTo(targetType) { + return newInjectorConfigurationError( + fmt.Sprintf("cannot assign %s to %s as specified in As argument", b.providedType, targetType), + nil, + ) + } + b.typeof = targetType + return nil +} + +// AsType is used in As function as an argument to register a provided type to another given assignable type +type AsType interface { + getType() reflect.Type +} + +type typeFor[T any] struct { +} + +func (t *typeFor[T]) getType() reflect.Type { + return reflect.TypeFor[T]() +} + +// Type return an AsType for a given type T +func Type[T any]() AsType { + return &typeFor[T]{} +} + +// As return an annotation that is used to override the binding registration type. +// Use it to bind a concrete type to an interface. +func As(target AsType) Annotation { + return &asAnnotation{target: target} +} + +type nameAnnotation struct { + name string +} + +func (a *nameAnnotation) apply(b *binding) error { + b.annotatedWith = a.name + return nil +} + +// Named return an annotation that is used to define the binding annotation name. +func Named(name string) Annotation { + return &nameAnnotation{name: name} +} + +type inAnnotation struct { + scope string +} + +// In return an annotation that is used to define the binding scope +func In(scope string) Annotation { + return &inAnnotation{scope: scope} +} + +func (a *inAnnotation) apply(b *binding) error { + b.scope = a.scope + return nil +} + +type withDestroyAnnotation struct { + destroyMethod any +} + +func (a *withDestroyAnnotation) apply(b *binding) error { + destroyMethodFnVal := reflect.ValueOf(a.destroyMethod) + if destroyMethodFnVal.Kind() != reflect.Func || + destroyMethodFnVal.Type().NumIn() != 1 || + destroyMethodFnVal.Type().In(0) != b.providedType || + destroyMethodFnVal.Type().NumOut() != 0 { + return newInjectorConfigurationError( + "argument of WithDestroy must be a function with one argument returning void", + nil, + ) + } + b.destroyMethod = func(val reflect.Value) { + destroyMethodFnVal.Call([]reflect.Value{val}) + } + return nil +} + +// WithDestroy return an annotation that declare a destroyMethod that will be used when closing a scope +func WithDestroy(destroyMethod any) Annotation { + return &withDestroyAnnotation{ + destroyMethod: destroyMethod, + } +} diff --git a/scope.go b/scope.go new file mode 100644 index 0000000..9c3d84a --- /dev/null +++ b/scope.go @@ -0,0 +1,199 @@ +package goinject + +import ( + "context" + "reflect" + "sync" +) + +// Instance is the return type for Scope ResolveBinding method. +// It is used to hidde the usage of reflect.Value in the public API +type Instance reflect.Value + +type instanceRegistry struct { + mu sync.Mutex // lock guarding instanceLock + instanceLock map[*binding]*sync.RWMutex // lock guarding instances + instances sync.Map + destroyMethodsLock sync.Mutex + destroyMethods []func() +} + +func (r *instanceRegistry) resolveBinding( + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + r.mu.Lock() + + if l, ok := r.instanceLock[binding]; ok { + r.mu.Unlock() + l.RLock() + defer l.RUnlock() + + instance, _ := r.instances.Load(binding) + return instance.(Instance), nil + } + + r.instanceLock[binding] = new(sync.RWMutex) + l := r.instanceLock[binding] + l.Lock() + r.mu.Unlock() + + instance, err := instanceCreator() + r.instances.Store(binding, instance) + + defer l.Unlock() + + return instance, err +} + +func (r *instanceRegistry) registerDestructionCallback( + destroyCallback func(), +) { + r.destroyMethodsLock.Lock() + defer r.destroyMethodsLock.Unlock() + r.destroyMethods = append(r.destroyMethods, destroyCallback) +} + +func (r *instanceRegistry) shutdown() { + r.destroyMethodsLock.Lock() + defer r.destroyMethodsLock.Unlock() + + for i := len(r.destroyMethods) - 1; i >= 0; i-- { + r.destroyMethods[i]() + } + + r.destroyMethods = []func(){} +} + +func newInstanceRegistry() *instanceRegistry { + return &instanceRegistry{ + instanceLock: make(map[*binding]*sync.RWMutex), + destroyMethods: []func(){}, + } +} + +// Scope defines a scope's behaviour +type Scope interface { + // ResolveBinding resolve a dependency injection context for current scope + ResolveBinding( + ctx context.Context, + binding *binding, + instanceCreator func() (Instance, error), + ) (Instance, error) + + // RegisterDestructionCallback register a destruction callback. It is the responsibility of the Scope to call + // this callback when destroying the Scope + RegisterDestructionCallback( + ctx context.Context, + destroyCallback func(), + ) +} + +const PerLookUp = "inject.PerLookUp" + +// perLookUpScope is a Scope that return a new instance when requested +type perLookUpScope struct { +} + +var _ Scope = new(perLookUpScope) + +func newPerLookUpScope() Scope { + return &perLookUpScope{} +} + +func (s *perLookUpScope) ResolveBinding( + _ context.Context, + _ *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + return instanceCreator() +} + +func (s *perLookUpScope) RegisterDestructionCallback( + _ context.Context, + _ func(), +) { + // nothing to do, per lookup provided need to close destroy method themselves +} + +const Singleton = "inject.Singleton" + +// singletonScope is our Scope to handle Singletons +type singletonScope struct { + instanceRegistry *instanceRegistry +} + +var _ Scope = new(singletonScope) + +func newSingletonScope() *singletonScope { + return &singletonScope{ + instanceRegistry: newInstanceRegistry(), + } +} + +func (s *singletonScope) ResolveBinding( + _ context.Context, + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + return s.instanceRegistry.resolveBinding(binding, instanceCreator) +} + +func (s *singletonScope) RegisterDestructionCallback( + _ context.Context, + destroyCallback func(), +) { + s.instanceRegistry.registerDestructionCallback(destroyCallback) +} + +func (s *singletonScope) Shutdown() { + s.instanceRegistry.shutdown() +} + +// contextualScope is an abstract scope to handle context attached scoped (request, session, ...) +type contextualScope struct { + key any +} + +var _ Scope = new(contextualScope) + +func (s *contextualScope) ResolveBinding( + ctx context.Context, + binding *binding, + instanceCreator func() (Instance, error), +) (Instance, error) { + if ctx == nil { + return Instance{}, newContextScopedNotActiveError() + } + scopeHolder, ok := ctx.Value(s.key).(*instanceRegistry) + if !ok { + return Instance{}, newContextScopedNotActiveError() + } + return scopeHolder.resolveBinding(binding, instanceCreator) +} + +func (s *contextualScope) RegisterDestructionCallback( + ctx context.Context, + destroyCallback func(), +) { + if scopeHolder, ok := ctx.Value(s.key).(*instanceRegistry); ok { + scopeHolder.registerDestructionCallback(destroyCallback) + } +} + +func NewContextualScope(key any) Scope { + return &contextualScope{ + key: key, + } +} + +func WithContextualScopeEnabled(ctx context.Context, key any) context.Context { + return context.WithValue(ctx, key, newInstanceRegistry()) +} + +func ShutdownContextualScope(ctx context.Context, key any) { + holder, ok := ctx.Value(key).(*instanceRegistry) + if ok { + holder.shutdown() + } +} diff --git a/scope_test.go b/scope_test.go new file mode 100644 index 0000000..b14673a --- /dev/null +++ b/scope_test.go @@ -0,0 +1,295 @@ +package goinject + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +type sessionScopeKey int + +const sessionScopeKeyVal sessionScopeKey = 0 + +type requestScopeKey int + +const requestScopeKeyVal requestScopeKey = 0 + +type Request struct { + ID int +} + +type Session struct { + ID int +} + +type ContextualScopesParams struct { + Params + Request *Request `inject:""` +} + +type ctxKey int + +const requestKey ctxKey = iota + +func TestContextualScopesUsingContextValue(t *testing.T) { + notAwareContextError := errors.New("not running in value aware context") + assert.NotPanics(t, func() { + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("request", NewContextualScope(requestScopeKeyVal)), + Provide(func(ctx InvocationContext) (*Request, error) { + if r, ok := ctx.Value(requestKey).(*Request); ok { + return r, nil + } else { + return nil, notAwareContextError + } + }, In("request")), + ), + ) + assert.Nil(t, err) + + t.Run("Provider should be able to provide from InvocationContext", func(t *testing.T) { + ctx := context.Background() + requestCtx := WithContextualScopeEnabled( + context.WithValue(ctx, requestKey, &Request{ID: 42}), + requestScopeKeyVal, + ) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + invokeErr := injector.Invoke(requestCtx, func(r *Request) { + assert.Equal(t, 42, r.ID) + }) + assert.Nil(t, invokeErr) + }) + + t.Run("Provider should be able to provide from InvocationContext with error", func(t *testing.T) { + ctx := context.Background() + requestCtx := WithContextualScopeEnabled( + ctx, + requestScopeKeyVal, + ) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + invokeErr := injector.Invoke(requestCtx, func(*Request) { + assert.Fail(t, "should not be called") + }) + assert.ErrorIs(t, invokeErr, notAwareContextError) + }) + }) +} + +func TestContextualScopes(t *testing.T) { + assert.NotPanics(t, func() { + count := 0 + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("request", NewContextualScope(requestScopeKeyVal)), + RegisterScope("session", NewContextualScope(sessionScopeKeyVal)), + Provide(func() *Request { + res := &Request{ID: count} + count++ + return res + }, In("request")), + Provide(func() *Session { + res := &Session{ID: count} + count++ + return res + }, In("session")), + ), + ) + assert.Nil(t, err) + + ctx := context.Background() + + t.Run("Contextual scope should return error if not active", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ *Request) { + assert.Fail(t, "Should not be reached") + }) + assert.True(t, errors.Is(err, &contextScopedNotActiveError{})) + }) + + t.Run("Contextual scope should return error if not active (using Params)", func(t *testing.T) { + err = injector.Invoke(ctx, func(_ ContextualScopesParams) { + assert.Fail(t, "Should not be reached") + }) + assert.True(t, errors.Is(err, &contextScopedNotActiveError{})) + }) + + var sessionID int + var sessionID2 int + + t.Run("Test session with multiple request should keep same session scope but different request scope", + func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + var request1ID int + var request2ID int + var sessionIDBis int + + t.Run("Test request 1", func(t *testing.T) { + requestCtx := WithContextualScopeEnabled(sessionCtx, requestScopeKeyVal) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + err := injector.Invoke(requestCtx, func(session *Session, request *Request) { + sessionID = session.ID + request1ID = request.ID + }) + assert.Nil(t, err) + }) + + t.Run("Test request 2", func(t *testing.T) { + requestCtx := WithContextualScopeEnabled(sessionCtx, requestScopeKeyVal) + defer ShutdownContextualScope(requestCtx, requestScopeKeyVal) + + err := injector.Invoke(requestCtx, func(session *Session, request *Request) { + sessionIDBis = session.ID + request2ID = request.ID + }) + assert.Nil(t, err) + }) + + assert.NotZero(t, request1ID) + assert.NotZero(t, request2ID) + assert.NotEqual(t, request2ID, request1ID) + + assert.Equal(t, sessionID, sessionIDBis) + }) + + t.Run("Test session 2 (without request scope)", func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + err := injector.Invoke(sessionCtx, func(session *Session) { + sessionID2 = session.ID + }) + assert.Nil(t, err) + }) + + assert.NotEqual(t, sessionID, sessionID2) + }) +} + +func TestContextualScopeDestroy(t *testing.T) { + assert.NotPanics(t, func() { + count := 0 + injector, err := NewInjector( + Module("contextualScopeTest", + RegisterScope("session", NewContextualScope(sessionScopeKeyVal)), + Provide(func() *Session { + res := &Session{ID: count} + count++ + return res + }, In("session"), WithDestroy(func(_ *Session) { + count-- + })), + ), + ) + assert.Nil(t, err) + ctx := context.Background() + + t.Run("Run session", func(t *testing.T) { + sessionCtx := WithContextualScopeEnabled(ctx, sessionScopeKeyVal) + defer ShutdownContextualScope(sessionCtx, sessionScopeKeyVal) + + err := injector.Invoke(sessionCtx, func(_ *Session) { + assert.Equal(t, 1, count) + }) + assert.Nil(t, err) + }) + + assert.Equal(t, 0, count) + }) +} + +type SingletonInjectee struct { + ID int +} + +func TestSingletonScope(t *testing.T) { + count := 0 + assert.NotPanics(t, func() { + injector, err := NewInjector( + Provide(func() *SingletonInjectee { + res := &SingletonInjectee{ID: count} + count++ + return res + }, In(Singleton)), + ) + assert.Nil(t, err) + + ctx := context.Background() + var fetch1 *SingletonInjectee + var fetch2 *SingletonInjectee + err = injector.Invoke(ctx, func(s *SingletonInjectee) { + fetch1 = s + }) + assert.Nil(t, err) + err = injector.Invoke(ctx, func(s *SingletonInjectee) { + fetch2 = s + }) + assert.Nil(t, err) + assert.NotNil(t, fetch1) + assert.NotNil(t, fetch2) + assert.Same(t, fetch1, fetch2) + }) +} + +type PerLookUpInjectee struct { + ID int +} + +func TestPerLookUpScope(t *testing.T) { + assert.NotPanics(t, func() { + t.Run("Should return new instance on each request", func(t *testing.T) { + count := 0 + injector, err := NewInjector( + Provide(func() *PerLookUpInjectee { + res := &PerLookUpInjectee{ID: count} + count++ + return res + }, In(PerLookUp)), + ) + assert.Nil(t, err) + + ctx := context.Background() + var fetch1 *PerLookUpInjectee + var fetch2 *PerLookUpInjectee + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + fetch1 = s + }) + assert.Nil(t, err) + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + fetch2 = s + }) + assert.Nil(t, err) + assert.NotNil(t, fetch1) + assert.NotNil(t, fetch2) + assert.NotEqual(t, fetch1, fetch2) + }) + + t.Run("Should ignore destroy instance methods", func(t *testing.T) { + count := 0 + injector, err := NewInjector( + Provide(func() *PerLookUpInjectee { + res := &PerLookUpInjectee{ID: count} + count++ + return res + }, In(PerLookUp), WithDestroy(func(_ *PerLookUpInjectee) { count-- })), + ) + assert.Nil(t, err) + + ctx := context.Background() + err = injector.Invoke(ctx, func(s *PerLookUpInjectee) { + assert.Equal(t, 0, s.ID) + assert.Equal(t, 1, count) + }) + assert.Nil(t, err) + assert.Equal(t, 1, count) + injector.Shutdown() + assert.Equal(t, 1, count) + }) + }) +} diff --git a/special.go b/special.go new file mode 100644 index 0000000..35d8d6c --- /dev/null +++ b/special.go @@ -0,0 +1,60 @@ +package goinject + +import ( + "context" + "reflect" +) + +// Params may be embedded in struct to request the injector to create it +// as special struct. When a constructor accepts such a struct, instead of the +// struct becoming a dependency for that constructor, all its fields become +// dependencies instead. +// +// Fields of the struct may optionally be tagged. +// The following tags are supported, +// +// annotation Requests a value with the same name and type from the +// container. See Named Values for more information. +// optional If set to true, indicates that the dependency is optional and +// the constructor gracefully handles its absence. +type Params struct{} + +var _paramType = reflect.TypeOf(Params{}) + +// EmbedsParams checks whether the given struct is an inject.Params struct. A struct qualifies +// as an inject.Params struct if it embeds inject.Params type. +// +// A struct MUST qualify as an inject.Params struct for its fields to be treated +// specially by the injector. +func EmbedsParams(o reflect.Type) bool { + return embedsType(o, _paramType) +} + +// Returns true if t embeds e +func embedsType(t, e reflect.Type) bool { + if t.Kind() == reflect.Ptr { + return embedsType(t.Elem(), e) + } + + if t.Kind() != reflect.Struct { + // for now, only struct are supported, it might be a good idea to support pointer too + return false + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous && f.Type == e { + return true + } + } + + return false +} + +type Provider[T any] func(ctx InvocationContext) (T, error) + +// InvocationContext wrap context.Context. +// Use this interface to retrieve the context pass to the Invoke method of the injector in providers +type InvocationContext interface { + context.Context +}