diff --git a/bot/arguments.go b/bot/arguments.go index d11b112..c382c63 100644 --- a/bot/arguments.go +++ b/bot/arguments.go @@ -74,14 +74,15 @@ func (r RawArguments) Length() int { } // CustomParser has a CustomParse method, which would be passed in the full -// message content with the prefix trimmed (but not the command). This is used -// for commands that require more advanced parsing than the default CSV reader. +// message content with the prefix trimmed, but not the command. This is used +// for commands that require more advanced parsing than the default parser. type CustomParser interface { CustomParse(content string) error } // CustomArguments implements the CustomParser interface, which sets the string -// exactly. +// exactly. This string contains the command, subcommand, and all its arguments. +// It does not contain the prefix. type Content string func (c *Content) CustomParse(content string) error { @@ -120,7 +121,7 @@ var ParseArgs = func(args string) ([]string, error) { // nilV, only used to return an error var nilV = reflect.Value{} -func getArgumentValueFn(t reflect.Type, variadic bool) (*Argument, error) { +func newArgument(t reflect.Type, variadic bool) (*Argument, error) { // Allow array types if varidic is true. if variadic && t.Kind() == reflect.Slice { t = t.Elem() @@ -134,6 +135,39 @@ func getArgumentValueFn(t reflect.Type, variadic bool) (*Argument, error) { ptr = true } + // This shouldn't be varidic. + if !variadic && typeI.Implements(typeICusP) { + mt, _ := typeI.MethodByName("CustomParse") + + // TODO: maybe ish? + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return &Argument{ + String: t.String(), + rtype: t, + pointer: ptr, + custom: &mt, + }, nil + } + + // This shouldn't be variadic either. + if !variadic && typeI.Implements(typeIManP) { + mt, _ := typeI.MethodByName("ParseContent") + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return &Argument{ + String: t.String(), + rtype: t, + pointer: ptr, + manual: &mt, + }, nil + } + if typeI.Implements(typeIParser) { mt, ok := typeI.MethodByName("Parse") if !ok { diff --git a/bot/arguments_test.go b/bot/arguments_test.go index 07132ba..de43ba2 100644 --- a/bot/arguments_test.go +++ b/bot/arguments_test.go @@ -27,14 +27,14 @@ func TestArguments(t *testing.T) { testArgs(t, mockParse("testString"), "testString") testArgs(t, *mockParse("testString"), "testString") - _, err := getArgumentValueFn(reflect.TypeOf(struct{}{}), false) + _, err := newArgument(reflect.TypeOf(struct{}{}), false) if !strings.HasPrefix(err.Error(), "invalid type: ") { t.Fatal("Unexpected error:", err) } } func testArgs(t *testing.T, expect interface{}, input string) { - f, err := getArgumentValueFn(reflect.TypeOf(expect), false) + f, err := newArgument(reflect.TypeOf(expect), false) if err != nil { t.Fatal("Failed to get argument value function:", err) } diff --git a/bot/ctx_call.go b/bot/ctx_call.go index 7f273db..0e8ba9f 100644 --- a/bot/ctx_call.go +++ b/bot/ctx_call.go @@ -153,33 +153,37 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { } // parse arguments - args, err := ParseArgs(content) + parts, err := ParseArgs(content) if err != nil { return errors.Wrap(err, "Failed to parse command") } - if len(args) == 0 { + if len(parts) == 0 { return nil // ??? } var cmd *CommandContext var sub *Subcommand - var start int // arg starts from $start + // var start int // arg starts from $start // Check if plumb: if ctx.plumb { cmd = ctx.Commands[0] sub = ctx.Subcommand - start = 0 + // start = 0 } + // Arguments slice, which will be sliced away until only arguments are left. + var arguments = parts + // If not plumb, search for the command if cmd == nil { for _, c := range ctx.Commands { - if c.Command == args[0] { + if c.Command == parts[0] { cmd = c sub = ctx.Subcommand - start = 1 + arguments = arguments[1:] + // start = 1 break } } @@ -189,7 +193,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { // entry. if cmd == nil { for _, s := range ctx.subcommands { - if s.Command != args[0] { + if s.Command != parts[0] { continue } @@ -197,21 +201,24 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { if s.plumb { cmd = s.Commands[0] sub = s - start = 1 + arguments = arguments[1:] + // start = 1 break } // There's no second argument, so we can only look for Plumbed // subcommands. - if len(args) < 2 { + if len(parts) < 2 { continue } for _, c := range s.Commands { - if c.Command == args[1] { + if c.Command == parts[1] { cmd = c sub = s - start = 2 + arguments = arguments[2:] + break + // start = 2 } } @@ -221,8 +228,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { } return &ErrUnknownCommand{ - Command: args[1], - Parent: args[0], + Command: parts[1], + Parent: parts[0], ctx: s.Commands, } } @@ -237,7 +244,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { } return &ErrUnknownCommand{ - Command: args[0], + Command: parts[0], ctx: ctx.Commands, } } @@ -255,6 +262,10 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { // Start converting var argv []reflect.Value + var argc int + + // the last argument in the list, not used until set + var last Argument // Here's an edge case: when the handler takes no arguments, we allow that // anyway, as they might've used the raw content. @@ -262,59 +273,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { goto Call } - // Check manual or parser - if cmd.Arguments[0].fn == nil { - // Create a zero value instance of this: - v := reflect.New(cmd.Arguments[0].rtype) - ret := []reflect.Value{} - - switch { - case cmd.Arguments[0].manual != nil: - // Pop out the subcommand name, if there's one: - if sub.Command != "" { - args = args[1:] - } - - // Call the manual parse method: - ret = cmd.Arguments[0].manual.Func.Call([]reflect.Value{ - v, reflect.ValueOf(args), - }) - - case cmd.Arguments[0].custom != nil: - var pad = len(cmd.Command) - if len(sub.Command) > 0 { // if this is also a subcommand: - pad += len(sub.Command) + 1 - } - - // For consistent behavior, clear the subcommand (and command) name off: - content = content[pad:] - // Trim space if there are any: - content = strings.TrimSpace(content) - - // Call the method with the raw unparsed command: - ret = cmd.Arguments[0].custom.Func.Call([]reflect.Value{ - v, reflect.ValueOf(content), - }) - } - - // Check the returned error: - _, err := errorReturns(ret) - if err != nil { - return err - } - - // Check if the argument wants a non-pointer: - if cmd.Arguments[0].pointer { - v = v.Elem() - } - - // Add the argument to the list of arguments: - argv = append(argv, v) - goto Call - } - // Argument count check. - if argdelta := len(args[start:]) - len(cmd.Arguments); argdelta != 0 { + if argdelta := len(arguments) - len(cmd.Arguments); argdelta != 0 { var err error // no err if nil switch { @@ -332,59 +292,107 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { if err != nil { return &ErrInvalidUsage{ Prefix: pf, - Args: args, - Index: len(args) - 1, + Args: parts, + Index: len(parts) - 1, Wrap: err, Ctx: cmd, } } } - // Allocate a new slice the length of function arguments. - argv = make([]reflect.Value, len(cmd.Arguments)) + // The last argument in the arguments slice. + last = cmd.Arguments[len(cmd.Arguments)-1] - for i := 0; i < len(argv); i++ { - v, err := cmd.Arguments[i].fn(args[start+i]) + // Allocate a new slice the length of function arguments. + argc = len(cmd.Arguments) - 1 // arg len without last + argv = make([]reflect.Value, 0, argc) // could be 0 + + // Parse all arguments except for the last one. + for i := 0; i < argc; i++ { + v, err := cmd.Arguments[i].fn(arguments[0]) if err != nil { return &ErrInvalidUsage{ Prefix: pf, - Args: args, - Index: i, + Args: parts, + Index: len(parts) - len(arguments) + i, Wrap: err, Ctx: cmd, } } - argv[i] = v + // Pop arguments. + arguments = arguments[1:] + argv = append(argv, v) } - // Parse the rest with variadic arguments. Go's reflect states that varidic - // parameters will automatically be copied, which is good. - if len(args) > len(argv) { - // The location to continue parsing from args. - argc := len(argv) - // Allocate a new slice to append into. We start 1-off from the start, - // as the first argument of the variadic slice is already parsed. - vars := make([]reflect.Value, len(args)-len(argv)-1) - last := cmd.Arguments[len(cmd.Arguments)-1] + // Is this last argument actually a variadic slice? If yes, then it + // should still have fn normally. + if last.fn != nil { + // Allocate a new slice to append into. + vars := make([]reflect.Value, 0, len(arguments)) - // Continue the above loop, where i stops before len(argv). - for i := 0; i < len(vars); i++ { - v, err := last.fn(args[argc+i+1]) + // Parse the rest with variadic arguments. Go's reflect states that + // varidic parameters will automatically be copied, which is good. + for i := 0; len(arguments) > 0; i++ { + v, err := last.fn(arguments[0]) if err != nil { return &ErrInvalidUsage{ Prefix: pf, - Args: args, - Index: i, + Args: parts, + Index: len(parts) - len(arguments) + i, Wrap: err, Ctx: cmd, } } - vars[i] = v + arguments = arguments[1:] + vars = append(vars, v) } argv = append(argv, vars...) + + } else { + // Create a zero value instance of this: + v := reflect.New(last.rtype) + var err error // return error + + switch { + // If the argument wants all arguments: + case last.manual != nil: + // Call the manual parse method: + _, err = callWith(last.manual.Func, v, reflect.ValueOf(arguments)) + + // If the argument wants all arguments in string: + case last.custom != nil: + // Manual string seeking is a must here. This is because the string + // could contain multiple whitespaces, and the parser would not + // count them. + var seekTo = cmd.Command + if sub.Command != "" { + seekTo = sub.Command + } + + // Seek to the string. + if i := strings.Index(content, seekTo); i > -1 { + content = strings.TrimSpace(content[i:]) + } + + // Call the method with the raw unparsed command: + _, err = callWith(last.custom.Func, v, reflect.ValueOf(content)) + } + + // Check the returned error: + if err != nil { + return err + } + + // Check if the argument wants a non-pointer: + if last.pointer { + v = v.Elem() + } + + // Add the argument into argv. + argv = append(argv, v) } Call: @@ -469,12 +477,17 @@ func callWith( caller reflect.Value, ev interface{}, values ...reflect.Value) (interface{}, error) { - values = append( - []reflect.Value{reflect.ValueOf(ev)}, - values..., - ) + var callargs = make([]reflect.Value, 0, 1+len(values)) - return errorReturns(caller.Call(values)) + if v, ok := ev.(reflect.Value); ok { + callargs = append(callargs, v) + } else { + callargs = append(callargs, reflect.ValueOf(ev)) + } + + callargs = append(callargs, values...) + + return errorReturns(caller.Call(callargs)) } func errorReturns(returns []reflect.Value) (interface{}, error) { diff --git a/bot/ctx_plumb_test.go b/bot/ctx_plumb_test.go index 79a7476..ddb71a1 100644 --- a/bot/ctx_plumb_test.go +++ b/bot/ctx_plumb_test.go @@ -32,7 +32,7 @@ func TestSubcommandPlumb(t *testing.T) { Store: state.NewDefaultStore(nil), } - c, err := New(state, &testCommands{}) + c, err := New(state, &testc{}) if err != nil { t.Fatal("Failed to create new context:", err) } @@ -64,7 +64,7 @@ func TestSubcommandPlumb(t *testing.T) { t.Fatal("Normal method called for hasPlumb") } - if p.Plumbed != "test command" { + if p.Plumbed != "hasPlumb test command" { t.Fatal("Unexpected custom argument for plumbed:", p.Plumbed) } } diff --git a/bot/ctx_test.go b/bot/ctx_test.go index 60bffde..fef2652 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -12,47 +12,52 @@ import ( "github.com/diamondburned/arikawa/state" ) -type testCommands struct { +type testc struct { Ctx *Context Return chan interface{} Counter uint64 Typed bool } -func (t *testCommands) MーBumpCounter(interface{}) error { +func (t *testc) MーBumpCounter(interface{}) error { t.Counter++ return nil } -func (t *testCommands) GetCounter(_ *gateway.MessageCreateEvent) error { +func (t *testc) GetCounter(_ *gateway.MessageCreateEvent) error { t.Return <- strconv.FormatUint(t.Counter, 10) return nil } -func (t *testCommands) Send(_ *gateway.MessageCreateEvent, args ...string) error { +func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error { t.Return <- args return errors.New("oh no") } -func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) error { +func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) error { t.Return <- c.args return nil } -func (t *testCommands) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) error { +func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) error { t.Return <- c[len(c)-1] return nil } -func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error { - return errors.New("passed") -} - -func (t *testCommands) Noop(_ *gateway.MessageCreateEvent) error { +func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) error { + t.Return <- c.args return nil } -func (t *testCommands) OnTyping(_ *gateway.TypingStartEvent) error { +func (t *testc) NoArgs(_ *gateway.MessageCreateEvent) error { + return errors.New("passed") +} + +func (t *testc) Noop(_ *gateway.MessageCreateEvent) error { + return nil +} + +func (t *testc) OnTyping(_ *gateway.TypingStartEvent) error { t.Typed = true return nil } @@ -62,7 +67,7 @@ func TestNewContext(t *testing.T) { Store: state.NewDefaultStore(nil), } - c, err := New(state, &testCommands{}) + c, err := New(state, &testc{}) if err != nil { t.Fatal("Failed to create new context:", err) } @@ -73,7 +78,7 @@ func TestNewContext(t *testing.T) { } func TestContext(t *testing.T) { - var given = &testCommands{} + var given = &testc{} var state = &state.State{ Store: state.NewDefaultStore(nil), } @@ -119,6 +124,8 @@ func TestContext(t *testing.T) { }) testReturn := func(expects interface{}, content string) (call error) { + t.Helper() + // Return channel for testing ret := make(chan interface{}) given.Return = ret @@ -189,7 +196,7 @@ func TestContext(t *testing.T) { t.Run("call command custom manual parser", func(t *testing.T) { ctx.HasPrefix = NewPrefix("!") - expects := []string{"custom", "arg1", ":)"} + expects := []string{"arg1", ":)"} if err := testReturn(expects, "!custom arg1 :)"); err != nil { t.Fatal("Unexpected call error:", err) @@ -205,6 +212,15 @@ func TestContext(t *testing.T) { } }) + t.Run("call command custom trailing manual parser", func(t *testing.T) { + ctx.HasPrefix = NewPrefix("!") + expects := []string{"arikawa"} + + if err := testReturn(expects, "!trailCustom hime arikawa"); err != nil { + t.Fatal("Unexpected call error:", err) + } + }) + testMessage := func(content string) error { // Mock a messageCreate event m := &gateway.MessageCreateEvent{ @@ -241,16 +257,16 @@ func TestContext(t *testing.T) { t.Run("register subcommand", func(t *testing.T) { ctx.HasPrefix = NewPrefix("run ") - _, err := ctx.RegisterSubcommand(&testCommands{}) + _, err := ctx.RegisterSubcommand(&testc{}) if err != nil { t.Fatal("Failed to register subcommand:", err) } - if err := testMessage("run testCommands noop"); err != nil { + if err := testMessage("run testc noop"); err != nil { t.Fatal("Unexpected error:", err) } - cmd := ctx.FindCommand("testCommands", "Noop") + cmd := ctx.FindCommand("testc", "Noop") if cmd == nil { t.Fatal("Failed to find subcommand Noop") } @@ -263,12 +279,12 @@ func BenchmarkConstructor(b *testing.B) { } for i := 0; i < b.N; i++ { - _, _ = New(state, &testCommands{}) + _, _ = New(state, &testc{}) } } func BenchmarkCall(b *testing.B) { - var given = &testCommands{} + var given = &testc{} var state = &state.State{ Store: state.NewDefaultStore(nil), } @@ -295,7 +311,7 @@ func BenchmarkCall(b *testing.B) { } func BenchmarkHelp(b *testing.B) { - var given = &testCommands{} + var given = &testc{} var state = &state.State{ Store: state.NewDefaultStore(nil), } diff --git a/bot/subcommand.go b/bot/subcommand.go index 3242df1..eefc30c 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -120,7 +120,8 @@ type CommandContext struct { // Hidden is true if the method has a hidden nameflag. Hidden bool - // Variadic is true if the function is a variadic one. + // Variadic is true if the function is a variadic one or if the last + // argument accepts multiple strings. Variadic bool value reflect.Value // Func @@ -445,65 +446,25 @@ func (sub *Subcommand) parseCommands() error { continue } - // The argument's second argument (the first is the event). - var inT = methodT.In(1) - var ptr bool - - if inT.Kind() != reflect.Ptr { - inT = reflect.PtrTo(inT) - ptr = true - } - - // If the second argument implements CustomParse() - if t := inT; t.Implements(typeICusP) { - mt, _ := inT.MethodByName("CustomParse") - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - command.Arguments = []Argument{{ - String: t.String(), - rtype: t, - pointer: ptr, - custom: &mt, - }} - - goto Done - } - - // If the second argument implements ParseContent() - if t := inT; t.Implements(typeIManP) { - mt, _ := inT.MethodByName("ParseContent") - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - command.Arguments = []Argument{{ - String: t.String(), - rtype: t, - pointer: ptr, - manual: &mt, - }} - - goto Done - } - command.Arguments = make([]Argument, 0, numArgs) - // Fill up arguments + // Fill up arguments. This should work with cusP and manP for i := 1; i < numArgs; i++ { t := methodT.In(i) - a, err := getArgumentValueFn(t, command.Variadic) + a, err := newArgument(t, command.Variadic) if err != nil { return errors.Wrap(err, "Error parsing argument "+t.String()) } command.Arguments = append(command.Arguments, *a) + + // We're done if the type accepts multiple arguments. + if a.custom != nil || a.manual != nil { + command.Variadic = true // treat as variadic + break + } } - Done: // If the current event is a plumb event: if flag.Is(Plumb) { command.Command = "" // plumbers don't have names diff --git a/bot/subcommand_test.go b/bot/subcommand_test.go index c33edcc..6f77905 100644 --- a/bot/subcommand_test.go +++ b/bot/subcommand_test.go @@ -3,14 +3,14 @@ package bot import "testing" func TestNewSubcommand(t *testing.T) { - _, err := NewSubcommand(&testCommands{}) + _, err := NewSubcommand(&testc{}) if err != nil { t.Fatal("Failed to create new subcommand:", err) } } func TestSubcommand(t *testing.T) { - var given = &testCommands{} + var given = &testc{} var sub = &Subcommand{ command: given, } @@ -27,7 +27,7 @@ func TestSubcommand(t *testing.T) { } // !!! CHANGE ME - if len(sub.Commands) != 6 { + if len(sub.Commands) != 7 { t.Fatal("invalid ctx.commands len", len(sub.Commands)) } @@ -57,7 +57,7 @@ func TestSubcommand(t *testing.T) { t.Fatal("expected 0 arguments, got non-zero") } - case "noop", "getCounter", "variadic": + case "noop", "getCounter", "variadic", "trailCustom": // Found, but whatever } @@ -88,6 +88,6 @@ func TestSubcommand(t *testing.T) { func BenchmarkSubcommandConstructor(b *testing.B) { for i := 0; i < b.N; i++ { - NewSubcommand(&testCommands{}) + NewSubcommand(&testc{}) } }