From 1c25ccbf8e315010ee9b1a9af9cbf70f0253464a Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 26 Jan 2020 01:06:54 -0800 Subject: [PATCH] Bot: added data return types --- _example/advanced_bot/bot.go | 47 ++++++++++++++++------------------- bot/arguments.go | 3 ++- bot/ctx.go | 23 ++++++++++++++--- bot/ctx_call.go | 48 ++++++++++++++++++++++++++++-------- bot/subcommand.go | 39 ++++++++++++++++++++++++++--- discord/time.go | 8 ++++++ 6 files changed, 124 insertions(+), 44 deletions(-) diff --git a/_example/advanced_bot/bot.go b/_example/advanced_bot/bot.go index 3805380..16e1ebe 100644 --- a/_example/advanced_bot/bot.go +++ b/_example/advanced_bot/bot.go @@ -20,9 +20,8 @@ type Bot struct { } // Help prints the default help message. -func (bot *Bot) Help(m *gateway.MessageCreateEvent) error { - _, err := bot.Ctx.SendMessage(m.ChannelID, bot.Ctx.Help(), nil) - return err +func (bot *Bot) Help(m *gateway.MessageCreateEvent) (string, error) { + return bot.Ctx.Help(), nil } // Add demonstrates the usage of typed arguments. Run it with "~add 1 2". @@ -40,40 +39,39 @@ func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error { } // Say demonstrates how arguments.Flag could be used without the flag library. -func (bot *Bot) Say(m *gateway.MessageCreateEvent, f *arguments.Flag) error { +func (bot *Bot) Say( + m *gateway.MessageCreateEvent, f *arguments.Flag) (string, error) { + args := f.String() if args == "" { // Empty message, ignore - return nil + return "", nil } - _, err := bot.Ctx.SendMessage(m.ChannelID, args, nil) - return err + return args, nil } // GuildInfo demonstrates the use of command flags, in this case the GuildOnly // flag. -func (bot *Bot) GーGuildInfo(m *gateway.MessageCreateEvent) error { +func (bot *Bot) GーGuildInfo(m *gateway.MessageCreateEvent) (string, error) { g, err := bot.Ctx.Guild(m.GuildID) if err != nil { - return fmt.Errorf("Failed to get guild: %v", err) + return "", fmt.Errorf("Failed to get guild: %v", err) } - _, err = bot.Ctx.SendMessage(m.ChannelID, fmt.Sprintf( + return fmt.Sprintf( "Your guild is %s, and its maximum members is %d", g.Name, g.MaxMembers, - ), nil) - - return err + ), nil } // Repeat tells the bot to wait for the user's response, then repeat what they // said. -func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) error { +func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) { _, err := bot.Ctx.SendMessage(m.ChannelID, "What do you want me to say?", nil) if err != nil { - return err + return "", err } ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -91,19 +89,17 @@ func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) error { }) if v == nil { - return errors.New("Timed out waiting for response.") + return "", errors.New("Timed out waiting for response.") } ev := v.(*gateway.MessageCreateEvent) - - _, err = bot.Ctx.SendMessage(m.ChannelID, ev.Content, nil) - return err + return ev.Content, nil } // Embed is a simple embed creator. Its purpose is to demonstrate the usage of // the ParseContent interface, as well as using the stdlib flag package. func (bot *Bot) Embed( - m *gateway.MessageCreateEvent, f *arguments.Flag) error { + m *gateway.MessageCreateEvent, f *arguments.Flag) (*discord.Embed, error) { fs := arguments.NewFlagSet() @@ -115,22 +111,22 @@ func (bot *Bot) Embed( ) if err := f.With(fs.FlagSet); err != nil { - return err + return nil, err } if len(fs.Args()) < 1 { - return fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage()) + return nil, fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage()) } // Check if the color string is valid. if !strings.HasPrefix(*color, "#") || len(*color) != 7 { - return errors.New("Invalid color, format must be #hhhhhh") + return nil, errors.New("Invalid color, format must be #hhhhhh") } // Parse the color into decimal numbers. colorHex, err := strconv.ParseInt((*color)[1:], 16, 64) if err != nil { - return err + return nil, err } // Make a new embed @@ -151,6 +147,5 @@ func (bot *Bot) Embed( } } - _, err = bot.Ctx.SendMessage(m.ChannelID, "", &embed) - return err + return &embed, err } diff --git a/bot/arguments.go b/bot/arguments.go index b55d890..540e494 100644 --- a/bot/arguments.go +++ b/bot/arguments.go @@ -132,7 +132,8 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) { v, reflect.ValueOf(input), }) - if err := errorReturns(ret); err != nil { + _, err := errorReturns(ret) + if err != nil { return nilV, err } diff --git a/bot/ctx.go b/bot/ctx.go index 938794c..63aa558 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -16,19 +16,36 @@ import ( // Context is the bot state for commands and subcommands. // +// Commands +// // A command can be created by making it a method of Commands, or whatever // struct was given to the constructor. This following example creates a command // with a single integer argument (which can be ran with "~example 123"): // -// func (c *Commands) Example(m *gateway.MessageCreateEvent, i int) error { -// _, err := c.Ctx.SendMessage(m.ChannelID, fmt.Sprintf("You sent: %d", i)) -// return err +// func (c *Commands) Example( +// m *gateway.MessageCreateEvent, i int) (string, error) { +// +// return fmt.Sprintf("You sent: %d", i) // } // // Commands' exported methods will all be used as commands. Messages are parsed // with its first argument (the command) mapped accordingly to c.MapName, which // capitalizes the first letter automatically to reflect the exported method // name. +// +// A command can either return either an error, or data and error. The only data +// types allowed are string, *discord.Embed, and *api.SendMessageData. Any other +// return types will invalidate the method. +// +// Events +// +// An event can only have one argument, which is the pointer to the event +// struct. It can also only return error. +// +// func (c *Commands) Example(o *gateway.TypingStartEvent) error { +// log.Println("Someone's typing!") +// return nil +// } type Context struct { *Subcommand *state.State diff --git a/bot/ctx_call.go b/bot/ctx_call.go index a86f056..50c46c5 100644 --- a/bot/ctx_call.go +++ b/bot/ctx_call.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" + "github.com/diamondburned/arikawa/api" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" "github.com/pkg/errors" @@ -101,7 +102,8 @@ func (ctx *Context) callCmd(ev interface{}) error { } for _, c := range filtered { - if err := callWith(c.value, ev); err != nil { + _, err := callWith(c.value, ev) + if err != nil { ctx.ErrorLogger(err) } } @@ -268,7 +270,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error { } // Check the returned error: - if err := errorReturns(ret); err != nil { + _, err := errorReturns(ret) + if err != nil { return err } @@ -319,13 +322,32 @@ Call: // Try calling all middlewares first. We don't need to stack middlewares, as // there will only be one command match. for _, mw := range sub.mwMethods { - if err := callWith(mw.value, mc); err != nil { + _, err := callWith(mw.value, mc) + if err != nil { return err } } // call the function and parse the error return value - return callWith(cmd.value, mc, argv...) + v, err := callWith(cmd.value, mc, argv...) + if err != nil { + return err + } + + switch v := v.(type) { + case string: + v = sub.SanitizeMessage(v) + _, err = ctx.SendMessage(mc.ChannelID, v, nil) + case *discord.Embed: + _, err = ctx.SendMessage(mc.ChannelID, "", v) + case *api.SendMessageData: + if v.Content != "" { + v.Content = sub.SanitizeMessage(v.Content) + } + _, err = ctx.SendMessageComplex(mc.ChannelID, *v) + } + + return err } func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool { @@ -374,22 +396,28 @@ func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool { return res } -func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) error { +func callWith( + caller reflect.Value, + ev interface{}, values ...reflect.Value) (interface{}, error) { + return errorReturns(caller.Call(append( []reflect.Value{reflect.ValueOf(ev)}, values..., ))) } -func errorReturns(returns []reflect.Value) error { +func errorReturns(returns []reflect.Value) (interface{}, error) { // assume first is always error, since we checked for this in parseCommands - v := returns[0].Interface() - + v := returns[len(returns)-1].Interface() if v == nil { - return nil + if len(returns) == 1 { + return nil, nil + } + + return returns[0].Interface(), nil } - return v.(error) + return nil, v.(error) } func reflectChannelID(_struct interface{}) discord.Snowflake { diff --git a/bot/subcommand.go b/bot/subcommand.go index e4c5d14..30acc80 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -4,6 +4,8 @@ import ( "reflect" "strings" + "github.com/diamondburned/arikawa/api" + "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" "github.com/pkg/errors" ) @@ -11,6 +13,10 @@ import ( var ( typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil)) + typeString = reflect.TypeOf("") + typeEmbed = reflect.TypeOf((*discord.Embed)(nil)) + typeSend = reflect.TypeOf((*api.SendMessageData)(nil)) + typeSubcmd = reflect.TypeOf((*Subcommand)(nil)) typeIError = reflect.TypeOf((*error)(nil)).Elem() @@ -34,6 +40,9 @@ type Subcommand struct { // Parsed command name: Command string + // Commands can actually return either a string, an embed, or a + // SendMessageData, with error as the second argument. + // All registered command contexts: Commands []*CommandContext Events []*CommandContext @@ -44,6 +53,10 @@ type Subcommand struct { // struct flags Flag NameFlag + // SanitizeMessage is executed on the message content if the method returns + // a string content or a SendMessageData. + SanitizeMessage func(content string) string + // Plumb nameflag, use Commands[0] if true. plumb bool @@ -73,6 +86,9 @@ type CommandContext struct { event reflect.Type // gateway.*Event method reflect.Method + // return type + retType reflect.Type + Arguments []Argument } @@ -100,6 +116,9 @@ func (cctx *CommandContext) Usage() []string { func NewSubcommand(cmd interface{}) (*Subcommand, error) { var sub = Subcommand{ command: cmd, + SanitizeMessage: func(c string) string { + return c + }, } if err := sub.reflectCommands(); err != nil { @@ -286,13 +305,14 @@ func (sub *Subcommand) parseCommands() error { } // Check number of returns: - if methodT.NumOut() != 1 { + numOut := methodT.NumOut() + if numOut == 0 || numOut > 2 { continue } - // Check return type - if err := methodT.Out(0); err == nil || !err.Implements(typeIError) { - // Invalid, skip + // Check the last return's type: + if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) { + // Invalid, skip. continue } @@ -327,6 +347,17 @@ func (sub *Subcommand) parseCommands() error { continue } + // See if we know the first return type, if error's return is the + // second: + if numOut > 1 { + switch t := methodT.Out(0); t { + case typeString, typeEmbed, typeSend: + // noop, passes + default: + continue + } + } + // If a plumb method has been found: if sub.plumb { continue diff --git a/discord/time.go b/discord/time.go index 788b021..65d483c 100644 --- a/discord/time.go +++ b/discord/time.go @@ -15,6 +15,14 @@ var ( _ json.Marshaler = (*Timestamp)(nil) ) +func NewTimestamp(t time.Time) Timestamp { + return Timestamp(t) +} + +func NowTimestamp() Timestamp { + return NewTimestamp(time.Now()) +} + func (t *Timestamp) UnmarshalJSON(v []byte) error { str := strings.Trim(string(v), `"`) if str == "null" {