Bot: Partially implemented middlewares

This commit is contained in:
diamondburned (Forefront) 2020-05-10 01:45:00 -07:00
parent 1ca7d1c62c
commit 964e8cdf13
13 changed files with 619 additions and 749 deletions

View File

@ -2,6 +2,8 @@
Not a lot for a Discord bot: Not a lot for a Discord bot:
# THIS IS OUTDATED. TODO: UPDATE.
``` ```
# Cold functions, or functions that are called once in runtime: # Cold functions, or functions that are called once in runtime:
BenchmarkConstructor-8 150537 7617 ns/op BenchmarkConstructor-8 150537 7617 ns/op

236
bot/command.go Normal file
View File

@ -0,0 +1,236 @@
package bot
import (
"reflect"
)
type command struct {
value reflect.Value // Func
event reflect.Type
isInterface bool
}
func newCommand(value reflect.Value, event reflect.Type) command {
return command{
value: value,
event: event,
isInterface: event.Kind() == reflect.Interface,
}
}
func (c *command) isEvent(t reflect.Type) bool {
return (!c.isInterface && c.event == t) || (c.isInterface && t.Implements(c.event))
}
func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
return callWith(c.value, arg0, argv...)
}
func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(argv))
if v, ok := arg0.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(arg0))
}
callargs = append(callargs, argv...)
return errorReturns(caller.Call(callargs))
}
type caller interface {
call(arg0 interface{}, argv ...reflect.Value) (interface{}, error)
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil
}
// Treat the last return as an error.
return nil, v.(error)
}
// MethodContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type MethodContext struct {
command
method reflect.Method // extend
middlewares []*MiddlewareContext
Description string
// MethodName is the name of the method. This field should NOT be changed.
MethodName string
// Command is the Discord command used to call the method.
Command string // hidden if empty
// Hidden is true if the method has a hidden nameflag.
// Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
Arguments []Argument
}
func parseMethod(value reflect.Value, method reflect.Method) *MethodContext {
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
return nil
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
return nil
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
return nil
}
}
var command = MethodContext{
command: newCommand(value, methodT.In(0)),
method: method,
MethodName: method.Name,
Variadic: methodT.IsVariadic(),
}
// Only set the command name if it's a MessageCreate handler.
if command.event == typeMessageCreate {
command.Command = lowerFirstLetter(command.method.Name)
}
if numArgs > 1 {
// Event handlers that aren't MessageCreate should not have arguments.
if command.event != typeMessageCreate {
return nil
}
// If the event type is messageCreate:
command.Arguments = make([]Argument, 0, numArgs-1)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
panic("Error parsing argument " + t.String() + ": " + err.Error())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
}
return &command
}
func (cctx *MethodContext) addMiddleware(mw *MiddlewareContext) {
cctx.middlewares = append(cctx.middlewares, mw)
}
func (cctx *MethodContext) walkMiddlewares(ev reflect.Value) error {
for _, mw := range cctx.middlewares {
_, err := mw.call(ev)
if err != nil {
return err
}
}
return nil
}
func (cctx *MethodContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
}
// SetName sets the command name.
func (cctx *MethodContext) SetName(name string) {
cctx.Command = name
}
type MiddlewareContext struct {
command
}
// ParseMiddleware parses a middleware function. This function panics.
func ParseMiddleware(mw interface{}) *MiddlewareContext {
value := reflect.ValueOf(mw)
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs != 1 {
panic("Invalid argument signature for " + methodT.String())
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
if numOut > 1 {
panic("Invalid return signature for " + methodT.String())
}
// Check the last return's type if the method returns anything.
if numOut == 1 {
if i := methodT.Out(0); i == nil || !i.Implements(typeIError) {
panic("Unexpected return type (not error) for " + methodT.String())
}
}
var middleware = MiddlewareContext{
command: newCommand(value, methodT.In(0)),
}
return &middleware
}

View File

@ -217,19 +217,19 @@ func (ctx *Context) Subcommands() []*Subcommand {
return ctx.subcommands return ctx.subcommands
} }
// FindCommand finds a command based on the struct and method name. The queried // FindMethod finds a method based on the struct and method name. The queried
// names will have their flags stripped. // names will have their flags stripped.
// //
// Example // Example
// //
// // Find a command from the main context: // // Find a command from the main context:
// cmd := ctx.FindCommand("", "Method") // cmd := ctx.FindMethod("", "Method")
// // Find a command from a subcommand: // // Find a command from a subcommand:
// cmd = ctx.FindCommand("Starboard", "Reset") // cmd = ctx.FindMethod("Starboard", "Reset")
// //
func (ctx *Context) FindCommand(structname, methodname string) *CommandContext { func (ctx *Context) FindMethod(structname, methodname string) *MethodContext {
if structname == "" { if structname == "" {
for _, c := range ctx.Commands { for _, c := range ctx.Methods {
if c.MethodName == methodname { if c.MethodName == methodname {
return c return c
} }
@ -243,7 +243,7 @@ func (ctx *Context) FindCommand(structname, methodname string) *CommandContext {
continue continue
} }
for _, c := range sub.Commands { for _, c := range sub.Methods {
if c.MethodName == methodname { if c.MethodName == methodname {
return c return c
} }
@ -360,52 +360,55 @@ func (ctx *Context) HelpAdmin() string {
} }
func (ctx *Context) help(hideAdmin bool) string { func (ctx *Context) help(hideAdmin bool) string {
const indent = " " // const indent = " "
var help strings.Builder // var help strings.Builder
// Generate the headers and descriptions // // Generate the headers and descriptions
help.WriteString("__Help__") // help.WriteString("__Help__")
if ctx.Name != "" { // if ctx.Name != "" {
help.WriteString(": " + ctx.Name) // help.WriteString(": " + ctx.Name)
} // }
if ctx.Description != "" { // if ctx.Description != "" {
help.WriteString("\n" + indent + ctx.Description) // help.WriteString("\n" + indent + ctx.Description)
} // }
if ctx.Flag.Is(AdminOnly) { // if ctx.Flag.Is(AdminOnly) {
// That's it. // // That's it.
return help.String() // return help.String()
} // }
// Separators // // Separators
help.WriteString("\n---\n") // help.WriteString("\n---\n")
// Generate all commands // // Generate all commands
help.WriteString("__Commands__") // help.WriteString("__Commands__")
help.WriteString(ctx.Subcommand.Help(indent, hideAdmin)) // help.WriteString(ctx.Subcommand.Help(indent, hideAdmin))
help.WriteByte('\n') // help.WriteByte('\n')
var subHelp = strings.Builder{} // var subHelp = strings.Builder{}
var subcommands = ctx.Subcommands() // var subcommands = ctx.Subcommands()
for _, sub := range subcommands { // for _, sub := range subcommands {
if help := sub.Help(indent, hideAdmin); help != "" { // if help := sub.Help(indent, hideAdmin); help != "" {
for _, line := range strings.Split(help, "\n") { // for _, line := range strings.Split(help, "\n") {
subHelp.WriteString(indent) // subHelp.WriteString(indent)
subHelp.WriteString(line) // subHelp.WriteString(line)
subHelp.WriteByte('\n') // subHelp.WriteByte('\n')
} // }
} // }
} // }
if subHelp.Len() > 0 { // if subHelp.Len() > 0 {
help.WriteString("---\n") // help.WriteString("---\n")
help.WriteString("__Subcommands__\n") // help.WriteString("__Subcommands__\n")
help.WriteString(subHelp.String()) // help.WriteString(subHelp.String())
} // }
return help.String() // return help.String()
// TODO
return ""
} }

View File

@ -5,136 +5,75 @@ import (
"strings" "strings"
"github.com/diamondburned/arikawa/api" "github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// NonFatal is an interface that a method can implement to ignore all errors. // Break is a non-fatal error that could be returned from middlewares to stop
// This works similarly to Break. // the chain of execution.
type NonFatal interface { var Break = errors.New("break middleware chain, non-fatal")
error
IgnoreError() // noop method
}
func onlyFatal(err error) error {
if _, ok := err.(NonFatal); ok {
return nil
}
return err
}
type _Break struct{ error }
// implement NonFatal.
func (_Break) IgnoreError() {}
// Break is a non-fatal error that could be returned from middlewares or
// handlers to stop the chain of execution.
//
// Middlewares are guaranteed to be executed before handlers, but the exact
// order of each are undefined. Main handlers are also guaranteed to be executed
// before all subcommands. If a main middleware cancels, no subcommand
// middlewares will be called.
//
// Break implements the NonFatal interface, which causes an error to be ignored.
var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")}
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
var callers []*CommandContext
var middles []*CommandContext
var found bool
find := func(sub *Subcommand) {
for _, cmd := range sub.Events {
// Search only for callers, so skip middlewares.
if cmd.Flag.Is(Middleware) {
continue
}
if cmd.event == evT {
callers = append(callers, cmd)
found = true
}
}
// Only get middlewares if we found handlers for that same event.
if found {
// Search for middlewares with the same type:
for _, mw := range sub.mwMethods {
if mw.event == evT {
middles = append(middles, mw)
}
}
}
}
// filterEventType filters all commands and subcommands into a 2D slice,
// structured so that a Break would only exit out the nested slice.
func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
// Find the main context first. // Find the main context first.
find(ctx.Subcommand) callers = append(callers, ctx.eventCallers(evT))
for _, sub := range ctx.subcommands { for _, sub := range ctx.subcommands {
// Reset found status
found = false
// Find subcommands second. // Find subcommands second.
find(sub) callers = append(callers, sub.eventCallers(evT))
} }
return append(middles, callers...) return
} }
func (ctx *Context) callCmd(ev interface{}) error { func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
evT := reflect.TypeOf(ev) evV := reflect.ValueOf(ev)
evT := evV.Type()
var isAdmin *bool // I want to die. var callers [][]caller
var isGuild *bool
var callers []*CommandContext
// Hit the cache // Hit the cache
t, ok := ctx.typeCache.Load(evT) t, ok := ctx.typeCache.Load(evT)
if ok { if ok {
callers = t.([]*CommandContext) callers = t.([][]caller)
} else { } else {
callers = ctx.filterEventType(evT) callers = ctx.filterEventType(evT)
ctx.typeCache.Store(evT, callers) ctx.typeCache.Store(evT, callers)
} }
// We can't do the callers[:0] trick here, as it will modify the slice for _, subcallers := range callers {
// inside the sync.Map as well. for _, c := range subcallers {
var filtered = make([]*CommandContext, 0, len(callers)) _, err := c.call(evV)
if err != nil {
// Only count as an error if it's not Break.
if err = errNoBreak(err); err != nil {
bottomError = err
}
for _, cmd := range callers { // Break the caller loop only for this subcommand.
// Command flags will inherit its parent Subcommand's flags. break
if true &&
!(cmd.Flag.Is(AdminOnly) && !ctx.eventIsAdmin(ev, &isAdmin)) &&
!(cmd.Flag.Is(GuildOnly) && !ctx.eventIsGuild(ev, &isGuild)) {
filtered = append(filtered, cmd)
}
}
for _, c := range filtered {
_, err := callWith(c.value, ev)
if err != nil {
if err = onlyFatal(err); err != nil {
ctx.ErrorLogger(err)
} }
return err
} }
} }
// We call the messages later, since Hidden handlers will go into the Events // We call the messages later, since we want MessageCreate middlewares to
// slice, but we don't want to ignore those handlers either. // run as well.
if evT == typeMessageCreate { if evT == typeMessageCreate {
// safe assertion always // safe assertion always
err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent)) err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent), evV)
return onlyFatal(err) // There's no need for an errNoBreak here, as the method already checked
// for that.
if err != nil {
bottomError = err
}
} }
return nil return
} }
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error {
// check if bot // check if bot
if !ctx.AllowBot && mc.Author.Bot { if !ctx.AllowBot && mc.Author.Bot {
return nil return nil
@ -163,102 +102,18 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
return nil // ??? return nil // ???
} }
var cmd *CommandContext // Find the command and subcommand.
var sub *Subcommand arguments, cmd, sub, err := ctx.findCommand(parts)
// var start int // arg starts from $start if err != nil {
return errNoBreak(err)
// Check if plumb:
if ctx.plumb {
cmd = ctx.Commands[0]
sub = ctx.Subcommand
// start = 0
} }
// Arguments slice, which will be sliced away until only arguments are left. // We don't run the subcommand's middlewares here, as the callCmd function
var arguments = parts // already handles that.
// If not plumb, search for the command // Run command middlewares.
if cmd == nil { if err := cmd.walkMiddlewares(value); err != nil {
for _, c := range ctx.Commands { return errNoBreak(err)
if c.Command == parts[0] {
cmd = c
sub = ctx.Subcommand
arguments = arguments[1:]
// start = 1
break
}
}
}
// Can't find the command, look for subcommands if len(args) has a 2nd
// entry.
if cmd == nil {
for _, s := range ctx.subcommands {
if s.Command != parts[0] {
continue
}
// Check if plumb:
if s.plumb {
cmd = s.Commands[0]
sub = s
arguments = arguments[1:]
// start = 1
break
}
// There's no second argument, so we can only look for Plumbed
// subcommands.
if len(parts) < 2 {
continue
}
for _, c := range s.Commands {
if c.Command == parts[1] {
cmd = c
sub = s
arguments = arguments[2:]
break
// start = 2
}
}
if cmd == nil {
if s.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[1],
Parent: parts[0],
ctx: s.Commands,
}
}
break
}
}
if cmd == nil {
if ctx.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[0],
ctx: ctx.Commands,
}
}
// Check for IsAdmin and IsGuild
if cmd.Flag.Is(GuildOnly) && !mc.GuildID.Valid() {
return nil
}
if cmd.Flag.Is(AdminOnly) {
p, err := ctx.State.Permissions(mc.ChannelID, mc.Author.ID)
if err != nil || !p.Has(discord.PermissionAdministrator) {
return nil
}
} }
// Start converting // Start converting
@ -375,8 +230,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
// could contain multiple whitespaces, and the parser would not // could contain multiple whitespaces, and the parser would not
// count them. // count them.
var seekTo = cmd.Command var seekTo = cmd.Command
// If plumbed, then there would only be the subcommand. // Implicit plumbing behavior.
if sub.plumb { if seekTo == "" {
seekTo = sub.Command seekTo = sub.Command
} }
@ -406,17 +261,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
} }
Call: Call:
// Try calling all middlewares first. We don't need to stack middlewares, as
// there will only be one command match.
for _, mw := range sub.mwMethods {
_, err := callWith(mw.value, mc)
if err != nil {
return err
}
}
// call the function and parse the error return value // call the function and parse the error return value
v, err := callWith(cmd.value, mc, argv...) v, err := cmd.call(value, argv...)
if err != nil { if err != nil {
return err return err
} }
@ -437,91 +283,59 @@ Call:
return err return err
} }
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool { // findCommand filters.
if *is != nil { func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) {
return **is // Main command entrypoint cannot have plumb.
for _, c := range ctx.Methods {
if c.Command == parts[0] {
return parts[1:], c, ctx.Subcommand, nil
}
} }
var channelID = infer.ChannelID(ev) // Can't find the command, look for subcommands if len(args) has a 2nd
if !channelID.Valid() { // entry.
return false for _, s := range ctx.subcommands {
} if s.Command != parts[0] {
continue
var userID = infer.UserID(ev)
if !userID.Valid() {
return false
}
var res bool
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
res = true
}
*is = &res
return res
}
func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
if *is != nil {
return **is
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
c, err := ctx.State.Channel(channelID)
if err != nil {
return false
}
res := c.GuildID.Valid()
*is = &res
return res
}
func callWith(
caller reflect.Value,
ev interface{}, values ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(values))
if v, ok := ev.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(ev))
}
callargs = append(callargs, values...)
return errorReturns(caller.Call(callargs))
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
} }
// Return the first argument as-is. The above returns[-1] check assumes // If there's no second argument, TODO call Help.
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil if s.plumbed != nil {
return parts[1:], s.plumbed, s, nil
}
if len(parts) >= 2 {
for _, c := range s.Methods {
if c.event == typeMessageCreate && c.Command == parts[1] {
return parts[2:], c, s, nil
}
}
}
if s.QuietUnknownCommand || ctx.QuietUnknownCommand {
return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: s,
}
} }
// Treat the last return as an error. if ctx.QuietUnknownCommand {
return nil, v.(error) return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: ctx.Subcommand,
}
}
func errNoBreak(err error) error {
if errors.Is(err, Break) {
return nil
}
return err
} }

View File

@ -15,12 +15,16 @@ type hasPlumb struct {
NotPlumbed bool NotPlumbed bool
} }
func (h *hasPlumb) Setup(sub *Subcommand) {
sub.SetPlumb("Plumber")
}
func (h *hasPlumb) Normal(_ *gateway.MessageCreateEvent) error { func (h *hasPlumb) Normal(_ *gateway.MessageCreateEvent) error {
h.NotPlumbed = true h.NotPlumbed = true
return nil return nil
} }
func (h *hasPlumb) PーPlumber(_ *gateway.MessageCreateEvent, c RawArguments) error { func (h *hasPlumb) Plumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
h.Plumbed = string(c) h.Plumbed = string(c)
return nil return nil
} }
@ -43,10 +47,6 @@ func TestSubcommandPlumb(t *testing.T) {
t.Fatal("Failed to register hasPlumb:", err) t.Fatal("Failed to register hasPlumb:", err)
} }
if l := len(c.subcommands[0].Commands); l != 1 {
t.Fatal("Unexpected length for sub.Commands:", l)
}
// Try call exactly what's in the Plumb example: // Try call exactly what's in the Plumb example:
m := &gateway.MessageCreateEvent{ m := &gateway.MessageCreateEvent{
Message: discord.Message{ Message: discord.Message{

View File

@ -21,43 +21,38 @@ type testc struct {
Typed bool Typed bool
} }
func (t *testc) MーBumpCounter(interface{}) { func (t *testc) Setup(sub *Subcommand) {
t.Counter++ sub.AddMiddleware("*,GetCounter", func(v interface{}) {
t.Counter++
})
sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
t.Counter++
})
} }
func (t *testc) Noop(*gateway.MessageCreateEvent) {}
func (t *testc) GetCounter(_ *gateway.MessageCreateEvent) { func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
t.Return <- strconv.FormatUint(t.Counter, 10) t.Return <- strconv.FormatUint(t.Counter, 10)
} }
func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error { func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
t.Return <- args t.Return <- args
return errors.New("oh no") return errors.New("oh no")
} }
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) { func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) {
t.Return <- c.args t.Return <- c.args
} }
func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) { func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
t.Return <- c[len(c)-1] t.Return <- c[len(c)-1]
} }
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) { func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) {
t.Return <- c.args t.Return <- c.args
} }
func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) { func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
t.Return <- c t.Return <- c
} }
func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
func (t *testc) NoArgs(_ *gateway.MessageCreateEvent) error {
return errors.New("passed") return errors.New("passed")
} }
func (t *testc) OnTyping(*gateway.TypingStartEvent) {
func (t *testc) Noop(_ *gateway.MessageCreateEvent) {
}
func (t *testc) OnTyping(_ *gateway.TypingStartEvent) {
t.Typed = true t.Typed = true
} }
@ -108,26 +103,26 @@ func TestContext(t *testing.T) {
}) })
t.Run("find commands", func(t *testing.T) { t.Run("find commands", func(t *testing.T) {
cmd := ctx.FindCommand("", "NoArgs") cmd := ctx.FindMethod("", "NoArgs")
if cmd == nil { if cmd == nil {
t.Fatal("Failed to find NoArgs") t.Fatal("Failed to find NoArgs")
} }
}) })
t.Run("help", func(t *testing.T) { // t.Run("help", func(t *testing.T) {
if h := ctx.Help(); h == "" { // if h := ctx.Help(); h == "" {
t.Fatal("Empty help?") // t.Fatal("Empty help?")
} // }
if h := ctx.HelpAdmin(); h == "" { // if h := ctx.HelpAdmin(); h == "" {
t.Fatal("Empty admin help?") // t.Fatal("Empty admin help?")
} // }
}) // })
t.Run("middleware", func(t *testing.T) { t.Run("middleware", func(t *testing.T) {
ctx.HasPrefix = NewPrefix("pls do ") ctx.HasPrefix = NewPrefix("pls do ")
// This should trigger the middleware first. // This should trigger the middleware first.
if err := expect(ctx, given, "1", "pls do getCounter"); err != nil { if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
t.Fatal("Unexpected error:", err) t.Fatal("Unexpected error:", err)
} }
}) })
@ -247,7 +242,7 @@ func TestContext(t *testing.T) {
t.Fatal("Unexpected call error:", err) t.Fatal("Unexpected call error:", err)
} }
if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil { if cmd := ctx.FindMethod("testc", "Noop"); cmd == nil {
t.Fatal("Failed to find subcommand Noop") t.Fatal("Failed to find subcommand Noop")
} }
}) })
@ -308,6 +303,7 @@ func BenchmarkCall(b *testing.B) {
Subcommand: s, Subcommand: s,
State: state, State: state,
HasPrefix: NewPrefix("~"), HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
} }
m := &gateway.MessageCreateEvent{ m := &gateway.MessageCreateEvent{
@ -335,6 +331,7 @@ func BenchmarkHelp(b *testing.B) {
Subcommand: s, Subcommand: s,
State: state, State: state,
HasPrefix: NewPrefix("~"), HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
} }
b.ResetTimer() b.ResetTimer()

View File

@ -6,28 +6,19 @@ import (
) )
type ErrUnknownCommand struct { type ErrUnknownCommand struct {
Prefix string Parts []string // max len 2
Command string Subcmd *Subcommand
Parent string
// TODO: list available commands?
// Here, as a reminder
ctx []*CommandContext
} }
func (err *ErrUnknownCommand) Error() string { func (err *ErrUnknownCommand) Error() string {
if len(err.Parts) > 2 {
err.Parts = err.Parts[:2]
}
return UnknownCommandString(err) return UnknownCommandString(err)
} }
var UnknownCommandString = func(err *ErrUnknownCommand) string { var UnknownCommandString = func(err *ErrUnknownCommand) string {
var header = "Unknown command: " + err.Prefix return "Unknown command: " + strings.Join(err.Parts, " ")
if err.Parent != "" {
header += err.Parent + " " + err.Command
} else {
header += err.Command
}
return header
} }
var ( var (
@ -43,7 +34,7 @@ type ErrInvalidUsage struct {
// TODO: usage generator? // TODO: usage generator?
// Here, as a reminder // Here, as a reminder
Ctx *CommandContext Ctx *MethodContext
} }
func (err *ErrInvalidUsage) Error() string { func (err *ErrInvalidUsage) Error() string {

View File

@ -0,0 +1,49 @@
package middlewares
import (
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
)
func AdminOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return bot.Break
}
var userID = infer.UserID(ev)
if !userID.Valid() {
return bot.Break
}
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
return nil
}
return bot.Break
}
}
func GuildOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
// Try and infer the GuildID.
if guildID := infer.GuildID(ev); guildID.Valid() {
return nil
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return bot.Break
}
c, err := ctx.State.Channel(channelID)
if err != nil || !c.GuildID.Valid() {
return bot.Break
}
return nil
}
}

View File

@ -0,0 +1,11 @@
package main
import "testing"
func TestAdminOnly(t *testing.T) {
t.Fatal("Do me.")
}
func TestGuildOnly(t *testing.T) {
t.Fatal("Do me.")
}

View File

@ -1,107 +0,0 @@
package bot
import "strings"
type NameFlag uint64
var FlagSeparator = 'ー'
const None NameFlag = 0
// !!!
//
// These flags are applied to all events, if possible. The defined behavior
// is to search for "ChannelID" fields or "ID" fields in structs with
// "Channel" in its name. It doesn't handle individual events, as such, will
// not be able to guarantee it will always work. Refer to package infer.
// R - Raw, which tells the library to use the method name as-is (flags will
// still be stripped). For example, if a method is called Reset its
// command will also be Reset, without being all lower-cased.
const Raw NameFlag = 1 << 1
// A - AdminOnly, which tells the library to only run the Subcommand/method
// if the user is admin or not. This will automatically add GuildOnly as
// well.
const AdminOnly NameFlag = 1 << 2
// G - GuildOnly, which tells the library to only run the Subcommand/method
// if the user is inside a guild.
const GuildOnly NameFlag = 1 << 3
// M - Middleware, which tells the library that the method is a middleware.
// The method will be executed anytime a method of the same struct is
// matched.
//
// Using this flag inside the subcommand will drop all methods (this is an
// undefined behavior/UB).
const Middleware NameFlag = 1 << 4
// H - Hidden/Handler, which tells the router to not add this into the list
// of commands, hiding it from Help. Handlers that are hidden will not have
// any arguments parsed. It will be treated as an Event.
const Hidden NameFlag = 1 << 5
// P - Plumb, which tells the router to call only this handler with all the
// arguments (except the prefix string). If plumb is used, only this method
// will be called for the given struct, though all other events as well as
// methods with the H (Hidden/Handler) flag.
//
// This is different from using H (Hidden/Handler), as handlers are called
// regardless of command prefixes. Plumb methods are only called once, and
// no other methods will be called for that struct. That said, a Plumb
// method would still go into Commands, but only itself will be there.
//
// Note that if there's a Plumb method in the main commands, then none of
// the subcommands would be called. This is an unintended but expected side
// effect.
//
// Example
//
// A use for this would be subcommands that don't need a second command, or
// if the main struct manually handles command switching. This example
// demonstrates the second use-case:
//
// func (s *Sub) PーMain(
// c *gateway.MessageCreateGateway, c *Content) error {
//
// // Input: !sub this is a command
// // Output: this is a command
//
// log.Println(c.String())
// return nil
// }
//
const Plumb NameFlag = 1 << 6
func ParseFlag(name string) (NameFlag, string) {
parts := strings.SplitN(name, string(FlagSeparator), 2)
if len(parts) != 2 {
return 0, name
}
var f NameFlag
for _, r := range parts[0] {
switch r {
case 'R':
f |= Raw
case 'A':
f |= AdminOnly | GuildOnly
case 'G':
f |= GuildOnly
case 'M':
f |= Middleware
case 'H':
f |= Hidden
case 'P':
f |= Plumb
}
}
return f, parts[1]
}
func (f NameFlag) Is(flag NameFlag) bool {
return f&flag != 0
}

View File

@ -1,26 +0,0 @@
package bot
import "testing"
func TestNameFlag(t *testing.T) {
type entry struct {
Name string
Expect NameFlag
String string
}
var entries = []entry{{
Name: "AーEcho",
Expect: AdminOnly,
}, {
Name: "RAーGC",
Expect: Raw | AdminOnly,
}}
for _, entry := range entries {
var f, _ = ParseFlag(entry.Name)
if !f.Is(entry.Expect) {
t.Fatalf("unexpected expectation for %s: %v", entry.Name, f)
}
}
}

View File

@ -70,9 +70,6 @@ type Subcommand struct {
// Parsed command name: // Parsed command name:
Command string Command string
// struct flags
Flag NameFlag
// SanitizeMessage is executed on the message content if the method returns // SanitizeMessage is executed on the message content if the method returns
// a string content or a SendMessageData. // a string content or a SendMessageData.
SanitizeMessage func(content string) string SanitizeMessage func(content string) string
@ -85,15 +82,12 @@ type Subcommand struct {
// Commands can actually return either a string, an embed, or a // Commands can actually return either a string, an embed, or a
// SendMessageData, with error as the second argument. // SendMessageData, with error as the second argument.
// All registered command contexts: // All registered method contexts, including commands:
Commands []*CommandContext Methods []*MethodContext
Events []*CommandContext plumbed *MethodContext
// Middleware command contexts: // Global middlewares.
mwMethods []*CommandContext globalmws []*MiddlewareContext
// Plumb nameflag, use Commands[0] if true.
plumb bool
// Directly to struct // Directly to struct
cmdValue reflect.Value cmdValue reflect.Value
@ -103,34 +97,9 @@ type Subcommand struct {
ptrValue reflect.Value ptrValue reflect.Value
ptrType reflect.Type ptrType reflect.Type
// command interface as reference
command interface{} command interface{}
} }
// CommandContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type CommandContext struct {
Description string
Flag NameFlag
MethodName string
Command string // empty if Plumb
// Hidden is true if the method has a hidden nameflag.
Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
value reflect.Value // Func
event reflect.Type // gateway.*Event
method reflect.Method
Arguments []Argument
}
// CanSetup is used for subcommands to change variables, such as Description. // CanSetup is used for subcommands to change variables, such as Description.
// This method will be triggered when InitCommands is called, which is during // This method will be triggered when InitCommands is called, which is during
// New for Context and during RegisterSubcommand for subcommands. // New for Context and during RegisterSubcommand for subcommands.
@ -139,19 +108,6 @@ type CanSetup interface {
Setup(*Subcommand) Setup(*Subcommand)
} }
func (cctx *CommandContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
}
// NewSubcommand is used to make a new subcommand. You usually wouldn't call // NewSubcommand is used to make a new subcommand. You usually wouldn't call
// this function, but instead use (*Context).RegisterSubcommand(). // this function, but instead use (*Context).RegisterSubcommand().
func NewSubcommand(cmd interface{}) (*Subcommand, error) { func NewSubcommand(cmd interface{}) (*Subcommand, error) {
@ -177,34 +133,24 @@ func NewSubcommand(cmd interface{}) (*Subcommand, error) {
// shouldn't be called at all, rather you should use RegisterSubcommand. // shouldn't be called at all, rather you should use RegisterSubcommand.
func (sub *Subcommand) NeedsName() { func (sub *Subcommand) NeedsName() {
sub.StructName = sub.cmdType.Name() sub.StructName = sub.cmdType.Name()
sub.Command = lowerFirstLetter(sub.StructName)
flag, name := ParseFlag(sub.StructName)
if !flag.Is(Raw) {
name = lowerFirstLetter(name)
}
sub.Command = name
sub.Flag = flag
} }
// FindCommand finds the command. Nil is returned if nothing is found. It's a // FindMethod finds the MethodContext. It panics if methodName is not found.
// better idea to not handle nil, as they would become very subtle bugs. func (sub *Subcommand) FindMethod(methodName string) *MethodContext {
func (sub *Subcommand) FindCommand(methodName string) *CommandContext { for _, c := range sub.Methods {
for _, c := range sub.Commands { if c.MethodName == methodName {
if c.MethodName != methodName { return c
continue
} }
return c
} }
return nil panic("Can't find method " + methodName)
} }
// ChangeCommandInfo changes the matched methodName's Command and Description. // ChangeCommandInfo changes the matched methodName's Command and Description.
// Empty means unchanged. The returned bool is true when the method is found. // Empty means unchanged. The returned bool is true when the command is found.
func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool { func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool {
for _, c := range sub.Commands { for _, c := range sub.Methods {
if c.MethodName != methodName { if c.MethodName != methodName || !c.isEvent(typeMessageCreate) {
continue continue
} }
@ -222,70 +168,70 @@ func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool {
} }
func (sub *Subcommand) Help(indent string, hideAdmin bool) string { func (sub *Subcommand) Help(indent string, hideAdmin bool) string {
if sub.Flag.Is(AdminOnly) && hideAdmin { // // The header part:
return "" // var header string
}
// The header part: // if sub.Command != "" {
var header string // header += "**" + sub.Command + "**"
// }
if sub.Command != "" { // if sub.Description != "" {
header += "**" + sub.Command + "**" // if header != "" {
} // header += ": "
// }
if sub.Description != "" { // header += sub.Description
if header != "" { // }
header += ": "
}
header += sub.Description // header += "\n"
}
header += "\n" // // The commands part:
// var commands = ""
// The commands part: // for i, cmd := range sub.Commands {
var commands = "" // if cmd.Flag.Is(AdminOnly) && hideAdmin {
// continue
// }
for i, cmd := range sub.Commands { // switch {
if cmd.Flag.Is(AdminOnly) && hideAdmin { // case sub.Command != "" && cmd.Command != "":
continue // commands += indent + sub.Command + " " + cmd.Command
} // case sub.Command != "":
// commands += indent + sub.Command
// default:
// commands += indent + cmd.Command
// }
switch { // // Write the usages first.
case sub.Command != "" && cmd.Command != "": // for _, usage := range cmd.Usage() {
commands += indent + sub.Command + " " + cmd.Command // commands += " " + underline(usage)
case sub.Command != "": // }
commands += indent + sub.Command
default:
commands += indent + cmd.Command
}
// Write the usages first. // // Is the last argument trailing? If so, append ellipsis.
for _, usage := range cmd.Usage() { // if cmd.Variadic {
commands += " " + underline(usage) // commands += "..."
} // }
// Is the last argument trailing? If so, append ellipsis. // // Write the description if there's any.
if cmd.Variadic { // if cmd.Description != "" {
commands += "..." // commands += ": " + cmd.Description
} // }
// Write the description if there's any. // // Add a new line if this isn't the last command.
if cmd.Description != "" { // if i != len(sub.Commands)-1 {
commands += ": " + cmd.Description // commands += "\n"
} // }
// }
// Add a new line if this isn't the last command. // if commands == "" {
if i != len(sub.Commands)-1 { // return ""
commands += "\n" // }
}
}
if commands == "" { // return header + commands
return ""
}
return header + commands // TODO
// TODO: Interface Helper implements Help() string
return "TODO"
} }
func (sub *Subcommand) reflectCommands() error { func (sub *Subcommand) reflectCommands() error {
@ -327,12 +273,6 @@ func (sub *Subcommand) InitCommands(ctx *Context) error {
v.Setup(sub) v.Setup(sub)
} }
// Finalize the subcommand:
for _, cmd := range sub.Commands {
// Inherit parent's flags
cmd.Flag |= sub.Flag
}
return nil return nil
} }
@ -365,126 +305,93 @@ func (sub *Subcommand) parseCommands() error {
continue continue
} }
methodT := method.Type() methodT := sub.ptrType.Method(i)
numArgs := methodT.NumIn() if methodT.Name == "Setup" && methodT.Type == typeSetupFn {
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
continue continue
} }
if methodT == typeSetupFn { cctx := parseMethod(method, methodT)
// Method is a setup method, continue. if cctx == nil {
continue continue
} }
// Check number of returns: // Append.
numOut := methodT.NumOut() sub.Methods = append(sub.Methods, cctx)
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
continue
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
continue
}
}
var command = CommandContext{
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
Variadic: methodT.IsVariadic(),
}
// Parse the method name
flag, name := ParseFlag(command.method.Name)
// Set the method name, command, and flag:
command.MethodName = name
command.Command = name
command.Flag = flag
// Check if Raw is enabled for command:
if !flag.Is(Raw) {
command.Command = lowerFirstLetter(name)
}
// Middlewares shouldn't even have arguments.
if flag.Is(Middleware) {
sub.mwMethods = append(sub.mwMethods, &command)
continue
}
// TODO: allow more flexibility
if command.event != typeMessageCreate || flag.Is(Hidden) {
sub.Events = append(sub.Events, &command)
continue
}
// See if we know the first return type, if error's return is the
// second:
if numOut > 1 {
switch t := methodT.Out(0); t {
case typeString, typeEmbed, typeSend:
// noop, passes
default:
continue
}
}
// If a plumb method has been found:
if sub.plumb {
continue
}
// If the method only takes an event:
if numArgs == 1 {
sub.Commands = append(sub.Commands, &command)
continue
}
command.Arguments = make([]Argument, 0, numArgs)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
return errors.Wrap(err, "Error parsing argument "+t.String())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
// If the current event is a plumb event:
if flag.Is(Plumb) {
command.Command = "" // plumbers don't have names
sub.Commands = []*CommandContext{&command}
sub.plumb = true
continue
}
// Append
sub.Commands = append(sub.Commands, &command)
} }
return nil return nil
} }
func (sub *Subcommand) AddMiddleware(methodName string, middleware interface{}) {
var mw *MiddlewareContext
// Allow *MiddlewareContext to be passed into.
if v, ok := middleware.(*MiddlewareContext); ok {
mw = v
} else {
mw = ParseMiddleware(middleware)
}
// Parse method name:
for _, method := range strings.Split(methodName, ",") {
// Trim space.
if method = strings.TrimSpace(method); method == "*" {
// Append middleware to global middleware slice.
sub.globalmws = append(sub.globalmws, mw)
} else {
// Append middleware to that individual function.
sub.FindMethod(method).addMiddleware(mw)
}
}
}
func (sub *Subcommand) walkMiddlewares(ev reflect.Value) error {
for _, mw := range sub.globalmws {
_, err := mw.call(ev)
if err != nil {
return err
}
}
return nil
}
func (sub *Subcommand) eventCallers(evT reflect.Type) (callers []caller) {
// Search for global middlewares.
for _, mw := range sub.globalmws {
if mw.isEvent(evT) {
callers = append(callers, mw)
}
}
// Search for specific handlers.
for _, cctx := range sub.Methods {
// We only take middlewares and callers if the event matches and is not
// a MessageCreate. The other function already handles that.
if cctx.event != typeMessageCreate && cctx.isEvent(evT) {
// Add the command's middlewares first.
for _, mw := range cctx.middlewares {
// Concrete struct to interface conversion done implicitly.
callers = append(callers, mw)
}
callers = append(callers, cctx)
}
}
return
}
// SetPlumb sets the method as the plumbed command. This means that all calls
// without the second command argument will call this method in a subcommand. It
// panics if sub.Command is empty.
func (sub *Subcommand) SetPlumb(methodName string) {
if sub.Command == "" {
panic("SetPlumb called on a main command with sub.Command empty.")
}
method := sub.FindMethod(methodName)
method.Command = ""
sub.plumbed = method
}
func lowerFirstLetter(name string) string { func lowerFirstLetter(name string) string {
return strings.ToLower(string(name[0])) + name[1:] return strings.ToLower(string(name[0])) + name[1:]
} }

View File

@ -29,8 +29,8 @@ func TestSubcommand(t *testing.T) {
} }
// !!! CHANGE ME // !!! CHANGE ME
if len(sub.Commands) != 8 { if len(sub.Methods) < 8 {
t.Fatal("invalid ctx.commands len", len(sub.Commands)) t.Fatal("too low sub.Methods len", len(sub.Methods))
} }
var ( var (
@ -39,7 +39,7 @@ func TestSubcommand(t *testing.T) {
foundNoArgs bool foundNoArgs bool
) )
for _, this := range sub.Commands { for _, this := range sub.Methods {
switch this.Command { switch this.Command {
case "send": case "send":
foundSend = true foundSend = true
@ -58,13 +58,6 @@ func TestSubcommand(t *testing.T) {
if len(this.Arguments) != 0 { if len(this.Arguments) != 0 {
t.Fatal("expected 0 arguments, got non-zero") t.Fatal("expected 0 arguments, got non-zero")
} }
case "noop", "getCounter", "variadic", "trailCustom", "content":
// Found, but whatever
}
if this.event != typeMessageCreate {
t.Fatal("invalid event type:", this.event.String())
} }
} }