1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-07-29 00:42:16 +00:00

Compare commits

...

9 commits

Author SHA1 Message Date
diamondburned 9b4b707070 Handler: Use a free list over a linked list 2020-10-28 23:55:39 -07:00
diamondburned 9899f6073b Bot: Added automatic Intents detection from handlers
This commit adds automatic Intents detection into package bot. When the
Start function is used, the intents will be OR'd after running the
options callback.

This commit also breaks the old "AddIntent" methods to rename them to
"AddIntents" for correctness.
2020-10-28 22:49:18 -07:00
diamondburned ef48d686cd Handler: Changed to a linked list instead of a slice-backed map
This change should slightly improve the performance of the handler
container.

A rough benchmark was written and tested; the source code is at
https://gist.github.com/diamondburned/c369d13efda5c702a0e59874deee64bd.
2020-10-28 22:37:38 -07:00
diamondburned b7b8118d0b State: Fixed breaking change from previous Gateway change 2020-10-28 19:47:22 -07:00
diamondburned b8e4b2cf56 Gateway: Added an Event to Intents map for convenience 2020-10-28 19:44:04 -07:00
diamondburned c00d31ce30 Gateway: Added missing MessageReactionRemoveEmojiEvent constructor 2020-10-28 19:39:18 -07:00
diamondburned fd16db1385 Gateway: Fixed MessageReactionRemoveEmoji not having Event suffix 2020-10-28 19:31:43 -07:00
diamondburned 33e7abd4db Merge wsutil fix from v1 into v2 2020-10-28 19:03:10 -07:00
diamondburned 160a4e6606 wsutil: Fixed data races involving getters 2020-10-28 19:00:59 -07:00
15 changed files with 262 additions and 141 deletions

View file

@ -2,8 +2,24 @@ package bot
import (
"reflect"
"github.com/diamondburned/arikawa/v2/gateway"
)
// eventIntents maps event pointer types to intents.
var eventIntents = map[reflect.Type]gateway.Intents{}
func init() {
for event, intent := range gateway.EventIntents {
fn, ok := gateway.EventCreator[event]
if !ok {
continue
}
eventIntents[reflect.TypeOf(fn())] = intent
}
}
type command struct {
value reflect.Value // Func
event reflect.Type
@ -26,6 +42,15 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er
return callWith(c.value, arg0, argv...)
}
// intents returns the command's intents from the event.
func (c *command) intents() gateway.Intents {
intents, ok := eventIntents[c.event]
if !ok {
return 0
}
return intents
}
func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(argv))

View file

@ -165,6 +165,8 @@ func Start(
}
}
s.Gateway.AddIntents(c.DeriveIntents())
cancel := c.Start()
if err := s.Open(); err != nil {
@ -229,10 +231,10 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
return ctx, nil
}
// AddIntent adds the given Gateway Intent into the Gateway. This is a
// AddIntents adds the given Gateway Intent into the Gateway. This is a
// convenient function that calls Gateway's AddIntent.
func (ctx *Context) AddIntent(i gateway.Intents) {
ctx.Gateway.AddIntent(i)
func (ctx *Context) AddIntents(i gateway.Intents) {
ctx.Gateway.AddIntents(i)
}
// Subcommands returns the slice of subcommands. To add subcommands, use
@ -444,3 +446,13 @@ func IndentLines(input string) string {
}
return strings.Join(lines, "\n")
}
// DeriveIntents derives all possible gateway intents from this context and all
// its subcommands' method handlers and middlewares.
func (ctx *Context) DeriveIntents() gateway.Intents {
var intents = ctx.Subcommand.DeriveIntents()
for _, subcmd := range ctx.subcommands {
intents |= subcmd.DeriveIntents()
}
return intents
}

View file

@ -149,6 +149,21 @@ func TestContext(t *testing.T) {
}
})
t.Run("derive intents", func(t *testing.T) {
intents := ctx.DeriveIntents()
assertIntents := func(target gateway.Intents, name string) {
if !intents.Has(target) {
t.Error("Derived intents do not have", name)
}
}
assertIntents(gateway.IntentGuildMessages, "guild messages")
assertIntents(gateway.IntentDirectMessages, "direct messages")
assertIntents(gateway.IntentGuildMessageTyping, "guild typing")
assertIntents(gateway.IntentDirectMessageTyping, "direct message typing")
})
t.Run("typing event", func(t *testing.T) {
typing := &gateway.TypingStartEvent{}

View file

@ -143,6 +143,10 @@ func (sub *Subcommand) NeedsName() {
sub.Command = lowerFirstLetter(sub.StructName)
}
func lowerFirstLetter(name string) string {
return strings.ToLower(string(name[0])) + name[1:]
}
// FindCommand finds the MethodContext. It panics if methodName is not found.
func (sub *Subcommand) FindCommand(methodName string) *MethodContext {
for _, c := range sub.Commands {
@ -413,6 +417,23 @@ func (sub *Subcommand) AddAliases(commandName string, aliases ...string) {
command.Aliases = append(command.Aliases, aliases...)
}
func lowerFirstLetter(name string) string {
return strings.ToLower(string(name[0])) + name[1:]
// DeriveIntents derives all possible gateway intents from the method handlers
// and middlewares.
func (sub *Subcommand) DeriveIntents() gateway.Intents {
var intents gateway.Intents
for _, event := range sub.Events {
intents |= event.intents()
}
for _, command := range sub.Commands {
intents |= command.intents()
}
if sub.plumbed != nil {
intents |= sub.plumbed.intents()
}
for _, middleware := range sub.globalmws {
intents |= middleware.intents()
}
return intents
}

View file

@ -248,7 +248,7 @@ type (
MessageID discord.MessageID `json:"message_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageReactionRemoveEmoji struct {
MessageReactionRemoveEmojiEvent struct {
ChannelID discord.ChannelID `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
Emoji discord.Emoji `json:"emoji"`

View file

@ -3,6 +3,7 @@ package gateway
// Event is any event struct. They have an "Event" suffixed to them.
type Event = interface{}
// EventCreator maps an event type string to a constructor.
var EventCreator = map[string]func() Event{
"HELLO": func() Event { return new(HelloEvent) },
"READY": func() Event { return new(ReadyEvent) },
@ -44,9 +45,10 @@ var EventCreator = map[string]func() Event{
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
"MESSAGE_REACTION_REMOVE_EMOJI": func() Event { return new(MessageReactionRemoveEmojiEvent) },
"MESSAGE_ACK": func() Event { return new(MessageAckEvent) },

View file

@ -114,7 +114,7 @@ func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
}
for _, intent := range intents {
g.AddIntent(intent)
g.AddIntents(intent)
}
return g, nil
@ -154,9 +154,9 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
}
}
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
// such, this function will only work before Open() is called.
func (g *Gateway) AddIntent(i Intents) {
// AddIntents adds a Gateway Intent before connecting to the Gateway. As such,
// this function will only work before Open() is called.
func (g *Gateway) AddIntents(i Intents) {
g.Identifier.Intents |= i
}

View file

@ -55,28 +55,6 @@ func (i *IdentifyData) SetShard(id, num int) {
i.Shard[0], i.Shard[1] = id, num
}
// Intents for the new Discord API feature, documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32
const (
IntentGuilds Intents = 1 << iota
IntentGuildMembers
IntentGuildBans
IntentGuildEmojis
IntentGuildIntegrations
IntentGuildWebhooks
IntentGuildInvites
IntentGuildVoiceStates
IntentGuildPresences
IntentGuildMessages
IntentGuildMessageReactions
IntentGuildMessageTyping
IntentDirectMessages
IntentDirectMessageReactions
IntentDirectMessageTyping
)
type Identifier struct {
IdentifyData

44
gateway/intents.go Normal file
View file

@ -0,0 +1,44 @@
package gateway
import "github.com/diamondburned/arikawa/v2/discord"
// Intents for the new Discord API feature, documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32
const (
IntentGuilds Intents = 1 << iota
IntentGuildMembers
IntentGuildBans
IntentGuildEmojis
IntentGuildIntegrations
IntentGuildWebhooks
IntentGuildInvites
IntentGuildVoiceStates
IntentGuildPresences
IntentGuildMessages
IntentGuildMessageReactions
IntentGuildMessageTyping
IntentDirectMessages
IntentDirectMessageReactions
IntentDirectMessageTyping
)
// PrivilegedIntents contains a list of privileged intents that Discord requires
// bots to have these intents explicitly enabled in the Developer Portal.
var PrivilegedIntents = []Intents{
IntentGuildPresences,
IntentGuildMembers,
}
// Has returns true if i has the given intents.
func (i Intents) Has(intents Intents) bool {
return discord.HasFlag(uint64(i), uint64(intents))
}
// IsPrivileged returns true for each of the boolean that indicates the type of
// the privilege.
func (i Intents) IsPrivileged() (presences, member bool) {
// Keep this in sync with PrivilegedIntents.
return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers)
}

47
gateway/intents_map.go Normal file
View file

@ -0,0 +1,47 @@
package gateway
// EventIntents maps event types to intents.
var EventIntents = map[string]Intents{
"GUILD_CREATE": IntentGuilds,
"GUILD_UPDATE": IntentGuilds,
"GUILD_DELETE": IntentGuilds,
"GUILD_ROLE_CREATE": IntentGuilds,
"GUILD_ROLE_UPDATE": IntentGuilds,
"GUILD_ROLE_DELETE": IntentGuilds,
"CHANNEL_CREATE": IntentGuilds,
"CHANNEL_UPDATE": IntentGuilds,
"CHANNEL_DELETE": IntentGuilds,
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
"GUILD_MEMBER_ADD": IntentGuildMembers,
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
"GUILD_BAN_ADD": IntentGuildBans,
"GUILD_BAN_REMOVE": IntentGuildBans,
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
"INVITE_CREATE": IntentGuildInvites,
"INVITE_DELETE": IntentGuildInvites,
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
"PRESENCE_UPDATE": IntentGuildPresences,
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE_BULK": IntentGuildMessages,
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
}

View file

@ -240,7 +240,7 @@ func (s *State) onEvent(iface interface{}) {
return true
})
case *gateway.MessageReactionRemoveEmoji:
case *gateway.MessageReactionRemoveEmojiEvent:
s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool {
var i = findReaction(m.Reactions, ev.Emoji)
if i < 0 {

View file

@ -25,21 +25,19 @@ import (
"github.com/pkg/errors"
)
// Handler is a container for command handlers. A zero-value instance is a valid
// instance.
type Handler struct {
// Synchronous controls whether to spawn each event handler in its own
// goroutine. Default false (meaning goroutines are spawned).
Synchronous bool
handlers map[uint64]handler
horders []uint64
hserial uint64
hmutex sync.RWMutex
mutex sync.RWMutex
slab slab
}
func New() *Handler {
return &Handler{
handlers: map[uint64]handler{},
}
return &Handler{}
}
// Call calls all handlers with the given event. This is an internal method; use
@ -48,52 +46,18 @@ func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
h.mutex.RLock()
defer h.mutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
if !ok {
// This shouldn't ever happen, but we're adding this just in case.
continue
}
if handler.not(evT) {
for _, entry := range h.slab.Entries {
if entry.isInvalid() || entry.not(evT) {
continue
}
if h.Synchronous {
handler.call(evV)
entry.call(evV)
} else {
go handler.call(evV)
}
}
}
// CallDirect is the same as Call, but only calls those event handlers that
// listen for this specific event, i.e. that aren't interface handlers.
func (h *Handler) CallDirect(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
if !ok {
// This shouldn't ever happen, but we're adding this just in case.
continue
}
if evT != handler.event {
continue
}
if h.Synchronous {
handler.call(evV)
} else {
go handler.call(evV)
go entry.call(evV)
}
}
}
@ -213,47 +177,16 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
return nil, errors.Wrap(err, "handler reflect failed")
}
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Get the current counter value and increment the counter:
serial := h.hserial
h.hserial++
// Create a map if there's none:
if h.handlers == nil {
h.handlers = map[uint64]handler{}
}
// Use the serial for the map:
h.handlers[serial] = *r
// Append the serial into the list of keys:
h.horders = append(h.horders, serial)
h.mutex.Lock()
id := h.slab.Put(r)
h.mutex.Unlock()
return func() {
h.hmutex.Lock()
defer h.hmutex.Unlock()
h.mutex.Lock()
popped := h.slab.Pop(id)
h.mutex.Unlock()
// Take and delete the handler from the map, but return if we can't find
// it.
hd, ok := h.handlers[serial]
if !ok {
return
}
delete(h.handlers, serial)
// Delete the key from the orders slice:
for i, order := range h.horders {
if order == serial {
h.horders = append(h.horders[:i], h.horders[i+1:]...)
break
}
}
// Clean up the handler.
hd.cleanup()
popped.cleanup()
}, nil
}
@ -267,7 +200,7 @@ type handler struct {
// newHandler reflects either a channel or a function into a handler. A function
// must only have a single argument being the event and no return, and a channel
// must have the event type as the underlying type.
func newHandler(unknown interface{}) (*handler, error) {
func newHandler(unknown interface{}) (handler, error) {
fnV := reflect.ValueOf(unknown)
fnT := fnV.Type()
@ -279,11 +212,11 @@ func newHandler(unknown interface{}) (*handler, error) {
switch fnT.Kind() {
case reflect.Func:
if fnT.NumIn() != 1 {
return nil, errors.New("function can only accept 1 event as argument")
return handler, errors.New("function can only accept 1 event as argument")
}
if fnT.NumOut() > 0 {
return nil, errors.New("function can't accept returns")
return handler, errors.New("function can't accept returns")
}
handler.event = fnT.In(0)
@ -293,19 +226,19 @@ func newHandler(unknown interface{}) (*handler, error) {
handler.chanclose = reflect.ValueOf(make(chan struct{}))
default:
return nil, errors.New("given interface is not a function or channel")
return handler, errors.New("given interface is not a function or channel")
}
var kind = handler.event.Kind()
// Accept either pointer type or interface{} type
if kind != reflect.Ptr && kind != reflect.Interface {
return nil, errors.New("first argument is not pointer")
return handler, errors.New("first argument is not pointer")
}
handler.isIface = kind == reflect.Interface
return &handler, nil
return handler, nil
}
func (h handler) not(event reflect.Type) bool {
@ -316,7 +249,7 @@ func (h handler) not(event reflect.Type) bool {
return h.event != event
}
func (h *handler) call(event reflect.Value) {
func (h handler) call(event reflect.Value) {
if h.chanclose.IsValid() {
reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: h.callback, Send: event},
@ -327,7 +260,7 @@ func (h *handler) call(event reflect.Value) {
}
}
func (h *handler) cleanup() {
func (h handler) cleanup() {
if h.chanclose.IsValid() {
// Closing this channel will force all ongoing selects to return
// immediately.

View file

@ -20,9 +20,7 @@ func newMessage(content string) *gateway.MessageCreateEvent {
func TestCall(t *testing.T) {
var results = make(chan string)
h := &Handler{
handlers: map[uint64]handler{},
}
h := &Handler{}
// Add handler test
rm := h.AddHandler(func(m *gateway.MessageCreateEvent) {

44
utils/handler/slab.go Normal file
View file

@ -0,0 +1,44 @@
package handler
type slabEntry struct {
handler
index int
}
func (entry slabEntry) isInvalid() bool {
return entry.index != -1
}
// slab is an implementation of the internal handler free list.
type slab struct {
Entries []slabEntry
free int
}
func (s *slab) Put(entry handler) int {
if s.free == len(s.Entries) {
index := len(s.Entries)
s.Entries = append(s.Entries, slabEntry{entry, -1})
s.free++
return index
}
next := s.Entries[s.free].index
s.Entries[s.free] = slabEntry{entry, -1}
i := s.free
s.free = next
return i
}
func (s *slab) Get(i int) handler {
return s.Entries[i].handler
}
func (s *slab) Pop(i int) handler {
popped := s.Entries[i].handler
s.Entries[i] = slabEntry{handler{}, s.free}
s.free = i
return popped
}

View file

@ -85,24 +85,26 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
conn, _, err := c.dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
events := make(chan Event, WSBuffer)
go startReadLoop(conn, events)
var err error
c.mutex.Lock()
defer c.mutex.Unlock()
c.Conn = conn
c.events = events
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
c.events = make(chan Event, WSBuffer)
go startReadLoop(c.Conn, c.events)
return err
}
func (c *Conn) Listen() <-chan Event {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.events
}