1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-19 02:58:01 +00:00
arikawa/bot/ctx_call.go
diamondburned 29582d6131 Bot: Allow both plumbed and normal commands
This commit changes the existing Plumb behavior to allow normal commands
to coexist alongside a plumbed command. This change allows certain behaviors
that would otherwise require manually switching on arguments.

An example use case of this change would be having a default behavior
when a subcommand call doesn't have a command name. For example, given
this code:

    func (b *Banana) Setup(sub *bot.Subcommand) { sub.SetPlumb(b.Help) }

    func (b *Banana) Green(*gateway.MessageCreateEvent) {}
    func (b *Banana) Help(*gateway.MessageCreateEvent)  {}

The subcommand "banana" could have its own help when it's called as
"!banana", while "!banana green" would trigger another handler.
2020-11-30 14:26:53 -08:00

409 lines
9.9 KiB
Go

package bot
import (
	"reflect"
	"strings"
	"unicode"

	"github.com/diamondburned/arikawa/v2/api"
	"github.com/diamondburned/arikawa/v2/discord"
	"github.com/diamondburned/arikawa/v2/gateway"
	"github.com/pkg/errors"
)
var (
	// Break is a sentinel error that a middleware may return to stop the rest
	// of its chain from executing. It is non-fatal: errNoBreak filters it out
	// so it is never reported as a failure.
	Break = errors.New("break middleware chain, non-fatal")
)
// filterEventType gathers the callers registered for the event type evT: first
// those of the main context, then those of every subcommand, in order. Each
// (sub)command gets its own nested slice, so an error (including Break)
// returned by one caller only stops the rest of that nested slice.
func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
	callers = make([][]caller, 0, 1+len(ctx.subcommands))

	// Main context first.
	callers = append(callers, ctx.eventCallers(evT))

	// Then one group per subcommand.
	for i := range ctx.subcommands {
		callers = append(callers, ctx.subcommands[i].eventCallers(evT))
	}

	return callers
}
// callCmd dispatches a single gateway event through the bot.
//
// It first runs every caller (event handlers and middlewares) that the main
// context and each subcommand registered for the event's type. The per-type
// caller lists are computed by filterEventType once and then cached in
// ctx.typeCache. An error from a caller stops only the remaining callers of
// that same (sub)command; Break is treated as a silent stop.
//
// If the event is a *gateway.MessageCreateEvent, or a
// *gateway.MessageUpdateEvent while ctx.EditableCommands is set, the message is
// then parsed and dispatched as a command by callMessageCreate.
//
// The returned error is the command dispatch error if there is one; otherwise
// it is the last non-Break error returned by an event caller, or nil.
func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
	evV := reflect.ValueOf(ev)
	evT := evV.Type()

	var callers [][]caller

	// Hit the cache.
	if t, ok := ctx.typeCache.Load(evT); ok {
		callers = t.([][]caller)
	} else {
		callers = ctx.filterEventType(evT)
		ctx.typeCache.Store(evT, callers)
	}

	for _, subcallers := range callers {
		for _, c := range subcallers {
			if _, err := c.call(evV); err != nil {
				// Only count as an error if it's not Break.
				if err = errNoBreak(err); err != nil {
					bottomError = err
				}

				// Break the caller loop only for this subcommand.
				break
			}
		}
	}

	var msc *gateway.MessageCreateEvent

	// Commands are dispatched after the callers above, so that MessageCreate
	// middlewares run as well.
	switch {
	case evT == typeMessageCreate:
		msc = ev.(*gateway.MessageCreateEvent)

	case evT == typeMessageUpdate && ctx.EditableCommands:
		up := ev.(*gateway.MessageUpdateEvent)

		// Message updates could have empty contents when only their embeds
		// are filled. There is no command to run in that case.
		if up.Content == "" {
			return bottomError
		}

		// Query the updated message. Failing to fetch it is not treated as an
		// error; there is simply no command to run.
		m, err := ctx.Cabinet.Message(up.ChannelID, up.ID)
		if err != nil {
			return bottomError
		}

		// Treat the message update as a message create event to avoid
		// breaking changes.
		msc = &gateway.MessageCreateEvent{Message: *m, Member: up.Member}

		// Fill up member, if available.
		if m.GuildID.IsValid() && up.Member == nil {
			if mem, err := ctx.Cabinet.Member(m.GuildID, m.Author.ID); err == nil {
				msc.Member = mem
			}
		}

		// Update the reflect value as well.
		evV = reflect.ValueOf(msc)

	default:
		// Not a message event: there is no command to run, so only report
		// errors from the event callers (if any). Previously these were
		// discarded by an unconditional `return nil`.
		return bottomError
	}

	// callMessageCreate already filters Break, so no errNoBreak is needed.
	// A command error takes precedence over an event caller error.
	if err := ctx.callMessageCreate(msc, evV); err != nil {
		return err
	}

	return bottomError
}
// callMessageCreate parses mc as a command invocation and, if it names a known
// command, converts its arguments and calls the command's method. value is the
// reflect.Value handed to the command's middlewares and to the method as its
// event argument.
//
// It returns nil without doing anything when the author is a bot (unless
// ctx.AllowBot is set), when ctx.HasPrefix rejects the message, or when nothing
// follows the prefix. Wrong argument counts and per-argument conversion
// failures are returned as *ErrInvalidUsage; errors from a ManualParser or
// CustomParser last argument are returned as-is. A ParseArgs error is returned
// after conversion unless the command takes no arguments or its last argument
// is a CustomParser. A string, *discord.Embed, or *api.SendMessageData
// returned by the command is sent to mc's channel.
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error {
	// Ignore bots unless explicitly allowed.
	if !ctx.AllowBot && mc.Author.Bot {
		return nil
	}

	// The message must start with an accepted prefix.
	pf, ok := ctx.HasPrefix(mc)
	if !ok {
		return nil
	}

	// Trim the prefix before splitting, this way multi-word prefixes work.
	content := mc.Content[len(pf):]
	if content == "" {
		return nil // just the prefix only
	}

	// Parse arguments. Parse errors are not checked yet, as raw
	// (CustomParser) arguments may be able to ignore them.
	parts, parseErr := ctx.ParseArgs(content)
	if len(parts) == 0 {
		return parseErr
	}

	// Find the command and subcommand.
	arguments, cmd, sub, err := ctx.findCommand(parts)
	if err != nil {
		return errNoBreak(err)
	}

	// The subcommand's middlewares are not run here, as callCmd already
	// handles that. Run the command's own middlewares.
	if err := cmd.walkMiddlewares(value); err != nil {
		return errNoBreak(err)
	}

	// These are declared up front because the goto below may not jump over
	// variable declarations.
	var argv []reflect.Value
	var argc int
	// the last argument in the list, not used until set
	var last Argument

	// Edge case: when the handler takes no arguments, allow any input anyway,
	// as the handler might use the raw content. This also skips the
	// parse error check.
	if len(cmd.Arguments) == 0 {
		goto Call
	}

	// Argument count check.
	if argdelta := len(arguments) - len(cmd.Arguments); argdelta != 0 {
		var err error // stays nil if the count is acceptable

		// If the function is variadic, then the last argument may be empty.
		if cmd.Variadic {
			argdelta++
		}

		switch {
		// Not enough arguments were given.
		case argdelta < 0:
			err = ErrNotEnoughArgs

		// Too many arguments are only acceptable for variadic commands.
		case argdelta > 0 && !cmd.Variadic:
			err = ErrTooManyArgs
		}

		if err != nil {
			return &ErrInvalidUsage{
				Prefix: pf,
				Args:   parts,
				Index:  len(parts) - 1,
				Wrap:   err,
				Ctx:    cmd,
			}
		}
	}

	// The last argument in the arguments slice.
	last = cmd.Arguments[len(cmd.Arguments)-1]

	// Allocate a new slice for the function arguments.
	argc = len(cmd.Arguments) - 1         // arg len without last
	argv = make([]reflect.Value, 0, argc) // could be 0

	// Parse all arguments except for the last one.
	for i := 0; i < argc; i++ {
		v, err := cmd.Arguments[i].fn(arguments[0])
		if err != nil {
			return &ErrInvalidUsage{
				Prefix: pf,
				Args:   parts,
				// arguments is a suffix of parts that is popped as it is
				// consumed, so arguments[0] always sits at this index. Adding
				// the loop counter would double-count popped arguments.
				Index: len(parts) - len(arguments),
				Wrap:  err,
				Ctx:   cmd,
			}
		}

		// Pop arguments.
		arguments = arguments[1:]
		argv = append(argv, v)
	}

	// Is this last argument actually a variadic slice? If yes, then it
	// should still have fn normally.
	if last.fn != nil {
		// Allocate a new slice to append into.
		vars := make([]reflect.Value, 0, len(arguments))

		// Parse the rest with variadic arguments. Go's reflect states that
		// variadic parameters will automatically be copied, which is good.
		for len(arguments) > 0 {
			v, err := last.fn(arguments[0])
			if err != nil {
				return &ErrInvalidUsage{
					Prefix: pf,
					Args:   parts,
					// Same reasoning as above: no extra counter is needed.
					Index: len(parts) - len(arguments),
					Wrap:  err,
					Ctx:   cmd,
				}
			}

			arguments = arguments[1:]
			vars = append(vars, v)
		}

		argv = append(argv, vars...)

	} else {
		// Create a zero value instance of this:
		v := reflect.New(last.rtype)
		var err error // return error

		switch {
		// If the argument wants all arguments:
		case last.manual != nil:
			// Call the manual parse method:
			err = last.manual(v.Interface().(ManualParser), arguments)

		// If the argument wants all arguments in string:
		case last.custom != nil:
			// Ignore parser errors. This allows custom commands sliced away to
			// have erroneous hanging quotes.
			parseErr = nil

			content = trimPrefixStringAndSlice(content, sub.Command, sub.Aliases)

			// If the current command is not the plumbed command, then we can
			// keep trimming. We have to check for this, as a plumbed
			// subcommand may return other non-plumbed commands.
			if cmd != sub.plumbed {
				content = trimPrefixStringAndSlice(content, cmd.Command, cmd.Aliases)
			}

			// Call the method with the raw unparsed command:
			err = last.custom(v.Interface().(CustomParser), content)
		}

		// Check the returned error:
		if err != nil {
			return err
		}

		// Check if the argument wants a non-pointer:
		if last.pointer {
			v = v.Elem()
		}

		// Add the argument into argv.
		argv = append(argv, v)
	}

	// Check for parsing errors after parsing arguments.
	if parseErr != nil {
		return parseErr
	}

Call:
	// call the function and parse the error return value
	v, err := cmd.call(value, argv...)
	if err != nil {
		return err
	}

	// Send supported return values back to the channel.
	switch v := v.(type) {
	case string:
		v = sub.SanitizeMessage(v)
		_, err = ctx.SendMessage(mc.ChannelID, v, nil)
	case *discord.Embed:
		_, err = ctx.SendMessage(mc.ChannelID, "", v)
	case *api.SendMessageData:
		if v.Content != "" {
			v.Content = sub.SanitizeMessage(v.Content)
		}
		_, err = ctx.SendMessageComplex(mc.ChannelID, *v)
	}

	return err
}
// findCommand resolves parts (the parsed words following the bot prefix; the
// caller guarantees at least one) to a command. It returns the arguments left
// after the command name(s), the matched method, and the subcommand owning it.
//
// Lookup order:
//   - a command of the main context whose name or alias equals parts[0];
//   - otherwise the first subcommand whose name or alias equals parts[0], and
//     within it a non-plumbed command matching parts[1], if present;
//   - otherwise that subcommand's plumbed command, if it has one.
//
// When nothing matches, Break is returned if the relevant SilentUnknown flag is
// set (or the matched subcommand is hidden); otherwise an unknown-command
// error is returned.
func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) {
	head := parts[0]

	// The main command entrypoint cannot be plumbed.
	for _, method := range ctx.Commands {
		if searchStringAndSlice(head, method.Command, method.Aliases) {
			return parts[1:], method, ctx.Subcommand, nil
		}
	}

	// No top-level command matched; look for a subcommand instead.
	var matched *Subcommand
	for _, candidate := range ctx.subcommands {
		if searchStringAndSlice(head, candidate.Command, candidate.Aliases) {
			matched = candidate
			break
		}
	}

	if matched == nil {
		if ctx.SilentUnknown.Command {
			return nil, nil, nil, Break
		}
		return nil, nil, nil, newErrUnknownCommand(ctx.Subcommand, parts)
	}

	// Regular commands may coexist with a plumbed command; a matching second
	// word selects them, just as it would in a non-plumbed subcommand.
	if len(parts) > 1 {
		for _, method := range matched.Commands {
			// The plumbed command is considered to have no name of its own.
			if method != matched.plumbed &&
				searchStringAndSlice(parts[1], method.Command, method.Aliases) {
				return parts[2:], method, matched, nil
			}
		}
	}

	if matched.IsPlumbed() {
		return parts[1:], matched.plumbed, matched, nil
	}

	// Unknown command inside a known subcommand.
	if ctx.SilentUnknown.Subcommand || matched.Hidden {
		return nil, nil, nil, Break
	}
	return nil, nil, nil, newErrUnknownCommand(matched, parts)
}
// searchStringAndSlice reports whether str is exactly equal to isString or to
// any element of otherStrings. The comparison is case-sensitive. It is used to
// match a word against a command's name and its aliases.
func searchStringAndSlice(str string, isString string, otherStrings []string) bool {
	found := str == isString
	for i := 0; !found && i < len(otherStrings); i++ {
		found = otherStrings[i] == str
	}
	return found
}
// trimPrefixStringAndSlice is the trimming counterpart of searchStringAndSlice.
// Ignoring leading whitespace, if str begins with the whole word prefix, or
// with the whole word of any element of prefixes, that word is removed and the
// remainder is returned with surrounding whitespace trimmed. If nothing
// matches, str is returned unchanged.
//
// A match must be followed by whitespace or the end of str. This prevents a
// command named "s" from eating the first letter of its alias "say", and a
// name from matching a longer word. An empty prefix matches any str, in which
// case the whole string is returned space-trimmed.
func trimPrefixStringAndSlice(str string, prefix string, prefixes []string) string {
	// Leading whitespace (for example, left behind by a multi-word bot prefix)
	// must not prevent a match.
	trimmed := strings.TrimLeftFunc(str, unicode.IsSpace)

	// trimWord removes word from the start of trimmed if it is a whole word.
	trimWord := func(word string) (string, bool) {
		if !strings.HasPrefix(trimmed, word) {
			return "", false
		}

		rest := trimmed[len(word):]

		// Require a word boundary: rest is empty or starts with whitespace.
		// An empty word is exempt and matches anything.
		if word != "" && rest != "" && strings.TrimLeftFunc(rest, unicode.IsSpace) == rest {
			return "", false
		}

		return strings.TrimSpace(rest), true
	}

	if rest, ok := trimWord(prefix); ok {
		return rest
	}

	for _, p := range prefixes {
		if rest, ok := trimWord(p); ok {
			return rest
		}
	}

	return str
}
// errNoBreak filters out Break. It returns nil when errors.Is reports that
// err is (or wraps) Break. Otherwise it returns err unchanged, including a
// nil err.
func errNoBreak(err error) error {
	if !errors.Is(err, Break) {
		return err
	}
	return nil
}