Bot: Partially implemented middlewares

This commit is contained in:
diamondburned (Forefront) 2020-05-10 01:45:00 -07:00
parent 1ca7d1c62c
commit 964e8cdf13
13 changed files with 619 additions and 749 deletions

View File

@ -2,6 +2,8 @@
Not a lot for a Discord bot:
# THIS IS OUTDATED. TODO: UPDATE.
```
# Cold functions, or functions that are called once in runtime:
BenchmarkConstructor-8 150537 7617 ns/op

236
bot/command.go Normal file
View File

@ -0,0 +1,236 @@
package bot
import (
"reflect"
)
type command struct {
value reflect.Value // Func
event reflect.Type
isInterface bool
}
func newCommand(value reflect.Value, event reflect.Type) command {
return command{
value: value,
event: event,
isInterface: event.Kind() == reflect.Interface,
}
}
func (c *command) isEvent(t reflect.Type) bool {
return (!c.isInterface && c.event == t) || (c.isInterface && t.Implements(c.event))
}
func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
return callWith(c.value, arg0, argv...)
}
func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(argv))
if v, ok := arg0.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(arg0))
}
callargs = append(callargs, argv...)
return errorReturns(caller.Call(callargs))
}
type caller interface {
call(arg0 interface{}, argv ...reflect.Value) (interface{}, error)
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil
}
// Treat the last return as an error.
return nil, v.(error)
}
// MethodContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type MethodContext struct {
command
method reflect.Method // extend
middlewares []*MiddlewareContext
Description string
// MethodName is the name of the method. This field should NOT be changed.
MethodName string
// Command is the Discord command used to call the method.
Command string // hidden if empty
// Hidden is true if the method has a hidden nameflag.
// Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
Arguments []Argument
}
func parseMethod(value reflect.Value, method reflect.Method) *MethodContext {
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
return nil
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
return nil
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
return nil
}
}
var command = MethodContext{
command: newCommand(value, methodT.In(0)),
method: method,
MethodName: method.Name,
Variadic: methodT.IsVariadic(),
}
// Only set the command name if it's a MessageCreate handler.
if command.event == typeMessageCreate {
command.Command = lowerFirstLetter(command.method.Name)
}
if numArgs > 1 {
// Event handlers that aren't MessageCreate should not have arguments.
if command.event != typeMessageCreate {
return nil
}
// If the event type is messageCreate:
command.Arguments = make([]Argument, 0, numArgs-1)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
panic("Error parsing argument " + t.String() + ": " + err.Error())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
}
return &command
}
func (cctx *MethodContext) addMiddleware(mw *MiddlewareContext) {
cctx.middlewares = append(cctx.middlewares, mw)
}
func (cctx *MethodContext) walkMiddlewares(ev reflect.Value) error {
for _, mw := range cctx.middlewares {
_, err := mw.call(ev)
if err != nil {
return err
}
}
return nil
}
func (cctx *MethodContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
}
// SetName sets the command name.
func (cctx *MethodContext) SetName(name string) {
cctx.Command = name
}
type MiddlewareContext struct {
command
}
// ParseMiddleware parses a middleware function. This function panics.
func ParseMiddleware(mw interface{}) *MiddlewareContext {
value := reflect.ValueOf(mw)
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs != 1 {
panic("Invalid argument signature for " + methodT.String())
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
if numOut > 1 {
panic("Invalid return signature for " + methodT.String())
}
// Check the last return's type if the method returns anything.
if numOut == 1 {
if i := methodT.Out(0); i == nil || !i.Implements(typeIError) {
panic("Unexpected return type (not error) for " + methodT.String())
}
}
var middleware = MiddlewareContext{
command: newCommand(value, methodT.In(0)),
}
return &middleware
}

View File

@ -217,19 +217,19 @@ func (ctx *Context) Subcommands() []*Subcommand {
return ctx.subcommands
}
// FindCommand finds a command based on the struct and method name. The queried
// FindMethod finds a method based on the struct and method name. The queried
// names will have their flags stripped.
//
// Example
//
// // Find a command from the main context:
// cmd := ctx.FindCommand("", "Method")
// cmd := ctx.FindMethod("", "Method")
// // Find a command from a subcommand:
// cmd = ctx.FindCommand("Starboard", "Reset")
// cmd = ctx.FindMethod("Starboard", "Reset")
//
func (ctx *Context) FindCommand(structname, methodname string) *CommandContext {
func (ctx *Context) FindMethod(structname, methodname string) *MethodContext {
if structname == "" {
for _, c := range ctx.Commands {
for _, c := range ctx.Methods {
if c.MethodName == methodname {
return c
}
@ -243,7 +243,7 @@ func (ctx *Context) FindCommand(structname, methodname string) *CommandContext {
continue
}
for _, c := range sub.Commands {
for _, c := range sub.Methods {
if c.MethodName == methodname {
return c
}
@ -360,52 +360,55 @@ func (ctx *Context) HelpAdmin() string {
}
func (ctx *Context) help(hideAdmin bool) string {
const indent = " "
// const indent = " "
var help strings.Builder
// var help strings.Builder
// Generate the headers and descriptions
help.WriteString("__Help__")
// // Generate the headers and descriptions
// help.WriteString("__Help__")
if ctx.Name != "" {
help.WriteString(": " + ctx.Name)
}
// if ctx.Name != "" {
// help.WriteString(": " + ctx.Name)
// }
if ctx.Description != "" {
help.WriteString("\n" + indent + ctx.Description)
}
// if ctx.Description != "" {
// help.WriteString("\n" + indent + ctx.Description)
// }
if ctx.Flag.Is(AdminOnly) {
// That's it.
return help.String()
}
// if ctx.Flag.Is(AdminOnly) {
// // That's it.
// return help.String()
// }
// Separators
help.WriteString("\n---\n")
// // Separators
// help.WriteString("\n---\n")
// Generate all commands
help.WriteString("__Commands__")
help.WriteString(ctx.Subcommand.Help(indent, hideAdmin))
help.WriteByte('\n')
// // Generate all commands
// help.WriteString("__Commands__")
// help.WriteString(ctx.Subcommand.Help(indent, hideAdmin))
// help.WriteByte('\n')
var subHelp = strings.Builder{}
var subcommands = ctx.Subcommands()
// var subHelp = strings.Builder{}
// var subcommands = ctx.Subcommands()
for _, sub := range subcommands {
if help := sub.Help(indent, hideAdmin); help != "" {
for _, line := range strings.Split(help, "\n") {
subHelp.WriteString(indent)
subHelp.WriteString(line)
subHelp.WriteByte('\n')
}
}
}
// for _, sub := range subcommands {
// if help := sub.Help(indent, hideAdmin); help != "" {
// for _, line := range strings.Split(help, "\n") {
// subHelp.WriteString(indent)
// subHelp.WriteString(line)
// subHelp.WriteByte('\n')
// }
// }
// }
if subHelp.Len() > 0 {
help.WriteString("---\n")
help.WriteString("__Subcommands__\n")
help.WriteString(subHelp.String())
}
// if subHelp.Len() > 0 {
// help.WriteString("---\n")
// help.WriteString("__Subcommands__\n")
// help.WriteString(subHelp.String())
// }
return help.String()
// return help.String()
// TODO
return ""
}

View File

@ -5,136 +5,75 @@ import (
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
)
// NonFatal is an interface that a method can implement to ignore all errors.
// This works similarly to Break.
type NonFatal interface {
error
IgnoreError() // noop method
}
func onlyFatal(err error) error {
if _, ok := err.(NonFatal); ok {
return nil
}
return err
}
type _Break struct{ error }
// implement NonFatal.
func (_Break) IgnoreError() {}
// Break is a non-fatal error that could be returned from middlewares or
// handlers to stop the chain of execution.
//
// Middlewares are guaranteed to be executed before handlers, but the exact
// order of each are undefined. Main handlers are also guaranteed to be executed
// before all subcommands. If a main middleware cancels, no subcommand
// middlewares will be called.
//
// Break implements the NonFatal interface, which causes an error to be ignored.
var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")}
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
var callers []*CommandContext
var middles []*CommandContext
var found bool
find := func(sub *Subcommand) {
for _, cmd := range sub.Events {
// Search only for callers, so skip middlewares.
if cmd.Flag.Is(Middleware) {
continue
}
if cmd.event == evT {
callers = append(callers, cmd)
found = true
}
}
// Only get middlewares if we found handlers for that same event.
if found {
// Search for middlewares with the same type:
for _, mw := range sub.mwMethods {
if mw.event == evT {
middles = append(middles, mw)
}
}
}
}
// Break is a non-fatal error that could be returned from middlewares to stop
// the chain of execution.
var Break = errors.New("break middleware chain, non-fatal")
// filterEventType filters all commands and subcommands into a 2D slice,
// structured so that a Break would only exit out the nested slice.
func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
// Find the main context first.
find(ctx.Subcommand)
callers = append(callers, ctx.eventCallers(evT))
for _, sub := range ctx.subcommands {
// Reset found status
found = false
// Find subcommands second.
find(sub)
callers = append(callers, sub.eventCallers(evT))
}
return append(middles, callers...)
return
}
func (ctx *Context) callCmd(ev interface{}) error {
evT := reflect.TypeOf(ev)
func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
evV := reflect.ValueOf(ev)
evT := evV.Type()
var isAdmin *bool // I want to die.
var isGuild *bool
var callers []*CommandContext
var callers [][]caller
// Hit the cache
t, ok := ctx.typeCache.Load(evT)
if ok {
callers = t.([]*CommandContext)
callers = t.([][]caller)
} else {
callers = ctx.filterEventType(evT)
ctx.typeCache.Store(evT, callers)
}
// We can't do the callers[:0] trick here, as it will modify the slice
// inside the sync.Map as well.
var filtered = make([]*CommandContext, 0, len(callers))
for _, subcallers := range callers {
for _, c := range subcallers {
_, err := c.call(evV)
if err != nil {
// Only count as an error if it's not Break.
if err = errNoBreak(err); err != nil {
bottomError = err
}
for _, cmd := range callers {
// Command flags will inherit its parent Subcommand's flags.
if true &&
!(cmd.Flag.Is(AdminOnly) && !ctx.eventIsAdmin(ev, &isAdmin)) &&
!(cmd.Flag.Is(GuildOnly) && !ctx.eventIsGuild(ev, &isGuild)) {
filtered = append(filtered, cmd)
}
}
for _, c := range filtered {
_, err := callWith(c.value, ev)
if err != nil {
if err = onlyFatal(err); err != nil {
ctx.ErrorLogger(err)
// Break the caller loop only for this subcommand.
break
}
return err
}
}
// We call the messages later, since Hidden handlers will go into the Events
// slice, but we don't want to ignore those handlers either.
// We call the messages later, since we want MessageCreate middlewares to
// run as well.
if evT == typeMessageCreate {
// safe assertion always
err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent))
return onlyFatal(err)
err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent), evV)
// There's no need for an errNoBreak here, as the method already checked
// for that.
if err != nil {
bottomError = err
}
}
return nil
return
}
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error {
// check if bot
if !ctx.AllowBot && mc.Author.Bot {
return nil
@ -163,102 +102,18 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
return nil // ???
}
var cmd *CommandContext
var sub *Subcommand
// var start int // arg starts from $start
// Check if plumb:
if ctx.plumb {
cmd = ctx.Commands[0]
sub = ctx.Subcommand
// start = 0
// Find the command and subcommand.
arguments, cmd, sub, err := ctx.findCommand(parts)
if err != nil {
return errNoBreak(err)
}
// Arguments slice, which will be sliced away until only arguments are left.
var arguments = parts
// We don't run the subcommand's middlewares here, as the callCmd function
// already handles that.
// If not plumb, search for the command
if cmd == nil {
for _, c := range ctx.Commands {
if c.Command == parts[0] {
cmd = c
sub = ctx.Subcommand
arguments = arguments[1:]
// start = 1
break
}
}
}
// Can't find the command, look for subcommands if len(args) has a 2nd
// entry.
if cmd == nil {
for _, s := range ctx.subcommands {
if s.Command != parts[0] {
continue
}
// Check if plumb:
if s.plumb {
cmd = s.Commands[0]
sub = s
arguments = arguments[1:]
// start = 1
break
}
// There's no second argument, so we can only look for Plumbed
// subcommands.
if len(parts) < 2 {
continue
}
for _, c := range s.Commands {
if c.Command == parts[1] {
cmd = c
sub = s
arguments = arguments[2:]
break
// start = 2
}
}
if cmd == nil {
if s.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[1],
Parent: parts[0],
ctx: s.Commands,
}
}
break
}
}
if cmd == nil {
if ctx.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[0],
ctx: ctx.Commands,
}
}
// Check for IsAdmin and IsGuild
if cmd.Flag.Is(GuildOnly) && !mc.GuildID.Valid() {
return nil
}
if cmd.Flag.Is(AdminOnly) {
p, err := ctx.State.Permissions(mc.ChannelID, mc.Author.ID)
if err != nil || !p.Has(discord.PermissionAdministrator) {
return nil
}
// Run command middlewares.
if err := cmd.walkMiddlewares(value); err != nil {
return errNoBreak(err)
}
// Start converting
@ -375,8 +230,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
// could contain multiple whitespaces, and the parser would not
// count them.
var seekTo = cmd.Command
// If plumbed, then there would only be the subcommand.
if sub.plumb {
// Implicit plumbing behavior.
if seekTo == "" {
seekTo = sub.Command
}
@ -406,17 +261,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
}
Call:
// Try calling all middlewares first. We don't need to stack middlewares, as
// there will only be one command match.
for _, mw := range sub.mwMethods {
_, err := callWith(mw.value, mc)
if err != nil {
return err
}
}
// call the function and parse the error return value
v, err := callWith(cmd.value, mc, argv...)
v, err := cmd.call(value, argv...)
if err != nil {
return err
}
@ -437,91 +283,59 @@ Call:
return err
}
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
if *is != nil {
return **is
// findCommand filters.
func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) {
// Main command entrypoint cannot have plumb.
for _, c := range ctx.Methods {
if c.Command == parts[0] {
return parts[1:], c, ctx.Subcommand, nil
}
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
var userID = infer.UserID(ev)
if !userID.Valid() {
return false
}
var res bool
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
res = true
}
*is = &res
return res
}
func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
if *is != nil {
return **is
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
c, err := ctx.State.Channel(channelID)
if err != nil {
return false
}
res := c.GuildID.Valid()
*is = &res
return res
}
func callWith(
caller reflect.Value,
ev interface{}, values ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(values))
if v, ok := ev.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(ev))
}
callargs = append(callargs, values...)
return errorReturns(caller.Call(callargs))
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
// Can't find the command, look for subcommands if len(args) has a 2nd
// entry.
for _, s := range ctx.subcommands {
if s.Command != parts[0] {
continue
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil
// If there's no second argument, TODO call Help.
if s.plumbed != nil {
return parts[1:], s.plumbed, s, nil
}
if len(parts) >= 2 {
for _, c := range s.Methods {
if c.event == typeMessageCreate && c.Command == parts[1] {
return parts[2:], c, s, nil
}
}
}
if s.QuietUnknownCommand || ctx.QuietUnknownCommand {
return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: s,
}
}
// Treat the last return as an error.
return nil, v.(error)
if ctx.QuietUnknownCommand {
return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: ctx.Subcommand,
}
}
func errNoBreak(err error) error {
if errors.Is(err, Break) {
return nil
}
return err
}

View File

@ -15,12 +15,16 @@ type hasPlumb struct {
NotPlumbed bool
}
func (h *hasPlumb) Setup(sub *Subcommand) {
sub.SetPlumb("Plumber")
}
func (h *hasPlumb) Normal(_ *gateway.MessageCreateEvent) error {
h.NotPlumbed = true
return nil
}
func (h *hasPlumb) PーPlumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
func (h *hasPlumb) Plumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
h.Plumbed = string(c)
return nil
}
@ -43,10 +47,6 @@ func TestSubcommandPlumb(t *testing.T) {
t.Fatal("Failed to register hasPlumb:", err)
}
if l := len(c.subcommands[0].Commands); l != 1 {
t.Fatal("Unexpected length for sub.Commands:", l)
}
// Try call exactly what's in the Plumb example:
m := &gateway.MessageCreateEvent{
Message: discord.Message{

View File

@ -21,43 +21,38 @@ type testc struct {
Typed bool
}
func (t *testc) MーBumpCounter(interface{}) {
t.Counter++
func (t *testc) Setup(sub *Subcommand) {
sub.AddMiddleware("*,GetCounter", func(v interface{}) {
t.Counter++
})
sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
t.Counter++
})
}
func (t *testc) GetCounter(_ *gateway.MessageCreateEvent) {
func (t *testc) Noop(*gateway.MessageCreateEvent) {}
func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
t.Return <- strconv.FormatUint(t.Counter, 10)
}
func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
t.Return <- args
return errors.New("oh no")
}
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) {
t.Return <- c.args
}
func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
t.Return <- c[len(c)-1]
}
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) {
t.Return <- c.args
}
func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
t.Return <- c
}
func (t *testc) NoArgs(_ *gateway.MessageCreateEvent) error {
func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
return errors.New("passed")
}
func (t *testc) Noop(_ *gateway.MessageCreateEvent) {
}
func (t *testc) OnTyping(_ *gateway.TypingStartEvent) {
func (t *testc) OnTyping(*gateway.TypingStartEvent) {
t.Typed = true
}
@ -108,26 +103,26 @@ func TestContext(t *testing.T) {
})
t.Run("find commands", func(t *testing.T) {
cmd := ctx.FindCommand("", "NoArgs")
cmd := ctx.FindMethod("", "NoArgs")
if cmd == nil {
t.Fatal("Failed to find NoArgs")
}
})
t.Run("help", func(t *testing.T) {
if h := ctx.Help(); h == "" {
t.Fatal("Empty help?")
}
if h := ctx.HelpAdmin(); h == "" {
t.Fatal("Empty admin help?")
}
})
// t.Run("help", func(t *testing.T) {
// if h := ctx.Help(); h == "" {
// t.Fatal("Empty help?")
// }
// if h := ctx.HelpAdmin(); h == "" {
// t.Fatal("Empty admin help?")
// }
// })
t.Run("middleware", func(t *testing.T) {
ctx.HasPrefix = NewPrefix("pls do ")
// This should trigger the middleware first.
if err := expect(ctx, given, "1", "pls do getCounter"); err != nil {
if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
t.Fatal("Unexpected error:", err)
}
})
@ -247,7 +242,7 @@ func TestContext(t *testing.T) {
t.Fatal("Unexpected call error:", err)
}
if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil {
if cmd := ctx.FindMethod("testc", "Noop"); cmd == nil {
t.Fatal("Failed to find subcommand Noop")
}
})
@ -308,6 +303,7 @@ func BenchmarkCall(b *testing.B) {
Subcommand: s,
State: state,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
}
m := &gateway.MessageCreateEvent{
@ -335,6 +331,7 @@ func BenchmarkHelp(b *testing.B) {
Subcommand: s,
State: state,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
}
b.ResetTimer()

View File

@ -6,28 +6,19 @@ import (
)
type ErrUnknownCommand struct {
Prefix string
Command string
Parent string
// TODO: list available commands?
// Here, as a reminder
ctx []*CommandContext
Parts []string // max len 2
Subcmd *Subcommand
}
func (err *ErrUnknownCommand) Error() string {
if len(err.Parts) > 2 {
err.Parts = err.Parts[:2]
}
return UnknownCommandString(err)
}
var UnknownCommandString = func(err *ErrUnknownCommand) string {
var header = "Unknown command: " + err.Prefix
if err.Parent != "" {
header += err.Parent + " " + err.Command
} else {
header += err.Command
}
return header
return "Unknown command: " + strings.Join(err.Parts, " ")
}
var (
@ -43,7 +34,7 @@ type ErrInvalidUsage struct {
// TODO: usage generator?
// Here, as a reminder
Ctx *CommandContext
Ctx *MethodContext
}
func (err *ErrInvalidUsage) Error() string {

View File

@ -0,0 +1,49 @@
package middlewares
import (
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
)
func AdminOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return bot.Break
}
var userID = infer.UserID(ev)
if !userID.Valid() {
return bot.Break
}
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
return nil
}
return bot.Break
}
}
func GuildOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
// Try and infer the GuildID.
if guildID := infer.GuildID(ev); guildID.Valid() {
return nil
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return bot.Break
}
c, err := ctx.State.Channel(channelID)
if err != nil || !c.GuildID.Valid() {
return bot.Break
}
return nil
}
}

View File

@ -0,0 +1,11 @@
package main
import "testing"
func TestAdminOnly(t *testing.T) {
t.Fatal("Do me.")
}
func TestGuildOnly(t *testing.T) {
t.Fatal("Do me.")
}

View File

@ -1,107 +0,0 @@
package bot
import "strings"
type NameFlag uint64
var FlagSeparator = 'ー'
const None NameFlag = 0
// !!!
//
// These flags are applied to all events, if possible. The defined behavior
// is to search for "ChannelID" fields or "ID" fields in structs with
// "Channel" in its name. It doesn't handle individual events, as such, will
// not be able to guarantee it will always work. Refer to package infer.
// R - Raw, which tells the library to use the method name as-is (flags will
// still be stripped). For example, if a method is called Reset its
// command will also be Reset, without being all lower-cased.
const Raw NameFlag = 1 << 1
// A - AdminOnly, which tells the library to only run the Subcommand/method
// if the user is admin or not. This will automatically add GuildOnly as
// well.
const AdminOnly NameFlag = 1 << 2
// G - GuildOnly, which tells the library to only run the Subcommand/method
// if the user is inside a guild.
const GuildOnly NameFlag = 1 << 3
// M - Middleware, which tells the library that the method is a middleware.
// The method will be executed anytime a method of the same struct is
// matched.
//
// Using this flag inside the subcommand will drop all methods (this is an
// undefined behavior/UB).
const Middleware NameFlag = 1 << 4
// H - Hidden/Handler, which tells the router to not add this into the list
// of commands, hiding it from Help. Handlers that are hidden will not have
// any arguments parsed. It will be treated as an Event.
const Hidden NameFlag = 1 << 5
// P - Plumb, which tells the router to call only this handler with all the
// arguments (except the prefix string). If plumb is used, only this method
// will be called for the given struct, though all other events as well as
// methods with the H (Hidden/Handler) flag.
//
// This is different from using H (Hidden/Handler), as handlers are called
// regardless of command prefixes. Plumb methods are only called once, and
// no other methods will be called for that struct. That said, a Plumb
// method would still go into Commands, but only itself will be there.
//
// Note that if there's a Plumb method in the main commands, then none of
// the subcommands would be called. This is an unintended but expected side
// effect.
//
// Example
//
// A use for this would be subcommands that don't need a second command, or
// if the main struct manually handles command switching. This example
// demonstrates the second use-case:
//
// func (s *Sub) PーMain(
// c *gateway.MessageCreateGateway, c *Content) error {
//
// // Input: !sub this is a command
// // Output: this is a command
//
// log.Println(c.String())
// return nil
// }
//
const Plumb NameFlag = 1 << 6
func ParseFlag(name string) (NameFlag, string) {
parts := strings.SplitN(name, string(FlagSeparator), 2)
if len(parts) != 2 {
return 0, name
}
var f NameFlag
for _, r := range parts[0] {
switch r {
case 'R':
f |= Raw
case 'A':
f |= AdminOnly | GuildOnly
case 'G':
f |= GuildOnly
case 'M':
f |= Middleware
case 'H':
f |= Hidden
case 'P':
f |= Plumb
}
}
return f, parts[1]
}
func (f NameFlag) Is(flag NameFlag) bool {
return f&flag != 0
}

View File

@ -1,26 +0,0 @@
package bot
import "testing"
func TestNameFlag(t *testing.T) {
type entry struct {
Name string
Expect NameFlag
String string
}
var entries = []entry{{
Name: "AーEcho",
Expect: AdminOnly,
}, {
Name: "RAーGC",
Expect: Raw | AdminOnly,
}}
for _, entry := range entries {
var f, _ = ParseFlag(entry.Name)
if !f.Is(entry.Expect) {
t.Fatalf("unexpected expectation for %s: %v", entry.Name, f)
}
}
}

View File

@ -70,9 +70,6 @@ type Subcommand struct {
// Parsed command name:
Command string
// struct flags
Flag NameFlag
// SanitizeMessage is executed on the message content if the method returns
// a string content or a SendMessageData.
SanitizeMessage func(content string) string
@ -85,15 +82,12 @@ type Subcommand struct {
// Commands can actually return either a string, an embed, or a
// SendMessageData, with error as the second argument.
// All registered command contexts:
Commands []*CommandContext
Events []*CommandContext
// All registered method contexts, including commands:
Methods []*MethodContext
plumbed *MethodContext
// Middleware command contexts:
mwMethods []*CommandContext
// Plumb nameflag, use Commands[0] if true.
plumb bool
// Global middlewares.
globalmws []*MiddlewareContext
// Directly to struct
cmdValue reflect.Value
@ -103,34 +97,9 @@ type Subcommand struct {
ptrValue reflect.Value
ptrType reflect.Type
// command interface as reference
command interface{}
}
// CommandContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type CommandContext struct {
Description string
Flag NameFlag
MethodName string
Command string // empty if Plumb
// Hidden is true if the method has a hidden nameflag.
Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
value reflect.Value // Func
event reflect.Type // gateway.*Event
method reflect.Method
Arguments []Argument
}
// CanSetup is used for subcommands to change variables, such as Description.
// This method will be triggered when InitCommands is called, which is during
// New for Context and during RegisterSubcommand for subcommands.
@ -139,19 +108,6 @@ type CanSetup interface {
Setup(*Subcommand)
}
func (cctx *CommandContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
}
// NewSubcommand is used to make a new subcommand. You usually wouldn't call
// this function, but instead use (*Context).RegisterSubcommand().
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
@ -177,34 +133,24 @@ func NewSubcommand(cmd interface{}) (*Subcommand, error) {
// shouldn't be called at all, rather you should use RegisterSubcommand.
func (sub *Subcommand) NeedsName() {
sub.StructName = sub.cmdType.Name()
flag, name := ParseFlag(sub.StructName)
if !flag.Is(Raw) {
name = lowerFirstLetter(name)
}
sub.Command = name
sub.Flag = flag
sub.Command = lowerFirstLetter(sub.StructName)
}
// FindCommand finds the command. Nil is returned if nothing is found. It's a
// better idea to not handle nil, as they would become very subtle bugs.
func (sub *Subcommand) FindCommand(methodName string) *CommandContext {
for _, c := range sub.Commands {
if c.MethodName != methodName {
continue
// FindMethod finds the MethodContext. It panics if methodName is not found.
func (sub *Subcommand) FindMethod(methodName string) *MethodContext {
for _, c := range sub.Methods {
if c.MethodName == methodName {
return c
}
return c
}
return nil
panic("Can't find method " + methodName)
}
// ChangeCommandInfo changes the matched methodName's Command and Description.
// Empty means unchanged. The returned bool is true when the method is found.
// Empty means unchanged. The returned bool is true when the command is found.
func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool {
for _, c := range sub.Commands {
if c.MethodName != methodName {
for _, c := range sub.Methods {
if c.MethodName != methodName || !c.isEvent(typeMessageCreate) {
continue
}
@ -222,70 +168,70 @@ func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool {
}
func (sub *Subcommand) Help(indent string, hideAdmin bool) string {
if sub.Flag.Is(AdminOnly) && hideAdmin {
return ""
}
// // The header part:
// var header string
// The header part:
var header string
// if sub.Command != "" {
// header += "**" + sub.Command + "**"
// }
if sub.Command != "" {
header += "**" + sub.Command + "**"
}
// if sub.Description != "" {
// if header != "" {
// header += ": "
// }
if sub.Description != "" {
if header != "" {
header += ": "
}
// header += sub.Description
// }
header += sub.Description
}
// header += "\n"
header += "\n"
// // The commands part:
// var commands = ""
// The commands part:
var commands = ""
// for i, cmd := range sub.Commands {
// if cmd.Flag.Is(AdminOnly) && hideAdmin {
// continue
// }
for i, cmd := range sub.Commands {
if cmd.Flag.Is(AdminOnly) && hideAdmin {
continue
}
// switch {
// case sub.Command != "" && cmd.Command != "":
// commands += indent + sub.Command + " " + cmd.Command
// case sub.Command != "":
// commands += indent + sub.Command
// default:
// commands += indent + cmd.Command
// }
switch {
case sub.Command != "" && cmd.Command != "":
commands += indent + sub.Command + " " + cmd.Command
case sub.Command != "":
commands += indent + sub.Command
default:
commands += indent + cmd.Command
}
// // Write the usages first.
// for _, usage := range cmd.Usage() {
// commands += " " + underline(usage)
// }
// Write the usages first.
for _, usage := range cmd.Usage() {
commands += " " + underline(usage)
}
// // Is the last argument trailing? If so, append ellipsis.
// if cmd.Variadic {
// commands += "..."
// }
// Is the last argument trailing? If so, append ellipsis.
if cmd.Variadic {
commands += "..."
}
// // Write the description if there's any.
// if cmd.Description != "" {
// commands += ": " + cmd.Description
// }
// Write the description if there's any.
if cmd.Description != "" {
commands += ": " + cmd.Description
}
// // Add a new line if this isn't the last command.
// if i != len(sub.Commands)-1 {
// commands += "\n"
// }
// }
// Add a new line if this isn't the last command.
if i != len(sub.Commands)-1 {
commands += "\n"
}
}
// if commands == "" {
// return ""
// }
if commands == "" {
return ""
}
// return header + commands
return header + commands
// TODO
// TODO: Interface Helper implements Help() string
return "TODO"
}
func (sub *Subcommand) reflectCommands() error {
@ -327,12 +273,6 @@ func (sub *Subcommand) InitCommands(ctx *Context) error {
v.Setup(sub)
}
// Finalize the subcommand:
for _, cmd := range sub.Commands {
// Inherit parent's flags
cmd.Flag |= sub.Flag
}
return nil
}
@ -365,126 +305,93 @@ func (sub *Subcommand) parseCommands() error {
continue
}
methodT := method.Type()
numArgs := methodT.NumIn()
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
methodT := sub.ptrType.Method(i)
if methodT.Name == "Setup" && methodT.Type == typeSetupFn {
continue
}
if methodT == typeSetupFn {
// Method is a setup method, continue.
cctx := parseMethod(method, methodT)
if cctx == nil {
continue
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
continue
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
continue
}
}
var command = CommandContext{
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
Variadic: methodT.IsVariadic(),
}
// Parse the method name
flag, name := ParseFlag(command.method.Name)
// Set the method name, command, and flag:
command.MethodName = name
command.Command = name
command.Flag = flag
// Check if Raw is enabled for command:
if !flag.Is(Raw) {
command.Command = lowerFirstLetter(name)
}
// Middlewares shouldn't even have arguments.
if flag.Is(Middleware) {
sub.mwMethods = append(sub.mwMethods, &command)
continue
}
// TODO: allow more flexibility
if command.event != typeMessageCreate || flag.Is(Hidden) {
sub.Events = append(sub.Events, &command)
continue
}
// See if we know the first return type, if error's return is the
// second:
if numOut > 1 {
switch t := methodT.Out(0); t {
case typeString, typeEmbed, typeSend:
// noop, passes
default:
continue
}
}
// If a plumb method has been found:
if sub.plumb {
continue
}
// If the method only takes an event:
if numArgs == 1 {
sub.Commands = append(sub.Commands, &command)
continue
}
command.Arguments = make([]Argument, 0, numArgs)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
return errors.Wrap(err, "Error parsing argument "+t.String())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
// If the current event is a plumb event:
if flag.Is(Plumb) {
command.Command = "" // plumbers don't have names
sub.Commands = []*CommandContext{&command}
sub.plumb = true
continue
}
// Append
sub.Commands = append(sub.Commands, &command)
// Append.
sub.Methods = append(sub.Methods, cctx)
}
return nil
}
func (sub *Subcommand) AddMiddleware(methodName string, middleware interface{}) {
var mw *MiddlewareContext
// Allow *MiddlewareContext to be passed into.
if v, ok := middleware.(*MiddlewareContext); ok {
mw = v
} else {
mw = ParseMiddleware(middleware)
}
// Parse method name:
for _, method := range strings.Split(methodName, ",") {
// Trim space.
if method = strings.TrimSpace(method); method == "*" {
// Append middleware to global middleware slice.
sub.globalmws = append(sub.globalmws, mw)
} else {
// Append middleware to that individual function.
sub.FindMethod(method).addMiddleware(mw)
}
}
}
func (sub *Subcommand) walkMiddlewares(ev reflect.Value) error {
for _, mw := range sub.globalmws {
_, err := mw.call(ev)
if err != nil {
return err
}
}
return nil
}
func (sub *Subcommand) eventCallers(evT reflect.Type) (callers []caller) {
// Search for global middlewares.
for _, mw := range sub.globalmws {
if mw.isEvent(evT) {
callers = append(callers, mw)
}
}
// Search for specific handlers.
for _, cctx := range sub.Methods {
// We only take middlewares and callers if the event matches and is not
// a MessageCreate. The other function already handles that.
if cctx.event != typeMessageCreate && cctx.isEvent(evT) {
// Add the command's middlewares first.
for _, mw := range cctx.middlewares {
// Concrete struct to interface conversion done implicitly.
callers = append(callers, mw)
}
callers = append(callers, cctx)
}
}
return
}
// SetPlumb sets the method as the plumbed command. This means that all calls
// without the second command argument will call this method in a subcommand. It
// panics if sub.Command is empty.
func (sub *Subcommand) SetPlumb(methodName string) {
if sub.Command == "" {
panic("SetPlumb called on a main command with sub.Command empty.")
}
method := sub.FindMethod(methodName)
method.Command = ""
sub.plumbed = method
}
func lowerFirstLetter(name string) string {
return strings.ToLower(string(name[0])) + name[1:]
}

View File

@ -29,8 +29,8 @@ func TestSubcommand(t *testing.T) {
}
// !!! CHANGE ME
if len(sub.Commands) != 8 {
t.Fatal("invalid ctx.commands len", len(sub.Commands))
if len(sub.Methods) < 8 {
t.Fatal("too low sub.Methods len", len(sub.Methods))
}
var (
@ -39,7 +39,7 @@ func TestSubcommand(t *testing.T) {
foundNoArgs bool
)
for _, this := range sub.Commands {
for _, this := range sub.Methods {
switch this.Command {
case "send":
foundSend = true
@ -58,13 +58,6 @@ func TestSubcommand(t *testing.T) {
if len(this.Arguments) != 0 {
t.Fatal("expected 0 arguments, got non-zero")
}
case "noop", "getCounter", "variadic", "trailCustom", "content":
// Found, but whatever
}
if this.event != typeMessageCreate {
t.Fatal("invalid event type:", this.event.String())
}
}