diff --git a/bot/README.md b/bot/README.md index 54da469..e31ddfb 100644 --- a/bot/README.md +++ b/bot/README.md @@ -2,6 +2,8 @@ Not a lot for a Discord bot: +# THIS IS OUTDATED. TODO: UPDATE. + ``` # Cold functions, or functions that are called once in runtime: BenchmarkConstructor-8 150537 7617 ns/op diff --git a/bot/command.go b/bot/command.go new file mode 100644 index 0000000..0e70efe --- /dev/null +++ b/bot/command.go @@ -0,0 +1,236 @@ +package bot + +import ( + "reflect" +) + +type command struct { + value reflect.Value // Func + event reflect.Type + isInterface bool +} + +func newCommand(value reflect.Value, event reflect.Type) command { + return command{ + value: value, + event: event, + isInterface: event.Kind() == reflect.Interface, + } +} + +func (c *command) isEvent(t reflect.Type) bool { + return (!c.isInterface && c.event == t) || (c.isInterface && t.Implements(c.event)) +} + +func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, error) { + return callWith(c.value, arg0, argv...) +} + +func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) { + var callargs = make([]reflect.Value, 0, 1+len(argv)) + + if v, ok := arg0.(reflect.Value); ok { + callargs = append(callargs, v) + } else { + callargs = append(callargs, reflect.ValueOf(arg0)) + } + + callargs = append(callargs, argv...) + return errorReturns(caller.Call(callargs)) +} + +type caller interface { + call(arg0 interface{}, argv ...reflect.Value) (interface{}, error) +} + +func errorReturns(returns []reflect.Value) (interface{}, error) { + // Handlers may return nothing. + if len(returns) == 0 { + return nil, nil + } + + // assume first return is always error, since we checked for this in + // parseCommands. + v := returns[len(returns)-1].Interface() + // If the last return (error) is nil. + if v == nil { + // If we only have 1 returns, that return must be the error. The error + // is nil, so nil is returned. + if len(returns) == 1 { + return nil, nil + } + + // Return the first argument as-is. The above returns[-1] check assumes + // 2 return values (T, error), meaning returns[0] is the T value. + return returns[0].Interface(), nil + } + + // Treat the last return as an error. + return nil, v.(error) +} + +// MethodContext is an internal struct containing fields to make this library +// work. As such, they're all unexported. Description, however, is exported for +// editing, and may be used to generate more informative help messages. +type MethodContext struct { + command + method reflect.Method // extend + middlewares []*MiddlewareContext + + Description string + + // MethodName is the name of the method. This field should NOT be changed. + MethodName string + + // Command is the Discord command used to call the method. + Command string // hidden if empty + + // Hidden is true if the method has a hidden nameflag. + // Hidden bool + + // Variadic is true if the function is a variadic one or if the last + // argument accepts multiple strings. + Variadic bool + + Arguments []Argument +} + +func parseMethod(value reflect.Value, method reflect.Method) *MethodContext { + methodT := value.Type() + numArgs := methodT.NumIn() + + if numArgs == 0 { + // Doesn't meet the requirement for an event, continue. + return nil + } + + // Check number of returns: + numOut := methodT.NumOut() + + // Returns can either be: + // Nothing - func() + // An error - func() error + // An error and something else - func() (T, error) + if numOut > 2 { + return nil + } + + // Check the last return's type if the method returns anything. + if numOut > 0 { + if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) { + // Invalid, skip. + return nil + } + } + + var command = MethodContext{ + command: newCommand(value, methodT.In(0)), + method: method, + MethodName: method.Name, + Variadic: methodT.IsVariadic(), + } + + // Only set the command name if it's a MessageCreate handler. + if command.event == typeMessageCreate { + command.Command = lowerFirstLetter(command.method.Name) + } + + if numArgs > 1 { + // Event handlers that aren't MessageCreate should not have arguments. + if command.event != typeMessageCreate { + return nil + } + + // If the event type is messageCreate: + command.Arguments = make([]Argument, 0, numArgs-1) + + // Fill up arguments. This should work with cusP and manP + for i := 1; i < numArgs; i++ { + t := methodT.In(i) + a, err := newArgument(t, command.Variadic) + if err != nil { + panic("Error parsing argument " + t.String() + ": " + err.Error()) + } + + command.Arguments = append(command.Arguments, *a) + + // We're done if the type accepts multiple arguments. + if a.custom != nil || a.manual != nil { + command.Variadic = true // treat as variadic + break + } + } + } + + return &command +} + +func (cctx *MethodContext) addMiddleware(mw *MiddlewareContext) { + cctx.middlewares = append(cctx.middlewares, mw) +} + +func (cctx *MethodContext) walkMiddlewares(ev reflect.Value) error { + for _, mw := range cctx.middlewares { + _, err := mw.call(ev) + if err != nil { + return err + } + } + return nil +} + +func (cctx *MethodContext) Usage() []string { + if len(cctx.Arguments) == 0 { + return nil + } + + var arguments = make([]string, len(cctx.Arguments)) + for i, arg := range cctx.Arguments { + arguments[i] = arg.String + } + + return arguments +} + +// SetName sets the command name. +func (cctx *MethodContext) SetName(name string) { + cctx.Command = name +} + +type MiddlewareContext struct { + command +} + +// ParseMiddleware parses a middleware function. This function panics. +func ParseMiddleware(mw interface{}) *MiddlewareContext { + value := reflect.ValueOf(mw) + methodT := value.Type() + numArgs := methodT.NumIn() + + if numArgs != 1 { + panic("Invalid argument signature for " + methodT.String()) + } + + // Check number of returns: + numOut := methodT.NumOut() + + // Returns can either be: + // Nothing - func() + // An error - func() error + if numOut > 1 { + panic("Invalid return signature for " + methodT.String()) + } + + // Check the last return's type if the method returns anything. + if numOut == 1 { + if i := methodT.Out(0); i == nil || !i.Implements(typeIError) { + panic("Unexpected return type (not error) for " + methodT.String()) + } + } + + var middleware = MiddlewareContext{ + command: newCommand(value, methodT.In(0)), + } + + return &middleware +} diff --git a/bot/ctx.go b/bot/ctx.go index 2acd2e5..bd0b4d3 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -217,19 +217,19 @@ func (ctx *Context) Subcommands() []*Subcommand { return ctx.subcommands } -// FindCommand finds a command based on the struct and method name. The queried +// FindMethod finds a method based on the struct and method name. The queried // names will have their flags stripped. // // Example // // // Find a command from the main context: -// cmd := ctx.FindCommand("", "Method") +// cmd := ctx.FindMethod("", "Method") // // Find a command from a subcommand: -// cmd = ctx.FindCommand("Starboard", "Reset") +// cmd = ctx.FindMethod("Starboard", "Reset") // -func (ctx *Context) FindCommand(structname, methodname string) *CommandContext { +func (ctx *Context) FindMethod(structname, methodname string) *MethodContext { if structname == "" { - for _, c := range ctx.Commands { + for _, c := range ctx.Methods { if c.MethodName == methodname { return c } @@ -243,7 +243,7 @@ func (ctx *Context) FindCommand(structname, methodname string) *CommandContext { continue } - for _, c := range sub.Commands { + for _, c := range sub.Methods { if c.MethodName == methodname { return c } @@ -360,52 +360,55 @@ func (ctx *Context) HelpAdmin() string { } func (ctx *Context) help(hideAdmin bool) string { - const indent = " " + // const indent = " " - var help strings.Builder + // var help strings.Builder - // Generate the headers and descriptions - help.WriteString("__Help__") + // // Generate the headers and descriptions + // help.WriteString("__Help__") - if ctx.Name != "" { - help.WriteString(": " + ctx.Name) - } + // if ctx.Name != "" { + // help.WriteString(": " + ctx.Name) + // } - if ctx.Description != "" { - help.WriteString("\n" + indent + ctx.Description) - } + // if ctx.Description != "" { + // help.WriteString("\n" + indent + ctx.Description) + // } - if ctx.Flag.Is(AdminOnly) { - // That's it. - return help.String() - } + // if ctx.Flag.Is(AdminOnly) { + // // That's it. + // return help.String() + // } - // Separators - help.WriteString("\n---\n") + // // Separators + // help.WriteString("\n---\n") - // Generate all commands - help.WriteString("__Commands__") - help.WriteString(ctx.Subcommand.Help(indent, hideAdmin)) - help.WriteByte('\n') + // // Generate all commands + // help.WriteString("__Commands__") + // help.WriteString(ctx.Subcommand.Help(indent, hideAdmin)) + // help.WriteByte('\n') - var subHelp = strings.Builder{} - var subcommands = ctx.Subcommands() + // var subHelp = strings.Builder{} + // var subcommands = ctx.Subcommands() - for _, sub := range subcommands { - if help := sub.Help(indent, hideAdmin); help != "" { - for _, line := range strings.Split(help, "\n") { - subHelp.WriteString(indent) - subHelp.WriteString(line) - subHelp.WriteByte('\n') - } - } - } + // for _, sub := range subcommands { + // if help := sub.Help(indent, hideAdmin); help != "" { + // for _, line := range strings.Split(help, "\n") { + // subHelp.WriteString(indent) + // subHelp.WriteString(line) + // subHelp.WriteByte('\n') + // } + // } + // } - if subHelp.Len() > 0 { - help.WriteString("---\n") - help.WriteString("__Subcommands__\n") - help.WriteString(subHelp.String()) - } + // if subHelp.Len() > 0 { + // help.WriteString("---\n") + // help.WriteString("__Subcommands__\n") + // help.WriteString(subHelp.String()) + // } - return help.String() + // return help.String() + + // TODO + return "" } diff --git a/bot/ctx_call.go b/bot/ctx_call.go index d50326c..8a900e9 100644 --- a/bot/ctx_call.go +++ b/bot/ctx_call.go @@ -5,136 +5,75 @@ import ( "strings" "github.com/diamondburned/arikawa/api" - "github.com/diamondburned/arikawa/bot/extras/infer" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" "github.com/pkg/errors" ) -// NonFatal is an interface that a method can implement to ignore all errors. -// This works similarly to Break. -type NonFatal interface { - error - IgnoreError() // noop method -} - -func onlyFatal(err error) error { - if _, ok := err.(NonFatal); ok { - return nil - } - return err -} - -type _Break struct{ error } - -// implement NonFatal. -func (_Break) IgnoreError() {} - -// Break is a non-fatal error that could be returned from middlewares or -// handlers to stop the chain of execution. -// -// Middlewares are guaranteed to be executed before handlers, but the exact -// order of each are undefined. Main handlers are also guaranteed to be executed -// before all subcommands. If a main middleware cancels, no subcommand -// middlewares will be called. -// -// Break implements the NonFatal interface, which causes an error to be ignored. -var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")} - -func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext { - var callers []*CommandContext - var middles []*CommandContext - var found bool - - find := func(sub *Subcommand) { - for _, cmd := range sub.Events { - // Search only for callers, so skip middlewares. - if cmd.Flag.Is(Middleware) { - continue - } - - if cmd.event == evT { - callers = append(callers, cmd) - found = true - } - } - - // Only get middlewares if we found handlers for that same event. - if found { - // Search for middlewares with the same type: - for _, mw := range sub.mwMethods { - if mw.event == evT { - middles = append(middles, mw) - } - } - } - } +// Break is a non-fatal error that could be returned from middlewares to stop +// the chain of execution. +var Break = errors.New("break middleware chain, non-fatal") +// filterEventType filters all commands and subcommands into a 2D slice, +// structured so that a Break would only exit out the nested slice. +func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) { // Find the main context first. - find(ctx.Subcommand) + callers = append(callers, ctx.eventCallers(evT)) for _, sub := range ctx.subcommands { - // Reset found status - found = false // Find subcommands second. - find(sub) + callers = append(callers, sub.eventCallers(evT)) } - return append(middles, callers...) + return } -func (ctx *Context) callCmd(ev interface{}) error { - evT := reflect.TypeOf(ev) +func (ctx *Context) callCmd(ev interface{}) (bottomError error) { + evV := reflect.ValueOf(ev) + evT := evV.Type() - var isAdmin *bool // I want to die. - var isGuild *bool - var callers []*CommandContext + var callers [][]caller // Hit the cache t, ok := ctx.typeCache.Load(evT) if ok { - callers = t.([]*CommandContext) + callers = t.([][]caller) } else { callers = ctx.filterEventType(evT) ctx.typeCache.Store(evT, callers) } - // We can't do the callers[:0] trick here, as it will modify the slice - // inside the sync.Map as well. - var filtered = make([]*CommandContext, 0, len(callers)) + for _, subcallers := range callers { + for _, c := range subcallers { + _, err := c.call(evV) + if err != nil { + // Only count as an error if it's not Break. + if err = errNoBreak(err); err != nil { + bottomError = err + } - for _, cmd := range callers { - // Command flags will inherit its parent Subcommand's flags. - if true && - !(cmd.Flag.Is(AdminOnly) && !ctx.eventIsAdmin(ev, &isAdmin)) && - !(cmd.Flag.Is(GuildOnly) && !ctx.eventIsGuild(ev, &isGuild)) { - - filtered = append(filtered, cmd) - } - } - - for _, c := range filtered { - _, err := callWith(c.value, ev) - if err != nil { - if err = onlyFatal(err); err != nil { - ctx.ErrorLogger(err) + // Break the caller loop only for this subcommand. + break } - return err } } - // We call the messages later, since Hidden handlers will go into the Events - // slice, but we don't want to ignore those handlers either. + // We call the messages later, since we want MessageCreate middlewares to + // run as well. if evT == typeMessageCreate { // safe assertion always - err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent)) - return onlyFatal(err) + err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent), evV) + // There's no need for an errNoBreak here, as the method already checked + // for that. + if err != nil { + bottomError = err + } } - return nil + return } -func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { +func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error { // check if bot if !ctx.AllowBot && mc.Author.Bot { return nil @@ -163,102 +102,18 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { return nil // ??? } - var cmd *CommandContext - var sub *Subcommand - // var start int // arg starts from $start - - // Check if plumb: - if ctx.plumb { - cmd = ctx.Commands[0] - sub = ctx.Subcommand - // start = 0 + // Find the command and subcommand. + arguments, cmd, sub, err := ctx.findCommand(parts) + if err != nil { + return errNoBreak(err) } - // Arguments slice, which will be sliced away until only arguments are left. - var arguments = parts + // We don't run the subcommand's middlewares here, as the callCmd function + // already handles that. - // If not plumb, search for the command - if cmd == nil { - for _, c := range ctx.Commands { - if c.Command == parts[0] { - cmd = c - sub = ctx.Subcommand - arguments = arguments[1:] - // start = 1 - break - } - } - } - - // Can't find the command, look for subcommands if len(args) has a 2nd - // entry. - if cmd == nil { - for _, s := range ctx.subcommands { - if s.Command != parts[0] { - continue - } - - // Check if plumb: - if s.plumb { - cmd = s.Commands[0] - sub = s - arguments = arguments[1:] - // start = 1 - break - } - - // There's no second argument, so we can only look for Plumbed - // subcommands. - if len(parts) < 2 { - continue - } - - for _, c := range s.Commands { - if c.Command == parts[1] { - cmd = c - sub = s - arguments = arguments[2:] - break - // start = 2 - } - } - - if cmd == nil { - if s.QuietUnknownCommand { - return nil - } - - return &ErrUnknownCommand{ - Command: parts[1], - Parent: parts[0], - ctx: s.Commands, - } - } - - break - } - } - - if cmd == nil { - if ctx.QuietUnknownCommand { - return nil - } - - return &ErrUnknownCommand{ - Command: parts[0], - ctx: ctx.Commands, - } - } - - // Check for IsAdmin and IsGuild - if cmd.Flag.Is(GuildOnly) && !mc.GuildID.Valid() { - return nil - } - if cmd.Flag.Is(AdminOnly) { - p, err := ctx.State.Permissions(mc.ChannelID, mc.Author.ID) - if err != nil || !p.Has(discord.PermissionAdministrator) { - return nil - } + // Run command middlewares. + if err := cmd.walkMiddlewares(value); err != nil { + return errNoBreak(err) } // Start converting @@ -375,8 +230,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { // could contain multiple whitespaces, and the parser would not // count them. var seekTo = cmd.Command - // If plumbed, then there would only be the subcommand. - if sub.plumb { + // Implicit plumbing behavior. + if seekTo == "" { seekTo = sub.Command } @@ -406,17 +261,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { } Call: - // Try calling all middlewares first. We don't need to stack middlewares, as - // there will only be one command match. - for _, mw := range sub.mwMethods { - _, err := callWith(mw.value, mc) - if err != nil { - return err - } - } - // call the function and parse the error return value - v, err := callWith(cmd.value, mc, argv...) + v, err := cmd.call(value, argv...) if err != nil { return err } @@ -437,91 +283,59 @@ Call: return err } -func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool { - if *is != nil { - return **is +// findCommand filters. +func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) { + // Main command entrypoint cannot have plumb. + for _, c := range ctx.Methods { + if c.Command == parts[0] { + return parts[1:], c, ctx.Subcommand, nil + } } - var channelID = infer.ChannelID(ev) - if !channelID.Valid() { - return false - } - - var userID = infer.UserID(ev) - if !userID.Valid() { - return false - } - - var res bool - - p, err := ctx.State.Permissions(channelID, userID) - if err == nil && p.Has(discord.PermissionAdministrator) { - res = true - } - - *is = &res - return res -} - -func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool { - if *is != nil { - return **is - } - - var channelID = infer.ChannelID(ev) - if !channelID.Valid() { - return false - } - - c, err := ctx.State.Channel(channelID) - if err != nil { - return false - } - - res := c.GuildID.Valid() - *is = &res - return res -} - -func callWith( - caller reflect.Value, - ev interface{}, values ...reflect.Value) (interface{}, error) { - - var callargs = make([]reflect.Value, 0, 1+len(values)) - - if v, ok := ev.(reflect.Value); ok { - callargs = append(callargs, v) - } else { - callargs = append(callargs, reflect.ValueOf(ev)) - } - - callargs = append(callargs, values...) - - return errorReturns(caller.Call(callargs)) -} - -func errorReturns(returns []reflect.Value) (interface{}, error) { - // Handlers may return nothing. - if len(returns) == 0 { - return nil, nil - } - - // assume first return is always error, since we checked for this in - // parseCommands. - v := returns[len(returns)-1].Interface() - // If the last return (error) is nil. - if v == nil { - // If we only have 1 returns, that return must be the error. The error - // is nil, so nil is returned. - if len(returns) == 1 { - return nil, nil + // Can't find the command, look for subcommands if len(args) has a 2nd + // entry. + for _, s := range ctx.subcommands { + if s.Command != parts[0] { + continue } - // Return the first argument as-is. The above returns[-1] check assumes - // 2 return values (T, error), meaning returns[0] is the T value. - return returns[0].Interface(), nil + // If there's no second argument, TODO call Help. + + if s.plumbed != nil { + return parts[1:], s.plumbed, s, nil + } + + if len(parts) >= 2 { + for _, c := range s.Methods { + if c.event == typeMessageCreate && c.Command == parts[1] { + return parts[2:], c, s, nil + } + } + } + + if s.QuietUnknownCommand || ctx.QuietUnknownCommand { + return nil, nil, nil, Break + } + + return nil, nil, nil, &ErrUnknownCommand{ + Parts: parts, + Subcmd: s, + } } - // Treat the last return as an error. - return nil, v.(error) + if ctx.QuietUnknownCommand { + return nil, nil, nil, Break + } + + return nil, nil, nil, &ErrUnknownCommand{ + Parts: parts, + Subcmd: ctx.Subcommand, + } +} + +func errNoBreak(err error) error { + if errors.Is(err, Break) { + return nil + } + return err } diff --git a/bot/ctx_plumb_test.go b/bot/ctx_plumb_test.go index a7ebb9c..192b31d 100644 --- a/bot/ctx_plumb_test.go +++ b/bot/ctx_plumb_test.go @@ -15,12 +15,16 @@ type hasPlumb struct { NotPlumbed bool } +func (h *hasPlumb) Setup(sub *Subcommand) { + sub.SetPlumb("Plumber") +} + func (h *hasPlumb) Normal(_ *gateway.MessageCreateEvent) error { h.NotPlumbed = true return nil } -func (h *hasPlumb) PーPlumber(_ *gateway.MessageCreateEvent, c RawArguments) error { +func (h *hasPlumb) Plumber(_ *gateway.MessageCreateEvent, c RawArguments) error { h.Plumbed = string(c) return nil } @@ -43,10 +47,6 @@ func TestSubcommandPlumb(t *testing.T) { t.Fatal("Failed to register hasPlumb:", err) } - if l := len(c.subcommands[0].Commands); l != 1 { - t.Fatal("Unexpected length for sub.Commands:", l) - } - // Try call exactly what's in the Plumb example: m := &gateway.MessageCreateEvent{ Message: discord.Message{ diff --git a/bot/ctx_test.go b/bot/ctx_test.go index 5a835a4..45e7d92 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -21,43 +21,38 @@ type testc struct { Typed bool } -func (t *testc) MーBumpCounter(interface{}) { - t.Counter++ +func (t *testc) Setup(sub *Subcommand) { + sub.AddMiddleware("*,GetCounter", func(v interface{}) { + t.Counter++ + }) + sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) { + t.Counter++ + }) } - -func (t *testc) GetCounter(_ *gateway.MessageCreateEvent) { +func (t *testc) Noop(*gateway.MessageCreateEvent) {} +func (t *testc) GetCounter(*gateway.MessageCreateEvent) { t.Return <- strconv.FormatUint(t.Counter, 10) } - func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error { t.Return <- args return errors.New("oh no") } - func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) { t.Return <- c.args } - func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) { t.Return <- c[len(c)-1] } - func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) { t.Return <- c.args } - func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) { t.Return <- c } - -func (t *testc) NoArgs(_ *gateway.MessageCreateEvent) error { +func (t *testc) NoArgs(*gateway.MessageCreateEvent) error { return errors.New("passed") } - -func (t *testc) Noop(_ *gateway.MessageCreateEvent) { -} - -func (t *testc) OnTyping(_ *gateway.TypingStartEvent) { +func (t *testc) OnTyping(*gateway.TypingStartEvent) { t.Typed = true } @@ -108,26 +103,26 @@ func TestContext(t *testing.T) { }) t.Run("find commands", func(t *testing.T) { - cmd := ctx.FindCommand("", "NoArgs") + cmd := ctx.FindMethod("", "NoArgs") if cmd == nil { t.Fatal("Failed to find NoArgs") } }) - t.Run("help", func(t *testing.T) { - if h := ctx.Help(); h == "" { - t.Fatal("Empty help?") - } - if h := ctx.HelpAdmin(); h == "" { - t.Fatal("Empty admin help?") - } - }) + // t.Run("help", func(t *testing.T) { + // if h := ctx.Help(); h == "" { + // t.Fatal("Empty help?") + // } + // if h := ctx.HelpAdmin(); h == "" { + // t.Fatal("Empty admin help?") + // } + // }) t.Run("middleware", func(t *testing.T) { ctx.HasPrefix = NewPrefix("pls do ") // This should trigger the middleware first. - if err := expect(ctx, given, "1", "pls do getCounter"); err != nil { + if err := expect(ctx, given, "3", "pls do getCounter"); err != nil { t.Fatal("Unexpected error:", err) } }) @@ -247,7 +242,7 @@ func TestContext(t *testing.T) { t.Fatal("Unexpected call error:", err) } - if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil { + if cmd := ctx.FindMethod("testc", "Noop"); cmd == nil { t.Fatal("Failed to find subcommand Noop") } }) @@ -308,6 +303,7 @@ func BenchmarkCall(b *testing.B) { Subcommand: s, State: state, HasPrefix: NewPrefix("~"), + ParseArgs: DefaultArgsParser(), } m := &gateway.MessageCreateEvent{ @@ -335,6 +331,7 @@ func BenchmarkHelp(b *testing.B) { Subcommand: s, State: state, HasPrefix: NewPrefix("~"), + ParseArgs: DefaultArgsParser(), } b.ResetTimer() diff --git a/bot/error.go b/bot/error.go index 6acb712..9249697 100644 --- a/bot/error.go +++ b/bot/error.go @@ -6,28 +6,19 @@ import ( ) type ErrUnknownCommand struct { - Prefix string - Command string - Parent string - - // TODO: list available commands? - // Here, as a reminder - ctx []*CommandContext + Parts []string // max len 2 + Subcmd *Subcommand } func (err *ErrUnknownCommand) Error() string { + if len(err.Parts) > 2 { + err.Parts = err.Parts[:2] + } return UnknownCommandString(err) } var UnknownCommandString = func(err *ErrUnknownCommand) string { - var header = "Unknown command: " + err.Prefix - if err.Parent != "" { - header += err.Parent + " " + err.Command - } else { - header += err.Command - } - - return header + return "Unknown command: " + strings.Join(err.Parts, " ") } var ( @@ -43,7 +34,7 @@ type ErrInvalidUsage struct { // TODO: usage generator? // Here, as a reminder - Ctx *CommandContext + Ctx *MethodContext } func (err *ErrInvalidUsage) Error() string { diff --git a/bot/extras/middlewares/middlewares.go b/bot/extras/middlewares/middlewares.go new file mode 100644 index 0000000..299b620 --- /dev/null +++ b/bot/extras/middlewares/middlewares.go @@ -0,0 +1,49 @@ +package middlewares + +import ( + "github.com/diamondburned/arikawa/bot" + "github.com/diamondburned/arikawa/bot/extras/infer" + "github.com/diamondburned/arikawa/discord" +) + +func AdminOnly(ctx *bot.Context) func(interface{}) error { + return func(ev interface{}) error { + var channelID = infer.ChannelID(ev) + if !channelID.Valid() { + return bot.Break + } + + var userID = infer.UserID(ev) + if !userID.Valid() { + return bot.Break + } + + p, err := ctx.State.Permissions(channelID, userID) + if err == nil && p.Has(discord.PermissionAdministrator) { + return nil + } + + return bot.Break + } +} + +func GuildOnly(ctx *bot.Context) func(interface{}) error { + return func(ev interface{}) error { + // Try and infer the GuildID. + if guildID := infer.GuildID(ev); guildID.Valid() { + return nil + } + + var channelID = infer.ChannelID(ev) + if !channelID.Valid() { + return bot.Break + } + + c, err := ctx.State.Channel(channelID) + if err != nil || !c.GuildID.Valid() { + return bot.Break + } + + return nil + } +} diff --git a/bot/extras/middlewares/test.go b/bot/extras/middlewares/test.go new file mode 100644 index 0000000..9ff2af6 --- /dev/null +++ b/bot/extras/middlewares/test.go @@ -0,0 +1,11 @@ +package main + +import "testing" + +func TestAdminOnly(t *testing.T) { + t.Fatal("Do me.") +} + +func TestGuildOnly(t *testing.T) { + t.Fatal("Do me.") +} diff --git a/bot/nameflag.go b/bot/nameflag.go deleted file mode 100644 index 9fd7320..0000000 --- a/bot/nameflag.go +++ /dev/null @@ -1,107 +0,0 @@ -package bot - -import "strings" - -type NameFlag uint64 - -var FlagSeparator = 'ー' - -const None NameFlag = 0 - -// !!! -// -// These flags are applied to all events, if possible. The defined behavior -// is to search for "ChannelID" fields or "ID" fields in structs with -// "Channel" in its name. It doesn't handle individual events, as such, will -// not be able to guarantee it will always work. Refer to package infer. - -// R - Raw, which tells the library to use the method name as-is (flags will -// still be stripped). For example, if a method is called Reset its -// command will also be Reset, without being all lower-cased. -const Raw NameFlag = 1 << 1 - -// A - AdminOnly, which tells the library to only run the Subcommand/method -// if the user is admin or not. This will automatically add GuildOnly as -// well. -const AdminOnly NameFlag = 1 << 2 - -// G - GuildOnly, which tells the library to only run the Subcommand/method -// if the user is inside a guild. -const GuildOnly NameFlag = 1 << 3 - -// M - Middleware, which tells the library that the method is a middleware. -// The method will be executed anytime a method of the same struct is -// matched. -// -// Using this flag inside the subcommand will drop all methods (this is an -// undefined behavior/UB). -const Middleware NameFlag = 1 << 4 - -// H - Hidden/Handler, which tells the router to not add this into the list -// of commands, hiding it from Help. Handlers that are hidden will not have -// any arguments parsed. It will be treated as an Event. -const Hidden NameFlag = 1 << 5 - -// P - Plumb, which tells the router to call only this handler with all the -// arguments (except the prefix string). If plumb is used, only this method -// will be called for the given struct, though all other events as well as -// methods with the H (Hidden/Handler) flag. -// -// This is different from using H (Hidden/Handler), as handlers are called -// regardless of command prefixes. Plumb methods are only called once, and -// no other methods will be called for that struct. That said, a Plumb -// method would still go into Commands, but only itself will be there. -// -// Note that if there's a Plumb method in the main commands, then none of -// the subcommands would be called. This is an unintended but expected side -// effect. -// -// Example -// -// A use for this would be subcommands that don't need a second command, or -// if the main struct manually handles command switching. This example -// demonstrates the second use-case: -// -// func (s *Sub) PーMain( -// c *gateway.MessageCreateGateway, c *Content) error { -// -// // Input: !sub this is a command -// // Output: this is a command -// -// log.Println(c.String()) -// return nil -// } -// -const Plumb NameFlag = 1 << 6 - -func ParseFlag(name string) (NameFlag, string) { - parts := strings.SplitN(name, string(FlagSeparator), 2) - if len(parts) != 2 { - return 0, name - } - - var f NameFlag - - for _, r := range parts[0] { - switch r { - case 'R': - f |= Raw - case 'A': - f |= AdminOnly | GuildOnly - case 'G': - f |= GuildOnly - case 'M': - f |= Middleware - case 'H': - f |= Hidden - case 'P': - f |= Plumb - } - } - - return f, parts[1] -} - -func (f NameFlag) Is(flag NameFlag) bool { - return f&flag != 0 -} diff --git a/bot/nameflag_test.go b/bot/nameflag_test.go deleted file mode 100644 index 915a0b9..0000000 --- a/bot/nameflag_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package bot - -import "testing" - -func TestNameFlag(t *testing.T) { - type entry struct { - Name string - Expect NameFlag - String string - } - - var entries = []entry{{ - Name: "AーEcho", - Expect: AdminOnly, - }, { - Name: "RAーGC", - Expect: Raw | AdminOnly, - }} - - for _, entry := range entries { - var f, _ = ParseFlag(entry.Name) - if !f.Is(entry.Expect) { - t.Fatalf("unexpected expectation for %s: %v", entry.Name, f) - } - } -} diff --git a/bot/subcommand.go b/bot/subcommand.go index f0ac732..02f34ad 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -70,9 +70,6 @@ type Subcommand struct { // Parsed command name: Command string - // struct flags - Flag NameFlag - // SanitizeMessage is executed on the message content if the method returns // a string content or a SendMessageData. SanitizeMessage func(content string) string @@ -85,15 +82,12 @@ type Subcommand struct { // Commands can actually return either a string, an embed, or a // SendMessageData, with error as the second argument. - // All registered command contexts: - Commands []*CommandContext - Events []*CommandContext + // All registered method contexts, including commands: + Methods []*MethodContext + plumbed *MethodContext - // Middleware command contexts: - mwMethods []*CommandContext - - // Plumb nameflag, use Commands[0] if true. - plumb bool + // Global middlewares. + globalmws []*MiddlewareContext // Directly to struct cmdValue reflect.Value @@ -103,34 +97,9 @@ type Subcommand struct { ptrValue reflect.Value ptrType reflect.Type - // command interface as reference command interface{} } -// CommandContext is an internal struct containing fields to make this library -// work. As such, they're all unexported. Description, however, is exported for -// editing, and may be used to generate more informative help messages. -type CommandContext struct { - Description string - Flag NameFlag - - MethodName string - Command string // empty if Plumb - - // Hidden is true if the method has a hidden nameflag. - Hidden bool - - // Variadic is true if the function is a variadic one or if the last - // argument accepts multiple strings. - Variadic bool - - value reflect.Value // Func - event reflect.Type // gateway.*Event - method reflect.Method - - Arguments []Argument -} - // CanSetup is used for subcommands to change variables, such as Description. // This method will be triggered when InitCommands is called, which is during // New for Context and during RegisterSubcommand for subcommands. @@ -139,19 +108,6 @@ type CanSetup interface { Setup(*Subcommand) } -func (cctx *CommandContext) Usage() []string { - if len(cctx.Arguments) == 0 { - return nil - } - - var arguments = make([]string, len(cctx.Arguments)) - for i, arg := range cctx.Arguments { - arguments[i] = arg.String - } - - return arguments -} - // NewSubcommand is used to make a new subcommand. You usually wouldn't call // this function, but instead use (*Context).RegisterSubcommand(). func NewSubcommand(cmd interface{}) (*Subcommand, error) { @@ -177,34 +133,24 @@ func NewSubcommand(cmd interface{}) (*Subcommand, error) { // shouldn't be called at all, rather you should use RegisterSubcommand. func (sub *Subcommand) NeedsName() { sub.StructName = sub.cmdType.Name() - - flag, name := ParseFlag(sub.StructName) - - if !flag.Is(Raw) { - name = lowerFirstLetter(name) - } - - sub.Command = name - sub.Flag = flag + sub.Command = lowerFirstLetter(sub.StructName) } -// FindCommand finds the command. Nil is returned if nothing is found. It's a -// better idea to not handle nil, as they would become very subtle bugs. -func (sub *Subcommand) FindCommand(methodName string) *CommandContext { - for _, c := range sub.Commands { - if c.MethodName != methodName { - continue +// FindMethod finds the MethodContext. It panics if methodName is not found. +func (sub *Subcommand) FindMethod(methodName string) *MethodContext { + for _, c := range sub.Methods { + if c.MethodName == methodName { + return c } - return c } - return nil + panic("Can't find method " + methodName) } // ChangeCommandInfo changes the matched methodName's Command and Description. -// Empty means unchanged. The returned bool is true when the method is found. +// Empty means unchanged. The returned bool is true when the command is found. func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool { - for _, c := range sub.Commands { - if c.MethodName != methodName { + for _, c := range sub.Methods { + if c.MethodName != methodName || !c.isEvent(typeMessageCreate) { continue } @@ -222,70 +168,70 @@ func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool { } func (sub *Subcommand) Help(indent string, hideAdmin bool) string { - if sub.Flag.Is(AdminOnly) && hideAdmin { - return "" - } + // // The header part: + // var header string - // The header part: - var header string + // if sub.Command != "" { + // header += "**" + sub.Command + "**" + // } - if sub.Command != "" { - header += "**" + sub.Command + "**" - } + // if sub.Description != "" { + // if header != "" { + // header += ": " + // } - if sub.Description != "" { - if header != "" { - header += ": " - } + // header += sub.Description + // } - header += sub.Description - } + // header += "\n" - header += "\n" + // // The commands part: + // var commands = "" - // The commands part: - var commands = "" + // for i, cmd := range sub.Commands { + // if cmd.Flag.Is(AdminOnly) && hideAdmin { + // continue + // } - for i, cmd := range sub.Commands { - if cmd.Flag.Is(AdminOnly) && hideAdmin { - continue - } + // switch { + // case sub.Command != "" && cmd.Command != "": + // commands += indent + sub.Command + " " + cmd.Command + // case sub.Command != "": + // commands += indent + sub.Command + // default: + // commands += indent + cmd.Command + // } - switch { - case sub.Command != "" && cmd.Command != "": - commands += indent + sub.Command + " " + cmd.Command - case sub.Command != "": - commands += indent + sub.Command - default: - commands += indent + cmd.Command - } + // // Write the usages first. + // for _, usage := range cmd.Usage() { + // commands += " " + underline(usage) + // } - // Write the usages first. - for _, usage := range cmd.Usage() { - commands += " " + underline(usage) - } + // // Is the last argument trailing? If so, append ellipsis. + // if cmd.Variadic { + // commands += "..." + // } - // Is the last argument trailing? If so, append ellipsis. - if cmd.Variadic { - commands += "..." - } + // // Write the description if there's any. + // if cmd.Description != "" { + // commands += ": " + cmd.Description + // } - // Write the description if there's any. - if cmd.Description != "" { - commands += ": " + cmd.Description - } + // // Add a new line if this isn't the last command. + // if i != len(sub.Commands)-1 { + // commands += "\n" + // } + // } - // Add a new line if this isn't the last command. - if i != len(sub.Commands)-1 { - commands += "\n" - } - } + // if commands == "" { + // return "" + // } - if commands == "" { - return "" - } + // return header + commands - return header + commands + // TODO + // TODO: Interface Helper implements Help() string + return "TODO" } func (sub *Subcommand) reflectCommands() error { @@ -327,12 +273,6 @@ func (sub *Subcommand) InitCommands(ctx *Context) error { v.Setup(sub) } - // Finalize the subcommand: - for _, cmd := range sub.Commands { - // Inherit parent's flags - cmd.Flag |= sub.Flag - } - return nil } @@ -365,126 +305,93 @@ func (sub *Subcommand) parseCommands() error { continue } - methodT := method.Type() - numArgs := methodT.NumIn() - - if numArgs == 0 { - // Doesn't meet the requirement for an event, continue. + methodT := sub.ptrType.Method(i) + if methodT.Name == "Setup" && methodT.Type == typeSetupFn { continue } - if methodT == typeSetupFn { - // Method is a setup method, continue. + cctx := parseMethod(method, methodT) + if cctx == nil { continue } - // Check number of returns: - numOut := methodT.NumOut() - - // Returns can either be: - // Nothing - func() - // An error - func() error - // An error and something else - func() (T, error) - if numOut > 2 { - continue - } - - // Check the last return's type if the method returns anything. - if numOut > 0 { - if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) { - // Invalid, skip. - continue - } - } - - var command = CommandContext{ - method: sub.ptrType.Method(i), - value: method, - event: methodT.In(0), // parse event - Variadic: methodT.IsVariadic(), - } - - // Parse the method name - flag, name := ParseFlag(command.method.Name) - - // Set the method name, command, and flag: - command.MethodName = name - command.Command = name - command.Flag = flag - - // Check if Raw is enabled for command: - if !flag.Is(Raw) { - command.Command = lowerFirstLetter(name) - } - - // Middlewares shouldn't even have arguments. - if flag.Is(Middleware) { - sub.mwMethods = append(sub.mwMethods, &command) - continue - } - - // TODO: allow more flexibility - if command.event != typeMessageCreate || flag.Is(Hidden) { - sub.Events = append(sub.Events, &command) - continue - } - - // See if we know the first return type, if error's return is the - // second: - if numOut > 1 { - switch t := methodT.Out(0); t { - case typeString, typeEmbed, typeSend: - // noop, passes - default: - continue - } - } - - // If a plumb method has been found: - if sub.plumb { - continue - } - - // If the method only takes an event: - if numArgs == 1 { - sub.Commands = append(sub.Commands, &command) - continue - } - - command.Arguments = make([]Argument, 0, numArgs) - - // Fill up arguments. This should work with cusP and manP - for i := 1; i < numArgs; i++ { - t := methodT.In(i) - a, err := newArgument(t, command.Variadic) - if err != nil { - return errors.Wrap(err, "Error parsing argument "+t.String()) - } - - command.Arguments = append(command.Arguments, *a) - - // We're done if the type accepts multiple arguments. - if a.custom != nil || a.manual != nil { - command.Variadic = true // treat as variadic - break - } - } - - // If the current event is a plumb event: - if flag.Is(Plumb) { - command.Command = "" // plumbers don't have names - sub.Commands = []*CommandContext{&command} - sub.plumb = true - continue - } - - // Append - sub.Commands = append(sub.Commands, &command) + // Append. + sub.Methods = append(sub.Methods, cctx) } return nil } +func (sub *Subcommand) AddMiddleware(methodName string, middleware interface{}) { + var mw *MiddlewareContext + // Allow *MiddlewareContext to be passed into. + if v, ok := middleware.(*MiddlewareContext); ok { + mw = v + } else { + mw = ParseMiddleware(middleware) + } + + // Parse method name: + for _, method := range strings.Split(methodName, ",") { + // Trim space. + if method = strings.TrimSpace(method); method == "*" { + // Append middleware to global middleware slice. + sub.globalmws = append(sub.globalmws, mw) + } else { + // Append middleware to that individual function. + sub.FindMethod(method).addMiddleware(mw) + } + } +} + +func (sub *Subcommand) walkMiddlewares(ev reflect.Value) error { + for _, mw := range sub.globalmws { + _, err := mw.call(ev) + if err != nil { + return err + } + } + return nil +} + +func (sub *Subcommand) eventCallers(evT reflect.Type) (callers []caller) { + // Search for global middlewares. + for _, mw := range sub.globalmws { + if mw.isEvent(evT) { + callers = append(callers, mw) + } + } + + // Search for specific handlers. + for _, cctx := range sub.Methods { + // We only take middlewares and callers if the event matches and is not + // a MessageCreate. The other function already handles that. + if cctx.event != typeMessageCreate && cctx.isEvent(evT) { + // Add the command's middlewares first. + for _, mw := range cctx.middlewares { + // Concrete struct to interface conversion done implicitly. + callers = append(callers, mw) + } + + callers = append(callers, cctx) + } + } + return +} + +// SetPlumb sets the method as the plumbed command. This means that all calls +// without the second command argument will call this method in a subcommand. It +// panics if sub.Command is empty. +func (sub *Subcommand) SetPlumb(methodName string) { + if sub.Command == "" { + panic("SetPlumb called on a main command with sub.Command empty.") + } + + method := sub.FindMethod(methodName) + method.Command = "" + sub.plumbed = method +} + func lowerFirstLetter(name string) string { return strings.ToLower(string(name[0])) + name[1:] } diff --git a/bot/subcommand_test.go b/bot/subcommand_test.go index 22eda07..998e57d 100644 --- a/bot/subcommand_test.go +++ b/bot/subcommand_test.go @@ -29,8 +29,8 @@ func TestSubcommand(t *testing.T) { } // !!! CHANGE ME - if len(sub.Commands) != 8 { - t.Fatal("invalid ctx.commands len", len(sub.Commands)) + if len(sub.Methods) < 8 { + t.Fatal("too low sub.Methods len", len(sub.Methods)) } var ( @@ -39,7 +39,7 @@ func TestSubcommand(t *testing.T) { foundNoArgs bool ) - for _, this := range sub.Commands { + for _, this := range sub.Methods { switch this.Command { case "send": foundSend = true @@ -58,13 +58,6 @@ func TestSubcommand(t *testing.T) { if len(this.Arguments) != 0 { t.Fatal("expected 0 arguments, got non-zero") } - - case "noop", "getCounter", "variadic", "trailCustom", "content": - // Found, but whatever - } - - if this.event != typeMessageCreate { - t.Fatal("invalid event type:", this.event.String()) } }