1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-11-30 16:17:57 +00:00

Bot: added data return types

This commit is contained in:
diamondburned (Forefront) 2020-01-26 01:06:54 -08:00
parent 0c57c6d127
commit 1c25ccbf8e
6 changed files with 124 additions and 44 deletions

View file

@ -20,9 +20,8 @@ type Bot struct {
}
// Help prints the default help message.
func (bot *Bot) Help(m *gateway.MessageCreateEvent) error {
_, err := bot.Ctx.SendMessage(m.ChannelID, bot.Ctx.Help(), nil)
return err
func (bot *Bot) Help(m *gateway.MessageCreateEvent) (string, error) {
return bot.Ctx.Help(), nil
}
// Add demonstrates the usage of typed arguments. Run it with "~add 1 2".
@ -40,40 +39,39 @@ func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error {
}
// Say demonstrates how arguments.Flag could be used without the flag library.
func (bot *Bot) Say(m *gateway.MessageCreateEvent, f *arguments.Flag) error {
func (bot *Bot) Say(
m *gateway.MessageCreateEvent, f *arguments.Flag) (string, error) {
args := f.String()
if args == "" {
// Empty message, ignore
return nil
return "", nil
}
_, err := bot.Ctx.SendMessage(m.ChannelID, args, nil)
return err
return args, nil
}
// GuildInfo demonstrates the use of command flags, in this case the GuildOnly
// flag.
func (bot *Bot) GーGuildInfo(m *gateway.MessageCreateEvent) error {
func (bot *Bot) GーGuildInfo(m *gateway.MessageCreateEvent) (string, error) {
g, err := bot.Ctx.Guild(m.GuildID)
if err != nil {
return fmt.Errorf("Failed to get guild: %v", err)
return "", fmt.Errorf("Failed to get guild: %v", err)
}
_, err = bot.Ctx.SendMessage(m.ChannelID, fmt.Sprintf(
return fmt.Sprintf(
"Your guild is %s, and its maximum members is %d",
g.Name, g.MaxMembers,
), nil)
return err
), nil
}
// Repeat tells the bot to wait for the user's response, then repeat what they
// said.
func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) error {
func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) {
_, err := bot.Ctx.SendMessage(m.ChannelID,
"What do you want me to say?", nil)
if err != nil {
return err
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
@ -91,19 +89,17 @@ func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) error {
})
if v == nil {
return errors.New("Timed out waiting for response.")
return "", errors.New("Timed out waiting for response.")
}
ev := v.(*gateway.MessageCreateEvent)
_, err = bot.Ctx.SendMessage(m.ChannelID, ev.Content, nil)
return err
return ev.Content, nil
}
// Embed is a simple embed creator. Its purpose is to demonstrate the usage of
// the ParseContent interface, as well as using the stdlib flag package.
func (bot *Bot) Embed(
m *gateway.MessageCreateEvent, f *arguments.Flag) error {
m *gateway.MessageCreateEvent, f *arguments.Flag) (*discord.Embed, error) {
fs := arguments.NewFlagSet()
@ -115,22 +111,22 @@ func (bot *Bot) Embed(
)
if err := f.With(fs.FlagSet); err != nil {
return err
return nil, err
}
if len(fs.Args()) < 1 {
return fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage())
return nil, fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage())
}
// Check if the color string is valid.
if !strings.HasPrefix(*color, "#") || len(*color) != 7 {
return errors.New("Invalid color, format must be #hhhhhh")
return nil, errors.New("Invalid color, format must be #hhhhhh")
}
// Parse the color into decimal numbers.
colorHex, err := strconv.ParseInt((*color)[1:], 16, 64)
if err != nil {
return err
return nil, err
}
// Make a new embed
@ -151,6 +147,5 @@ func (bot *Bot) Embed(
}
}
_, err = bot.Ctx.SendMessage(m.ChannelID, "", &embed)
return err
return &embed, err
}

View file

@ -132,7 +132,8 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) {
v, reflect.ValueOf(input),
})
if err := errorReturns(ret); err != nil {
_, err := errorReturns(ret)
if err != nil {
return nilV, err
}

View file

@ -16,19 +16,36 @@ import (
// Context is the bot state for commands and subcommands.
//
// Commands
//
// A command can be created by making it a method of Commands, or whatever
// struct was given to the constructor. This following example creates a command
// with a single integer argument (which can be ran with "~example 123"):
//
// func (c *Commands) Example(m *gateway.MessageCreateEvent, i int) error {
// _, err := c.Ctx.SendMessage(m.ChannelID, fmt.Sprintf("You sent: %d", i))
// return err
// func (c *Commands) Example(
// m *gateway.MessageCreateEvent, i int) (string, error) {
//
// return fmt.Sprintf("You sent: %d", i)
// }
//
// Commands' exported methods will all be used as commands. Messages are parsed
// with its first argument (the command) mapped accordingly to c.MapName, which
// capitalizes the first letter automatically to reflect the exported method
// name.
//
// A command can either return either an error, or data and error. The only data
// types allowed are string, *discord.Embed, and *api.SendMessageData. Any other
// return types will invalidate the method.
//
// Events
//
// An event can only have one argument, which is the pointer to the event
// struct. It can also only return error.
//
// func (c *Commands) Example(o *gateway.TypingStartEvent) error {
// log.Println("Someone's typing!")
// return nil
// }
type Context struct {
*Subcommand
*state.State

View file

@ -4,6 +4,7 @@ import (
"reflect"
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
@ -101,7 +102,8 @@ func (ctx *Context) callCmd(ev interface{}) error {
}
for _, c := range filtered {
if err := callWith(c.value, ev); err != nil {
_, err := callWith(c.value, ev)
if err != nil {
ctx.ErrorLogger(err)
}
}
@ -268,7 +270,8 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
}
// Check the returned error:
if err := errorReturns(ret); err != nil {
_, err := errorReturns(ret)
if err != nil {
return err
}
@ -319,13 +322,32 @@ Call:
// Try calling all middlewares first. We don't need to stack middlewares, as
// there will only be one command match.
for _, mw := range sub.mwMethods {
if err := callWith(mw.value, mc); err != nil {
_, err := callWith(mw.value, mc)
if err != nil {
return err
}
}
// call the function and parse the error return value
return callWith(cmd.value, mc, argv...)
v, err := callWith(cmd.value, mc, argv...)
if err != nil {
return err
}
switch v := v.(type) {
case string:
v = sub.SanitizeMessage(v)
_, err = ctx.SendMessage(mc.ChannelID, v, nil)
case *discord.Embed:
_, err = ctx.SendMessage(mc.ChannelID, "", v)
case *api.SendMessageData:
if v.Content != "" {
v.Content = sub.SanitizeMessage(v.Content)
}
_, err = ctx.SendMessageComplex(mc.ChannelID, *v)
}
return err
}
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
@ -374,22 +396,28 @@ func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
return res
}
func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) error {
func callWith(
caller reflect.Value,
ev interface{}, values ...reflect.Value) (interface{}, error) {
return errorReturns(caller.Call(append(
[]reflect.Value{reflect.ValueOf(ev)},
values...,
)))
}
func errorReturns(returns []reflect.Value) error {
func errorReturns(returns []reflect.Value) (interface{}, error) {
// assume first is always error, since we checked for this in parseCommands
v := returns[0].Interface()
v := returns[len(returns)-1].Interface()
if v == nil {
return nil
if len(returns) == 1 {
return nil, nil
}
return returns[0].Interface(), nil
}
return v.(error)
return nil, v.(error)
}
func reflectChannelID(_struct interface{}) discord.Snowflake {

View file

@ -4,6 +4,8 @@ import (
"reflect"
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
)
@ -11,6 +13,10 @@ import (
var (
typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil))
typeString = reflect.TypeOf("")
typeEmbed = reflect.TypeOf((*discord.Embed)(nil))
typeSend = reflect.TypeOf((*api.SendMessageData)(nil))
typeSubcmd = reflect.TypeOf((*Subcommand)(nil))
typeIError = reflect.TypeOf((*error)(nil)).Elem()
@ -34,6 +40,9 @@ type Subcommand struct {
// Parsed command name:
Command string
// Commands can actually return either a string, an embed, or a
// SendMessageData, with error as the second argument.
// All registered command contexts:
Commands []*CommandContext
Events []*CommandContext
@ -44,6 +53,10 @@ type Subcommand struct {
// struct flags
Flag NameFlag
// SanitizeMessage is executed on the message content if the method returns
// a string content or a SendMessageData.
SanitizeMessage func(content string) string
// Plumb nameflag, use Commands[0] if true.
plumb bool
@ -73,6 +86,9 @@ type CommandContext struct {
event reflect.Type // gateway.*Event
method reflect.Method
// return type
retType reflect.Type
Arguments []Argument
}
@ -100,6 +116,9 @@ func (cctx *CommandContext) Usage() []string {
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
var sub = Subcommand{
command: cmd,
SanitizeMessage: func(c string) string {
return c
},
}
if err := sub.reflectCommands(); err != nil {
@ -286,13 +305,14 @@ func (sub *Subcommand) parseCommands() error {
}
// Check number of returns:
if methodT.NumOut() != 1 {
numOut := methodT.NumOut()
if numOut == 0 || numOut > 2 {
continue
}
// Check return type
if err := methodT.Out(0); err == nil || !err.Implements(typeIError) {
// Invalid, skip
// Check the last return's type:
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
continue
}
@ -327,6 +347,17 @@ func (sub *Subcommand) parseCommands() error {
continue
}
// See if we know the first return type, if error's return is the
// second:
if numOut > 1 {
switch t := methodT.Out(0); t {
case typeString, typeEmbed, typeSend:
// noop, passes
default:
continue
}
}
// If a plumb method has been found:
if sub.plumb {
continue

View file

@ -15,6 +15,14 @@ var (
_ json.Marshaler = (*Timestamp)(nil)
)
func NewTimestamp(t time.Time) Timestamp {
return Timestamp(t)
}
func NowTimestamp() Timestamp {
return NewTimestamp(time.Now())
}
func (t *Timestamp) UnmarshalJSON(v []byte) error {
str := strings.Trim(string(v), `"`)
if str == "null" {