From 3b903dce686aae6b4b47ab4bea2144e98f5b9cab Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 3 May 2020 15:59:10 -0700 Subject: [PATCH] Bot: Added variadic arguments support --- bot/arguments.go | 29 ++++++++------ bot/arguments_test.go | 25 +++++++++++-- bot/ctx_call.go | 85 ++++++++++++++++++++++++++++++++---------- bot/ctx_test.go | 38 ++++++++++++------- bot/error.go | 28 ++++++++++---- bot/subcommand.go | 16 +++++--- bot/subcommand_test.go | 7 +--- 7 files changed, 162 insertions(+), 66 deletions(-) diff --git a/bot/arguments.go b/bot/arguments.go index 09e8202..d11b112 100644 --- a/bot/arguments.go +++ b/bot/arguments.go @@ -93,7 +93,7 @@ func (c *Content) CustomParse(content string) error { type Argument struct { String string // Rule: pointer for structs, direct for primitives - Type reflect.Type + rtype reflect.Type // indicates if the type is referenced, meaning it's a pointer but not the // original call. @@ -105,6 +105,10 @@ type Argument struct { custom *reflect.Method } +func (a *Argument) Type() reflect.Type { + return a.rtype +} + var ShellwordsEscaper = strings.NewReplacer( "\\", "\\\\", ) @@ -116,7 +120,12 @@ var ParseArgs = func(args string) ([]string, error) { // nilV, only used to return an error var nilV = reflect.Value{} -func getArgumentValueFn(t reflect.Type) (*Argument, error) { +func getArgumentValueFn(t reflect.Type, variadic bool) (*Argument, error) { + // Allow array types if varidic is true. + if variadic && t.Kind() == reflect.Slice { + t = t.Elem() + } + var typeI = t var ptr = false @@ -152,7 +161,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) { return &Argument{ String: fromUsager(typeI), - Type: typeI, + rtype: typeI, pointer: ptr, fn: avfn, }, nil @@ -166,17 +175,13 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) { return reflect.ValueOf(s), nil } - case reflect.Int, reflect.Int8, - reflect.Int16, reflect.Int32, reflect.Int64: - + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: fn = func(s string) (reflect.Value, error) { i, err := strconv.ParseInt(s, 10, 64) return quickRet(i, err, t) } - case reflect.Uint, reflect.Uint8, - reflect.Uint16, reflect.Uint32, reflect.Uint64: - + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: fn = func(s string) (reflect.Value, error) { u, err := strconv.ParseUint(s, 10, 64) return quickRet(u, err, t) @@ -196,7 +201,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) { case "False", "FALSE", "false", "F", "f", "no", "n", "N", "0": return reflect.ValueOf(false), nil default: - return nilV, errors.New("invalid bool [true/false]") + return nilV, errors.New("invalid bool [true|false]") } } } @@ -207,7 +212,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) { return &Argument{ String: t.String(), - Type: t, + rtype: t, fn: fn, }, nil } @@ -232,9 +237,11 @@ func fromUsager(typeI reflect.Type) string { if !ok { panic("BUG: type IUsager does not implement Usage") } + vs := mt.Func.Call([]reflect.Value{reflect.New(typeI.Elem())}) return vs[0].String() } + s := strings.Split(typeI.String(), ".") return s[len(s)-1] } diff --git a/bot/arguments_test.go b/bot/arguments_test.go index dd994a1..07132ba 100644 --- a/bot/arguments_test.go +++ b/bot/arguments_test.go @@ -27,15 +27,14 @@ func TestArguments(t *testing.T) { testArgs(t, mockParse("testString"), "testString") testArgs(t, *mockParse("testString"), "testString") - _, err := getArgumentValueFn(reflect.TypeOf(struct{}{})) + _, err := getArgumentValueFn(reflect.TypeOf(struct{}{}), false) if !strings.HasPrefix(err.Error(), "invalid type: ") { t.Fatal("Unexpected error:", err) } - } func testArgs(t *testing.T, expect interface{}, input string) { - f, err := getArgumentValueFn(reflect.TypeOf(expect)) + f, err := getArgumentValueFn(reflect.TypeOf(expect), false) if err != nil { t.Fatal("Failed to get argument value function:", err) } @@ -49,3 +48,23 @@ func testArgs(t *testing.T, expect interface{}, input string) { t.Fatal("Value :", v, "\nExpects:", expect) } } + +// used for ctx_test.go + +type customManualParsed struct { + args []string +} + +func (c *customManualParsed) ParseContent(args []string) error { + c.args = args + return nil +} + +type customParsed struct { + parsed bool +} + +func (c *customParsed) Parse(string) error { + c.parsed = true + return nil +} diff --git a/bot/ctx_call.go b/bot/ctx_call.go index 7ccebd5..7f273db 100644 --- a/bot/ctx_call.go +++ b/bot/ctx_call.go @@ -265,7 +265,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { // Check manual or parser if cmd.Arguments[0].fn == nil { // Create a zero value instance of this: - v := reflect.New(cmd.Arguments[0].Type) + v := reflect.New(cmd.Arguments[0].rtype) ret := []reflect.Value{} switch { @@ -313,35 +313,78 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { goto Call } - // Not enough arguments given - if delta := len(args[start:]) - len(cmd.Arguments); delta != 0 { - var err = "Not enough arguments given" - if delta > 0 { - err = "Too many arguments given" + // Argument count check. + if argdelta := len(args[start:]) - len(cmd.Arguments); argdelta != 0 { + var err error // no err if nil + + switch { + // If there aren't enough arguments given. + case argdelta < 0: + err = ErrNotEnoughArgs + + // If there are too many arguments, then check if the command supports + // variadic arguments. We already did a length check above. + case argdelta > 0 && !cmd.Variadic: + // If it's not variadic, then we can't accept it. + err = ErrTooManyArgs } - return &ErrInvalidUsage{ - Args: args, - Index: len(args) - 1, - Err: err, - Ctx: cmd, + if err != nil { + return &ErrInvalidUsage{ + Prefix: pf, + Args: args, + Index: len(args) - 1, + Wrap: err, + Ctx: cmd, + } } } + // Allocate a new slice the length of function arguments. argv = make([]reflect.Value, len(cmd.Arguments)) - for i := start; i < len(args); i++ { - v, err := cmd.Arguments[i-start].fn(args[i]) + for i := 0; i < len(argv); i++ { + v, err := cmd.Arguments[i].fn(args[start+i]) if err != nil { return &ErrInvalidUsage{ - Args: args, - Index: i, - Err: err.Error(), - Ctx: cmd, + Prefix: pf, + Args: args, + Index: i, + Wrap: err, + Ctx: cmd, } } - argv[i-start] = v + argv[i] = v + } + + // Parse the rest with variadic arguments. Go's reflect states that varidic + // parameters will automatically be copied, which is good. + if len(args) > len(argv) { + // The location to continue parsing from args. + argc := len(argv) + // Allocate a new slice to append into. We start 1-off from the start, + // as the first argument of the variadic slice is already parsed. + vars := make([]reflect.Value, len(args)-len(argv)-1) + last := cmd.Arguments[len(cmd.Arguments)-1] + + // Continue the above loop, where i stops before len(argv). + for i := 0; i < len(vars); i++ { + v, err := last.fn(args[argc+i+1]) + if err != nil { + return &ErrInvalidUsage{ + Prefix: pf, + Args: args, + Index: i, + Wrap: err, + Ctx: cmd, + } + } + + vars[i] = v + } + + argv = append(argv, vars...) } Call: @@ -426,10 +469,12 @@ func callWith( caller reflect.Value, ev interface{}, values ...reflect.Value) (interface{}, error) { - return errorReturns(caller.Call(append( + values = append( []reflect.Value{reflect.ValueOf(ev)}, values..., - ))) + ) + + return errorReturns(caller.Call(values)) } func errorReturns(returns []reflect.Value) (interface{}, error) { diff --git a/bot/ctx_test.go b/bot/ctx_test.go index 641923c..60bffde 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -29,16 +29,21 @@ func (t *testCommands) GetCounter(_ *gateway.MessageCreateEvent) error { return nil } -func (t *testCommands) Send(_ *gateway.MessageCreateEvent, arg string) error { - t.Return <- arg +func (t *testCommands) Send(_ *gateway.MessageCreateEvent, args ...string) error { + t.Return <- args return errors.New("oh no") } -func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *customParseable) error { +func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) error { t.Return <- c.args return nil } +func (t *testCommands) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) error { + t.Return <- c[len(c)-1] + return nil +} + func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error { return errors.New("passed") } @@ -52,15 +57,6 @@ func (t *testCommands) OnTyping(_ *gateway.TypingStartEvent) error { return nil } -type customParseable struct { - args []string -} - -func (c *customParseable) ParseContent(args []string) error { - c.args = args - return nil -} - func TestNewContext(t *testing.T) { var state = &state.State{ Store: state.NewDefaultStore(nil), @@ -181,12 +177,17 @@ func TestContext(t *testing.T) { // Set a custom prefix ctx.HasPrefix = NewPrefix("~") - if err := testReturn("test", "~send test"); err.Error() != "oh no" { + var ( + strings = "hacka doll no. 3" + expects = []string{"hacka", "doll", "no.", "3"} + ) + + if err := testReturn(expects, "~send "+strings); err.Error() != "oh no" { t.Fatal("Unexpected error:", err) } }) - t.Run("call command custom parser", func(t *testing.T) { + t.Run("call command custom manual parser", func(t *testing.T) { ctx.HasPrefix = NewPrefix("!") expects := []string{"custom", "arg1", ":)"} @@ -195,6 +196,15 @@ func TestContext(t *testing.T) { } }) + t.Run("call command custom variadic parser", func(t *testing.T) { + ctx.HasPrefix = NewPrefix("!") + expects := &customParsed{true} + + if err := testReturn(expects, "!variadic bruh moment"); err != nil { + t.Fatal("Unexpected call error:", err) + } + }) + testMessage := func(content string) error { // Mock a messageCreate event m := &gateway.MessageCreateEvent{ diff --git a/bot/error.go b/bot/error.go index 931a09b..6acb712 100644 --- a/bot/error.go +++ b/bot/error.go @@ -1,10 +1,12 @@ package bot import ( + "errors" "strings" ) type ErrUnknownCommand struct { + Prefix string Command string Parent string @@ -18,7 +20,7 @@ func (err *ErrUnknownCommand) Error() string { } var UnknownCommandString = func(err *ErrUnknownCommand) string { - var header = "Unknown command: " + var header = "Unknown command: " + err.Prefix if err.Parent != "" { header += err.Parent + " " + err.Command } else { @@ -28,10 +30,16 @@ var UnknownCommandString = func(err *ErrUnknownCommand) string { return header } +var ( + ErrTooManyArgs = errors.New("Too many arguments given") + ErrNotEnoughArgs = errors.New("Not enough arguments given") +) + type ErrInvalidUsage struct { - Args []string - Index int - Err string + Prefix string + Args []string + Index int + Wrap error // TODO: usage generator? // Here, as a reminder @@ -42,9 +50,13 @@ func (err *ErrInvalidUsage) Error() string { return InvalidUsageString(err) } +func (err *ErrInvalidUsage) Unwrap() error { + return err.Wrap +} + var InvalidUsageString = func(err *ErrInvalidUsage) string { if err.Index == 0 { - return "Invalid usage, error: " + err.Err + return "Invalid usage, error: " + err.Wrap.Error() + "." } if len(err.Args) == 0 { @@ -52,6 +64,8 @@ var InvalidUsageString = func(err *ErrInvalidUsage) string { } body := "Invalid usage at " + + // Write the prefix. + err.Prefix + // Write the first part strings.Join(err.Args[:err.Index], " ") + // Write the wrong part @@ -59,8 +73,8 @@ var InvalidUsageString = func(err *ErrInvalidUsage) string { // Write the last part strings.Join(err.Args[err.Index+1:], " ") - if err.Err != "" { - body += "\nError: " + err.Err + if err.Wrap != nil { + body += "\nError: " + err.Wrap.Error() + "." } return body diff --git a/bot/subcommand.go b/bot/subcommand.go index 7f81dd7..3242df1 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -120,6 +120,9 @@ type CommandContext struct { // Hidden is true if the method has a hidden nameflag. Hidden bool + // Variadic is true if the function is a variadic one. + Variadic bool + value reflect.Value // Func event reflect.Type // gateway.*Event method reflect.Method @@ -389,9 +392,10 @@ func (sub *Subcommand) parseCommands() error { } var command = CommandContext{ - method: sub.ptrType.Method(i), - value: method, - event: methodT.In(0), // parse event + method: sub.ptrType.Method(i), + value: method, + event: methodT.In(0), // parse event + Variadic: methodT.IsVariadic(), } // Parse the method name @@ -460,7 +464,7 @@ func (sub *Subcommand) parseCommands() error { command.Arguments = []Argument{{ String: t.String(), - Type: t, + rtype: t, pointer: ptr, custom: &mt, }} @@ -478,7 +482,7 @@ func (sub *Subcommand) parseCommands() error { command.Arguments = []Argument{{ String: t.String(), - Type: t, + rtype: t, pointer: ptr, manual: &mt, }} @@ -491,7 +495,7 @@ func (sub *Subcommand) parseCommands() error { // Fill up arguments for i := 1; i < numArgs; i++ { t := methodT.In(i) - a, err := getArgumentValueFn(t) + a, err := getArgumentValueFn(t, command.Variadic) if err != nil { return errors.Wrap(err, "Error parsing argument "+t.String()) } diff --git a/bot/subcommand_test.go b/bot/subcommand_test.go index ad8736c..c33edcc 100644 --- a/bot/subcommand_test.go +++ b/bot/subcommand_test.go @@ -27,7 +27,7 @@ func TestSubcommand(t *testing.T) { } // !!! CHANGE ME - if len(sub.Commands) != 5 { + if len(sub.Commands) != 6 { t.Fatal("invalid ctx.commands len", len(sub.Commands)) } @@ -57,11 +57,8 @@ func TestSubcommand(t *testing.T) { t.Fatal("expected 0 arguments, got non-zero") } - case "noop", "getCounter": + case "noop", "getCounter", "variadic": // Found, but whatever - - default: - t.Fatal("Unexpected command:", this.Command) } if this.event != typeMessageCreate {