diff --git a/_example/advanced_bot/context.go b/_example/advanced_bot/context.go new file mode 100644 index 0000000..f556d6b --- /dev/null +++ b/_example/advanced_bot/context.go @@ -0,0 +1,92 @@ +package main + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "github.com/diamondburned/arikawa/bot" + "github.com/diamondburned/arikawa/bot/extras/arguments" + "github.com/diamondburned/arikawa/discord" + "github.com/diamondburned/arikawa/gateway" +) + +type Bot struct { + // Context must not be embedded. + Ctx *bot.Context +} + +func (bot *Bot) Help(m *gateway.MessageCreateEvent) error { + _, err := bot.Ctx.SendMessage(m.ChannelID, bot.Ctx.Help(), nil) + return err +} + +func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error { + _, err := bot.Ctx.SendMessage(m.ChannelID, "Pong!", nil) + return err +} + +func (bot *Bot) Say(m *gateway.MessageCreateEvent, f *arguments.Flag) error { + args := f.String() + if args == "" { + // Empty message, ignore + return nil + } + + _, err := bot.Ctx.SendMessage(m.ChannelID, args, nil) + return err +} + +func (bot *Bot) Embed( + m *gateway.MessageCreateEvent, f *arguments.Flag) error { + + fs := arguments.NewFlagSet() + + var ( + title = fs.String("title", "", "Title") + author = fs.String("author", "", "Author") + footer = fs.String("footer", "", "Footer") + color = fs.String("color", "#FFFFFF", "Color in hex format #hhhhhh") + ) + + if err := f.With(fs.FlagSet); err != nil { + return err + } + + if len(fs.Args()) < 1 { + return fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage()) + } + + // Check if the color string is valid. + if !strings.HasPrefix(*color, "#") || len(*color) != 7 { + return errors.New("Invalid color, format must be #hhhhhh") + } + + // Parse the color into decimal numbers. + colorHex, err := strconv.ParseInt((*color)[1:], 16, 64) + if err != nil { + return err + } + + // Make a new embed + embed := discord.Embed{ + Title: *title, + Description: strings.Join(fs.Args(), " "), + Color: discord.Color(colorHex), + } + + if *author != "" { + embed.Author = &discord.EmbedAuthor{ + Name: *author, + } + } + if *footer != "" { + embed.Footer = &discord.EmbedFooter{ + Text: *footer, + } + } + + _, err = bot.Ctx.SendMessage(m.ChannelID, "", &embed) + return err +} diff --git a/_example/advanced_bot/main.go b/_example/advanced_bot/main.go new file mode 100644 index 0000000..4121e71 --- /dev/null +++ b/_example/advanced_bot/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "log" + "os" + + "github.com/diamondburned/arikawa/bot" +) + +// To run, do `BOT_TOKEN="TOKEN HERE" go run .` + +func main() { + var token = os.Getenv("BOT_TOKEN") + if token == "" { + log.Fatalln("No $BOT_TOKEN given.") + } + + commands := &Bot{} + + stop, err := bot.Start(token, commands, func(ctx *bot.Context) error { + ctx.Prefix = "!" + return nil + }) + + if err != nil { + log.Fatalln(err) + } + + defer stop() + + log.Println("Bot started") + + // Automatically block until SIGINT. + bot.Wait() +} diff --git a/bot/arguments.go b/bot/arguments.go new file mode 100644 index 0000000..364d4c1 --- /dev/null +++ b/bot/arguments.go @@ -0,0 +1,122 @@ +package bot + +import ( + "errors" + "reflect" + "strconv" +) + +type argumentValueFn func(string) (reflect.Value, error) + +// Parseable implements a Parse(string) method for data structures that can be +// used as arguments. +type Parseable interface { + Parse(string) error +} + +// ManaulParseable implements a ParseContent(string) method. If the library sees +// this for an argument, it will send all of the arguments (including the +// command) into the method. If used, this should be the only argument followed +// after the Message Create event. Any more and the router will ignore. +type ManualParseable interface { + // $0 will have its prefix trimmed. + ParseContent([]string) error +} + +type RawArguments struct { + Arguments []string +} + +func (r *RawArguments) ParseContent(args []string) error { + r.Arguments = args + return nil +} + +// nilV, only used to return an error +var nilV = reflect.Value{} + +func getArgumentValueFn(t reflect.Type) (argumentValueFn, error) { + if t.Implements(typeIParser) { + mt, ok := t.MethodByName("Parse") + if !ok { + panic("BUG: type IParser does not implement Parse") + } + + return func(input string) (reflect.Value, error) { + v := reflect.New(t.Elem()) + + ret := mt.Func.Call([]reflect.Value{ + v, reflect.ValueOf(input), + }) + + if err := errorReturns(ret); err != nil { + return nilV, err + } + + return v, nil + }, nil + } + + var fn argumentValueFn + + switch t.Kind() { + case reflect.String: + fn = func(s string) (reflect.Value, error) { + return reflect.ValueOf(s), nil + } + + case reflect.Int, reflect.Int8, + reflect.Int16, reflect.Int32, reflect.Int64: + + fn = func(s string) (reflect.Value, error) { + i, err := strconv.ParseInt(s, 10, 64) + return quickRet(i, err, t) + } + + case reflect.Uint, reflect.Uint8, + reflect.Uint16, reflect.Uint32, reflect.Uint64: + + fn = func(s string) (reflect.Value, error) { + u, err := strconv.ParseUint(s, 10, 64) + return quickRet(u, err, t) + } + + case reflect.Float32, reflect.Float64: + fn = func(s string) (reflect.Value, error) { + f, err := strconv.ParseFloat(s, 64) + return quickRet(f, err, t) + } + + case reflect.Bool: + fn = func(s string) (reflect.Value, error) { + switch s { + case "true", "yes", "y", "Y", "1": + return reflect.ValueOf(true), nil + case "false", "no", "n", "N", "0": + return reflect.ValueOf(false), nil + default: + return nilV, errors.New("invalid bool [true/false]") + } + } + } + + if fn == nil { + return nil, errors.New("invalid type: " + t.String()) + } + + return fn, nil +} + +func quickRet(v interface{}, err error, t reflect.Type) (reflect.Value, error) { + if err != nil { + return nilV, err + } + + rv := reflect.ValueOf(v) + + if t == nil { + return rv, nil + } + + return rv.Convert(t), nil +} diff --git a/bot/copied_from_d.go b/bot/copied_from_d.go new file mode 100644 index 0000000..5191940 --- /dev/null +++ b/bot/copied_from_d.go @@ -0,0 +1,114 @@ +package bot + +/* +// UserPermissions but userID is after channelID. +func (ctx *Context) UserPermissions(channelID, userID string, +) (apermissions int, err error) { + + // Try to just get permissions from state. + apermissions, err = ctx.Session.State.UserChannelPermissions( + userID, channelID) + if err == nil { + return + } + + // Otherwise try get as much data from state as possible, falling back to the network. + channel, err := ctx.Channel(channelID) + if err != nil { + return + } + + guild, err := ctx.Guild(channel.GuildID) + if err != nil { + return + } + + if userID == guild.OwnerID { + apermissions = discordgo.PermissionAll + return + } + + member, err := ctx.Member(guild.ID, userID) + if err != nil { + return + } + + return MemberPermissions(guild, channel, member), nil +} + +// Why this isn't exported, I have no idea. +func MemberPermissions(guild *discordgo.Guild, channel *discordgo.Channel, + member *discordgo.Member) (apermissions int) { + + userID := member.User.ID + + if userID == guild.OwnerID { + apermissions = discordgo.PermissionAll + return + } + + for _, role := range guild.Roles { + if role.ID == guild.ID { + apermissions |= role.Permissions + break + } + } + + for _, role := range guild.Roles { + for _, roleID := range member.Roles { + if role.ID == roleID { + apermissions |= role.Permissions + break + } + } + } + + if apermissions&discordgo.PermissionAdministrator == + discordgo.PermissionAdministrator { + + apermissions |= discordgo.PermissionAll + } + + // Apply @everyone overrides from the channel. + for _, overwrite := range channel.PermissionOverwrites { + if guild.ID == overwrite.ID { + apermissions &= ^overwrite.Deny + apermissions |= overwrite.Allow + break + } + } + + denies := 0 + allows := 0 + + // Member overwrites can override role overrides, so do two passes + for _, overwrite := range channel.PermissionOverwrites { + for _, roleID := range member.Roles { + if overwrite.Type == "role" && roleID == overwrite.ID { + denies |= overwrite.Deny + allows |= overwrite.Allow + break + } + } + } + + apermissions &= ^denies + apermissions |= allows + + for _, overwrite := range channel.PermissionOverwrites { + if overwrite.Type == "member" && overwrite.ID == userID { + apermissions &= ^overwrite.Deny + apermissions |= overwrite.Allow + break + } + } + + if apermissions&discordgo.PermissionAdministrator == + discordgo.PermissionAdministrator { + + apermissions |= discordgo.PermissionAllChannel + } + + return apermissions +} +*/ diff --git a/bot/ctx.go b/bot/ctx.go new file mode 100644 index 0000000..89641ed --- /dev/null +++ b/bot/ctx.go @@ -0,0 +1,281 @@ +package bot + +import ( + "log" + "os" + "os/signal" + "strings" + + "github.com/diamondburned/arikawa/gateway" + "github.com/diamondburned/arikawa/state" + "github.com/pkg/errors" +) + +// TODO: add variadic arguments + +type Context struct { + *Subcommand + *state.State + + // Descriptive (but optional) bot name + Name string + + // Descriptive help body + Description string + + // The prefix for commands + Prefix string + + // FormatError formats any errors returned by anything, including the method + // commands or the reflect functions. This also includes invalid usage + // errors or unknown command errors. Returning an empty string means + // ignoring the error. + FormatError func(error) string + + // ErrorLogger logs any error that anything makes and the library can't + // reply to the client. This includes any event callback errors that aren't + // Message Create. + ErrorLogger func(error) + + // ReplyError when true replies to the user the error. + ReplyError bool + + // Subcommands contains all the registered subcommands. + Subcommands []*Subcommand +} + +// Start quickly starts a bot with the given command. It will prepend "Bot" +// into the token automatically. Refer to example/ for usage. +func Start(token string, cmd interface{}, + opts func(*Context) error) (stop func() error, err error) { + + s, err := state.New("Bot " + token) + if err != nil { + return nil, errors.Wrap(err, "Failed to create a dgo session") + } + + c, err := New(s, cmd) + if err != nil { + return nil, errors.Wrap(err, "Failed to create rfrouter") + } + + s.ErrorLog = func(err error) { + c.ErrorLogger(err) + } + + if opts != nil { + if err := opts(c); err != nil { + return nil, err + } + } + + cancel := c.Start() + + if err := s.Open(); err != nil { + return nil, errors.Wrap(err, "Failed to connect to Discord") + } + + return func() error { + cancel() + return s.Close() + }, nil +} + +// Wait is a convenient function that blocks until a SIGINT is sent. +func Wait() { + sigs := make(chan os.Signal) + signal.Notify(sigs, os.Interrupt) + <-sigs +} + +// New makes a new context with a "~" as the prefix. cmds must be a pointer to a +// struct with a *Context field. Example: +// +// type Commands struct { +// Ctx *Context +// } +// +// cmds := &Commands{} +// c, err := rfrouter.New(session, cmds) +// +// Commands' exported methods will all be used as commands. Messages are parsed +// with its first argument (the command) mapped accordingly to c.MapName, which +// capitalizes the first letter automatically to reflect the exported method +// name. +// +// The default prefix is "~", which means commands must start with "~" followed +// by the command name in the first argument, else it will be ignored. +// +// c.Start() should be called afterwards to actually handle incoming events. +func New(s *state.State, cmd interface{}) (*Context, error) { + c, err := NewSubcommand(cmd) + if err != nil { + return nil, err + } + + ctx := &Context{ + Subcommand: c, + State: s, + Prefix: "~", + FormatError: func(err error) string { + return err.Error() + }, + ErrorLogger: func(err error) { + log.Println("Bot error:", err) + }, + ReplyError: true, + } + + if err := ctx.InitCommands(ctx); err != nil { + return nil, errors.Wrap(err, "Failed to initialize with given cmds") + } + + return ctx, nil +} + +func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) { + s, err := NewSubcommand(cmd) + if err != nil { + return nil, errors.Wrap(err, "Failed to add subcommand") + } + + // Register the subcommand's name. + s.NeedsName() + + if err := s.InitCommands(ctx); err != nil { + return nil, errors.Wrap(err, "Failed to initialize subcommand") + } + + // Do a collision check + for _, sub := range ctx.Subcommands { + if sub.name == s.name { + return nil, errors.New( + "New subcommand has duplicate name: " + s.name) + } + } + + ctx.Subcommands = append(ctx.Subcommands, s) + return s, nil +} + +// Start adds itself into the discordgo Session handlers. This needs to be run. +// The returned function is a delete function, which removes itself from the +// Session handlers. +func (ctx *Context) Start() func() { + return ctx.Session.AddHandler(func(v interface{}) { + if err := ctx.callCmd(v); err != nil { + if str := ctx.FormatError(err); str != "" { + // Log the main error first + ctx.ErrorLogger(errors.Wrap(err, str)) + + mc, ok := v.(*gateway.MessageCreateEvent) + if !ok { + return + } + + if ctx.ReplyError { + _, Merr := ctx.SendMessage(mc.ChannelID, str, nil) + if Merr != nil { + // Then the message error + ctx.ErrorLogger(Merr) + // TODO: there ought to be a better way lol + } + } + } + } + }) +} + +// Call should only be used if you know what you're doing. +func (ctx *Context) Call(event interface{}) error { + return ctx.callCmd(event) +} + +// Help generates one. This function is used more for reference than an actual +// help message. As such, it only uses exported fields or methods. +func (ctx *Context) Help() string { + var help strings.Builder + + // Generate the headers and descriptions + help.WriteString("__Help__") + + if ctx.Name != "" { + help.WriteString(": " + ctx.Name) + } + + if ctx.Description != "" { + help.WriteString("\n " + ctx.Description) + } + + if ctx.Flag.Is(AdminOnly) { + // That's it. + return help.String() + } + + // Separators + help.WriteString("\n---\n") + + // Generate all commands + help.WriteString("__Commands__\n") + + for _, cmd := range ctx.Commands { + if cmd.Flag.Is(AdminOnly) { + // Hidden + continue + } + + help.WriteString(" " + ctx.Prefix + cmd.Name()) + + switch { + case len(cmd.Usage()) > 0: + help.WriteString(" " + strings.Join(cmd.Usage(), " ")) + case cmd.Description != "": + help.WriteString(": " + cmd.Description) + } + + help.WriteByte('\n') + } + + var subHelp = strings.Builder{} + + for _, sub := range ctx.Subcommands { + if sub.Flag.Is(AdminOnly) { + // Hidden + continue + } + + subHelp.WriteString(" " + sub.Name()) + + if sub.Description != "" { + subHelp.WriteString(": " + sub.Description) + } + + subHelp.WriteByte('\n') + + for _, cmd := range sub.Commands { + if cmd.Flag.Is(AdminOnly) { + continue + } + + subHelp.WriteString(" " + + ctx.Prefix + sub.Name() + " " + cmd.Name()) + + switch { + case len(cmd.Usage()) > 0: + subHelp.WriteString(" " + strings.Join(cmd.Usage(), " ")) + case cmd.Description != "": + subHelp.WriteString(": " + cmd.Description) + } + + subHelp.WriteByte('\n') + } + } + + if sub := subHelp.String(); sub != "" { + help.WriteString("---\n") + help.WriteString("__Subcommands__\n") + help.WriteString(sub) + } + + return help.String() +} diff --git a/bot/ctx_call.go b/bot/ctx_call.go new file mode 100644 index 0000000..f4242c3 --- /dev/null +++ b/bot/ctx_call.go @@ -0,0 +1,312 @@ +package bot + +import ( + "encoding/csv" + "reflect" + "strings" + + "github.com/diamondburned/arikawa/discord" + "github.com/diamondburned/arikawa/gateway" +) + +func (ctx *Context) callCmd(ev interface{}) error { + evT := reflect.TypeOf(ev) + + if evT != typeMessageCreate { + var callers []reflect.Value + var isAdmin *bool // i want to die + + for _, cmd := range ctx.Commands { + if cmd.event == evT { + if cmd.Flag.Is(AdminOnly) && + !ctx.eventIsAdmin(ev, &isAdmin) { + + continue + } + + callers = append(callers, cmd.value) + } + } + + for _, sub := range ctx.Subcommands { + if sub.Flag.Is(AdminOnly) && + !ctx.eventIsAdmin(ev, &isAdmin) { + + continue + } + + for _, cmd := range sub.Commands { + if cmd.event == evT { + if cmd.Flag.Is(AdminOnly) && + !ctx.eventIsAdmin(ev, &isAdmin) { + + continue + } + + callers = append(callers, cmd.value) + } + } + } + + for _, c := range callers { + if err := callWith(c, ev); err != nil { + ctx.ErrorLogger(err) + } + } + + return nil + } + + // safe assertion always + mc := ev.(*gateway.MessageCreateEvent) + + // check if prefix + if !strings.HasPrefix(mc.Content, ctx.Prefix) { + // not a command, ignore + return nil + } + + // trim the prefix before splitting, this way multi-words prefices work + content := mc.Content[len(ctx.Prefix):] + + if content == "" { + return nil // just the prefix only + } + + // parse arguments + args, err := ParseArgs(content) + if err != nil { + return err + } + + if len(args) < 1 { + return nil // ??? + } + + var cmd *CommandContext + var start int // arg starts from $start + + // Search for the command + for _, c := range ctx.Commands { + if c.name == args[0] { + cmd = c + start = 1 + break + } + } + + // Can't find command, look for subcommands of len(args) has a 2nd + // entry. + if cmd == nil && len(args) > 1 { + for _, s := range ctx.Subcommands { + if s.name != args[0] { + continue + } + + for _, c := range s.Commands { + if c.name == args[1] { + cmd = c + start = 2 + break + } + } + + if cmd == nil { + return &ErrUnknownCommand{ + Command: args[1], + Parent: args[0], + Prefix: ctx.Prefix, + ctx: s.Commands, + } + } + } + } + + if cmd == nil || start == 0 { + return &ErrUnknownCommand{ + Command: args[0], + Prefix: ctx.Prefix, + ctx: ctx.Commands, + } + } + + // Start converting + var argv []reflect.Value + + // Check manual parser + if cmd.parseType != nil { + // Create a zero value instance of this + v := reflect.New(cmd.parseType) + + // Call the manual parse method + ret := cmd.parseMethod.Func.Call([]reflect.Value{ + v, reflect.ValueOf(args), + }) + + // Check the method returns for error + if err := errorReturns(ret); err != nil { + // TODO: maybe wrap this? + return err + } + + // Add the pointer to the argument into argv + argv = append(argv, v) + goto Call + } + + // Here's an edge case: when the handler takes no arguments, we allow that + // anyway, as they might've used the raw content. + if len(cmd.arguments) == 0 { + goto Call + } + + // Not enough arguments given + if len(args[start:]) != len(cmd.arguments) { + return &ErrInvalidUsage{ + Args: args, + Prefix: ctx.Prefix, + Index: len(cmd.arguments) - start, + Err: "Not enough arguments given", + ctx: cmd, + } + } + + argv = make([]reflect.Value, len(cmd.arguments)) + + for i := start; i < len(args); i++ { + v, err := cmd.arguments[i-start](args[i]) + if err != nil { + return &ErrInvalidUsage{ + Args: args, + Prefix: ctx.Prefix, + Index: i, + Err: err.Error(), + ctx: cmd, + } + } + + argv[i-start] = v + } + +Call: + // call the function and parse the error return value + return callWith(cmd.value, ev, argv...) +} + +func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool { + if *is != nil { + return **is + } + + var channelID = reflectChannelID(ev) + if !channelID.Valid() { + return false + } + + var userID = reflectUserID(ev) + if !userID.Valid() { + return false + } + + var res bool + + p, err := ctx.State.Permissions(channelID, userID) + if err == nil && p.Has(discord.PermissionAdministrator) { + res = true + } + + *is = &res + return res +} + +func callWith(caller reflect.Value, ev interface{}, values ...reflect.Value) error { + return errorReturns(caller.Call(append( + []reflect.Value{reflect.ValueOf(ev)}, + values..., + ))) +} + +var ParseArgs = func(args string) ([]string, error) { + // TODO: make modular + // TODO: actual tokenizer+parser + r := csv.NewReader(strings.NewReader(args)) + r.Comma = ' ' + + return r.Read() +} + +func errorReturns(returns []reflect.Value) error { + // assume first is always error, since we checked for this in parseCommands + v := returns[0].Interface() + + if v == nil { + return nil + } + + return v.(error) +} + +func reflectChannelID(_struct interface{}) discord.Snowflake { + return _reflectID(reflect.ValueOf(_struct), "Channel") +} + +func reflectGuildID(_struct interface{}) discord.Snowflake { + return _reflectID(reflect.ValueOf(_struct), "Guild") +} + +func reflectUserID(_struct interface{}) discord.Snowflake { + return _reflectID(reflect.ValueOf(_struct), "User") +} + +func _reflectID(v reflect.Value, thing string) discord.Snowflake { + if !v.IsValid() { + return 0 + } + + t := v.Type() + + if t.Kind() == reflect.Ptr { + v = v.Elem() + + // Recheck after dereferring + if !v.IsValid() { + return 0 + } + + t = v.Type() + } + + if t.Kind() != reflect.Struct { + return 0 + } + + numFields := t.NumField() + + for i := 0; i < numFields; i++ { + field := t.Field(i) + fType := field.Type + + if fType.Kind() == reflect.Ptr { + fType = fType.Elem() + } + + switch fType.Kind() { + case reflect.Struct: + if chID := _reflectID(v.Field(i), thing); chID.Valid() { + return chID + } + case reflect.Int64: + if field.Name == thing+"ID" { + // grab value real quick + return discord.Snowflake(v.Field(i).Int()) + } + + // Special case where the struct name has Channel in it + if field.Name == "ID" && strings.Contains(t.Name(), thing) { + return discord.Snowflake(v.Field(i).Int()) + } + } + } + + return 0 +} diff --git a/bot/ctx_test.go b/bot/ctx_test.go new file mode 100644 index 0000000..3164588 --- /dev/null +++ b/bot/ctx_test.go @@ -0,0 +1,312 @@ +package bot + +import ( + "reflect" + "strings" + "testing" + + "github.com/diamondburned/arikawa/discord" + "github.com/diamondburned/arikawa/gateway" + "github.com/diamondburned/arikawa/state" + "github.com/pkg/errors" +) + +type testCommands struct { + Ctx *Context + Return chan interface{} +} + +func (t *testCommands) Send(_ *gateway.MessageCreateEvent, arg string) error { + t.Return <- arg + return errors.New("oh no") +} + +func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *CustomParseable) error { + t.Return <- c.args + return nil +} + +func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error { + return errors.New("passed") +} + +func (t *testCommands) Noop(_ *gateway.MessageCreateEvent) error { + return nil +} + +type CustomParseable struct { + args []string +} + +func (c *CustomParseable) ParseContent(args []string) error { + c.args = args + return nil +} + +func TestNewContext(t *testing.T) { + var state = &state.State{ + Store: state.NewDefaultStore(nil), + } + + _, err := New(state, &testCommands{}) + if err != nil { + t.Fatal("Failed to create new context:", err) + } +} + +func TestContext(t *testing.T) { + var given = &testCommands{} + var state = &state.State{ + Store: state.NewDefaultStore(nil), + } + + s, err := NewSubcommand(given) + if err != nil { + t.Fatal("Failed to create subcommand:", err) + } + + var ctx = &Context{ + Subcommand: s, + State: state, + } + + t.Run("init commands", func(t *testing.T) { + if err := ctx.Subcommand.InitCommands(ctx); err != nil { + t.Fatal("Failed to init commands:", err) + } + + if given.Ctx == nil { + t.Fatal("given's Context field is nil") + } + + if given.Ctx.State.Store == nil { + t.Fatal("given's State is nil") + } + }) + + testReturn := func(expects interface{}, content string) (call error) { + // Return channel for testing + ret := make(chan interface{}) + given.Return = ret + + // Mock a messageCreate event + m := &gateway.MessageCreateEvent{ + Content: content, + } + + var ( + callCh = make(chan error) + ) + + go func() { + callCh <- ctx.callCmd(m) + }() + + select { + case arg := <-ret: + if !reflect.DeepEqual(arg, expects) { + t.Fatal("returned argument is invalid:", arg) + } + call = <-callCh + + case call = <-callCh: + t.Fatal("expected return before error:", call) + } + + return + } + + t.Run("call command", func(t *testing.T) { + // Set a custom prefix + ctx.Prefix = "~" + + if err := testReturn("test", "~send test"); err.Error() != "oh no" { + t.Fatal("unexpected error:", err) + } + }) + + t.Run("call command custom parser", func(t *testing.T) { + ctx.Prefix = "!" + expects := []string{"custom", "arg1", ":)"} + + if err := testReturn(expects, "!custom arg1 :)"); err != nil { + t.Fatal("Unexpected call error:", err) + } + }) + + testMessage := func(content string) error { + // Mock a messageCreate event + m := &gateway.MessageCreateEvent{ + Content: content, + } + + return ctx.callCmd(m) + } + + t.Run("call command without args", func(t *testing.T) { + ctx.Prefix = "" + + if err := testMessage("noargs"); err.Error() != "passed" { + t.Fatal("unexpected error:", err) + } + }) + + // Test error cases + + t.Run("call unknown command", func(t *testing.T) { + ctx.Prefix = "joe pls " + + err := testMessage("joe pls no") + + if err == nil || !strings.HasPrefix(err.Error(), "Unknown command:") { + t.Fatal("unexpected error:", err) + } + }) + + // Test subcommands + + t.Run("register subcommand", func(t *testing.T) { + ctx.Prefix = "run " + + _, err := ctx.RegisterSubcommand(&testCommands{}) + if err != nil { + t.Fatal("Failed to register subcommand:", err) + } + + if err := testMessage("run testcommands noop"); err != nil { + t.Fatal("unexpected error:", err) + } + }) +} + +func BenchmarkConstructor(b *testing.B) { + var state = &state.State{ + Store: state.NewDefaultStore(nil), + } + + for i := 0; i < b.N; i++ { + _, _ = New(state, &testCommands{}) + } +} + +func BenchmarkCall(b *testing.B) { + var given = &testCommands{} + var state = &state.State{ + Store: state.NewDefaultStore(nil), + } + + s, _ := NewSubcommand(given) + + var ctx = &Context{ + Subcommand: s, + State: state, + Prefix: "~", + } + + m := &gateway.MessageCreateEvent{ + Content: "~noop", + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ctx.callCmd(m) + } +} + +func BenchmarkHelp(b *testing.B) { + var given = &testCommands{} + var state = &state.State{ + Store: state.NewDefaultStore(nil), + } + + s, _ := NewSubcommand(given) + + var ctx = &Context{ + Subcommand: s, + State: state, + Prefix: "~", + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = ctx.Help() + } +} + +type hasID struct { + ChannelID discord.Snowflake +} + +type embedsID struct { + *hasID + *embedsID +} + +type hasChannelInName struct { + ID discord.Snowflake +} + +func TestReflectChannelID(t *testing.T) { + var s = &hasID{ + ChannelID: 69420, + } + + t.Run("hasID", func(t *testing.T) { + if id := reflectChannelID(s); id != 69420 { + t.Fatal("unexpected channelID:", id) + } + }) + + t.Run("embedsID", func(t *testing.T) { + var e = &embedsID{ + hasID: s, + } + + if id := reflectChannelID(e); id != 69420 { + t.Fatal("unexpected channelID:", id) + } + }) + + t.Run("hasChannelInName", func(t *testing.T) { + var s = &hasChannelInName{ + ID: 69420, + } + + if id := reflectChannelID(s); id != 69420 { + t.Fatal("unexpected channelID:", id) + } + }) +} + +func BenchmarkReflectChannelID_1Level(b *testing.B) { + var s = &hasID{ + ChannelID: 69420, + } + + for i := 0; i < b.N; i++ { + _ = reflectChannelID(s) + } +} + +func BenchmarkReflectChannelID_5Level(b *testing.B) { + var s = &embedsID{ + nil, + &embedsID{ + nil, + &embedsID{ + nil, + &embedsID{ + hasID: &hasID{ + ChannelID: 69420, + }, + }, + }, + }, + } + + for i := 0; i < b.N; i++ { + _ = reflectChannelID(s) + } +} diff --git a/bot/error.go b/bot/error.go new file mode 100644 index 0000000..163cf65 --- /dev/null +++ b/bot/error.go @@ -0,0 +1,66 @@ +package bot + +import ( + "strings" +) + +type ErrUnknownCommand struct { + Command string + Parent string + + Prefix string + + // TODO: list available commands? + // Here, as a reminder + ctx []*CommandContext +} + +func (err *ErrUnknownCommand) Error() string { + var header = "Unknown command: " + err.Prefix + if err.Parent != "" { + header += err.Parent + " " + err.Command + } else { + header += err.Command + } + + return header +} + +type ErrInvalidUsage struct { + Args []string + Prefix string + + Index int + Err string + + // TODO: usage generator? + // Here, as a reminder + ctx *CommandContext +} + +func (err *ErrInvalidUsage) Error() string { + if err.Index == 0 { + return "Invalid usage" + } + + if len(err.Args) == 0 { + return "Missing arguments. Refer to help." + } + + body := "Invalid usage at " + err.Prefix + + // Write the first part + body += strings.Join(err.Args[:err.Index], " ") + + // Write the wrong part + body += " __" + err.Args[err.Index] + "__ " + + // Write the last part + body += strings.Join(err.Args[err.Index+1:], " ") + + if err.Err != "" { + body += "\nError: " + err.Err + } + + return body +} diff --git a/bot/extras/arguments/emoji.go b/bot/extras/arguments/emoji.go new file mode 100644 index 0000000..cdbc19a --- /dev/null +++ b/bot/extras/arguments/emoji.go @@ -0,0 +1,51 @@ +package arguments + +import ( + "errors" + "regexp" +) + +var ( + EmojiRegex = regexp.MustCompile(`<(a?):(.+?):(\d+)>`) + + ErrInvalidEmoji = errors.New("Invalid emoji") +) + +type Emoji struct { + ID string + + Custom bool + Name string + Animated bool +} + +func (e *Emoji) Parse(arg string) error { + // Check if Unicode + var unicode string + + for _, r := range arg { + if r < '\U0001F600' && r > '\U0001F64F' { + unicode += string(r) + } + } + + if unicode != "" { + e.ID = unicode + e.Custom = false + + return nil + } + + var matches = EmojiRegex.FindStringSubmatch(arg) + + if len(matches) != 4 { + return ErrInvalidEmoji + } + + e.Custom = true + e.Animated = matches[1] == "a" + e.Name = matches[2] + e.ID = matches[3] + + return nil +} diff --git a/bot/extras/arguments/flag.go b/bot/extras/arguments/flag.go new file mode 100644 index 0000000..f99ab7b --- /dev/null +++ b/bot/extras/arguments/flag.go @@ -0,0 +1,65 @@ +package arguments + +import ( + "bytes" + "flag" + "io/ioutil" + "strings" +) + +var FlagName = "command" + +type FlagSet struct { + *flag.FlagSet +} + +func NewFlagSet() *FlagSet { + fs := flag.NewFlagSet(FlagName, flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + + return &FlagSet{fs} +} + +func (fs *FlagSet) Usage() string { + var buf bytes.Buffer + + fs.FlagSet.SetOutput(&buf) + fs.FlagSet.Usage() + fs.FlagSet.SetOutput(ioutil.Discard) + + return buf.String() +} + +type Flag struct { + arguments []string +} + +func (f *Flag) ParseContent(arguments []string) error { + // trim the command out + f.arguments = arguments[1:] + return nil +} + +func (f *Flag) Usage() string { + return "flags..." +} + +func (f *Flag) Args() []string { + return f.arguments +} + +func (f *Flag) Arg(n int) string { + if n < 0 || n >= len(f.arguments) { + return "" + } + + return f.arguments[n] +} + +func (f *Flag) String() string { + return strings.Join(f.arguments, " ") +} + +func (f *Flag) With(fs *flag.FlagSet) error { + return fs.Parse(f.arguments) +} diff --git a/bot/extras/arguments/mention.go b/bot/extras/arguments/mention.go new file mode 100644 index 0000000..e04261b --- /dev/null +++ b/bot/extras/arguments/mention.go @@ -0,0 +1,52 @@ +package arguments + +import ( + "errors" + "regexp" +) + +var ( + ChannelRegex = regexp.MustCompile(`<#(\d+)>`) + UserRegex = regexp.MustCompile(`<@!?(\d+)>`) + RoleRegex = regexp.MustCompile(`<@&(\d+)>`) +) + +type ChannelMention string + +func (m *ChannelMention) Parse(arg string) error { + return grabFirst(ChannelRegex, "channel mention", arg, (*string)(m)) +} + +func (m *ChannelMention) Usage() string { + return "#channel" +} + +type UserMention string + +func (m *UserMention) Parse(arg string) error { + return grabFirst(UserRegex, "user mention", arg, (*string)(m)) +} + +func (m *UserMention) Usage() string { + return "@user" +} + +type RoleMention string + +func (m *RoleMention) Parse(arg string) error { + return grabFirst(RoleRegex, "role mention", arg, (*string)(m)) +} + +func (m *RoleMention) Usage() string { + return "@role" +} + +func grabFirst(reg *regexp.Regexp, item, input string, output *string) error { + matches := reg.FindStringSubmatch(input) + if len(matches) < 2 { + return errors.New("Invalid " + item) + } + + *output = matches[1] + return nil +} diff --git a/bot/nameflag.go b/bot/nameflag.go new file mode 100644 index 0000000..9ab6ce3 --- /dev/null +++ b/bot/nameflag.go @@ -0,0 +1,40 @@ +package bot + +import "strings" + +type NameFlag uint64 + +const FlagSeparator = 'ー' + +const ( + None NameFlag = 1 << iota + + // These flags only apply to messageCreate events. + + Raw // R + AdminOnly // A +) + +func ParseFlag(name string) (NameFlag, string) { + parts := strings.SplitN(name, string(FlagSeparator), 2) + if len(parts) != 2 { + return 0, name + } + + var f NameFlag + + for _, r := range parts[0] { + switch r { + case 'R': + f |= Raw + case 'A': + f |= AdminOnly + } + } + + return f, parts[1] +} + +func (f NameFlag) Is(flag NameFlag) bool { + return f&flag != 0 +} diff --git a/bot/nameflag_test.go b/bot/nameflag_test.go new file mode 100644 index 0000000..915a0b9 --- /dev/null +++ b/bot/nameflag_test.go @@ -0,0 +1,26 @@ +package bot + +import "testing" + +func TestNameFlag(t *testing.T) { + type entry struct { + Name string + Expect NameFlag + String string + } + + var entries = []entry{{ + Name: "AーEcho", + Expect: AdminOnly, + }, { + Name: "RAーGC", + Expect: Raw | AdminOnly, + }} + + for _, entry := range entries { + var f, _ = ParseFlag(entry.Name) + if !f.Is(entry.Expect) { + t.Fatalf("unexpected expectation for %s: %v", entry.Name, f) + } + } +} diff --git a/bot/subcommand.go b/bot/subcommand.go new file mode 100644 index 0000000..8c6cf66 --- /dev/null +++ b/bot/subcommand.go @@ -0,0 +1,298 @@ +package bot + +import ( + "reflect" + "strings" + + "github.com/diamondburned/arikawa/gateway" + "github.com/pkg/errors" +) + +var ( + typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil)) + // typeof.Implements(typeI*) + typeIError = reflect.TypeOf((*error)(nil)).Elem() + typeIManP = reflect.TypeOf((*ManualParseable)(nil)).Elem() + typeIParser = reflect.TypeOf((*Parseable)(nil)).Elem() + typeIUsager = reflect.TypeOf((*Usager)(nil)).Elem() +) + +type Subcommand struct { + Description string + + // Commands contains all the registered command contexts. + Commands []*CommandContext + + // struct name + name string + + // struct flags + Flag NameFlag + + // Directly to struct + cmdValue reflect.Value + cmdType reflect.Type + + // Pointer value + ptrValue reflect.Value + ptrType reflect.Type + + // command interface as reference + command interface{} +} + +// CommandContext is an internal struct containing fields to make this library +// work. As such, they're all unexported. Description, however, is exported for +// editing, and may be used to generate more informative help messages. +type CommandContext struct { + Description string + Flag NameFlag + + name string // all lower-case + value reflect.Value // Func + event reflect.Type // gateway.*Event + method reflect.Method + + // equal slices + argStrings []string + arguments []argumentValueFn + + // only for ParseContent interface + parseMethod reflect.Method + parseType reflect.Type + parseUsage string +} + +// Descriptor is optionally used to set the Description of a command context. +type Descriptor interface { + Description() string +} + +// Namer is optionally used to override the command context's name. +type Namer interface { + Name() string +} + +// Usager is optionally used to override the generated usage for either an +// argument, or multiple (using ManualParseable). +type Usager interface { + Usage() string +} + +func (cctx *CommandContext) Name() string { + return cctx.name +} + +func (cctx *CommandContext) Usage() []string { + if cctx.parseType != nil { + return []string{cctx.parseUsage} + } + + if len(cctx.arguments) == 0 { + return nil + } + + return cctx.argStrings +} + +func NewSubcommand(cmd interface{}) (*Subcommand, error) { + var sub = Subcommand{ + command: cmd, + } + + // Set description + if d, ok := cmd.(Descriptor); ok { + sub.Description = d.Description() + } + + if err := sub.reflectCommands(); err != nil { + return nil, errors.Wrap(err, "Failed to reflect commands") + } + + if err := sub.parseCommands(); err != nil { + return nil, errors.Wrap(err, "Failed to parse commands") + } + + return &sub, nil +} + +// Name returns the command name in lower case. This only returns non-zero for +// subcommands. +func (sub *Subcommand) Name() string { + return sub.name +} + +// NeedsName sets the name for this subcommand. Like InitCommands, this +// shouldn't be called at all, rather you should use RegisterSubcommand. +func (sub *Subcommand) NeedsName() { + flag, name := ParseFlag(sub.cmdType.Name()) + + // Check for interface + if n, ok := sub.command.(Namer); ok { + name = n.Name() + } + + if !flag.Is(Raw) { + name = strings.ToLower(name) + } + + sub.name = name + sub.Flag = flag +} + +func (sub *Subcommand) reflectCommands() error { + t := reflect.TypeOf(sub.command) + v := reflect.ValueOf(sub.command) + + if t.Kind() != reflect.Ptr { + return errors.New("sub is not a pointer") + } + + // Set the pointer fields + sub.ptrValue = v + sub.ptrType = t + + ts := t.Elem() + vs := v.Elem() + + if ts.Kind() != reflect.Struct { + return errors.New("sub is not pointer to struct") + } + + // Set the struct fields + sub.cmdValue = vs + sub.cmdType = ts + + return nil +} + +// InitCommands fills a Subcommand with a context. This shouldn't be called at +// all, rather you should use the RegisterSubcommand method of a Context. +func (sub *Subcommand) InitCommands(ctx *Context) error { + // Start filling up a *Context field + for i := 0; i < sub.cmdValue.NumField(); i++ { + field := sub.cmdValue.Field(i) + + if !field.CanSet() || !field.CanInterface() { + continue + } + + if _, ok := field.Interface().(*Context); !ok { + continue + } + + field.Set(reflect.ValueOf(ctx)) + return nil + } + + return errors.New("No fields with *Command found") +} + +func (sub *Subcommand) parseCommands() error { + var numMethods = sub.ptrValue.NumMethod() + var commands = make([]*CommandContext, 0, numMethods) + + for i := 0; i < numMethods; i++ { + method := sub.ptrValue.Method(i) + + if !method.CanInterface() { + continue + } + + methodT := method.Type() + numArgs := methodT.NumIn() + + // Doesn't meet requirement for an event + if numArgs == 0 { + continue + } + + // Check return type + if err := methodT.Out(0); err == nil || !err.Implements(typeIError) { + // Invalid, skip + continue + } + + var command = CommandContext{ + method: sub.ptrType.Method(i), + value: method, + event: methodT.In(0), // parse event + } + + // Parse the method name + flag, name := ParseFlag(command.method.Name) + + if !flag.Is(Raw) { + name = strings.ToLower(name) + } + + // Set the method name and flag + command.name = name + command.Flag = flag + + // TODO: allow more flexibility + if command.event != typeMessageCreate { + goto Done + } + + if numArgs == 1 { + // done + goto Done + } + + // If the second argument implements ParseContent() + if t := methodT.In(1); t.Implements(typeIManP) { + mt, _ := t.MethodByName("ParseContent") + + command.parseMethod = mt + command.parseType = t.Elem() + + command.parseUsage = usager(t) + if command.parseUsage == "" { + command.parseUsage = t.String() + } + + goto Done + } + + command.arguments = make([]argumentValueFn, 0, numArgs) + + // Fill up arguments + for i := 1; i < numArgs; i++ { + t := methodT.In(i) + + avfs, err := getArgumentValueFn(t) + if err != nil { + return errors.Wrap(err, "Error parsing argument "+t.String()) + } + + command.arguments = append(command.arguments, avfs) + + var usage = usager(t) + if usage == "" { + usage = t.String() + } + + command.argStrings = append(command.argStrings, usage) + } + + Done: + // Append + commands = append(commands, &command) + } + + sub.Commands = commands + return nil +} + +func usager(t reflect.Type) string { + if !t.Implements(typeIUsager) { + return "" + } + + usageFn, _ := t.MethodByName("Usage") + v := usageFn.Func.Call([]reflect.Value{ + reflect.New(t.Elem()), + }) + return v[0].String() +} diff --git a/bot/subcommand_test.go b/bot/subcommand_test.go new file mode 100644 index 0000000..57ffe4e --- /dev/null +++ b/bot/subcommand_test.go @@ -0,0 +1,96 @@ +package bot + +import "testing" + +func TestNewSubcommand(t *testing.T) { + _, err := NewSubcommand(&testCommands{}) + if err != nil { + t.Fatal("Failed to create new subcommand:", err) + } +} + +func TestSubcommand(t *testing.T) { + var given = &testCommands{} + var sub = &Subcommand{ + command: given, + } + + t.Run("reflect commands", func(t *testing.T) { + if err := sub.reflectCommands(); err != nil { + t.Fatal("Failed to reflect commands:", err) + } + }) + + t.Run("parse commands", func(t *testing.T) { + if err := sub.parseCommands(); err != nil { + t.Fatal("Failed to parse commands:", err) + } + + // !!! CHANGE ME + if len(sub.Commands) != 4 { + t.Fatal("invalid ctx.commands len", len(sub.Commands)) + } + + var ( + foundSend bool + foundCustom bool + foundNoArgs bool + ) + + for _, this := range sub.Commands { + switch this.name { + case "send": + foundSend = true + if len(this.arguments) != 1 { + t.Fatal("invalid arguments len", len(this.arguments)) + } + + case "custom": + foundCustom = true + if len(this.arguments) > 0 { + t.Fatal("arguments should be 0 for custom") + } + if this.parseType == nil { + t.Fatal("custom has nil manualParse") + } + + case "noargs": + foundNoArgs = true + if len(this.arguments) != 0 { + t.Fatal("expected 0 arguments, got non-zero") + } + if this.parseType != nil { + t.Fatal("unexpected parseType") + } + + case "noop": + // Found, but whatever + + default: + t.Fatal("Unexpected command:", this.name) + } + + if this.event != typeMessageCreate { + t.Fatal("invalid event type:", this.event.String()) + } + } + + if !foundSend { + t.Fatal("missing send") + } + + if !foundCustom { + t.Fatal("missing custom") + } + + if !foundNoArgs { + t.Fatal("missing noargs") + } + }) +} + +func BenchmarkSubcommandConstructor(b *testing.B) { + for i := 0; i < b.N; i++ { + NewSubcommand(&testCommands{}) + } +} diff --git a/discord/message_embed.go b/discord/message_embed.go index c768aac..6ececbc 100644 --- a/discord/message_embed.go +++ b/discord/message_embed.go @@ -2,7 +2,7 @@ package discord import "fmt" -type Color uint +type Color uint32 const DefaultColor Color = 0x303030 diff --git a/discord/snowflake.go b/discord/snowflake.go index 24cc689..5857dfb 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -37,6 +37,8 @@ func (s *Snowflake) MarshalJSON() ([]byte, error) { switch i := int64(*s); i { case -1: // @me id = "@me" + case 0: + return []byte("null"), nil default: id = strconv.FormatInt(i, 10) } diff --git a/discord/time.go b/discord/time.go index 30fca0b..f965d38 100644 --- a/discord/time.go +++ b/discord/time.go @@ -31,6 +31,10 @@ func (t *Timestamp) UnmarshalJSON(v []byte) error { } func (t Timestamp) MarshalJSON() ([]byte, error) { + if time.Time(t).IsZero() { + return []byte("null"), nil + } + return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil } diff --git a/state/state.go b/state/state.go index 9138da9..21a31a6 100644 --- a/state/state.go +++ b/state/state.go @@ -9,6 +9,7 @@ import ( "github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/handler" "github.com/diamondburned/arikawa/session" + "github.com/pkg/errors" ) var ( @@ -76,6 +77,29 @@ func (s *State) Unhook() { //// +func (s *State) Permissions( + channelID, userID discord.Snowflake) (discord.Permissions, error) { + + ch, err := s.Channel(channelID) + if err != nil { + return 0, errors.Wrap(err, "Failed to get channel") + } + + g, err := s.Guild(ch.GuildID) + if err != nil { + return 0, errors.Wrap(err, "Failed to get guild") + } + + m, err := s.Member(ch.GuildID, userID) + if err != nil { + return 0, errors.Wrap(err, "Failed to get member") + } + + return discord.CalcOverwrites(*g, *ch, *m), nil +} + +//// + func (s *State) Self() (*discord.User, error) { u, err := s.Store.Self() if err == nil {