Skip to content

Commit

Permalink
feat: support recursive injection of provider parameters
Browse files Browse the repository at this point in the history
This allows provider functions to accept parameters that are injected by other
bindings or binding providers, eg. call the provider function with the root CLI
struct (which is automatically bound by Kong):

  kong.BindToProvider(func(cli *CLI) (*Injected, error) { ... })
  • Loading branch information
alecthomas committed Nov 1, 2024
1 parent 373692a commit 7bbb0b7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
42 changes: 20 additions & 22 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"strings"
)

type bindings map[reflect.Type]func() (reflect.Value, error)
// A map of type to function that returns a value of that type.
//
// The function should have the signature func(...) (T, error). Arguments are recursively resolved.
type bindings map[reflect.Type]any

func (b bindings) String() string {
out := []string{}
Expand All @@ -19,32 +22,23 @@ func (b bindings) String() string {
func (b bindings) add(values ...interface{}) bindings {
for _, v := range values {
v := v
b[reflect.TypeOf(v)] = func() (reflect.Value, error) { return reflect.ValueOf(v), nil }
b[reflect.TypeOf(v)] = func() (any, error) { return v, nil }
}
return b
}

func (b bindings) addTo(impl, iface interface{}) {
valueOf := reflect.ValueOf(impl)
b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil }
b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil }
}

func (b bindings) addProvider(provider interface{}) error {
pv := reflect.ValueOf(provider)
t := pv.Type()
if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("%T must be a function with the signature func()(T, error)", provider)
if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
}
rt := pv.Type().Out(0)
b[rt] = func() (reflect.Value, error) {
out := pv.Call(nil)
errv := out[1]
var err error
if !errv.IsNil() {
err = errv.Interface().(error) //nolint
}
return out[0], err
}
b[rt] = provider
return nil
}

Expand Down Expand Up @@ -101,15 +95,19 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
t := f.Type()
for i := 0; i < t.NumIn(); i++ {
pt := t.In(i)
if argf, ok := bindings[pt]; ok {
argv, err := argf()
if err != nil {
return nil, err
}
in = append(in, argv)
} else {
argf, ok := bindings[pt]
if !ok {
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
}
// Recursively resolve binding functions.
argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
if err != nil {
return nil, fmt.Errorf("%s: %w", pt, err)
}
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() {
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
}
in = append(in, reflect.ValueOf(argv[0]))
}
outv := f.Call(in)
out = make([]any, len(outv))
Expand Down
6 changes: 5 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ func BindTo(impl, iface interface{}) Option {
})
}

// BindToProvider allows binding of provider functions.
// BindToProvider binds an injected value to a provider function.
//
// The provider function must have the signature:
//
// func() (interface{}, error)
//
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
Expand Down
7 changes: 6 additions & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ func TestCallbackCustomError(t *testing.T) {
}

type bindToProviderCLI struct {
Filled bool `default:"true"`
Called bool
Cmd bindToProviderCmd `cmd:""`
}

type boundThing struct {
Filled bool
}

type bindToProviderCmd struct{}
Expand All @@ -105,7 +107,10 @@ func (*bindToProviderCmd) Run(cli *bindToProviderCLI, b *boundThing) error {

func TestBindToProvider(t *testing.T) {
var cli bindToProviderCLI
app, err := New(&cli, BindToProvider(func() (*boundThing, error) { return &boundThing{}, nil }))
app, err := New(&cli, BindToProvider(func(cli *bindToProviderCLI) (*boundThing, error) {
assert.True(t, cli.Filled, "CLI struct should have already been populated by Kong")
return &boundThing{Filled: cli.Filled}, nil
}))
assert.NoError(t, err)
ctx, err := app.Parse([]string{"cmd"})
assert.NoError(t, err)
Expand Down

0 comments on commit 7bbb0b7

Please sign in to comment.