diff --git a/bot/command.go b/bot/command.go index e6a8de6..04e5e52 100644 --- a/bot/command.go +++ b/bot/command.go @@ -2,8 +2,24 @@ package bot import ( "reflect" + + "github.com/diamondburned/arikawa/v2/gateway" ) +// eventIntents maps event pointer types to intents. +var eventIntents = map[reflect.Type]gateway.Intents{} + +func init() { + for event, intent := range gateway.EventIntents { + fn, ok := gateway.EventCreator[event] + if !ok { + continue + } + + eventIntents[reflect.TypeOf(fn())] = intent + } +} + type command struct { value reflect.Value // Func event reflect.Type @@ -26,6 +42,15 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er return callWith(c.value, arg0, argv...) } +// intents returns the command's intents from the event. +func (c *command) intents() gateway.Intents { + intents, ok := eventIntents[c.event] + if !ok { + return 0 + } + return intents +} + func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) { var callargs = make([]reflect.Value, 0, 1+len(argv)) diff --git a/bot/ctx.go b/bot/ctx.go index 0077cf7..e60d1eb 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -165,6 +165,8 @@ func Start( } } + s.Gateway.AddIntents(c.DeriveIntents()) + cancel := c.Start() if err := s.Open(); err != nil { @@ -229,10 +231,10 @@ func New(s *state.State, cmd interface{}) (*Context, error) { return ctx, nil } -// AddIntent adds the given Gateway Intent into the Gateway. This is a +// AddIntents adds the given Gateway Intent into the Gateway. This is a // convenient function that calls Gateway's AddIntent. -func (ctx *Context) AddIntent(i gateway.Intents) { - ctx.Gateway.AddIntent(i) +func (ctx *Context) AddIntents(i gateway.Intents) { + ctx.Gateway.AddIntents(i) } // Subcommands returns the slice of subcommands. To add subcommands, use @@ -444,3 +446,13 @@ func IndentLines(input string) string { } return strings.Join(lines, "\n") } + +// DeriveIntents derives all possible gateway intents from this context and all +// its subcommands' method handlers and middlewares. +func (ctx *Context) DeriveIntents() gateway.Intents { + var intents = ctx.Subcommand.DeriveIntents() + for _, subcmd := range ctx.subcommands { + intents |= subcmd.DeriveIntents() + } + return intents +} diff --git a/bot/ctx_test.go b/bot/ctx_test.go index daa215a..95c9fc0 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -149,6 +149,21 @@ func TestContext(t *testing.T) { } }) + t.Run("derive intents", func(t *testing.T) { + intents := ctx.DeriveIntents() + + assertIntents := func(target gateway.Intents, name string) { + if !intents.Has(target) { + t.Error("Derived intents do not have", name) + } + } + + assertIntents(gateway.IntentGuildMessages, "guild messages") + assertIntents(gateway.IntentDirectMessages, "direct messages") + assertIntents(gateway.IntentGuildMessageTyping, "guild typing") + assertIntents(gateway.IntentDirectMessageTyping, "direct message typing") + }) + t.Run("typing event", func(t *testing.T) { typing := &gateway.TypingStartEvent{} diff --git a/bot/subcommand.go b/bot/subcommand.go index 7ab0274..ad9d51a 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -143,6 +143,10 @@ func (sub *Subcommand) NeedsName() { sub.Command = lowerFirstLetter(sub.StructName) } +func lowerFirstLetter(name string) string { + return strings.ToLower(string(name[0])) + name[1:] +} + // FindCommand finds the MethodContext. It panics if methodName is not found. func (sub *Subcommand) FindCommand(methodName string) *MethodContext { for _, c := range sub.Commands { @@ -413,6 +417,23 @@ func (sub *Subcommand) AddAliases(commandName string, aliases ...string) { command.Aliases = append(command.Aliases, aliases...) } -func lowerFirstLetter(name string) string { - return strings.ToLower(string(name[0])) + name[1:] +// DeriveIntents derives all possible gateway intents from the method handlers +// and middlewares. +func (sub *Subcommand) DeriveIntents() gateway.Intents { + var intents gateway.Intents + + for _, event := range sub.Events { + intents |= event.intents() + } + for _, command := range sub.Commands { + intents |= command.intents() + } + if sub.plumbed != nil { + intents |= sub.plumbed.intents() + } + for _, middleware := range sub.globalmws { + intents |= middleware.intents() + } + + return intents } diff --git a/gateway/gateway.go b/gateway/gateway.go index 751de52..3c3664d 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -114,7 +114,7 @@ func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) { } for _, intent := range intents { - g.AddIntent(intent) + g.AddIntents(intent) } return g, nil @@ -154,9 +154,9 @@ func NewCustomGateway(gatewayURL, token string) *Gateway { } } -// AddIntent adds a Gateway Intent before connecting to the Gateway. As -// such, this function will only work before Open() is called. -func (g *Gateway) AddIntent(i Intents) { +// AddIntents adds a Gateway Intent before connecting to the Gateway. As such, +// this function will only work before Open() is called. +func (g *Gateway) AddIntents(i Intents) { g.Identifier.Intents |= i } diff --git a/gateway/intents.go b/gateway/intents.go index d858e3f..12fc96a 100644 --- a/gateway/intents.go +++ b/gateway/intents.go @@ -1,5 +1,7 @@ package gateway +import "github.com/diamondburned/arikawa/v2/discord" + // Intents for the new Discord API feature, documented at // https://discordapp.com/developers/docs/topics/gateway#gateway-intents. type Intents uint32 @@ -28,3 +30,15 @@ var PrivilegedIntents = []Intents{ IntentGuildPresences, IntentGuildMembers, } + +// Has returns true if i has the given intents. +func (i Intents) Has(intents Intents) bool { + return discord.HasFlag(uint64(i), uint64(intents)) +} + +// IsPrivileged returns true for each of the boolean that indicates the type of +// the privilege. +func (i Intents) IsPrivileged() (presences, member bool) { + // Keep this in sync with PrivilegedIntents. + return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers) +}