mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-10-19 23:54:47 +00:00
Ported rfrouter over
This commit is contained in:
parent
189853de32
commit
28228a60f5
92
_example/advanced_bot/context.go
Normal file
92
_example/advanced_bot/context.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/bot"
|
||||
"github.com/diamondburned/arikawa/bot/extras/arguments"
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
)
|
||||
|
||||
type Bot struct {
|
||||
// Context must not be embedded.
|
||||
Ctx *bot.Context
|
||||
}
|
||||
|
||||
func (bot *Bot) Help(m *gateway.MessageCreateEvent) error {
|
||||
_, err := bot.Ctx.SendMessage(m.ChannelID, bot.Ctx.Help(), nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error {
|
||||
_, err := bot.Ctx.SendMessage(m.ChannelID, "Pong!", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (bot *Bot) Say(m *gateway.MessageCreateEvent, f *arguments.Flag) error {
|
||||
args := f.String()
|
||||
if args == "" {
|
||||
// Empty message, ignore
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := bot.Ctx.SendMessage(m.ChannelID, args, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (bot *Bot) Embed(
|
||||
m *gateway.MessageCreateEvent, f *arguments.Flag) error {
|
||||
|
||||
fs := arguments.NewFlagSet()
|
||||
|
||||
var (
|
||||
title = fs.String("title", "", "Title")
|
||||
author = fs.String("author", "", "Author")
|
||||
footer = fs.String("footer", "", "Footer")
|
||||
color = fs.String("color", "#FFFFFF", "Color in hex format #hhhhhh")
|
||||
)
|
||||
|
||||
if err := f.With(fs.FlagSet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(fs.Args()) < 1 {
|
||||
return fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage())
|
||||
}
|
||||
|
||||
// Check if the color string is valid.
|
||||
if !strings.HasPrefix(*color, "#") || len(*color) != 7 {
|
||||
return errors.New("Invalid color, format must be #hhhhhh")
|
||||
}
|
||||
|
||||
// Parse the color into decimal numbers.
|
||||
colorHex, err := strconv.ParseInt((*color)[1:], 16, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make a new embed
|
||||
embed := discord.Embed{
|
||||
Title: *title,
|
||||
Description: strings.Join(fs.Args(), " "),
|
||||
Color: discord.Color(colorHex),
|
||||
}
|
||||
|
||||
if *author != "" {
|
||||
embed.Author = &discord.EmbedAuthor{
|
||||
Name: *author,
|
||||
}
|
||||
}
|
||||
if *footer != "" {
|
||||
embed.Footer = &discord.EmbedFooter{
|
||||
Text: *footer,
|
||||
}
|
||||
}
|
||||
|
||||
_, err = bot.Ctx.SendMessage(m.ChannelID, "", &embed)
|
||||
return err
|
||||
}
|
35
_example/advanced_bot/main.go
Normal file
35
_example/advanced_bot/main.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/diamondburned/arikawa/bot"
|
||||
)
|
||||
|
||||
// To run, do `BOT_TOKEN="TOKEN HERE" go run .`
|
||||
|
||||
func main() {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
commands := &Bot{}
|
||||
|
||||
stop, err := bot.Start(token, commands, func(ctx *bot.Context) error {
|
||||
ctx.Prefix = "!"
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
defer stop()
|
||||
|
||||
log.Println("Bot started")
|
||||
|
||||
// Automatically block until SIGINT.
|
||||
bot.Wait()
|
||||
}
|
122
bot/arguments.go
Normal file
122
bot/arguments.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type argumentValueFn func(string) (reflect.Value, error)
|
||||
|
||||
// Parseable implements a Parse(string) method for data structures that can be
|
||||
// used as arguments.
|
||||
type Parseable interface {
|
||||
Parse(string) error
|
||||
}
|
||||
|
||||
// ManaulParseable implements a ParseContent(string) method. If the library sees
|
||||
// this for an argument, it will send all of the arguments (including the
|
||||
// command) into the method. If used, this should be the only argument followed
|
||||
// after the Message Create event. Any more and the router will ignore.
|
||||
type ManualParseable interface {
|
||||
// $0 will have its prefix trimmed.
|
||||
ParseContent([]string) error
|
||||
}
|
||||
|
||||
type RawArguments struct {
|
||||
Arguments []string
|
||||
}
|
||||
|
||||
func (r *RawArguments) ParseContent(args []string) error {
|
||||
r.Arguments = args
|
||||
return nil
|
||||
}
|
||||
|
||||
// nilV, only used to return an error
|
||||
var nilV = reflect.Value{}
|
||||
|
||||
func getArgumentValueFn(t reflect.Type) (argumentValueFn, error) {
|
||||
if t.Implements(typeIParser) {
|
||||
mt, ok := t.MethodByName("Parse")
|
||||
if !ok {
|
||||
panic("BUG: type IParser does not implement Parse")
|
||||
}
|
||||
|
||||
return func(input string) (reflect.Value, error) {
|
||||
v := reflect.New(t.Elem())
|
||||
|
||||
ret := mt.Func.Call([]reflect.Value{
|
||||
v, reflect.ValueOf(input),
|
||||
})
|
||||
|
||||
if err := errorReturns(ret); err != nil {
|
||||
return nilV, err
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
var fn argumentValueFn
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
fn = func(s string) (reflect.Value, error) {
|
||||
return reflect.ValueOf(s), nil
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int8,
|
||||
reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
|
||||
fn = func(s string) (reflect.Value, error) {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
return quickRet(i, err, t)
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8,
|
||||
reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
|
||||
fn = func(s string) (reflect.Value, error) {
|
||||
u, err := strconv.ParseUint(s, 10, 64)
|
||||
return quickRet(u, err, t)
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fn = func(s string) (reflect.Value, error) {
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
return quickRet(f, err, t)
|
||||
}
|
||||
|
||||
case reflect.Bool:
|
||||
fn = func(s string) (reflect.Value, error) {
|
||||
switch s {
|
||||
case "true", "yes", "y", "Y", "1":
|
||||
return reflect.ValueOf(true), nil
|
||||
case "false", "no", "n", "N", "0":
|
||||
return reflect.ValueOf(false), nil
|
||||
default:
|
||||
return nilV, errors.New("invalid bool [true/false]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fn == nil {
|
||||
return nil, errors.New("invalid type: " + t.String())
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func quickRet(v interface{}, err error, t reflect.Type) (reflect.Value, error) {
|
||||
if err != nil {
|
||||
return nilV, err
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
if t == nil {
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
return rv.Convert(t), nil
|
||||
}
|
114
bot/copied_from_d.go
Normal file
114
bot/copied_from_d.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package bot
|
||||
|
||||
/*
|
||||
// UserPermissions but userID is after channelID.
|
||||
func (ctx *Context) UserPermissions(channelID, userID string,
|
||||
) (apermissions int, err error) {
|
||||
|
||||
// Try to just get permissions from state.
|
||||
apermissions, err = ctx.Session.State.UserChannelPermissions(
|
||||
userID, channelID)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise try get as much data from state as possible, falling back to the network.
|
||||
channel, err := ctx.Channel(channelID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
guild, err := ctx.Guild(channel.GuildID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if userID == guild.OwnerID {
|
||||
apermissions = discordgo.PermissionAll
|
||||
return
|
||||
}
|
||||
|
||||
member, err := ctx.Member(guild.ID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return MemberPermissions(guild, channel, member), nil
|
||||
}
|
||||
|
||||
// Why this isn't exported, I have no idea.
|
||||
func MemberPermissions(guild *discordgo.Guild, channel *discordgo.Channel,
|
||||
member *discordgo.Member) (apermissions int) {
|
||||
|
||||
userID := member.User.ID
|
||||
|
||||
if userID == guild.OwnerID {
|
||||
apermissions = discordgo.PermissionAll
|
||||
return
|
||||
}
|
||||
|
||||
for _, role := range guild.Roles {
|
||||
if role.ID == guild.ID {
|
||||
apermissions |= role.Permissions
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, role := range guild.Roles {
|
||||
for _, roleID := range member.Roles {
|
||||
if role.ID == roleID {
|
||||
apermissions |= role.Permissions
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apermissions&discordgo.PermissionAdministrator ==
|
||||
discordgo.PermissionAdministrator {
|
||||
|
||||
apermissions |= discordgo.PermissionAll
|
||||
}
|
||||
|
||||
// Apply @everyone overrides from the channel.
|
||||
for _, overwrite := range channel.PermissionOverwrites {
|
||||
if guild.ID == overwrite.ID {
|
||||
apermissions &= ^overwrite.Deny
|
||||
apermissions |= overwrite.Allow
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
denies := 0
|
||||
allows := 0
|
||||
|
||||
// Member overwrites can override role overrides, so do two passes
|
||||
for _, overwrite := range channel.PermissionOverwrites {
|
||||
for _, roleID := range member.Roles {
|
||||
if overwrite.Type == "role" && roleID == overwrite.ID {
|
||||
denies |= overwrite.Deny
|
||||
allows |= overwrite.Allow
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
apermissions &= ^denies
|
||||
apermissions |= allows
|
||||
|
||||
for _, overwrite := range channel.PermissionOverwrites {
|
||||
if overwrite.Type == "member" && overwrite.ID == userID {
|
||||
apermissions &= ^overwrite.Deny
|
||||
apermissions |= overwrite.Allow
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if apermissions&discordgo.PermissionAdministrator ==
|
||||
discordgo.PermissionAdministrator {
|
||||
|
||||
apermissions |= discordgo.PermissionAllChannel
|
||||
}
|
||||
|
||||
return apermissions
|
||||
}
|
||||
*/
|
281
bot/ctx.go
Normal file
281
bot/ctx.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// TODO: add variadic arguments
|
||||
|
||||
type Context struct {
|
||||
*Subcommand
|
||||
*state.State
|
||||
|
||||
// Descriptive (but optional) bot name
|
||||
Name string
|
||||
|
||||
// Descriptive help body
|
||||
Description string
|
||||
|
||||
// The prefix for commands
|
||||
Prefix string
|
||||
|
||||
// FormatError formats any errors returned by anything, including the method
|
||||
// commands or the reflect functions. This also includes invalid usage
|
||||
// errors or unknown command errors. Returning an empty string means
|
||||
// ignoring the error.
|
||||
FormatError func(error) string
|
||||
|
||||
// ErrorLogger logs any error that anything makes and the library can't
|
||||
// reply to the client. This includes any event callback errors that aren't
|
||||
// Message Create.
|
||||
ErrorLogger func(error)
|
||||
|
||||
// ReplyError when true replies to the user the error.
|
||||
ReplyError bool
|
||||
|
||||
// Subcommands contains all the registered subcommands.
|
||||
Subcommands []*Subcommand
|
||||
}
|
||||
|
||||
// Start quickly starts a bot with the given command. It will prepend "Bot"
|
||||
// into the token automatically. Refer to example/ for usage.
|
||||
func Start(token string, cmd interface{},
|
||||
opts func(*Context) error) (stop func() error, err error) {
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to create a dgo session")
|
||||
}
|
||||
|
||||
c, err := New(s, cmd)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to create rfrouter")
|
||||
}
|
||||
|
||||
s.ErrorLog = func(err error) {
|
||||
c.ErrorLogger(err)
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
if err := opts(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cancel := c.Start()
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to connect to Discord")
|
||||
}
|
||||
|
||||
return func() error {
|
||||
cancel()
|
||||
return s.Close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Wait is a convenient function that blocks until a SIGINT is sent.
|
||||
func Wait() {
|
||||
sigs := make(chan os.Signal)
|
||||
signal.Notify(sigs, os.Interrupt)
|
||||
<-sigs
|
||||
}
|
||||
|
||||
// New makes a new context with a "~" as the prefix. cmds must be a pointer to a
|
||||
// struct with a *Context field. Example:
|
||||
//
|
||||
// type Commands struct {
|
||||
// Ctx *Context
|
||||
// }
|
||||
//
|
||||
// cmds := &Commands{}
|
||||
// c, err := rfrouter.New(session, cmds)
|
||||
//
|
||||
// Commands' exported methods will all be used as commands. Messages are parsed
|
||||
// with its first argument (the command) mapped accordingly to c.MapName, which
|
||||
// capitalizes the first letter automatically to reflect the exported method
|
||||
// name.
|
||||
//
|
||||
// The default prefix is "~", which means commands must start with "~" followed
|
||||
// by the command name in the first argument, else it will be ignored.
|
||||
//
|
||||
// c.Start() should be called afterwards to actually handle incoming events.
|
||||
func New(s *state.State, cmd interface{}) (*Context, error) {
|
||||
c, err := NewSubcommand(cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := &Context{
|
||||
Subcommand: c,
|
||||
State: s,
|
||||
Prefix: "~",
|
||||
FormatError: func(err error) string {
|
||||
return err.Error()
|
||||
},
|
||||
ErrorLogger: func(err error) {
|
||||
log.Println("Bot error:", err)
|
||||
},
|
||||
ReplyError: true,
|
||||
}
|
||||
|
||||
if err := ctx.InitCommands(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to initialize with given cmds")
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) {
|
||||
s, err := NewSubcommand(cmd)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to add subcommand")
|
||||
}
|
||||
|
||||
// Register the subcommand's name.
|
||||
s.NeedsName()
|
||||
|
||||
if err := s.InitCommands(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to initialize subcommand")
|
||||
}
|
||||
|
||||
// Do a collision check
|
||||
for _, sub := range ctx.Subcommands {
|
||||
if sub.name == s.name {
|
||||
return nil, errors.New(
|
||||
"New subcommand has duplicate name: " + s.name)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Subcommands = append(ctx.Subcommands, s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Start adds itself into the discordgo Session handlers. This needs to be run.
|
||||
// The returned function is a delete function, which removes itself from the
|
||||
// Session handlers.
|
||||
func (ctx *Context) Start() func() {
|
||||
return ctx.Session.AddHandler(func(v interface{}) {
|
||||
if err := ctx.callCmd(v); err != nil {
|
||||
if str := ctx.FormatError(err); str != "" {
|
||||
// Log the main error first
|
||||
ctx.ErrorLogger(errors.Wrap(err, str))
|
||||
|
||||
mc, ok := v.(*gateway.MessageCreateEvent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.ReplyError {
|
||||
_, Merr := ctx.SendMessage(mc.ChannelID, str, nil)
|
||||
if Merr != nil {
|
||||
// Then the message error
|
||||
ctx.ErrorLogger(Merr)
|
||||
// TODO: there ought to be a better way lol
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Call should only be used if you know what you're doing.
|
||||
func (ctx *Context) Call(event interface{}) error {
|
||||
return ctx.callCmd(event)
|
||||
}
|
||||
|
||||
// Help generates one. This function is used more for reference than an actual
|
||||
// help message. As such, it only uses exported fields or methods.
|
||||
func (ctx *Context) Help() string {
|
||||
var help strings.Builder
|
||||
|
||||
// Generate the headers and descriptions
|
||||
help.WriteString("__Help__")
|
||||
|
||||
if ctx.Name != "" {
|
||||
help.WriteString(": " + ctx.Name)
|
||||
}
|
||||
|
||||
if ctx.Description != "" {
|
||||
help.WriteString("\n " + ctx.Description)
|
||||
}
|
||||
|
||||
if ctx.Flag.Is(AdminOnly) {
|
||||
// That's it.
|
||||
return help.String()
|
||||
}
|
||||
|
||||
// Separators
|
||||
help.WriteString("\n---\n")
|
||||
|
||||
// Generate all commands
|
||||
help.WriteString("__Commands__\n")
|
||||
|
||||
for _, cmd := range ctx.Commands {
|
||||
if cmd.Flag.Is(AdminOnly) {
|
||||
// Hidden
|
||||
continue
|
||||
}
|
||||
|
||||
help.WriteString(" " + ctx.Prefix + cmd.Name())
|
||||
|
||||
switch {
|
||||
case len(cmd.Usage()) > 0:
|
||||
help.WriteString(" " + strings.Join(cmd.Usage(), " "))
|
||||
case cmd.Description != "":
|
||||
help.WriteString(": " + cmd.Description)
|
||||
}
|
||||
|
||||
help.WriteByte('\n')
|
||||
}
|
||||
|
||||
var subHelp = strings.Builder{}
|
||||
|
||||
for _, sub := range ctx.Subcommands {
|
||||
if sub.Flag.Is(AdminOnly) {
|
||||
// Hidden
|
||||
continue
|
||||
}
|
||||
|
||||
subHelp.WriteString(" " + sub.Name())
|
||||
|
||||
if sub.Description != "" {
|
||||
subHelp.WriteString(": " + sub.Description)
|
||||
}
|
||||
|
||||
subHelp.WriteByte('\n')
|
||||
|
||||
for _, cmd := range sub.Commands {
|
||||
if cmd.Flag.Is(AdminOnly) {
|
||||
continue
|
||||
}
|
||||
|
||||
subHelp.WriteString(" " +
|
||||
ctx.Prefix + sub.Name() + " " + cmd.Name())
|
||||
|
||||
switch {
|
||||
case len(cmd.Usage()) > 0:
|
||||
subHelp.WriteString(" " + strings.Join(cmd.Usage(), " "))
|
||||
case cmd.Description != "":
|
||||
subHelp.WriteString(": " + cmd.Description)
|
||||
}
|
||||
|
||||
subHelp.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
if sub := subHelp.String(); sub != "" {
|
||||
help.WriteString("---\n")
|
||||
help.WriteString("__Subcommands__\n")
|
||||
help.WriteString(sub)
|
||||
}
|
||||
|
||||
return help.String()
|
||||
}
|
312
bot/ctx_call.go
Normal file
312
bot/ctx_call.go
Normal file
|
@ -0,0 +1,312 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
)
|
||||
|
||||
func (ctx *Context) callCmd(ev interface{}) error {
|
||||
evT := reflect.TypeOf(ev)
|
||||
|
||||
if evT != typeMessageCreate {
|
||||
var callers []reflect.Value
|
||||
var isAdmin *bool // i want to die
|
||||
|
||||
for _, cmd := range ctx.Commands {
|
||||
if cmd.event == evT {
|
||||
if cmd.Flag.Is(AdminOnly) &&
|
||||
!ctx.eventIsAdmin(ev, &isAdmin) {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
callers = append(callers, cmd.value)
|
||||
}
|
||||
}
|
||||
|
||||
for _, sub := range ctx.Subcommands {
|
||||
if sub.Flag.Is(AdminOnly) &&
|
||||
!ctx.eventIsAdmin(ev, &isAdmin) {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, cmd := range sub.Commands {
|
||||
if cmd.event == evT {
|
||||
if cmd.Flag.Is(AdminOnly) &&
|
||||
!ctx.eventIsAdmin(ev, &isAdmin) {
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
callers = append(callers, cmd.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range callers {
|
||||
if err := callWith(c, ev); err != nil {
|
||||
ctx.ErrorLogger(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// safe assertion always
|
||||
mc := ev.(*gateway.MessageCreateEvent)
|
||||
|
||||
// check if prefix
|
||||
if !strings.HasPrefix(mc.Content, ctx.Prefix) {
|
||||
// not a command, ignore
|
||||
return nil
|
||||
}
|
||||
|
||||
// trim the prefix before splitting, this way multi-words prefices work
|
||||
content := mc.Content[len(ctx.Prefix):]
|
||||
|
||||
if content == "" {
|
||||
return nil // just the prefix only
|
||||
}
|
||||
|
||||
// parse arguments
|
||||
args, err := ParseArgs(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
return nil // ???
|
||||
}
|
||||
|
||||
var cmd *CommandContext
|
||||
var start int // arg starts from $start
|
||||
|
||||
// Search for the command
|
||||
for _, c := range ctx.Commands {
|
||||
if c.name == args[0] {
|
||||
cmd = c
|
||||
start = 1
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Can't find command, look for subcommands of len(args) has a 2nd
|
||||
// entry.
|
||||
if cmd == nil && len(args) > 1 {
|
||||
for _, s := range ctx.Subcommands {
|
||||
if s.name != args[0] {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, c := range s.Commands {
|
||||
if c.name == args[1] {
|
||||
cmd = c
|
||||
start = 2
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil {
|
||||
return &ErrUnknownCommand{
|
||||
Command: args[1],
|
||||
Parent: args[0],
|
||||
Prefix: ctx.Prefix,
|
||||
ctx: s.Commands,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil || start == 0 {
|
||||
return &ErrUnknownCommand{
|
||||
Command: args[0],
|
||||
Prefix: ctx.Prefix,
|
||||
ctx: ctx.Commands,
|
||||
}
|
||||
}
|
||||
|
||||
// Start converting
|
||||
var argv []reflect.Value
|
||||
|
||||
// Check manual parser
|
||||
if cmd.parseType != nil {
|
||||
// Create a zero value instance of this
|
||||
v := reflect.New(cmd.parseType)
|
||||
|
||||
// Call the manual parse method
|
||||
ret := cmd.parseMethod.Func.Call([]reflect.Value{
|
||||
v, reflect.ValueOf(args),
|
||||
})
|
||||
|
||||
// Check the method returns for error
|
||||
if err := errorReturns(ret); err != nil {
|
||||
// TODO: maybe wrap this?
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the pointer to the argument into argv
|
||||
argv = append(argv, v)
|
||||
goto Call
|
||||
}
|
||||
|
||||
// Here's an edge case: when the handler takes no arguments, we allow that
|
||||
// anyway, as they might've used the raw content.
|
||||
if len(cmd.arguments) == 0 {
|
||||
goto Call
|
||||
}
|
||||
|
||||
// Not enough arguments given
|
||||
if len(args[start:]) != len(cmd.arguments) {
|
||||
return &ErrInvalidUsage{
|
||||
Args: args,
|
||||
Prefix: ctx.Prefix,
|
||||
Index: len(cmd.arguments) - start,
|
||||
Err: "Not enough arguments given",
|
||||
ctx: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
argv = make([]reflect.Value, len(cmd.arguments))
|
||||
|
||||
for i := start; i < len(args); i++ {
|
||||
v, err := cmd.arguments[i-start](args[i])
|
||||
if err != nil {
|
||||
return &ErrInvalidUsage{
|
||||
Args: args,
|
||||
Prefix: ctx.Prefix,
|
||||
Index: i,
|
||||
Err: err.Error(),
|
||||
ctx: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
argv[i-start] = v
|
||||
}
|
||||
|
||||
Call:
|
||||
// call the function and parse the error return value
|
||||
return callWith(cmd.value, ev, argv...)
|
||||
}
|
||||
|
||||
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
|
||||
if *is != nil {
|
||||
return **is
|
||||
}
|
||||
|
||||
var channelID = reflectChannelID(ev)
|
||||
if !channelID.Valid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var userID = reflectUserID(ev)
|
||||
if !userID.Valid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var res bool
|
||||
|
||||
p, err := ctx.State.Permissions(channelID, userID)
|
||||
if err == nil && p.Has(discord.PermissionAdministrator) {
|
||||
res = true
|
||||
}
|
||||
|
||||
*is = &res
|
||||
return res
|
||||
}
|
||||
|
||||
func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) error {
|
||||
return errorReturns(caller.Call(append(
|
||||
[]reflect.Value{reflect.ValueOf(ev)},
|
||||
values...,
|
||||
)))
|
||||
}
|
||||
|
||||
var ParseArgs = func(args string) ([]string, error) {
|
||||
// TODO: make modular
|
||||
// TODO: actual tokenizer+parser
|
||||
r := csv.NewReader(strings.NewReader(args))
|
||||
r.Comma = ' '
|
||||
|
||||
return r.Read()
|
||||
}
|
||||
|
||||
func errorReturns(returns []reflect.Value) error {
|
||||
// assume first is always error, since we checked for this in parseCommands
|
||||
v := returns[0].Interface()
|
||||
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v.(error)
|
||||
}
|
||||
|
||||
func reflectChannelID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "Channel")
|
||||
}
|
||||
|
||||
func reflectGuildID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "Guild")
|
||||
}
|
||||
|
||||
func reflectUserID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "User")
|
||||
}
|
||||
|
||||
func _reflectID(v reflect.Value, thing string) discord.Snowflake {
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
|
||||
// Recheck after dereferring
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t = v.Type()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return 0
|
||||
}
|
||||
|
||||
numFields := t.NumField()
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := t.Field(i)
|
||||
fType := field.Type
|
||||
|
||||
if fType.Kind() == reflect.Ptr {
|
||||
fType = fType.Elem()
|
||||
}
|
||||
|
||||
switch fType.Kind() {
|
||||
case reflect.Struct:
|
||||
if chID := _reflectID(v.Field(i), thing); chID.Valid() {
|
||||
return chID
|
||||
}
|
||||
case reflect.Int64:
|
||||
if field.Name == thing+"ID" {
|
||||
// grab value real quick
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
|
||||
// Special case where the struct name has Channel in it
|
||||
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
312
bot/ctx_test.go
Normal file
312
bot/ctx_test.go
Normal file
|
@ -0,0 +1,312 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type testCommands struct {
|
||||
Ctx *Context
|
||||
Return chan interface{}
|
||||
}
|
||||
|
||||
func (t *testCommands) Send(_ *gateway.MessageCreateEvent, arg string) error {
|
||||
t.Return <- arg
|
||||
return errors.New("oh no")
|
||||
}
|
||||
|
||||
func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *CustomParseable) error {
|
||||
t.Return <- c.args
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error {
|
||||
return errors.New("passed")
|
||||
}
|
||||
|
||||
func (t *testCommands) Noop(_ *gateway.MessageCreateEvent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type CustomParseable struct {
|
||||
args []string
|
||||
}
|
||||
|
||||
func (c *CustomParseable) ParseContent(args []string) error {
|
||||
c.args = args
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
var state = &state.State{
|
||||
Store: state.NewDefaultStore(nil),
|
||||
}
|
||||
|
||||
_, err := New(state, &testCommands{})
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create new context:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
var given = &testCommands{}
|
||||
var state = &state.State{
|
||||
Store: state.NewDefaultStore(nil),
|
||||
}
|
||||
|
||||
s, err := NewSubcommand(given)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create subcommand:", err)
|
||||
}
|
||||
|
||||
var ctx = &Context{
|
||||
Subcommand: s,
|
||||
State: state,
|
||||
}
|
||||
|
||||
t.Run("init commands", func(t *testing.T) {
|
||||
if err := ctx.Subcommand.InitCommands(ctx); err != nil {
|
||||
t.Fatal("Failed to init commands:", err)
|
||||
}
|
||||
|
||||
if given.Ctx == nil {
|
||||
t.Fatal("given's Context field is nil")
|
||||
}
|
||||
|
||||
if given.Ctx.State.Store == nil {
|
||||
t.Fatal("given's State is nil")
|
||||
}
|
||||
})
|
||||
|
||||
testReturn := func(expects interface{}, content string) (call error) {
|
||||
// Return channel for testing
|
||||
ret := make(chan interface{})
|
||||
given.Return = ret
|
||||
|
||||
// Mock a messageCreate event
|
||||
m := &gateway.MessageCreateEvent{
|
||||
Content: content,
|
||||
}
|
||||
|
||||
var (
|
||||
callCh = make(chan error)
|
||||
)
|
||||
|
||||
go func() {
|
||||
callCh <- ctx.callCmd(m)
|
||||
}()
|
||||
|
||||
select {
|
||||
case arg := <-ret:
|
||||
if !reflect.DeepEqual(arg, expects) {
|
||||
t.Fatal("returned argument is invalid:", arg)
|
||||
}
|
||||
call = <-callCh
|
||||
|
||||
case call = <-callCh:
|
||||
t.Fatal("expected return before error:", call)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("call command", func(t *testing.T) {
|
||||
// Set a custom prefix
|
||||
ctx.Prefix = "~"
|
||||
|
||||
if err := testReturn("test", "~send test"); err.Error() != "oh no" {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("call command custom parser", func(t *testing.T) {
|
||||
ctx.Prefix = "!"
|
||||
expects := []string{"custom", "arg1", ":)"}
|
||||
|
||||
if err := testReturn(expects, "!custom arg1 :)"); err != nil {
|
||||
t.Fatal("Unexpected call error:", err)
|
||||
}
|
||||
})
|
||||
|
||||
testMessage := func(content string) error {
|
||||
// Mock a messageCreate event
|
||||
m := &gateway.MessageCreateEvent{
|
||||
Content: content,
|
||||
}
|
||||
|
||||
return ctx.callCmd(m)
|
||||
}
|
||||
|
||||
t.Run("call command without args", func(t *testing.T) {
|
||||
ctx.Prefix = ""
|
||||
|
||||
if err := testMessage("noargs"); err.Error() != "passed" {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test error cases
|
||||
|
||||
t.Run("call unknown command", func(t *testing.T) {
|
||||
ctx.Prefix = "joe pls "
|
||||
|
||||
err := testMessage("joe pls no")
|
||||
|
||||
if err == nil || !strings.HasPrefix(err.Error(), "Unknown command:") {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test subcommands
|
||||
|
||||
t.Run("register subcommand", func(t *testing.T) {
|
||||
ctx.Prefix = "run "
|
||||
|
||||
_, err := ctx.RegisterSubcommand(&testCommands{})
|
||||
if err != nil {
|
||||
t.Fatal("Failed to register subcommand:", err)
|
||||
}
|
||||
|
||||
if err := testMessage("run testcommands noop"); err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkConstructor(b *testing.B) {
|
||||
var state = &state.State{
|
||||
Store: state.NewDefaultStore(nil),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = New(state, &testCommands{})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCall(b *testing.B) {
|
||||
var given = &testCommands{}
|
||||
var state = &state.State{
|
||||
Store: state.NewDefaultStore(nil),
|
||||
}
|
||||
|
||||
s, _ := NewSubcommand(given)
|
||||
|
||||
var ctx = &Context{
|
||||
Subcommand: s,
|
||||
State: state,
|
||||
Prefix: "~",
|
||||
}
|
||||
|
||||
m := &gateway.MessageCreateEvent{
|
||||
Content: "~noop",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx.callCmd(m)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHelp(b *testing.B) {
|
||||
var given = &testCommands{}
|
||||
var state = &state.State{
|
||||
Store: state.NewDefaultStore(nil),
|
||||
}
|
||||
|
||||
s, _ := NewSubcommand(given)
|
||||
|
||||
var ctx = &Context{
|
||||
Subcommand: s,
|
||||
State: state,
|
||||
Prefix: "~",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ctx.Help()
|
||||
}
|
||||
}
|
||||
|
||||
type hasID struct {
|
||||
ChannelID discord.Snowflake
|
||||
}
|
||||
|
||||
type embedsID struct {
|
||||
*hasID
|
||||
*embedsID
|
||||
}
|
||||
|
||||
type hasChannelInName struct {
|
||||
ID discord.Snowflake
|
||||
}
|
||||
|
||||
func TestReflectChannelID(t *testing.T) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
t.Run("hasID", func(t *testing.T) {
|
||||
if id := reflectChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("embedsID", func(t *testing.T) {
|
||||
var e = &embedsID{
|
||||
hasID: s,
|
||||
}
|
||||
|
||||
if id := reflectChannelID(e); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hasChannelInName", func(t *testing.T) {
|
||||
var s = &hasChannelInName{
|
||||
ID: 69420,
|
||||
}
|
||||
|
||||
if id := reflectChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = reflectChannelID(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReflectChannelID_5Level(b *testing.B) {
|
||||
var s = &embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
hasID: &hasID{
|
||||
ChannelID: 69420,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = reflectChannelID(s)
|
||||
}
|
||||
}
|
66
bot/error.go
Normal file
66
bot/error.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ErrUnknownCommand struct {
|
||||
Command string
|
||||
Parent string
|
||||
|
||||
Prefix string
|
||||
|
||||
// TODO: list available commands?
|
||||
// Here, as a reminder
|
||||
ctx []*CommandContext
|
||||
}
|
||||
|
||||
func (err *ErrUnknownCommand) Error() string {
|
||||
var header = "Unknown command: " + err.Prefix
|
||||
if err.Parent != "" {
|
||||
header += err.Parent + " " + err.Command
|
||||
} else {
|
||||
header += err.Command
|
||||
}
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
type ErrInvalidUsage struct {
|
||||
Args []string
|
||||
Prefix string
|
||||
|
||||
Index int
|
||||
Err string
|
||||
|
||||
// TODO: usage generator?
|
||||
// Here, as a reminder
|
||||
ctx *CommandContext
|
||||
}
|
||||
|
||||
func (err *ErrInvalidUsage) Error() string {
|
||||
if err.Index == 0 {
|
||||
return "Invalid usage"
|
||||
}
|
||||
|
||||
if len(err.Args) == 0 {
|
||||
return "Missing arguments. Refer to help."
|
||||
}
|
||||
|
||||
body := "Invalid usage at " + err.Prefix
|
||||
|
||||
// Write the first part
|
||||
body += strings.Join(err.Args[:err.Index], " ")
|
||||
|
||||
// Write the wrong part
|
||||
body += " __" + err.Args[err.Index] + "__ "
|
||||
|
||||
// Write the last part
|
||||
body += strings.Join(err.Args[err.Index+1:], " ")
|
||||
|
||||
if err.Err != "" {
|
||||
body += "\nError: " + err.Err
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
51
bot/extras/arguments/emoji.go
Normal file
51
bot/extras/arguments/emoji.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package arguments
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
EmojiRegex = regexp.MustCompile(`<(a?):(.+?):(\d+)>`)
|
||||
|
||||
ErrInvalidEmoji = errors.New("Invalid emoji")
|
||||
)
|
||||
|
||||
type Emoji struct {
|
||||
ID string
|
||||
|
||||
Custom bool
|
||||
Name string
|
||||
Animated bool
|
||||
}
|
||||
|
||||
func (e *Emoji) Parse(arg string) error {
|
||||
// Check if Unicode
|
||||
var unicode string
|
||||
|
||||
for _, r := range arg {
|
||||
if r < '\U0001F600' && r > '\U0001F64F' {
|
||||
unicode += string(r)
|
||||
}
|
||||
}
|
||||
|
||||
if unicode != "" {
|
||||
e.ID = unicode
|
||||
e.Custom = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var matches = EmojiRegex.FindStringSubmatch(arg)
|
||||
|
||||
if len(matches) != 4 {
|
||||
return ErrInvalidEmoji
|
||||
}
|
||||
|
||||
e.Custom = true
|
||||
e.Animated = matches[1] == "a"
|
||||
e.Name = matches[2]
|
||||
e.ID = matches[3]
|
||||
|
||||
return nil
|
||||
}
|
65
bot/extras/arguments/flag.go
Normal file
65
bot/extras/arguments/flag.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package arguments
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var FlagName = "command"
|
||||
|
||||
type FlagSet struct {
|
||||
*flag.FlagSet
|
||||
}
|
||||
|
||||
func NewFlagSet() *FlagSet {
|
||||
fs := flag.NewFlagSet(FlagName, flag.ContinueOnError)
|
||||
fs.SetOutput(ioutil.Discard)
|
||||
|
||||
return &FlagSet{fs}
|
||||
}
|
||||
|
||||
func (fs *FlagSet) Usage() string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
fs.FlagSet.SetOutput(&buf)
|
||||
fs.FlagSet.Usage()
|
||||
fs.FlagSet.SetOutput(ioutil.Discard)
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type Flag struct {
|
||||
arguments []string
|
||||
}
|
||||
|
||||
func (f *Flag) ParseContent(arguments []string) error {
|
||||
// trim the command out
|
||||
f.arguments = arguments[1:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Flag) Usage() string {
|
||||
return "flags..."
|
||||
}
|
||||
|
||||
func (f *Flag) Args() []string {
|
||||
return f.arguments
|
||||
}
|
||||
|
||||
func (f *Flag) Arg(n int) string {
|
||||
if n < 0 || n >= len(f.arguments) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return f.arguments[n]
|
||||
}
|
||||
|
||||
func (f *Flag) String() string {
|
||||
return strings.Join(f.arguments, " ")
|
||||
}
|
||||
|
||||
func (f *Flag) With(fs *flag.FlagSet) error {
|
||||
return fs.Parse(f.arguments)
|
||||
}
|
52
bot/extras/arguments/mention.go
Normal file
52
bot/extras/arguments/mention.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package arguments
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
ChannelRegex = regexp.MustCompile(`<#(\d+)>`)
|
||||
UserRegex = regexp.MustCompile(`<@!?(\d+)>`)
|
||||
RoleRegex = regexp.MustCompile(`<@&(\d+)>`)
|
||||
)
|
||||
|
||||
type ChannelMention string
|
||||
|
||||
func (m *ChannelMention) Parse(arg string) error {
|
||||
return grabFirst(ChannelRegex, "channel mention", arg, (*string)(m))
|
||||
}
|
||||
|
||||
func (m *ChannelMention) Usage() string {
|
||||
return "#channel"
|
||||
}
|
||||
|
||||
type UserMention string
|
||||
|
||||
func (m *UserMention) Parse(arg string) error {
|
||||
return grabFirst(UserRegex, "user mention", arg, (*string)(m))
|
||||
}
|
||||
|
||||
func (m *UserMention) Usage() string {
|
||||
return "@user"
|
||||
}
|
||||
|
||||
type RoleMention string
|
||||
|
||||
func (m *RoleMention) Parse(arg string) error {
|
||||
return grabFirst(RoleRegex, "role mention", arg, (*string)(m))
|
||||
}
|
||||
|
||||
func (m *RoleMention) Usage() string {
|
||||
return "@role"
|
||||
}
|
||||
|
||||
func grabFirst(reg *regexp.Regexp, item, input string, output *string) error {
|
||||
matches := reg.FindStringSubmatch(input)
|
||||
if len(matches) < 2 {
|
||||
return errors.New("Invalid " + item)
|
||||
}
|
||||
|
||||
*output = matches[1]
|
||||
return nil
|
||||
}
|
40
bot/nameflag.go
Normal file
40
bot/nameflag.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package bot
|
||||
|
||||
import "strings"
|
||||
|
||||
type NameFlag uint64
|
||||
|
||||
const FlagSeparator = 'ー'
|
||||
|
||||
const (
|
||||
None NameFlag = 1 << iota
|
||||
|
||||
// These flags only apply to messageCreate events.
|
||||
|
||||
Raw // R
|
||||
AdminOnly // A
|
||||
)
|
||||
|
||||
func ParseFlag(name string) (NameFlag, string) {
|
||||
parts := strings.SplitN(name, string(FlagSeparator), 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, name
|
||||
}
|
||||
|
||||
var f NameFlag
|
||||
|
||||
for _, r := range parts[0] {
|
||||
switch r {
|
||||
case 'R':
|
||||
f |= Raw
|
||||
case 'A':
|
||||
f |= AdminOnly
|
||||
}
|
||||
}
|
||||
|
||||
return f, parts[1]
|
||||
}
|
||||
|
||||
func (f NameFlag) Is(flag NameFlag) bool {
|
||||
return f&flag != 0
|
||||
}
|
26
bot/nameflag_test.go
Normal file
26
bot/nameflag_test.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package bot
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNameFlag(t *testing.T) {
|
||||
type entry struct {
|
||||
Name string
|
||||
Expect NameFlag
|
||||
String string
|
||||
}
|
||||
|
||||
var entries = []entry{{
|
||||
Name: "AーEcho",
|
||||
Expect: AdminOnly,
|
||||
}, {
|
||||
Name: "RAーGC",
|
||||
Expect: Raw | AdminOnly,
|
||||
}}
|
||||
|
||||
for _, entry := range entries {
|
||||
var f, _ = ParseFlag(entry.Name)
|
||||
if !f.Is(entry.Expect) {
|
||||
t.Fatalf("unexpected expectation for %s: %v", entry.Name, f)
|
||||
}
|
||||
}
|
||||
}
|
298
bot/subcommand.go
Normal file
298
bot/subcommand.go
Normal file
|
@ -0,0 +1,298 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil))
|
||||
// typeof.Implements(typeI*)
|
||||
typeIError = reflect.TypeOf((*error)(nil)).Elem()
|
||||
typeIManP = reflect.TypeOf((*ManualParseable)(nil)).Elem()
|
||||
typeIParser = reflect.TypeOf((*Parseable)(nil)).Elem()
|
||||
typeIUsager = reflect.TypeOf((*Usager)(nil)).Elem()
|
||||
)
|
||||
|
||||
type Subcommand struct {
|
||||
Description string
|
||||
|
||||
// Commands contains all the registered command contexts.
|
||||
Commands []*CommandContext
|
||||
|
||||
// struct name
|
||||
name string
|
||||
|
||||
// struct flags
|
||||
Flag NameFlag
|
||||
|
||||
// Directly to struct
|
||||
cmdValue reflect.Value
|
||||
cmdType reflect.Type
|
||||
|
||||
// Pointer value
|
||||
ptrValue reflect.Value
|
||||
ptrType reflect.Type
|
||||
|
||||
// command interface as reference
|
||||
command interface{}
|
||||
}
|
||||
|
||||
// CommandContext is an internal struct containing fields to make this library
|
||||
// work. As such, they're all unexported. Description, however, is exported for
|
||||
// editing, and may be used to generate more informative help messages.
|
||||
type CommandContext struct {
|
||||
Description string
|
||||
Flag NameFlag
|
||||
|
||||
name string // all lower-case
|
||||
value reflect.Value // Func
|
||||
event reflect.Type // gateway.*Event
|
||||
method reflect.Method
|
||||
|
||||
// equal slices
|
||||
argStrings []string
|
||||
arguments []argumentValueFn
|
||||
|
||||
// only for ParseContent interface
|
||||
parseMethod reflect.Method
|
||||
parseType reflect.Type
|
||||
parseUsage string
|
||||
}
|
||||
|
||||
// Descriptor is optionally used to set the Description of a command context.
|
||||
type Descriptor interface {
|
||||
Description() string
|
||||
}
|
||||
|
||||
// Namer is optionally used to override the command context's name.
|
||||
type Namer interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Usager is optionally used to override the generated usage for either an
|
||||
// argument, or multiple (using ManualParseable).
|
||||
type Usager interface {
|
||||
Usage() string
|
||||
}
|
||||
|
||||
func (cctx *CommandContext) Name() string {
|
||||
return cctx.name
|
||||
}
|
||||
|
||||
func (cctx *CommandContext) Usage() []string {
|
||||
if cctx.parseType != nil {
|
||||
return []string{cctx.parseUsage}
|
||||
}
|
||||
|
||||
if len(cctx.arguments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cctx.argStrings
|
||||
}
|
||||
|
||||
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
|
||||
var sub = Subcommand{
|
||||
command: cmd,
|
||||
}
|
||||
|
||||
// Set description
|
||||
if d, ok := cmd.(Descriptor); ok {
|
||||
sub.Description = d.Description()
|
||||
}
|
||||
|
||||
if err := sub.reflectCommands(); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to reflect commands")
|
||||
}
|
||||
|
||||
if err := sub.parseCommands(); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to parse commands")
|
||||
}
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Name returns the command name in lower case. This only returns non-zero for
|
||||
// subcommands.
|
||||
func (sub *Subcommand) Name() string {
|
||||
return sub.name
|
||||
}
|
||||
|
||||
// NeedsName sets the name for this subcommand. Like InitCommands, this
|
||||
// shouldn't be called at all, rather you should use RegisterSubcommand.
|
||||
func (sub *Subcommand) NeedsName() {
|
||||
flag, name := ParseFlag(sub.cmdType.Name())
|
||||
|
||||
// Check for interface
|
||||
if n, ok := sub.command.(Namer); ok {
|
||||
name = n.Name()
|
||||
}
|
||||
|
||||
if !flag.Is(Raw) {
|
||||
name = strings.ToLower(name)
|
||||
}
|
||||
|
||||
sub.name = name
|
||||
sub.Flag = flag
|
||||
}
|
||||
|
||||
func (sub *Subcommand) reflectCommands() error {
|
||||
t := reflect.TypeOf(sub.command)
|
||||
v := reflect.ValueOf(sub.command)
|
||||
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return errors.New("sub is not a pointer")
|
||||
}
|
||||
|
||||
// Set the pointer fields
|
||||
sub.ptrValue = v
|
||||
sub.ptrType = t
|
||||
|
||||
ts := t.Elem()
|
||||
vs := v.Elem()
|
||||
|
||||
if ts.Kind() != reflect.Struct {
|
||||
return errors.New("sub is not pointer to struct")
|
||||
}
|
||||
|
||||
// Set the struct fields
|
||||
sub.cmdValue = vs
|
||||
sub.cmdType = ts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitCommands fills a Subcommand with a context. This shouldn't be called at
|
||||
// all, rather you should use the RegisterSubcommand method of a Context.
|
||||
func (sub *Subcommand) InitCommands(ctx *Context) error {
|
||||
// Start filling up a *Context field
|
||||
for i := 0; i < sub.cmdValue.NumField(); i++ {
|
||||
field := sub.cmdValue.Field(i)
|
||||
|
||||
if !field.CanSet() || !field.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := field.Interface().(*Context); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
field.Set(reflect.ValueOf(ctx))
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("No fields with *Command found")
|
||||
}
|
||||
|
||||
func (sub *Subcommand) parseCommands() error {
|
||||
var numMethods = sub.ptrValue.NumMethod()
|
||||
var commands = make([]*CommandContext, 0, numMethods)
|
||||
|
||||
for i := 0; i < numMethods; i++ {
|
||||
method := sub.ptrValue.Method(i)
|
||||
|
||||
if !method.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
methodT := method.Type()
|
||||
numArgs := methodT.NumIn()
|
||||
|
||||
// Doesn't meet requirement for an event
|
||||
if numArgs == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check return type
|
||||
if err := methodT.Out(0); err == nil || !err.Implements(typeIError) {
|
||||
// Invalid, skip
|
||||
continue
|
||||
}
|
||||
|
||||
var command = CommandContext{
|
||||
method: sub.ptrType.Method(i),
|
||||
value: method,
|
||||
event: methodT.In(0), // parse event
|
||||
}
|
||||
|
||||
// Parse the method name
|
||||
flag, name := ParseFlag(command.method.Name)
|
||||
|
||||
if !flag.Is(Raw) {
|
||||
name = strings.ToLower(name)
|
||||
}
|
||||
|
||||
// Set the method name and flag
|
||||
command.name = name
|
||||
command.Flag = flag
|
||||
|
||||
// TODO: allow more flexibility
|
||||
if command.event != typeMessageCreate {
|
||||
goto Done
|
||||
}
|
||||
|
||||
if numArgs == 1 {
|
||||
// done
|
||||
goto Done
|
||||
}
|
||||
|
||||
// If the second argument implements ParseContent()
|
||||
if t := methodT.In(1); t.Implements(typeIManP) {
|
||||
mt, _ := t.MethodByName("ParseContent")
|
||||
|
||||
command.parseMethod = mt
|
||||
command.parseType = t.Elem()
|
||||
|
||||
command.parseUsage = usager(t)
|
||||
if command.parseUsage == "" {
|
||||
command.parseUsage = t.String()
|
||||
}
|
||||
|
||||
goto Done
|
||||
}
|
||||
|
||||
command.arguments = make([]argumentValueFn, 0, numArgs)
|
||||
|
||||
// Fill up arguments
|
||||
for i := 1; i < numArgs; i++ {
|
||||
t := methodT.In(i)
|
||||
|
||||
avfs, err := getArgumentValueFn(t)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error parsing argument "+t.String())
|
||||
}
|
||||
|
||||
command.arguments = append(command.arguments, avfs)
|
||||
|
||||
var usage = usager(t)
|
||||
if usage == "" {
|
||||
usage = t.String()
|
||||
}
|
||||
|
||||
command.argStrings = append(command.argStrings, usage)
|
||||
}
|
||||
|
||||
Done:
|
||||
// Append
|
||||
commands = append(commands, &command)
|
||||
}
|
||||
|
||||
sub.Commands = commands
|
||||
return nil
|
||||
}
|
||||
|
||||
func usager(t reflect.Type) string {
|
||||
if !t.Implements(typeIUsager) {
|
||||
return ""
|
||||
}
|
||||
|
||||
usageFn, _ := t.MethodByName("Usage")
|
||||
v := usageFn.Func.Call([]reflect.Value{
|
||||
reflect.New(t.Elem()),
|
||||
})
|
||||
return v[0].String()
|
||||
}
|
96
bot/subcommand_test.go
Normal file
96
bot/subcommand_test.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package bot
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewSubcommand(t *testing.T) {
|
||||
_, err := NewSubcommand(&testCommands{})
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create new subcommand:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubcommand(t *testing.T) {
|
||||
var given = &testCommands{}
|
||||
var sub = &Subcommand{
|
||||
command: given,
|
||||
}
|
||||
|
||||
t.Run("reflect commands", func(t *testing.T) {
|
||||
if err := sub.reflectCommands(); err != nil {
|
||||
t.Fatal("Failed to reflect commands:", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse commands", func(t *testing.T) {
|
||||
if err := sub.parseCommands(); err != nil {
|
||||
t.Fatal("Failed to parse commands:", err)
|
||||
}
|
||||
|
||||
// !!! CHANGE ME
|
||||
if len(sub.Commands) != 4 {
|
||||
t.Fatal("invalid ctx.commands len", len(sub.Commands))
|
||||
}
|
||||
|
||||
var (
|
||||
foundSend bool
|
||||
foundCustom bool
|
||||
foundNoArgs bool
|
||||
)
|
||||
|
||||
for _, this := range sub.Commands {
|
||||
switch this.name {
|
||||
case "send":
|
||||
foundSend = true
|
||||
if len(this.arguments) != 1 {
|
||||
t.Fatal("invalid arguments len", len(this.arguments))
|
||||
}
|
||||
|
||||
case "custom":
|
||||
foundCustom = true
|
||||
if len(this.arguments) > 0 {
|
||||
t.Fatal("arguments should be 0 for custom")
|
||||
}
|
||||
if this.parseType == nil {
|
||||
t.Fatal("custom has nil manualParse")
|
||||
}
|
||||
|
||||
case "noargs":
|
||||
foundNoArgs = true
|
||||
if len(this.arguments) != 0 {
|
||||
t.Fatal("expected 0 arguments, got non-zero")
|
||||
}
|
||||
if this.parseType != nil {
|
||||
t.Fatal("unexpected parseType")
|
||||
}
|
||||
|
||||
case "noop":
|
||||
// Found, but whatever
|
||||
|
||||
default:
|
||||
t.Fatal("Unexpected command:", this.name)
|
||||
}
|
||||
|
||||
if this.event != typeMessageCreate {
|
||||
t.Fatal("invalid event type:", this.event.String())
|
||||
}
|
||||
}
|
||||
|
||||
if !foundSend {
|
||||
t.Fatal("missing send")
|
||||
}
|
||||
|
||||
if !foundCustom {
|
||||
t.Fatal("missing custom")
|
||||
}
|
||||
|
||||
if !foundNoArgs {
|
||||
t.Fatal("missing noargs")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSubcommandConstructor(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewSubcommand(&testCommands{})
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ package discord
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Color uint
|
||||
type Color uint32
|
||||
|
||||
const DefaultColor Color = 0x303030
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ func (s *Snowflake) MarshalJSON() ([]byte, error) {
|
|||
switch i := int64(*s); i {
|
||||
case -1: // @me
|
||||
id = "@me"
|
||||
case 0:
|
||||
return []byte("null"), nil
|
||||
default:
|
||||
id = strconv.FormatInt(i, 10)
|
||||
}
|
||||
|
|
|
@ -31,6 +31,10 @@ func (t *Timestamp) UnmarshalJSON(v []byte) error {
|
|||
}
|
||||
|
||||
func (t Timestamp) MarshalJSON() ([]byte, error) {
|
||||
if time.Time(t).IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/handler"
|
||||
"github.com/diamondburned/arikawa/session"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -76,6 +77,29 @@ func (s *State) Unhook() {
|
|||
|
||||
////
|
||||
|
||||
func (s *State) Permissions(
|
||||
channelID, userID discord.Snowflake) (discord.Permissions, error) {
|
||||
|
||||
ch, err := s.Channel(channelID)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to get channel")
|
||||
}
|
||||
|
||||
g, err := s.Guild(ch.GuildID)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to get guild")
|
||||
}
|
||||
|
||||
m, err := s.Member(ch.GuildID, userID)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to get member")
|
||||
}
|
||||
|
||||
return discord.CalcOverwrites(*g, *ch, *m), nil
|
||||
}
|
||||
|
||||
////
|
||||
|
||||
func (s *State) Self() (*discord.User, error) {
|
||||
u, err := s.Store.Self()
|
||||
if err == nil {
|
||||
|
|
Loading…
Reference in a new issue