arikawa/bot/ctx_call.go

570 lines
12 KiB
Go
Raw Normal View History

2020-01-19 06:06:00 +00:00
package bot
import (
"reflect"
"strings"
2020-01-26 09:06:54 +00:00
"github.com/diamondburned/arikawa/api"
2020-01-19 06:06:00 +00:00
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
2020-01-26 07:17:18 +00:00
"github.com/pkg/errors"
2020-01-19 06:06:00 +00:00
)
// NonFatal is an interface that a method can implement to ignore all errors.
// This works similarly to Break.
type NonFatal interface {
error
IgnoreError() // noop method
}
func onlyFatal(err error) error {
if _, ok := err.(NonFatal); ok {
return nil
2020-01-23 07:20:24 +00:00
}
return err
}
2020-01-19 06:06:00 +00:00
type _Break struct{ error }
// implement NonFatal.
func (_Break) IgnoreError() {}
// Break is a non-fatal error that could be returned from middlewares or
// handlers to stop the chain of execution.
//
// Middlewares are guaranteed to be executed before handlers, but the exact
// order of each are undefined. Main handlers are also guaranteed to be executed
// before all subcommands. If a main middleware cancels, no subcommand
// middlewares will be called.
//
// Break implements the NonFatal interface, which causes an error to be ignored.
var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")}
// filterEventType collects the handlers registered for the event type evT,
// searching the main context first and then each subcommand in order.
//
// Entries of a Subcommand's Events slice flagged as Middleware are never
// collected as handlers. A subcommand's middlewares (mwMethods) with the same
// event type are collected only when that same subcommand contributed at least
// one handler. The result lists every collected middleware first, followed by
// every handler, each group in discovery order.
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
	var handlers, middlewares []*CommandContext

	collect := func(sub *Subcommand) {
		matched := false
		for _, ev := range sub.Events {
			// Only real handlers here; middlewares are gathered below.
			if !ev.Flag.Is(Middleware) && ev.event == evT {
				handlers = append(handlers, ev)
				matched = true
			}
		}

		// Middlewares only matter if this subcommand handles the event.
		if !matched {
			return
		}
		for _, mw := range sub.mwMethods {
			if mw.event == evT {
				middlewares = append(middlewares, mw)
			}
		}
	}

	collect(ctx.Subcommand)
	for _, sub := range ctx.subcommands {
		collect(sub)
	}

	return append(middlewares, handlers...)
}
2020-01-19 06:06:00 +00:00
2020-01-23 07:20:24 +00:00
// callCmd dispatches ev to every handler registered for its dynamic type
// (middlewares first, as ordered by filterEventType). The per-type handler list
// is cached in ctx.typeCache.
//
// Handlers flagged AdminOnly or GuildOnly are skipped when ev does not satisfy
// the flag. The admin/guild checks are computed at most once per call and
// shared across handlers.
//
// The first handler error stops the chain. A fatal error is passed to
// ctx.ErrorLogger and returned. A NonFatal one (such as Break) yields nil.
// For *gateway.MessageCreateEvent, command routing through callMessageCreate
// runs after the event handlers. Its fatal errors are returned (not logged
// here) and its NonFatal errors are dropped.
func (ctx *Context) callCmd(ev interface{}) error {
	evT := reflect.TypeOf(ev)

	// Resolve the handler list, hitting the cache when possible.
	var callers []*CommandContext
	if cached, ok := ctx.typeCache.Load(evT); ok {
		callers = cached.([]*CommandContext)
	} else {
		callers = ctx.filterEventType(evT)
		ctx.typeCache.Store(evT, callers)
	}

	// Lazily computed flags, shared by all handlers of this event.
	var isAdmin, isGuild *bool

	// Filter into a fresh slice: reusing callers' backing array would modify
	// the slice stored inside the sync.Map.
	filtered := make([]*CommandContext, 0, len(callers))
	for _, cmd := range callers {
		// Command flags inherit their parent Subcommand's flags.
		if cmd.Flag.Is(AdminOnly) && !ctx.eventIsAdmin(ev, &isAdmin) {
			continue
		}
		if cmd.Flag.Is(GuildOnly) && !ctx.eventIsGuild(ev, &isGuild) {
			continue
		}
		filtered = append(filtered, cmd)
	}

	for _, cmd := range filtered {
		if _, err := callWith(cmd.value, ev); err != nil {
			err = onlyFatal(err)
			if err != nil {
				ctx.ErrorLogger(err)
			}
			return err
		}
	}

	if evT != typeMessageCreate {
		return nil
	}

	// Commands are routed last: Hidden handlers also live in the Events slice
	// and were already run above, so they are not skipped.
	return onlyFatal(ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent)))
}
2020-01-19 06:06:00 +00:00
// callMessageCreate routes a message to the command it invokes.
//
// Messages from bots are ignored unless ctx.AllowBot is set, and messages
// without a prefix accepted by ctx.HasPrefix are ignored. The content after
// the prefix is split with ParseArgs and resolved, in order, as:
//   - the first command of a plumbed main context (ctx.plumb);
//   - a main-context command named args[0];
//   - a subcommand named args[0], either plumbed (its first command) or
//     followed by a command named args[1].
//
// On a match it enforces GuildOnly/AdminOnly (silently returning nil when not
// met), converts the arguments, calls the owning Subcommand's middlewares, and
// then calls the handler. It sends the handler's result (string,
// *discord.Embed or *api.SendMessageData) to the message's channel.
//
// Returned errors: the wrapped ParseArgs error, *ErrUnknownCommand (unless
// QuietUnknownCommand is set), *ErrInvalidUsage for bad arguments, and any
// error from a middleware, the handler or the send call.
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
	// check if bot
	if !ctx.AllowBot && mc.Author.Bot {
		return nil
	}

	// check if prefix
	pf, ok := ctx.HasPrefix(mc)
	if !ok {
		return nil
	}

	// trim the prefix before splitting, this way multi-words prefices work
	content := mc.Content[len(pf):]
	if content == "" {
		return nil // just the prefix only
	}

	// parse arguments
	args, err := ParseArgs(content)
	if err != nil {
		return errors.Wrap(err, "Failed to parse command")
	}
	if len(args) == 0 {
		return nil // nothing to dispatch
	}

	var cmd *CommandContext
	var sub *Subcommand
	// args[:start] are the consumed command/subcommand names; args[start:] are
	// the user-supplied arguments.
	var start int

	// Check if plumb: every message goes to the first command and no command
	// name is consumed.
	if ctx.plumb {
		cmd = ctx.Commands[0]
		sub = ctx.Subcommand
		start = 0
	}

	// If not plumb, search for the command
	if cmd == nil {
		for _, c := range ctx.Commands {
			if c.Command == args[0] {
				cmd = c
				sub = ctx.Subcommand
				start = 1
				break
			}
		}
	}

	// Can't find the command, look for a subcommand named args[0].
	if cmd == nil {
		for _, s := range ctx.subcommands {
			if s.Command != args[0] {
				continue
			}

			// A plumbed subcommand consumes only its own name.
			if s.plumb {
				cmd = s.Commands[0]
				sub = s
				start = 1
				break
			}

			// There's no second argument, so we can only look for plumbed
			// subcommands.
			if len(args) < 2 {
				continue
			}

			for _, c := range s.Commands {
				if c.Command == args[1] {
					cmd = c
					sub = s
					start = 2
					break // first match wins, like the main-context search
				}
			}

			if cmd == nil {
				if s.QuietUnknownCommand {
					return nil
				}
				return &ErrUnknownCommand{
					Command: args[1],
					Parent:  args[0],
					ctx:     s.Commands,
				}
			}
			break
		}
	}

	if cmd == nil {
		if ctx.QuietUnknownCommand {
			return nil
		}
		return &ErrUnknownCommand{
			Command: args[0],
			ctx:     ctx.Commands,
		}
	}

	// Check for IsAdmin and IsGuild
	if cmd.Flag.Is(GuildOnly) && !mc.GuildID.Valid() {
		return nil
	}
	if cmd.Flag.Is(AdminOnly) {
		p, err := ctx.State.Permissions(mc.ChannelID, mc.Author.ID)
		if err != nil || !p.Has(discord.PermissionAdministrator) {
			return nil
		}
	}

	// Start converting
	var argv []reflect.Value

	switch {
	case len(cmd.Arguments) == 0:
		// Edge case: the handler takes no arguments. We allow that anyway, as
		// it might use the raw content.

	case cmd.Arguments[0].fn == nil:
		// Manual or custom parser: the argument type parses the input itself.
		// Create a zero value instance of it:
		v := reflect.New(cmd.Arguments[0].rtype)
		var ret []reflect.Value

		switch {
		case cmd.Arguments[0].manual != nil:
			// Pop out the subcommand name, if there's one:
			if sub.Command != "" {
				args = args[1:]
			}
			// Call the manual parse method:
			ret = cmd.Arguments[0].manual.Func.Call([]reflect.Value{
				v, reflect.ValueOf(args),
			})

		case cmd.Arguments[0].custom != nil:
			// For consistent behavior, clear the consumed names (args[:start])
			// off the raw content. Each name is removed only when the trimmed
			// content starts with it, so extra whitespace between names is
			// tolerated. Plumbed commands, whose own name is not part of the
			// content, are never over-trimmed.
			for _, name := range args[:start] {
				content = strings.TrimPrefix(strings.TrimSpace(content), name)
			}
			content = strings.TrimSpace(content)

			// Call the method with the raw unparsed command:
			ret = cmd.Arguments[0].custom.Func.Call([]reflect.Value{
				v, reflect.ValueOf(content),
			})
		}

		// Check the returned error:
		if _, err := errorReturns(ret); err != nil {
			return err
		}

		// Check if the argument wants a non-pointer:
		if cmd.Arguments[0].pointer {
			v = v.Elem()
		}
		argv = []reflect.Value{v}

	default:
		// Positional arguments, each converted by its own parser. All indexing
		// is relative to the user-supplied arguments, not the full args.
		arguments := args[start:]

		// Argument count check.
		if argdelta := len(arguments) - len(cmd.Arguments); argdelta != 0 {
			var usageErr error // no error if nil
			switch {
			case argdelta < 0:
				// Not enough arguments given.
				usageErr = ErrNotEnoughArgs
			case !cmd.Variadic:
				// Too many arguments, and the command can't absorb them.
				usageErr = ErrTooManyArgs
			}

			if usageErr != nil {
				return &ErrInvalidUsage{
					Prefix: pf,
					Args:   args,
					Index:  len(args) - 1,
					Wrap:   usageErr,
					Ctx:    cmd,
				}
			}
		}

		// len(arguments) >= len(cmd.Arguments) holds past the check above.
		argv = make([]reflect.Value, 0, len(arguments))

		for i := range cmd.Arguments {
			v, err := cmd.Arguments[i].fn(arguments[i])
			if err != nil {
				return &ErrInvalidUsage{
					Prefix: pf,
					Args:   args,
					Index:  start + i, // Index points into the full args
					Wrap:   err,
					Ctx:    cmd,
				}
			}
			argv = append(argv, v)
		}

		// Parse the rest with the variadic (last) argument's parser. Go's
		// reflect states that variadic parameters will automatically be
		// copied, which is good.
		if len(arguments) > len(cmd.Arguments) {
			last := cmd.Arguments[len(cmd.Arguments)-1]
			for i := len(cmd.Arguments); i < len(arguments); i++ {
				v, err := last.fn(arguments[i])
				if err != nil {
					return &ErrInvalidUsage{
						Prefix: pf,
						Args:   args,
						Index:  start + i,
						Wrap:   err,
						Ctx:    cmd,
					}
				}
				argv = append(argv, v)
			}
		}
	}

	// Try calling all middlewares first. We don't need to stack middlewares, as
	// there will only be one command match.
	// NOTE(review): unlike filterEventType, these are not filtered by
	// mw.event; confirm every middleware here accepts a MessageCreateEvent.
	for _, mw := range sub.mwMethods {
		if _, err := callWith(mw.value, mc); err != nil {
			return err
		}
	}

	// call the function and parse the error return value
	v, err := callWith(cmd.value, mc, argv...)
	if err != nil {
		return err
	}

	switch v := v.(type) {
	case string:
		v = sub.SanitizeMessage(v)
		_, err = ctx.SendMessage(mc.ChannelID, v, nil)
	case *discord.Embed:
		_, err = ctx.SendMessage(mc.ChannelID, "", v)
	case *api.SendMessageData:
		// A nil pointer means nothing to send; dereferencing it would panic.
		if v == nil {
			return nil
		}
		if v.Content != "" {
			v.Content = sub.SanitizeMessage(v.Content)
		}
		_, err = ctx.SendMessageComplex(mc.ChannelID, *v)
	}
	return err
}
// eventIsAdmin reports whether the user of ev holds the Administrator
// permission in ev's channel. The channel and user IDs are located by
// reflection; if either is missing, false is returned and nothing is cached.
//
// Otherwise the answer is memoized through is, so later calls sharing the same
// *bool skip the permission lookup. A failed lookup counts as "not admin".
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
	if cached := *is; cached != nil {
		return *cached
	}

	channelID := reflectChannelID(ev)
	if !channelID.Valid() {
		return false
	}

	userID := reflectUserID(ev)
	if !userID.Valid() {
		return false
	}

	p, err := ctx.State.Permissions(channelID, userID)
	admin := err == nil && p.Has(discord.PermissionAdministrator)

	*is = &admin
	return admin
}
2020-01-23 07:20:24 +00:00
// eventIsGuild reports whether ev happened in a guild channel. It finds the
// channel ID by reflection and looks the channel up in ctx.State. A missing
// channel ID or a failed lookup returns false without caching.
//
// A successful answer is memoized through is, so later calls sharing the same
// *bool skip the lookup.
func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
	if cached := *is; cached != nil {
		return *cached
	}

	channelID := reflectChannelID(ev)
	if !channelID.Valid() {
		return false
	}

	channel, err := ctx.State.Channel(channelID)
	if err != nil {
		return false
	}

	inGuild := channel.GuildID.Valid()
	*is = &inGuild
	return inGuild
}
2020-01-26 09:06:54 +00:00
// callWith invokes caller with ev as its first argument, followed by values,
// and interprets the results with errorReturns. The values slice passed in is
// not modified.
func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) (interface{}, error) {
	in := make([]reflect.Value, 0, len(values)+1)
	in = append(in, reflect.ValueOf(ev))
	in = append(in, values...)

	return errorReturns(caller.Call(in))
}
2020-01-26 09:06:54 +00:00
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
2020-01-26 09:06:54 +00:00
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
2020-01-19 06:06:00 +00:00
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
2020-01-26 09:06:54 +00:00
if len(returns) == 1 {
return nil, nil
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
2020-01-26 09:06:54 +00:00
return returns[0].Interface(), nil
2020-01-19 06:06:00 +00:00
}
// Treat the last return as an error.
2020-01-26 09:06:54 +00:00
return nil, v.(error)
2020-01-19 06:06:00 +00:00
}
// reflectChannelID returns the channel ID found in _struct by _reflectID, or 0.
func reflectChannelID(_struct interface{}) discord.Snowflake {
	v := reflect.ValueOf(_struct)
	return _reflectID(v, "Channel")
}

// reflectGuildID returns the guild ID found in _struct by _reflectID, or 0.
func reflectGuildID(_struct interface{}) discord.Snowflake {
	v := reflect.ValueOf(_struct)
	return _reflectID(v, "Guild")
}

// reflectUserID returns the user ID found in _struct by _reflectID, or 0.
func reflectUserID(_struct interface{}) discord.Snowflake {
	v := reflect.ValueOf(_struct)
	return _reflectID(v, "User")
}
// _reflectID searches v (a struct, or a pointer to one) for a snowflake
// identifying thing, e.g. "Channel", "Guild" or "User". Fields are visited in
// declaration order and nested structs are searched depth-first.
//
// An int64-kinded field matches when it is named thing+"ID", or when it is
// named "ID" inside a struct whose type name contains thing (e.g. the ID of a
// User struct when thing is "User"). A matching field is returned as-is, even
// if zero. A nested struct's result is only used when it is Valid.
//
// Pointer fields, both to structs and to int64-kinded types such as
// *discord.Snowflake, are dereferenced; nil pointers are skipped. 0 is returned
// when nothing matches.
func _reflectID(v reflect.Value, thing string) discord.Snowflake {
	if !v.IsValid() {
		return 0
	}

	t := v.Type()
	if t.Kind() == reflect.Ptr {
		v = v.Elem()
		// Recheck after dereferencing: a nil pointer yields an invalid Value.
		if !v.IsValid() {
			return 0
		}
		t = v.Type()
	}

	if t.Kind() != reflect.Struct {
		return 0
	}

	for i, numFields := 0, t.NumField(); i < numFields; i++ {
		field := t.Field(i)
		fType := field.Type
		if fType.Kind() == reflect.Ptr {
			fType = fType.Elem()
		}

		switch fType.Kind() {
		case reflect.Struct:
			if id := _reflectID(v.Field(i), thing); id.Valid() {
				return id
			}

		case reflect.Int64:
			named := field.Name == thing+"ID"
			// Special case where the struct name has thing in it, e.g. User.ID.
			owned := field.Name == "ID" && strings.Contains(t.Name(), thing)
			if !named && !owned {
				continue
			}

			fv := v.Field(i)
			// Value.Int panics on a pointer Value, so dereference pointer
			// fields first and skip nil ones.
			if fv.Kind() == reflect.Ptr {
				if fv.IsNil() {
					continue
				}
				fv = fv.Elem()
			}
			return discord.Snowflake(fv.Int())
		}
	}

	return 0
}