1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-10-19 23:54:47 +00:00

Ported rfrouter over

This commit is contained in:
diamondburned (Forefront) 2020-01-18 22:06:00 -08:00
parent 189853de32
commit 28228a60f5
19 changed files with 1993 additions and 1 deletions

View file

@ -0,0 +1,92 @@
package main
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/arguments"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
)
type Bot struct {
// Context must not be embedded.
Ctx *bot.Context
}
func (bot *Bot) Help(m *gateway.MessageCreateEvent) error {
_, err := bot.Ctx.SendMessage(m.ChannelID, bot.Ctx.Help(), nil)
return err
}
func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error {
_, err := bot.Ctx.SendMessage(m.ChannelID, "Pong!", nil)
return err
}
func (bot *Bot) Say(m *gateway.MessageCreateEvent, f *arguments.Flag) error {
args := f.String()
if args == "" {
// Empty message, ignore
return nil
}
_, err := bot.Ctx.SendMessage(m.ChannelID, args, nil)
return err
}
func (bot *Bot) Embed(
m *gateway.MessageCreateEvent, f *arguments.Flag) error {
fs := arguments.NewFlagSet()
var (
title = fs.String("title", "", "Title")
author = fs.String("author", "", "Author")
footer = fs.String("footer", "", "Footer")
color = fs.String("color", "#FFFFFF", "Color in hex format #hhhhhh")
)
if err := f.With(fs.FlagSet); err != nil {
return err
}
if len(fs.Args()) < 1 {
return fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage())
}
// Check if the color string is valid.
if !strings.HasPrefix(*color, "#") || len(*color) != 7 {
return errors.New("Invalid color, format must be #hhhhhh")
}
// Parse the color into decimal numbers.
colorHex, err := strconv.ParseInt((*color)[1:], 16, 64)
if err != nil {
return err
}
// Make a new embed
embed := discord.Embed{
Title: *title,
Description: strings.Join(fs.Args(), " "),
Color: discord.Color(colorHex),
}
if *author != "" {
embed.Author = &discord.EmbedAuthor{
Name: *author,
}
}
if *footer != "" {
embed.Footer = &discord.EmbedFooter{
Text: *footer,
}
}
_, err = bot.Ctx.SendMessage(m.ChannelID, "", &embed)
return err
}

View file

@ -0,0 +1,35 @@
package main
import (
"log"
"os"
"github.com/diamondburned/arikawa/bot"
)
// To run, do `BOT_TOKEN="TOKEN HERE" go run .`
func main() {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
log.Fatalln("No $BOT_TOKEN given.")
}
commands := &Bot{}
stop, err := bot.Start(token, commands, func(ctx *bot.Context) error {
ctx.Prefix = "!"
return nil
})
if err != nil {
log.Fatalln(err)
}
defer stop()
log.Println("Bot started")
// Automatically block until SIGINT.
bot.Wait()
}

122
bot/arguments.go Normal file
View file

@ -0,0 +1,122 @@
package bot
import (
"errors"
"reflect"
"strconv"
)
type argumentValueFn func(string) (reflect.Value, error)
// Parseable implements a Parse(string) method for data structures that can be
// used as arguments.
type Parseable interface {
Parse(string) error
}
// ManaulParseable implements a ParseContent(string) method. If the library sees
// this for an argument, it will send all of the arguments (including the
// command) into the method. If used, this should be the only argument followed
// after the Message Create event. Any more and the router will ignore.
type ManualParseable interface {
// $0 will have its prefix trimmed.
ParseContent([]string) error
}
type RawArguments struct {
Arguments []string
}
func (r *RawArguments) ParseContent(args []string) error {
r.Arguments = args
return nil
}
// nilV, only used to return an error
var nilV = reflect.Value{}
func getArgumentValueFn(t reflect.Type) (argumentValueFn, error) {
if t.Implements(typeIParser) {
mt, ok := t.MethodByName("Parse")
if !ok {
panic("BUG: type IParser does not implement Parse")
}
return func(input string) (reflect.Value, error) {
v := reflect.New(t.Elem())
ret := mt.Func.Call([]reflect.Value{
v, reflect.ValueOf(input),
})
if err := errorReturns(ret); err != nil {
return nilV, err
}
return v, nil
}, nil
}
var fn argumentValueFn
switch t.Kind() {
case reflect.String:
fn = func(s string) (reflect.Value, error) {
return reflect.ValueOf(s), nil
}
case reflect.Int, reflect.Int8,
reflect.Int16, reflect.Int32, reflect.Int64:
fn = func(s string) (reflect.Value, error) {
i, err := strconv.ParseInt(s, 10, 64)
return quickRet(i, err, t)
}
case reflect.Uint, reflect.Uint8,
reflect.Uint16, reflect.Uint32, reflect.Uint64:
fn = func(s string) (reflect.Value, error) {
u, err := strconv.ParseUint(s, 10, 64)
return quickRet(u, err, t)
}
case reflect.Float32, reflect.Float64:
fn = func(s string) (reflect.Value, error) {
f, err := strconv.ParseFloat(s, 64)
return quickRet(f, err, t)
}
case reflect.Bool:
fn = func(s string) (reflect.Value, error) {
switch s {
case "true", "yes", "y", "Y", "1":
return reflect.ValueOf(true), nil
case "false", "no", "n", "N", "0":
return reflect.ValueOf(false), nil
default:
return nilV, errors.New("invalid bool [true/false]")
}
}
}
if fn == nil {
return nil, errors.New("invalid type: " + t.String())
}
return fn, nil
}
func quickRet(v interface{}, err error, t reflect.Type) (reflect.Value, error) {
if err != nil {
return nilV, err
}
rv := reflect.ValueOf(v)
if t == nil {
return rv, nil
}
return rv.Convert(t), nil
}

114
bot/copied_from_d.go Normal file
View file

@ -0,0 +1,114 @@
package bot
/*
// UserPermissions but userID is after channelID.
func (ctx *Context) UserPermissions(channelID, userID string,
) (apermissions int, err error) {
// Try to just get permissions from state.
apermissions, err = ctx.Session.State.UserChannelPermissions(
userID, channelID)
if err == nil {
return
}
// Otherwise try get as much data from state as possible, falling back to the network.
channel, err := ctx.Channel(channelID)
if err != nil {
return
}
guild, err := ctx.Guild(channel.GuildID)
if err != nil {
return
}
if userID == guild.OwnerID {
apermissions = discordgo.PermissionAll
return
}
member, err := ctx.Member(guild.ID, userID)
if err != nil {
return
}
return MemberPermissions(guild, channel, member), nil
}
// Why this isn't exported, I have no idea.
func MemberPermissions(guild *discordgo.Guild, channel *discordgo.Channel,
member *discordgo.Member) (apermissions int) {
userID := member.User.ID
if userID == guild.OwnerID {
apermissions = discordgo.PermissionAll
return
}
for _, role := range guild.Roles {
if role.ID == guild.ID {
apermissions |= role.Permissions
break
}
}
for _, role := range guild.Roles {
for _, roleID := range member.Roles {
if role.ID == roleID {
apermissions |= role.Permissions
break
}
}
}
if apermissions&discordgo.PermissionAdministrator ==
discordgo.PermissionAdministrator {
apermissions |= discordgo.PermissionAll
}
// Apply @everyone overrides from the channel.
for _, overwrite := range channel.PermissionOverwrites {
if guild.ID == overwrite.ID {
apermissions &= ^overwrite.Deny
apermissions |= overwrite.Allow
break
}
}
denies := 0
allows := 0
// Member overwrites can override role overrides, so do two passes
for _, overwrite := range channel.PermissionOverwrites {
for _, roleID := range member.Roles {
if overwrite.Type == "role" && roleID == overwrite.ID {
denies |= overwrite.Deny
allows |= overwrite.Allow
break
}
}
}
apermissions &= ^denies
apermissions |= allows
for _, overwrite := range channel.PermissionOverwrites {
if overwrite.Type == "member" && overwrite.ID == userID {
apermissions &= ^overwrite.Deny
apermissions |= overwrite.Allow
break
}
}
if apermissions&discordgo.PermissionAdministrator ==
discordgo.PermissionAdministrator {
apermissions |= discordgo.PermissionAllChannel
}
return apermissions
}
*/

281
bot/ctx.go Normal file
View file

@ -0,0 +1,281 @@
package bot
import (
"log"
"os"
"os/signal"
"strings"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
"github.com/pkg/errors"
)
// TODO: add variadic arguments
type Context struct {
*Subcommand
*state.State
// Descriptive (but optional) bot name
Name string
// Descriptive help body
Description string
// The prefix for commands
Prefix string
// FormatError formats any errors returned by anything, including the method
// commands or the reflect functions. This also includes invalid usage
// errors or unknown command errors. Returning an empty string means
// ignoring the error.
FormatError func(error) string
// ErrorLogger logs any error that anything makes and the library can't
// reply to the client. This includes any event callback errors that aren't
// Message Create.
ErrorLogger func(error)
// ReplyError when true replies to the user the error.
ReplyError bool
// Subcommands contains all the registered subcommands.
Subcommands []*Subcommand
}
// Start quickly starts a bot with the given command. It will prepend "Bot"
// into the token automatically. Refer to example/ for usage.
func Start(token string, cmd interface{},
opts func(*Context) error) (stop func() error, err error) {
s, err := state.New("Bot " + token)
if err != nil {
return nil, errors.Wrap(err, "Failed to create a dgo session")
}
c, err := New(s, cmd)
if err != nil {
return nil, errors.Wrap(err, "Failed to create rfrouter")
}
s.ErrorLog = func(err error) {
c.ErrorLogger(err)
}
if opts != nil {
if err := opts(c); err != nil {
return nil, err
}
}
cancel := c.Start()
if err := s.Open(); err != nil {
return nil, errors.Wrap(err, "Failed to connect to Discord")
}
return func() error {
cancel()
return s.Close()
}, nil
}
// Wait is a convenient function that blocks until a SIGINT is sent.
func Wait() {
sigs := make(chan os.Signal)
signal.Notify(sigs, os.Interrupt)
<-sigs
}
// New makes a new context with a "~" as the prefix. cmds must be a pointer to a
// struct with a *Context field. Example:
//
// type Commands struct {
// Ctx *Context
// }
//
// cmds := &Commands{}
// c, err := rfrouter.New(session, cmds)
//
// Commands' exported methods will all be used as commands. Messages are parsed
// with its first argument (the command) mapped accordingly to c.MapName, which
// capitalizes the first letter automatically to reflect the exported method
// name.
//
// The default prefix is "~", which means commands must start with "~" followed
// by the command name in the first argument, else it will be ignored.
//
// c.Start() should be called afterwards to actually handle incoming events.
func New(s *state.State, cmd interface{}) (*Context, error) {
c, err := NewSubcommand(cmd)
if err != nil {
return nil, err
}
ctx := &Context{
Subcommand: c,
State: s,
Prefix: "~",
FormatError: func(err error) string {
return err.Error()
},
ErrorLogger: func(err error) {
log.Println("Bot error:", err)
},
ReplyError: true,
}
if err := ctx.InitCommands(ctx); err != nil {
return nil, errors.Wrap(err, "Failed to initialize with given cmds")
}
return ctx, nil
}
func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) {
s, err := NewSubcommand(cmd)
if err != nil {
return nil, errors.Wrap(err, "Failed to add subcommand")
}
// Register the subcommand's name.
s.NeedsName()
if err := s.InitCommands(ctx); err != nil {
return nil, errors.Wrap(err, "Failed to initialize subcommand")
}
// Do a collision check
for _, sub := range ctx.Subcommands {
if sub.name == s.name {
return nil, errors.New(
"New subcommand has duplicate name: " + s.name)
}
}
ctx.Subcommands = append(ctx.Subcommands, s)
return s, nil
}
// Start adds itself into the discordgo Session handlers. This needs to be run.
// The returned function is a delete function, which removes itself from the
// Session handlers.
func (ctx *Context) Start() func() {
return ctx.Session.AddHandler(func(v interface{}) {
if err := ctx.callCmd(v); err != nil {
if str := ctx.FormatError(err); str != "" {
// Log the main error first
ctx.ErrorLogger(errors.Wrap(err, str))
mc, ok := v.(*gateway.MessageCreateEvent)
if !ok {
return
}
if ctx.ReplyError {
_, Merr := ctx.SendMessage(mc.ChannelID, str, nil)
if Merr != nil {
// Then the message error
ctx.ErrorLogger(Merr)
// TODO: there ought to be a better way lol
}
}
}
}
})
}
// Call should only be used if you know what you're doing.
func (ctx *Context) Call(event interface{}) error {
return ctx.callCmd(event)
}
// Help generates one. This function is used more for reference than an actual
// help message. As such, it only uses exported fields or methods.
func (ctx *Context) Help() string {
var help strings.Builder
// Generate the headers and descriptions
help.WriteString("__Help__")
if ctx.Name != "" {
help.WriteString(": " + ctx.Name)
}
if ctx.Description != "" {
help.WriteString("\n " + ctx.Description)
}
if ctx.Flag.Is(AdminOnly) {
// That's it.
return help.String()
}
// Separators
help.WriteString("\n---\n")
// Generate all commands
help.WriteString("__Commands__\n")
for _, cmd := range ctx.Commands {
if cmd.Flag.Is(AdminOnly) {
// Hidden
continue
}
help.WriteString(" " + ctx.Prefix + cmd.Name())
switch {
case len(cmd.Usage()) > 0:
help.WriteString(" " + strings.Join(cmd.Usage(), " "))
case cmd.Description != "":
help.WriteString(": " + cmd.Description)
}
help.WriteByte('\n')
}
var subHelp = strings.Builder{}
for _, sub := range ctx.Subcommands {
if sub.Flag.Is(AdminOnly) {
// Hidden
continue
}
subHelp.WriteString(" " + sub.Name())
if sub.Description != "" {
subHelp.WriteString(": " + sub.Description)
}
subHelp.WriteByte('\n')
for _, cmd := range sub.Commands {
if cmd.Flag.Is(AdminOnly) {
continue
}
subHelp.WriteString(" " +
ctx.Prefix + sub.Name() + " " + cmd.Name())
switch {
case len(cmd.Usage()) > 0:
subHelp.WriteString(" " + strings.Join(cmd.Usage(), " "))
case cmd.Description != "":
subHelp.WriteString(": " + cmd.Description)
}
subHelp.WriteByte('\n')
}
}
if sub := subHelp.String(); sub != "" {
help.WriteString("---\n")
help.WriteString("__Subcommands__\n")
help.WriteString(sub)
}
return help.String()
}

312
bot/ctx_call.go Normal file
View file

@ -0,0 +1,312 @@
package bot
import (
"encoding/csv"
"reflect"
"strings"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
)
func (ctx *Context) callCmd(ev interface{}) error {
evT := reflect.TypeOf(ev)
if evT != typeMessageCreate {
var callers []reflect.Value
var isAdmin *bool // i want to die
for _, cmd := range ctx.Commands {
if cmd.event == evT {
if cmd.Flag.Is(AdminOnly) &&
!ctx.eventIsAdmin(ev, &isAdmin) {
continue
}
callers = append(callers, cmd.value)
}
}
for _, sub := range ctx.Subcommands {
if sub.Flag.Is(AdminOnly) &&
!ctx.eventIsAdmin(ev, &isAdmin) {
continue
}
for _, cmd := range sub.Commands {
if cmd.event == evT {
if cmd.Flag.Is(AdminOnly) &&
!ctx.eventIsAdmin(ev, &isAdmin) {
continue
}
callers = append(callers, cmd.value)
}
}
}
for _, c := range callers {
if err := callWith(c, ev); err != nil {
ctx.ErrorLogger(err)
}
}
return nil
}
// safe assertion always
mc := ev.(*gateway.MessageCreateEvent)
// check if prefix
if !strings.HasPrefix(mc.Content, ctx.Prefix) {
// not a command, ignore
return nil
}
// trim the prefix before splitting, this way multi-words prefices work
content := mc.Content[len(ctx.Prefix):]
if content == "" {
return nil // just the prefix only
}
// parse arguments
args, err := ParseArgs(content)
if err != nil {
return err
}
if len(args) < 1 {
return nil // ???
}
var cmd *CommandContext
var start int // arg starts from $start
// Search for the command
for _, c := range ctx.Commands {
if c.name == args[0] {
cmd = c
start = 1
break
}
}
// Can't find command, look for subcommands of len(args) has a 2nd
// entry.
if cmd == nil && len(args) > 1 {
for _, s := range ctx.Subcommands {
if s.name != args[0] {
continue
}
for _, c := range s.Commands {
if c.name == args[1] {
cmd = c
start = 2
break
}
}
if cmd == nil {
return &ErrUnknownCommand{
Command: args[1],
Parent: args[0],
Prefix: ctx.Prefix,
ctx: s.Commands,
}
}
}
}
if cmd == nil || start == 0 {
return &ErrUnknownCommand{
Command: args[0],
Prefix: ctx.Prefix,
ctx: ctx.Commands,
}
}
// Start converting
var argv []reflect.Value
// Check manual parser
if cmd.parseType != nil {
// Create a zero value instance of this
v := reflect.New(cmd.parseType)
// Call the manual parse method
ret := cmd.parseMethod.Func.Call([]reflect.Value{
v, reflect.ValueOf(args),
})
// Check the method returns for error
if err := errorReturns(ret); err != nil {
// TODO: maybe wrap this?
return err
}
// Add the pointer to the argument into argv
argv = append(argv, v)
goto Call
}
// Here's an edge case: when the handler takes no arguments, we allow that
// anyway, as they might've used the raw content.
if len(cmd.arguments) == 0 {
goto Call
}
// Not enough arguments given
if len(args[start:]) != len(cmd.arguments) {
return &ErrInvalidUsage{
Args: args,
Prefix: ctx.Prefix,
Index: len(cmd.arguments) - start,
Err: "Not enough arguments given",
ctx: cmd,
}
}
argv = make([]reflect.Value, len(cmd.arguments))
for i := start; i < len(args); i++ {
v, err := cmd.arguments[i-start](args[i])
if err != nil {
return &ErrInvalidUsage{
Args: args,
Prefix: ctx.Prefix,
Index: i,
Err: err.Error(),
ctx: cmd,
}
}
argv[i-start] = v
}
Call:
// call the function and parse the error return value
return callWith(cmd.value, ev, argv...)
}
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
if *is != nil {
return **is
}
var channelID = reflectChannelID(ev)
if !channelID.Valid() {
return false
}
var userID = reflectUserID(ev)
if !userID.Valid() {
return false
}
var res bool
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
res = true
}
*is = &res
return res
}
func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) error {
return errorReturns(caller.Call(append(
[]reflect.Value{reflect.ValueOf(ev)},
values...,
)))
}
var ParseArgs = func(args string) ([]string, error) {
// TODO: make modular
// TODO: actual tokenizer+parser
r := csv.NewReader(strings.NewReader(args))
r.Comma = ' '
return r.Read()
}
func errorReturns(returns []reflect.Value) error {
// assume first is always error, since we checked for this in parseCommands
v := returns[0].Interface()
if v == nil {
return nil
}
return v.(error)
}
func reflectChannelID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "Channel")
}
func reflectGuildID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "Guild")
}
func reflectUserID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "User")
}
func _reflectID(v reflect.Value, thing string) discord.Snowflake {
if !v.IsValid() {
return 0
}
t := v.Type()
if t.Kind() == reflect.Ptr {
v = v.Elem()
// Recheck after dereferring
if !v.IsValid() {
return 0
}
t = v.Type()
}
if t.Kind() != reflect.Struct {
return 0
}
numFields := t.NumField()
for i := 0; i < numFields; i++ {
field := t.Field(i)
fType := field.Type
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
}
switch fType.Kind() {
case reflect.Struct:
if chID := _reflectID(v.Field(i), thing); chID.Valid() {
return chID
}
case reflect.Int64:
if field.Name == thing+"ID" {
// grab value real quick
return discord.Snowflake(v.Field(i).Int())
}
// Special case where the struct name has Channel in it
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
return discord.Snowflake(v.Field(i).Int())
}
}
}
return 0
}

312
bot/ctx_test.go Normal file
View file

@ -0,0 +1,312 @@
package bot
import (
"reflect"
"strings"
"testing"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
"github.com/pkg/errors"
)
type testCommands struct {
Ctx *Context
Return chan interface{}
}
func (t *testCommands) Send(_ *gateway.MessageCreateEvent, arg string) error {
t.Return <- arg
return errors.New("oh no")
}
func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *CustomParseable) error {
t.Return <- c.args
return nil
}
func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error {
return errors.New("passed")
}
func (t *testCommands) Noop(_ *gateway.MessageCreateEvent) error {
return nil
}
type CustomParseable struct {
args []string
}
func (c *CustomParseable) ParseContent(args []string) error {
c.args = args
return nil
}
func TestNewContext(t *testing.T) {
var state = &state.State{
Store: state.NewDefaultStore(nil),
}
_, err := New(state, &testCommands{})
if err != nil {
t.Fatal("Failed to create new context:", err)
}
}
func TestContext(t *testing.T) {
var given = &testCommands{}
var state = &state.State{
Store: state.NewDefaultStore(nil),
}
s, err := NewSubcommand(given)
if err != nil {
t.Fatal("Failed to create subcommand:", err)
}
var ctx = &Context{
Subcommand: s,
State: state,
}
t.Run("init commands", func(t *testing.T) {
if err := ctx.Subcommand.InitCommands(ctx); err != nil {
t.Fatal("Failed to init commands:", err)
}
if given.Ctx == nil {
t.Fatal("given's Context field is nil")
}
if given.Ctx.State.Store == nil {
t.Fatal("given's State is nil")
}
})
testReturn := func(expects interface{}, content string) (call error) {
// Return channel for testing
ret := make(chan interface{})
given.Return = ret
// Mock a messageCreate event
m := &gateway.MessageCreateEvent{
Content: content,
}
var (
callCh = make(chan error)
)
go func() {
callCh <- ctx.callCmd(m)
}()
select {
case arg := <-ret:
if !reflect.DeepEqual(arg, expects) {
t.Fatal("returned argument is invalid:", arg)
}
call = <-callCh
case call = <-callCh:
t.Fatal("expected return before error:", call)
}
return
}
t.Run("call command", func(t *testing.T) {
// Set a custom prefix
ctx.Prefix = "~"
if err := testReturn("test", "~send test"); err.Error() != "oh no" {
t.Fatal("unexpected error:", err)
}
})
t.Run("call command custom parser", func(t *testing.T) {
ctx.Prefix = "!"
expects := []string{"custom", "arg1", ":)"}
if err := testReturn(expects, "!custom arg1 :)"); err != nil {
t.Fatal("Unexpected call error:", err)
}
})
testMessage := func(content string) error {
// Mock a messageCreate event
m := &gateway.MessageCreateEvent{
Content: content,
}
return ctx.callCmd(m)
}
t.Run("call command without args", func(t *testing.T) {
ctx.Prefix = ""
if err := testMessage("noargs"); err.Error() != "passed" {
t.Fatal("unexpected error:", err)
}
})
// Test error cases
t.Run("call unknown command", func(t *testing.T) {
ctx.Prefix = "joe pls "
err := testMessage("joe pls no")
if err == nil || !strings.HasPrefix(err.Error(), "Unknown command:") {
t.Fatal("unexpected error:", err)
}
})
// Test subcommands
t.Run("register subcommand", func(t *testing.T) {
ctx.Prefix = "run "
_, err := ctx.RegisterSubcommand(&testCommands{})
if err != nil {
t.Fatal("Failed to register subcommand:", err)
}
if err := testMessage("run testcommands noop"); err != nil {
t.Fatal("unexpected error:", err)
}
})
}
func BenchmarkConstructor(b *testing.B) {
var state = &state.State{
Store: state.NewDefaultStore(nil),
}
for i := 0; i < b.N; i++ {
_, _ = New(state, &testCommands{})
}
}
func BenchmarkCall(b *testing.B) {
var given = &testCommands{}
var state = &state.State{
Store: state.NewDefaultStore(nil),
}
s, _ := NewSubcommand(given)
var ctx = &Context{
Subcommand: s,
State: state,
Prefix: "~",
}
m := &gateway.MessageCreateEvent{
Content: "~noop",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ctx.callCmd(m)
}
}
func BenchmarkHelp(b *testing.B) {
var given = &testCommands{}
var state = &state.State{
Store: state.NewDefaultStore(nil),
}
s, _ := NewSubcommand(given)
var ctx = &Context{
Subcommand: s,
State: state,
Prefix: "~",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ctx.Help()
}
}
type hasID struct {
ChannelID discord.Snowflake
}
type embedsID struct {
*hasID
*embedsID
}
type hasChannelInName struct {
ID discord.Snowflake
}
func TestReflectChannelID(t *testing.T) {
var s = &hasID{
ChannelID: 69420,
}
t.Run("hasID", func(t *testing.T) {
if id := reflectChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("embedsID", func(t *testing.T) {
var e = &embedsID{
hasID: s,
}
if id := reflectChannelID(e); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("hasChannelInName", func(t *testing.T) {
var s = &hasChannelInName{
ID: 69420,
}
if id := reflectChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
}
func BenchmarkReflectChannelID_1Level(b *testing.B) {
var s = &hasID{
ChannelID: 69420,
}
for i := 0; i < b.N; i++ {
_ = reflectChannelID(s)
}
}
func BenchmarkReflectChannelID_5Level(b *testing.B) {
var s = &embedsID{
nil,
&embedsID{
nil,
&embedsID{
nil,
&embedsID{
hasID: &hasID{
ChannelID: 69420,
},
},
},
},
}
for i := 0; i < b.N; i++ {
_ = reflectChannelID(s)
}
}

66
bot/error.go Normal file
View file

@ -0,0 +1,66 @@
package bot
import (
"strings"
)
type ErrUnknownCommand struct {
Command string
Parent string
Prefix string
// TODO: list available commands?
// Here, as a reminder
ctx []*CommandContext
}
func (err *ErrUnknownCommand) Error() string {
var header = "Unknown command: " + err.Prefix
if err.Parent != "" {
header += err.Parent + " " + err.Command
} else {
header += err.Command
}
return header
}
type ErrInvalidUsage struct {
Args []string
Prefix string
Index int
Err string
// TODO: usage generator?
// Here, as a reminder
ctx *CommandContext
}
func (err *ErrInvalidUsage) Error() string {
if err.Index == 0 {
return "Invalid usage"
}
if len(err.Args) == 0 {
return "Missing arguments. Refer to help."
}
body := "Invalid usage at " + err.Prefix
// Write the first part
body += strings.Join(err.Args[:err.Index], " ")
// Write the wrong part
body += " __" + err.Args[err.Index] + "__ "
// Write the last part
body += strings.Join(err.Args[err.Index+1:], " ")
if err.Err != "" {
body += "\nError: " + err.Err
}
return body
}

View file

@ -0,0 +1,51 @@
package arguments
import (
"errors"
"regexp"
)
var (
EmojiRegex = regexp.MustCompile(`<(a?):(.+?):(\d+)>`)
ErrInvalidEmoji = errors.New("Invalid emoji")
)
type Emoji struct {
ID string
Custom bool
Name string
Animated bool
}
func (e *Emoji) Parse(arg string) error {
// Check if Unicode
var unicode string
for _, r := range arg {
if r < '\U0001F600' && r > '\U0001F64F' {
unicode += string(r)
}
}
if unicode != "" {
e.ID = unicode
e.Custom = false
return nil
}
var matches = EmojiRegex.FindStringSubmatch(arg)
if len(matches) != 4 {
return ErrInvalidEmoji
}
e.Custom = true
e.Animated = matches[1] == "a"
e.Name = matches[2]
e.ID = matches[3]
return nil
}

View file

@ -0,0 +1,65 @@
package arguments
import (
"bytes"
"flag"
"io/ioutil"
"strings"
)
var FlagName = "command"
type FlagSet struct {
*flag.FlagSet
}
func NewFlagSet() *FlagSet {
fs := flag.NewFlagSet(FlagName, flag.ContinueOnError)
fs.SetOutput(ioutil.Discard)
return &FlagSet{fs}
}
func (fs *FlagSet) Usage() string {
var buf bytes.Buffer
fs.FlagSet.SetOutput(&buf)
fs.FlagSet.Usage()
fs.FlagSet.SetOutput(ioutil.Discard)
return buf.String()
}
type Flag struct {
arguments []string
}
func (f *Flag) ParseContent(arguments []string) error {
// trim the command out
f.arguments = arguments[1:]
return nil
}
func (f *Flag) Usage() string {
return "flags..."
}
func (f *Flag) Args() []string {
return f.arguments
}
func (f *Flag) Arg(n int) string {
if n < 0 || n >= len(f.arguments) {
return ""
}
return f.arguments[n]
}
func (f *Flag) String() string {
return strings.Join(f.arguments, " ")
}
func (f *Flag) With(fs *flag.FlagSet) error {
return fs.Parse(f.arguments)
}

View file

@ -0,0 +1,52 @@
package arguments
import (
"errors"
"regexp"
)
var (
ChannelRegex = regexp.MustCompile(`<#(\d+)>`)
UserRegex = regexp.MustCompile(`<@!?(\d+)>`)
RoleRegex = regexp.MustCompile(`<@&(\d+)>`)
)
type ChannelMention string
func (m *ChannelMention) Parse(arg string) error {
return grabFirst(ChannelRegex, "channel mention", arg, (*string)(m))
}
func (m *ChannelMention) Usage() string {
return "#channel"
}
type UserMention string
func (m *UserMention) Parse(arg string) error {
return grabFirst(UserRegex, "user mention", arg, (*string)(m))
}
func (m *UserMention) Usage() string {
return "@user"
}
type RoleMention string
func (m *RoleMention) Parse(arg string) error {
return grabFirst(RoleRegex, "role mention", arg, (*string)(m))
}
func (m *RoleMention) Usage() string {
return "@role"
}
func grabFirst(reg *regexp.Regexp, item, input string, output *string) error {
matches := reg.FindStringSubmatch(input)
if len(matches) < 2 {
return errors.New("Invalid " + item)
}
*output = matches[1]
return nil
}

40
bot/nameflag.go Normal file
View file

@ -0,0 +1,40 @@
package bot
import "strings"
type NameFlag uint64
const FlagSeparator = 'ー'
const (
None NameFlag = 1 << iota
// These flags only apply to messageCreate events.
Raw // R
AdminOnly // A
)
func ParseFlag(name string) (NameFlag, string) {
parts := strings.SplitN(name, string(FlagSeparator), 2)
if len(parts) != 2 {
return 0, name
}
var f NameFlag
for _, r := range parts[0] {
switch r {
case 'R':
f |= Raw
case 'A':
f |= AdminOnly
}
}
return f, parts[1]
}
func (f NameFlag) Is(flag NameFlag) bool {
return f&flag != 0
}

26
bot/nameflag_test.go Normal file
View file

@ -0,0 +1,26 @@
package bot
import "testing"
func TestNameFlag(t *testing.T) {
type entry struct {
Name string
Expect NameFlag
String string
}
var entries = []entry{{
Name: "AーEcho",
Expect: AdminOnly,
}, {
Name: "RAーGC",
Expect: Raw | AdminOnly,
}}
for _, entry := range entries {
var f, _ = ParseFlag(entry.Name)
if !f.Is(entry.Expect) {
t.Fatalf("unexpected expectation for %s: %v", entry.Name, f)
}
}
}

298
bot/subcommand.go Normal file
View file

@ -0,0 +1,298 @@
package bot
import (
"reflect"
"strings"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
)
var (
typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil))
// typeof.Implements(typeI*)
typeIError = reflect.TypeOf((*error)(nil)).Elem()
typeIManP = reflect.TypeOf((*ManualParseable)(nil)).Elem()
typeIParser = reflect.TypeOf((*Parseable)(nil)).Elem()
typeIUsager = reflect.TypeOf((*Usager)(nil)).Elem()
)
type Subcommand struct {
Description string
// Commands contains all the registered command contexts.
Commands []*CommandContext
// struct name
name string
// struct flags
Flag NameFlag
// Directly to struct
cmdValue reflect.Value
cmdType reflect.Type
// Pointer value
ptrValue reflect.Value
ptrType reflect.Type
// command interface as reference
command interface{}
}
// CommandContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type CommandContext struct {
Description string
Flag NameFlag
name string // all lower-case
value reflect.Value // Func
event reflect.Type // gateway.*Event
method reflect.Method
// equal slices
argStrings []string
arguments []argumentValueFn
// only for ParseContent interface
parseMethod reflect.Method
parseType reflect.Type
parseUsage string
}
// Descriptor is optionally used to set the Description of a command context.
type Descriptor interface {
Description() string
}
// Namer is optionally used to override the command context's name.
type Namer interface {
Name() string
}
// Usager is optionally used to override the generated usage for either an
// argument, or multiple (using ManualParseable).
type Usager interface {
Usage() string
}
func (cctx *CommandContext) Name() string {
return cctx.name
}
func (cctx *CommandContext) Usage() []string {
if cctx.parseType != nil {
return []string{cctx.parseUsage}
}
if len(cctx.arguments) == 0 {
return nil
}
return cctx.argStrings
}
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
var sub = Subcommand{
command: cmd,
}
// Set description
if d, ok := cmd.(Descriptor); ok {
sub.Description = d.Description()
}
if err := sub.reflectCommands(); err != nil {
return nil, errors.Wrap(err, "Failed to reflect commands")
}
if err := sub.parseCommands(); err != nil {
return nil, errors.Wrap(err, "Failed to parse commands")
}
return &sub, nil
}
// Name returns the command name in lower case. This only returns non-zero for
// subcommands.
func (sub *Subcommand) Name() string {
return sub.name
}
// NeedsName sets the name for this subcommand. Like InitCommands, this
// shouldn't be called at all, rather you should use RegisterSubcommand.
func (sub *Subcommand) NeedsName() {
flag, name := ParseFlag(sub.cmdType.Name())
// Check for interface
if n, ok := sub.command.(Namer); ok {
name = n.Name()
}
if !flag.Is(Raw) {
name = strings.ToLower(name)
}
sub.name = name
sub.Flag = flag
}
func (sub *Subcommand) reflectCommands() error {
t := reflect.TypeOf(sub.command)
v := reflect.ValueOf(sub.command)
if t.Kind() != reflect.Ptr {
return errors.New("sub is not a pointer")
}
// Set the pointer fields
sub.ptrValue = v
sub.ptrType = t
ts := t.Elem()
vs := v.Elem()
if ts.Kind() != reflect.Struct {
return errors.New("sub is not pointer to struct")
}
// Set the struct fields
sub.cmdValue = vs
sub.cmdType = ts
return nil
}
// InitCommands fills a Subcommand with a context. This shouldn't be called at
// all, rather you should use the RegisterSubcommand method of a Context.
func (sub *Subcommand) InitCommands(ctx *Context) error {
// Start filling up a *Context field
for i := 0; i < sub.cmdValue.NumField(); i++ {
field := sub.cmdValue.Field(i)
if !field.CanSet() || !field.CanInterface() {
continue
}
if _, ok := field.Interface().(*Context); !ok {
continue
}
field.Set(reflect.ValueOf(ctx))
return nil
}
return errors.New("No fields with *Command found")
}
func (sub *Subcommand) parseCommands() error {
var numMethods = sub.ptrValue.NumMethod()
var commands = make([]*CommandContext, 0, numMethods)
for i := 0; i < numMethods; i++ {
method := sub.ptrValue.Method(i)
if !method.CanInterface() {
continue
}
methodT := method.Type()
numArgs := methodT.NumIn()
// Doesn't meet requirement for an event
if numArgs == 0 {
continue
}
// Check return type
if err := methodT.Out(0); err == nil || !err.Implements(typeIError) {
// Invalid, skip
continue
}
var command = CommandContext{
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
}
// Parse the method name
flag, name := ParseFlag(command.method.Name)
if !flag.Is(Raw) {
name = strings.ToLower(name)
}
// Set the method name and flag
command.name = name
command.Flag = flag
// TODO: allow more flexibility
if command.event != typeMessageCreate {
goto Done
}
if numArgs == 1 {
// done
goto Done
}
// If the second argument implements ParseContent()
if t := methodT.In(1); t.Implements(typeIManP) {
mt, _ := t.MethodByName("ParseContent")
command.parseMethod = mt
command.parseType = t.Elem()
command.parseUsage = usager(t)
if command.parseUsage == "" {
command.parseUsage = t.String()
}
goto Done
}
command.arguments = make([]argumentValueFn, 0, numArgs)
// Fill up arguments
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
avfs, err := getArgumentValueFn(t)
if err != nil {
return errors.Wrap(err, "Error parsing argument "+t.String())
}
command.arguments = append(command.arguments, avfs)
var usage = usager(t)
if usage == "" {
usage = t.String()
}
command.argStrings = append(command.argStrings, usage)
}
Done:
// Append
commands = append(commands, &command)
}
sub.Commands = commands
return nil
}
func usager(t reflect.Type) string {
if !t.Implements(typeIUsager) {
return ""
}
usageFn, _ := t.MethodByName("Usage")
v := usageFn.Func.Call([]reflect.Value{
reflect.New(t.Elem()),
})
return v[0].String()
}

96
bot/subcommand_test.go Normal file
View file

@ -0,0 +1,96 @@
package bot
import "testing"
func TestNewSubcommand(t *testing.T) {
_, err := NewSubcommand(&testCommands{})
if err != nil {
t.Fatal("Failed to create new subcommand:", err)
}
}
func TestSubcommand(t *testing.T) {
var given = &testCommands{}
var sub = &Subcommand{
command: given,
}
t.Run("reflect commands", func(t *testing.T) {
if err := sub.reflectCommands(); err != nil {
t.Fatal("Failed to reflect commands:", err)
}
})
t.Run("parse commands", func(t *testing.T) {
if err := sub.parseCommands(); err != nil {
t.Fatal("Failed to parse commands:", err)
}
// !!! CHANGE ME
if len(sub.Commands) != 4 {
t.Fatal("invalid ctx.commands len", len(sub.Commands))
}
var (
foundSend bool
foundCustom bool
foundNoArgs bool
)
for _, this := range sub.Commands {
switch this.name {
case "send":
foundSend = true
if len(this.arguments) != 1 {
t.Fatal("invalid arguments len", len(this.arguments))
}
case "custom":
foundCustom = true
if len(this.arguments) > 0 {
t.Fatal("arguments should be 0 for custom")
}
if this.parseType == nil {
t.Fatal("custom has nil manualParse")
}
case "noargs":
foundNoArgs = true
if len(this.arguments) != 0 {
t.Fatal("expected 0 arguments, got non-zero")
}
if this.parseType != nil {
t.Fatal("unexpected parseType")
}
case "noop":
// Found, but whatever
default:
t.Fatal("Unexpected command:", this.name)
}
if this.event != typeMessageCreate {
t.Fatal("invalid event type:", this.event.String())
}
}
if !foundSend {
t.Fatal("missing send")
}
if !foundCustom {
t.Fatal("missing custom")
}
if !foundNoArgs {
t.Fatal("missing noargs")
}
})
}
func BenchmarkSubcommandConstructor(b *testing.B) {
for i := 0; i < b.N; i++ {
NewSubcommand(&testCommands{})
}
}

View file

@ -2,7 +2,7 @@ package discord
import "fmt"
type Color uint
type Color uint32
const DefaultColor Color = 0x303030

View file

@ -37,6 +37,8 @@ func (s *Snowflake) MarshalJSON() ([]byte, error) {
switch i := int64(*s); i {
case -1: // @me
id = "@me"
case 0:
return []byte("null"), nil
default:
id = strconv.FormatInt(i, 10)
}

View file

@ -31,6 +31,10 @@ func (t *Timestamp) UnmarshalJSON(v []byte) error {
}
func (t Timestamp) MarshalJSON() ([]byte, error) {
if time.Time(t).IsZero() {
return []byte("null"), nil
}
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
}

View file

@ -9,6 +9,7 @@ import (
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
"github.com/diamondburned/arikawa/session"
"github.com/pkg/errors"
)
var (
@ -76,6 +77,29 @@ func (s *State) Unhook() {
////
func (s *State) Permissions(
channelID, userID discord.Snowflake) (discord.Permissions, error) {
ch, err := s.Channel(channelID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get channel")
}
g, err := s.Guild(ch.GuildID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get guild")
}
m, err := s.Member(ch.GuildID, userID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get member")
}
return discord.CalcOverwrites(*g, *ch, *m), nil
}
////
func (s *State) Self() (*discord.User, error) {
u, err := s.Store.Self()
if err == nil {