From efde3f4ea6965b5848af31963e4f24d99c688420 Mon Sep 17 00:00:00 2001 From: diamondburned Date: Wed, 3 Nov 2021 15:16:02 -0700 Subject: [PATCH] state, handler: Refactor state storage and sync handlers This commit refactors a lot of packages. It refactors the handler package, removing the Synchronous field and replacing it the AddSyncHandler API, which allows each handler to control whether or not it should be ran synchronously independent of other handlers. This is useful for libraries that need to guarantee the incoming order of events. It also refactors the store interfaces to accept more interfaces. This is to make the API more consistent as well as reducing potential useless copies. The public-facing state API should still be the same, so this change will mostly concern users with their own store implementations. Several miscellaneous functions (such as a few in package gateway) were modified to be more suitable to other packages, but those functions should rarely ever be used, anyway. Several tests are also fixed within this commit, namely fixing state's intents bug. --- gateway/ready.go | 44 ++++--- state/state.go | 49 ++++---- state/state_events.go | 108 ++++++++++------- state/store/defaultstore/channel.go | 147 +++++++++++------------ state/store/defaultstore/guild.go | 15 +-- state/store/defaultstore/member.go | 14 +-- state/store/defaultstore/message.go | 24 ++-- state/store/defaultstore/message_test.go | 36 +++--- state/store/defaultstore/presence.go | 14 +-- state/store/defaultstore/role.go | 14 +-- state/store/defaultstore/voicestate.go | 14 +-- state/store/store.go | 38 +++--- utils/handler/handler.go | 99 +++++++++++---- utils/handler/handler_test.go | 10 +- utils/handler/slab.go | 8 +- voice/session_test.go | 2 +- 16 files changed, 362 insertions(+), 274 deletions(-) diff --git a/gateway/ready.go b/gateway/ready.go index b16df8f..a818fa7 100644 --- a/gateway/ready.go +++ b/gateway/ready.go @@ -233,27 +233,35 @@ type ( } ) -// ConvertSupplementalMember converts a SupplementalMember to a regular Member. -func ConvertSupplementalMember(sm SupplementalMember) discord.Member { - return discord.Member{ - User: discord.User{ID: sm.UserID}, - Nick: sm.Nick, - RoleIDs: sm.RoleIDs, - Joined: sm.Joined, - BoostedSince: sm.BoostedSince, - Deaf: sm.Deaf, - Mute: sm.Mute, - IsPending: sm.IsPending, +// ConvertSupplementalMembers converts a SupplementalMember to a regular Member. +func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member { + members := make([]discord.Member, len(sms)) + for i, sm := range sms { + members[i] = discord.Member{ + User: discord.User{ID: sm.UserID}, + Nick: sm.Nick, + RoleIDs: sm.RoleIDs, + Joined: sm.Joined, + BoostedSince: sm.BoostedSince, + Deaf: sm.Deaf, + Mute: sm.Mute, + IsPending: sm.IsPending, + } } + return members } -// ConvertSupplementalPresence converts a SupplementalPresence to a regular +// ConvertSupplementalPresences converts a SupplementalPresence to a regular // Presence with an empty GuildID. -func ConvertSupplementalPresence(sp SupplementalPresence) discord.Presence { - return discord.Presence{ - User: discord.User{ID: sp.UserID}, - Status: sp.Status, - Activities: sp.Activities, - ClientStatus: sp.ClientStatus, +func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence { + presences := make([]discord.Presence, len(sps)) + for i, sp := range sps { + presences[i] = discord.Presence{ + User: discord.User{ID: sp.UserID}, + Status: sp.Status, + Activities: sp.Activities, + ClientStatus: sp.ClientStatus, + } } + return presences } diff --git a/state/state.go b/state/state.go index d32d281..ce50443 100644 --- a/state/state.go +++ b/state/state.go @@ -27,7 +27,9 @@ var ( // The user should initialize handlers and intents in the opts function. func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc { return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { - return NewFromSession(session.NewCustomShard(m, id), defaultstore.New()), nil + state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New()) + opts(m, state) + return state, nil } } @@ -363,7 +365,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { } if s.tracksChannel(c) { - err = s.Cabinet.ChannelSet(*c, false) + err = s.Cabinet.ChannelSet(c, false) } return @@ -383,8 +385,8 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, c := range cs { - if err = s.Cabinet.ChannelSet(c, false); err != nil { + for i := range cs { + if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil { return } } @@ -404,7 +406,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return nil, err } - return c, s.Cabinet.ChannelSet(*c, false) + return c, s.Cabinet.ChannelSet(c, false) } // PrivateChannels gets the direct messages of the user. @@ -420,8 +422,8 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { return nil, err } - for _, c := range cs { - if err := s.Cabinet.ChannelSet(c, false); err != nil { + for i := range cs { + if err := s.Cabinet.ChannelSet(&cs[i], false); err != nil { return nil, err } } @@ -509,8 +511,8 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, g := range gs { - if err = s.Cabinet.GuildSet(g, false); err != nil { + for i := range gs { + if err = s.Cabinet.GuildSet(&gs[i], false); err != nil { return } } @@ -546,8 +548,8 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error } if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - for _, m := range ms { - if err = s.Cabinet.MemberSet(guildID, m, false); err != nil { + for i := range ms { + if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil { return } } @@ -579,7 +581,7 @@ func (s *State) Message( go func() { c, cerr = s.Session.Channel(channelID) if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - cerr = s.Cabinet.ChannelSet(*c, false) + cerr = s.Cabinet.ChannelSet(c, false) } wg.Done() @@ -688,8 +690,9 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes i = len(apiMessages) } - for _, m := range apiMessages[:i] { - if err := s.Cabinet.MessageSet(m, false); err != nil { + msgs := apiMessages[:i] + for i := range msgs { + if err := s.Cabinet.MessageSet(&msgs[i], false); err != nil { return nil, err } } @@ -745,14 +748,14 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di return } - for _, r := range rs { + for i, r := range rs { if r.ID == roleID { r := r // copy to prevent mem aliasing target = &r } if s.Gateway.HasIntents(gateway.IntentGuilds) { - if err = s.RoleSet(guildID, r, false); err != nil { + if err = s.RoleSet(guildID, &rs[i], false); err != nil { return } } @@ -777,8 +780,8 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, r := range rs { - if err := s.RoleSet(guildID, r, false); err != nil { + for i := range rs { + if err := s.RoleSet(guildID, &rs[i], false); err != nil { return rs, err } } @@ -790,7 +793,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - err = s.Cabinet.GuildSet(*g, false) + err = s.Cabinet.GuildSet(g, false) } return @@ -799,7 +802,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { m, err = s.Session.Member(gID, uID) if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { - err = s.Cabinet.MemberSet(gID, *m, false) + err = s.Cabinet.MemberSet(gID, m, false) } return @@ -808,12 +811,14 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord // tracksMessage reports whether the state would track the passed message and // messages from the same channel. func (s *State) tracksMessage(m *discord.Message) bool { - return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || + return s.Gateway.Identifier.Intents == nil || + (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) } // tracksChannel reports whether the state would track the passed channel. func (s *State) tracksChannel(c *discord.Channel) bool { - return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || + return s.Gateway.Identifier.Intents == nil || + (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || !c.GuildID.IsValid() } diff --git a/state/state_events.go b/state/state_events.go index 5208275..036bb36 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -9,7 +9,7 @@ import ( ) func (s *State) hookSession() { - s.Session.AddHandler(func(event interface{}) { + s.Session.AddSyncHandler(func(event interface{}) { // Call the pre-handler before the state handler. if s.PreHandler != nil { s.PreHandler.Call(event) @@ -68,15 +68,15 @@ func (s *State) onEvent(iface interface{}) { } // Handle guild presences - for _, p := range ev.Presences { - if err := s.Cabinet.PresenceSet(p.GuildID, p, false); err != nil { + for i, presence := range ev.Presences { + if err := s.Cabinet.PresenceSet(presence.GuildID, &ev.Presences[i], false); err != nil { s.stateErr(err, "failed to set presence in Ready") } } // Handle private channels - for _, ch := range ev.PrivateChannels { - if err := s.Cabinet.ChannelSet(ch, false); err != nil { + for i := range ev.PrivateChannels { + if err := s.Cabinet.ChannelSet(&ev.PrivateChannels[i], false); err != nil { s.stateErr(err, "failed to set channel in Ready") } } @@ -90,16 +90,17 @@ func (s *State) onEvent(iface interface{}) { // Handle guilds for _, guild := range ev.Guilds { // Handle guild voice states - for _, v := range guild.VoiceStates { + for i := range guild.VoiceStates { + v := &guild.VoiceStates[i] if err := s.Cabinet.VoiceStateSet(guild.ID, v, false); err != nil { s.stateErr(err, "failed to set guild voice state in Ready Supplemental") } } } - for _, friend := range ev.MergedPresences.Friends { - sPresence := gateway.ConvertSupplementalPresence(friend) - if err := s.Cabinet.PresenceSet(0, sPresence, false); err != nil { + friendPresences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Friends) + for i := range friendPresences { + if err := s.Cabinet.PresenceSet(0, &friendPresences[i], false); err != nil { s.stateErr(err, "failed to set friend presence in Ready Supplemental") } } @@ -110,16 +111,16 @@ func (s *State) onEvent(iface interface{}) { for i := 0; i < len(ready.Guilds) && i < len(ev.MergedMembers); i++ { guild := ready.Guilds[i] - for _, member := range ev.MergedMembers[i] { - sMember := gateway.ConvertSupplementalMember(member) - if err := s.Cabinet.MemberSet(guild.ID, sMember, false); err != nil { + members := gateway.ConvertSupplementalMembers(ev.MergedMembers[i]) + for i := range members { + if err := s.Cabinet.MemberSet(guild.ID, &members[i], false); err != nil { s.stateErr(err, "failed to set friend presence in Ready Supplemental") } } - for _, member := range ev.MergedPresences.Guilds[i] { - sPresence := gateway.ConvertSupplementalPresence(member) - if err := s.Cabinet.PresenceSet(guild.ID, sPresence, false); err != nil { + presences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Guilds[i]) + for i := range presences { + if err := s.Cabinet.PresenceSet(guild.ID, &presences[i], false); err != nil { s.stateErr(err, "failed to set member presence in Ready Supplemental") } } @@ -129,7 +130,7 @@ func (s *State) onEvent(iface interface{}) { s.batchLog(storeGuildCreate(s.Cabinet, ev)) case *gateway.GuildUpdateEvent: - if err := s.Cabinet.GuildSet(ev.Guild, true); err != nil { + if err := s.Cabinet.GuildSet(&ev.Guild, true); err != nil { s.stateErr(err, "failed to update guild in state") } @@ -139,7 +140,7 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.GuildMemberAddEvent: - if err := s.Cabinet.MemberSet(ev.GuildID, ev.Member, false); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Member, false); err != nil { s.stateErr(err, "failed to add a member in state") } @@ -153,7 +154,7 @@ func (s *State) onEvent(iface interface{}) { // Update available fields from ev into m ev.Update(m) - if err := s.Cabinet.MemberSet(ev.GuildID, *m, true); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil { s.stateErr(err, "failed to update a member in state") } @@ -163,25 +164,25 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.GuildMembersChunkEvent: - for _, m := range ev.Members { - if err := s.Cabinet.MemberSet(ev.GuildID, m, false); err != nil { + for i := range ev.Members { + if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Members[i], false); err != nil { s.stateErr(err, "failed to add a member from chunk in state") } } - for _, p := range ev.Presences { - if err := s.Cabinet.PresenceSet(ev.GuildID, p, false); err != nil { + for i := range ev.Presences { + if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presences[i], false); err != nil { s.stateErr(err, "failed to add a presence from chunk in state") } } case *gateway.GuildRoleCreateEvent: - if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, false); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, false); err != nil { s.stateErr(err, "failed to add a role in state") } case *gateway.GuildRoleUpdateEvent: - if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, true); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, true); err != nil { s.stateErr(err, "failed to update a role in state") } @@ -196,17 +197,17 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.ChannelCreateEvent: - if err := s.Cabinet.ChannelSet(ev.Channel, false); err != nil { + if err := s.Cabinet.ChannelSet(&ev.Channel, false); err != nil { s.stateErr(err, "failed to create a channel in state") } case *gateway.ChannelUpdateEvent: - if err := s.Cabinet.ChannelSet(ev.Channel, true); err != nil { + if err := s.Cabinet.ChannelSet(&ev.Channel, true); err != nil { s.stateErr(err, "failed to update a channel in state") } case *gateway.ChannelDeleteEvent: - if err := s.Cabinet.ChannelRemove(ev.Channel); err != nil { + if err := s.Cabinet.ChannelRemove(&ev.Channel); err != nil { s.stateErr(err, "failed to remove a channel in state") } @@ -214,12 +215,12 @@ func (s *State) onEvent(iface interface{}) { // not tracked. case *gateway.MessageCreateEvent: - if err := s.Cabinet.MessageSet(ev.Message, false); err != nil { + if err := s.Cabinet.MessageSet(&ev.Message, false); err != nil { s.stateErr(err, "failed to add a message in state") } case *gateway.MessageUpdateEvent: - if err := s.Cabinet.MessageSet(ev.Message, true); err != nil { + if err := s.Cabinet.MessageSet(&ev.Message, true); err != nil { s.stateErr(err, "failed to update a message in state") } @@ -238,12 +239,17 @@ func (s *State) onEvent(iface interface{}) { case *gateway.MessageReactionAddEvent: s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool { if i := findReaction(m.Reactions, ev.Emoji); i > -1 { + // Copy the reactions slice so it's not racy. + m.Reactions = append([]discord.Reaction(nil), m.Reactions...) m.Reactions[i].Count++ } else { var me bool if u, _ := s.Cabinet.Me(); u != nil { me = ev.UserID == u.ID } + old := m.Reactions + m.Reactions = make([]discord.Reaction, 0, len(old)+1) + m.Reactions = append(m.Reactions, old...) m.Reactions = append(m.Reactions, discord.Reaction{ Count: 1, Me: me, @@ -261,18 +267,21 @@ func (s *State) onEvent(iface interface{}) { } r := &m.Reactions[i] - r.Count-- + newCount := r.Count - 1 switch { - case r.Count < 1: // If the count is 0: - // Remove the reaction. - m.Reactions = append(m.Reactions[:i], m.Reactions[i+1:]...) + case newCount < 1: // If the count is 0: + old := m.Reactions + m.Reactions = make([]discord.Reaction, len(m.Reactions)-1) + copy(m.Reactions[0:], old[:i]) + copy(m.Reactions[i:], old[i+1:]) case r.Me: // If reaction removal is the user's u, err := s.Cabinet.Me() if err == nil && ev.UserID == u.ID { r.Me = false } + r.Count-- } return true @@ -295,13 +304,13 @@ func (s *State) onEvent(iface interface{}) { }) case *gateway.PresenceUpdateEvent: - if err := s.Cabinet.PresenceSet(ev.GuildID, ev.Presence, true); err != nil { + if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presence, true); err != nil { s.stateErr(err, "failed to update presence in state") } case *gateway.PresencesReplaceEvent: for _, p := range *ev { - if err := s.Cabinet.PresenceSet(p.GuildID, p.Presence, true); err != nil { + if err := s.Cabinet.PresenceSet(p.GuildID, &p.Presence, true); err != nil { s.stateErr(err, "failed to update presence in state") } } @@ -332,7 +341,7 @@ func (s *State) onEvent(iface interface{}) { s.stateErr(err, "failed to remove voice state from state") } } else { - if err := s.Cabinet.VoiceStateSet(vs.GuildID, *vs, true); err != nil { + if err := s.Cabinet.VoiceStateSet(vs.GuildID, vs, true); err != nil { s.stateErr(err, "failed to update voice state in state") } } @@ -342,6 +351,7 @@ func (s *State) onEvent(iface interface{}) { func (s *State) stateErr(err error, wrap string) { s.StateLog(errors.Wrap(err, wrap)) } + func (s *State) batchLog(errors []error) { for _, err := range errors { s.StateLog(err) @@ -355,10 +365,16 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func if err != nil { return } + + // Copy the messages. + cpy := *m + m = &cpy + if !fn(m) { return } - if err := s.Cabinet.MessageSet(*m, true); err != nil { + + if err := s.Cabinet.MessageSet(m, true); err != nil { s.stateErr(err, "failed to save message in reaction add") } } @@ -379,7 +395,7 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err stack, errs := newErrorStack() - if err := cab.GuildSet(guild.Guild, false); err != nil { + if err := cab.GuildSet(&guild.Guild, false); err != nil { errs(err, "failed to set guild in Ready") } @@ -391,8 +407,8 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err } // Handle guild member - for _, m := range guild.Members { - if err := cab.MemberSet(guild.ID, m, false); err != nil { + for i := range guild.Members { + if err := cab.MemberSet(guild.ID, &guild.Members[i], false); err != nil { errs(err, "failed to set guild member in Ready") } } @@ -400,36 +416,40 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err // Handle guild channels for _, ch := range guild.Channels { // I HATE Discord. + ch := ch ch.GuildID = guild.ID - if err := cab.ChannelSet(ch, false); err != nil { + if err := cab.ChannelSet(&ch, false); err != nil { errs(err, "failed to set guild channel in Ready") } } // Handle threads. for _, ch := range guild.Threads { + ch := ch ch.GuildID = guild.ID - if err := cab.ChannelSet(ch, false); err != nil { + if err := cab.ChannelSet(&ch, false); err != nil { errs(err, "failed to set guild thread in Ready") } } // Handle guild presences for _, p := range guild.Presences { + p := p p.GuildID = guild.ID - if err := cab.PresenceSet(guild.ID, p, false); err != nil { + if err := cab.PresenceSet(guild.ID, &p, false); err != nil { errs(err, "failed to set guild presence in Ready") } } // Handle guild voice states for _, v := range guild.VoiceStates { + v := v v.GuildID = guild.ID - if err := cab.VoiceStateSet(guild.ID, v, false); err != nil { + if err := cab.VoiceStateSet(guild.ID, &v, false); err != nil { errs(err, "failed to set guild voice state in Ready") } } diff --git a/state/store/defaultstore/channel.go b/state/store/defaultstore/channel.go index 80a1eed..bc33ed9 100644 --- a/state/store/defaultstore/channel.go +++ b/state/store/defaultstore/channel.go @@ -2,6 +2,7 @@ package defaultstore import ( "errors" + "fmt" "sync" "github.com/diamondburned/arikawa/v3/discord" @@ -13,20 +14,18 @@ type Channel struct { // Channel references must be protected under the same mutex. - privates map[discord.UserID]*discord.Channel - privateChs []*discord.Channel - - channels map[discord.ChannelID]*discord.Channel - guildChs map[discord.GuildID][]*discord.Channel + channels map[discord.ChannelID]discord.Channel + privates map[discord.UserID]discord.ChannelID + guildChs map[discord.GuildID][]discord.ChannelID } var _ store.ChannelStore = (*Channel)(nil) func NewChannel() *Channel { return &Channel{ - privates: map[discord.UserID]*discord.Channel{}, - channels: map[discord.ChannelID]*discord.Channel{}, - guildChs: map[discord.GuildID][]*discord.Channel{}, + channels: map[discord.ChannelID]discord.Channel{}, + privates: map[discord.UserID]discord.ChannelID{}, + guildChs: map[discord.GuildID][]discord.ChannelID{}, } } @@ -34,9 +33,9 @@ func (s *Channel) Reset() error { s.mut.Lock() defer s.mut.Unlock() - s.privates = map[discord.UserID]*discord.Channel{} - s.channels = map[discord.ChannelID]*discord.Channel{} - s.guildChs = map[discord.GuildID][]*discord.Channel{} + s.channels = map[discord.ChannelID]discord.Channel{} + s.privates = map[discord.UserID]discord.ChannelID{} + s.guildChs = map[discord.GuildID][]discord.ChannelID{} return nil } @@ -50,20 +49,19 @@ func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) { return nil, store.ErrNotFound } - cpy := *ch - return &cpy, nil + return &ch, nil } func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - ch, ok := s.privates[recipient] + id, ok := s.privates[recipient] if !ok { return nil, store.ErrNotFound } - cpy := *ch + cpy := s.channels[id] return &cpy, nil } @@ -72,16 +70,20 @@ func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - chRefs, ok := s.guildChs[guildID] + chIDs, ok := s.guildChs[guildID] if !ok { return nil, store.ErrNotFound } // Reading chRefs is also covered by the global mutex. - var channels = make([]discord.Channel, len(chRefs)) - for i, chRef := range chRefs { - channels[i] = *chRef + var channels = make([]discord.Channel, 0, len(chIDs)) + for _, chID := range chIDs { + ch, ok := s.channels[chID] + if !ok { + continue + } + channels = append(channels, ch) } return channels, nil @@ -92,42 +94,47 @@ func (s *Channel) PrivateChannels() ([]discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - if len(s.privateChs) == 0 { + groupDMs := s.guildChs[0] + + if len(s.privates) == 0 && len(groupDMs) == 0 { return nil, store.ErrNotFound } - var channels = make([]discord.Channel, len(s.privateChs)) - for i, ch := range s.privateChs { - channels[i] = *ch + var channels = make([]discord.Channel, 0, len(s.privates)+len(groupDMs)) + for _, chID := range s.privates { + if ch, ok := s.channels[chID]; ok { + channels = append(channels, ch) + } + } + for _, chID := range groupDMs { + if ch, ok := s.channels[chID]; ok { + channels = append(channels, ch) + } } return channels, nil } // ChannelSet sets the Direct Message or Guild channel into the state. -func (s *Channel) ChannelSet(channel discord.Channel, update bool) error { +func (s *Channel) ChannelSet(channel *discord.Channel, update bool) error { + cpy := *channel + s.mut.Lock() defer s.mut.Unlock() // Update the reference if we can. - if ch, ok := s.channels[channel.ID]; ok { - if update { - *ch = channel - } - return nil - } + s.channels[channel.ID] = cpy switch channel.Type { case discord.DirectMessage: // Safety bound check. if len(channel.DMRecipients) != 1 { - return errors.New("DirectMessage channel does not have 1 recipient") + return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID) } - s.privates[channel.DMRecipients[0].ID] = &channel - fallthrough + s.privates[channel.DMRecipients[0].ID] = channel.ID + return nil case discord.GroupDM: - s.privateChs = append(s.privateChs, &channel) - s.channels[channel.ID] = &channel + s.guildChs[0] = addChannelID(s.guildChs[0], channel.ID) return nil } @@ -137,16 +144,11 @@ func (s *Channel) ChannelSet(channel discord.Channel, update bool) error { return errors.New("invalid guildID for guild channel") } - s.channels[channel.ID] = &channel - - channels, _ := s.guildChs[channel.GuildID] - channels = append(channels, &channel) - s.guildChs[channel.GuildID] = channels - + s.guildChs[channel.GuildID] = addChannelID(s.guildChs[channel.GuildID], channel.ID) return nil } -func (s *Channel) ChannelRemove(channel discord.Channel) error { +func (s *Channel) ChannelRemove(channel *discord.Channel) error { s.mut.Lock() defer s.mut.Unlock() @@ -158,49 +160,42 @@ func (s *Channel) ChannelRemove(channel discord.Channel) error { case discord.DirectMessage: // Safety bound check. if len(channel.DMRecipients) != 1 { - return errors.New("DirectMessage channel does not have 1 recipient") + return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID) } delete(s.privates, channel.DMRecipients[0].ID) - fallthrough + return nil case discord.GroupDM: - for i, priv := range s.privateChs { - if priv.ID == channel.ID { - s.privateChs = removeChannel(s.privateChs, i) - break - } - } + s.guildChs[0] = removeChannelID(s.guildChs[0], channel.ID) return nil } - // Wipe the channel off the guilds index, if available. - channels, ok := s.guildChs[channel.GuildID] - if !ok { - return nil - } - - for i, ch := range channels { - if ch.ID == channel.ID { - s.guildChs[channel.GuildID] = removeChannel(channels, i) - break - } - } - + s.guildChs[channel.GuildID] = removeChannelID(s.guildChs[channel.GuildID], channel.ID) return nil } -// removeChannel removes the given channel with the index from the given +func addChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID { + for _, ch := range channels { + if ch == id { + return channels + } + } + if channels == nil { + channels = make([]discord.ChannelID, 0, 5) + } + return append(channels, id) +} + +// removeChannelID removes the given channel with the index from the given // channels slice in an unordered fashion. -func removeChannel(channels []*discord.Channel, i int) []*discord.Channel { - // Fast unordered delete. Not sure if there's a benefit in doing - // this over using a map, but I guess the memory usage is less and - // there's no copying. - - // Move the last channel to the current channel, set the last - // channel there to a nil value to unreference its children, then - // slice the last channel off. - channels[i] = channels[len(channels)-1] - channels[len(channels)-1] = nil - channels = channels[:len(channels)-1] - +func removeChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID { + for i, ch := range channels { + if ch == id { + // Move the last channel to the current channel, then slice the last + // channel off. + channels[i] = channels[len(channels)-1] + channels = channels[:len(channels)-1] + break + } + } return channels } diff --git a/state/store/defaultstore/guild.go b/state/store/defaultstore/guild.go index 2bae4e6..43868e8 100644 --- a/state/store/defaultstore/guild.go +++ b/state/store/defaultstore/guild.go @@ -33,13 +33,12 @@ func (s *Guild) Guild(id discord.GuildID) (*discord.Guild, error) { s.mut.RLock() defer s.mut.RUnlock() - ch, ok := s.guilds[id] - if !ok { - return nil, store.ErrNotFound + g, ok := s.guilds[id] + if ok { + return &g, nil } - // implicit copy - return &ch, nil + return nil, store.ErrNotFound } func (s *Guild) Guilds() ([]discord.Guild, error) { @@ -58,10 +57,12 @@ func (s *Guild) Guilds() ([]discord.Guild, error) { return gs, nil } -func (s *Guild) GuildSet(guild discord.Guild, update bool) error { +func (s *Guild) GuildSet(guild *discord.Guild, update bool) error { + cpy := *guild + s.mut.Lock() if _, ok := s.guilds[guild.ID]; !ok || update { - s.guilds[guild.ID] = guild + s.guilds[guild.ID] = cpy } s.mut.Unlock() diff --git a/state/store/defaultstore/member.go b/state/store/defaultstore/member.go index 7706e1e..7aa659f 100644 --- a/state/store/defaultstore/member.go +++ b/state/store/defaultstore/member.go @@ -13,7 +13,7 @@ type Member struct { } type guildMembers struct { - mut sync.Mutex + mut sync.RWMutex members map[discord.UserID]discord.Member } @@ -41,8 +41,8 @@ func (s *Member) Member(guildID discord.GuildID, userID discord.UserID) (*discor gm := iv.(*guildMembers) - gm.mut.Lock() - defer gm.mut.Unlock() + gm.mut.RLock() + defer gm.mut.RUnlock() m, ok := gm.members[userID] if ok { @@ -60,8 +60,8 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) { gm := iv.(*guildMembers) - gm.mut.Lock() - defer gm.mut.Unlock() + gm.mut.RLock() + defer gm.mut.RUnlock() var members = make([]discord.Member, 0, len(gm.members)) for _, m := range gm.members { @@ -71,13 +71,13 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) { return members, nil } -func (s *Member) MemberSet(guildID discord.GuildID, m discord.Member, update bool) error { +func (s *Member) MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) gm := iv.(*guildMembers) gm.mut.Lock() if _, ok := gm.members[m.User.ID]; !ok || update { - gm.members[m.User.ID] = m + gm.members[m.User.ID] = *m } gm.mut.Unlock() diff --git a/state/store/defaultstore/message.go b/state/store/defaultstore/message.go index d7d7ca3..c658722 100644 --- a/state/store/defaultstore/message.go +++ b/state/store/defaultstore/message.go @@ -16,7 +16,7 @@ type Message struct { var _ store.MessageStore = (*Message)(nil) type messages struct { - mut sync.Mutex + mut sync.RWMutex messages []discord.Message } @@ -43,8 +43,8 @@ func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*disco msgs := iv.(*messages) - msgs.mut.Lock() - defer msgs.mut.Unlock() + msgs.mut.RLock() + defer msgs.mut.RUnlock() for _, m := range msgs.messages { if m.ID == mID { @@ -63,8 +63,8 @@ func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, erro msgs := iv.(*messages) - msgs.mut.Lock() - defer msgs.mut.Unlock() + msgs.mut.RLock() + defer msgs.mut.RUnlock() return append([]discord.Message(nil), msgs.messages...), nil } @@ -73,7 +73,7 @@ func (s *Message) MaxMessages() int { return s.maxMsgs } -func (s *Message) MessageSet(message discord.Message, update bool) error { +func (s *Message) MessageSet(message *discord.Message, update bool) error { if s.maxMsgs <= 0 { return nil } @@ -102,19 +102,19 @@ func (s *Message) MessageSet(message discord.Message, update bool) error { } if len(msgs.messages) == 0 { - msgs.messages = []discord.Message{message} + msgs.messages = []discord.Message{*message} } if pos := messageInsertPosition(message, msgs.messages); pos < 0 { // Messages are full, drop the oldest messages to make room. if len(msgs.messages) == s.maxMsgs { copy(msgs.messages[1:], msgs.messages) - msgs.messages[0] = message + msgs.messages[0] = *message } else { - msgs.messages = append([]discord.Message{message}, msgs.messages...) + msgs.messages = append([]discord.Message{*message}, msgs.messages...) } } else if pos > 0 && len(msgs.messages) < s.maxMsgs { - msgs.messages = append(msgs.messages, message) + msgs.messages = append(msgs.messages, *message) } // We already have this message or we can't append any more messages. @@ -131,7 +131,7 @@ func (s *Message) MessageSet(message discord.Message, update bool) error { // messageInsertPosition is biased as it will recommend adding the message even // if timestamps just match, even though the true order cannot be determined in // that case. -func messageInsertPosition(target discord.Message, messages []discord.Message) int8 { +func messageInsertPosition(target *discord.Message, messages []discord.Message) int8 { var ( targetTime = target.ID.Time() firstTime = messages[0].ID.Time() @@ -183,7 +183,7 @@ func messageInsertPosition(target discord.Message, messages []discord.Message) i } // DiffMessage fills non-empty fields from src to dst. -func DiffMessage(src discord.Message, dst *discord.Message) { +func DiffMessage(src, dst *discord.Message) { // Thanks, Discord. if src.Content != "" { dst.Content = src.Content diff --git a/state/store/defaultstore/message_test.go b/state/store/defaultstore/message_test.go index 80757f5..bb4e781 100644 --- a/state/store/defaultstore/message_test.go +++ b/state/store/defaultstore/message_test.go @@ -10,27 +10,27 @@ func populate12Store() *Message { store := NewMessage(10) // Insert a regular list of messages. - store.MessageSet(discord.Message{ID: 1 << 29, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 28, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 27, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 26, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 25, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 24, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 29, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 28, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 27, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 26, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 25, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 24, ChannelID: 1}, false) // Try to insert newer messages after inserting new messages. - store.MessageSet(discord.Message{ID: 1 << 30, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 31, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 32, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 33, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 34, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 30, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 31, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 32, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 33, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 34, ChannelID: 1}, false) // TThese messages should be discarded, due to age. - store.MessageSet(discord.Message{ID: 1 << 23, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 22, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 23, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 22, ChannelID: 1}, false) // These should be prepended. - store.MessageSet(discord.Message{ID: 1 << 35, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 36, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 35, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 36, ChannelID: 1}, false) return store } @@ -57,9 +57,9 @@ func TestMessageSet(t *testing.T) { func TestMessagesUpdate(t *testing.T) { store := populate12Store() - store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true) - store.MessageSet(discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true) - store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true) + store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true) + store.MessageSet(&discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true) + store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true) expect := map[discord.MessageID]string{ 5: "edited 3", diff --git a/state/store/defaultstore/presence.go b/state/store/defaultstore/presence.go index e3fb16b..dcc5c39 100644 --- a/state/store/defaultstore/presence.go +++ b/state/store/defaultstore/presence.go @@ -13,7 +13,7 @@ type Presence struct { } type presences struct { - mut sync.Mutex + mut sync.RWMutex presences map[discord.UserID]discord.Presence } @@ -41,8 +41,8 @@ func (s *Presence) Presence(gID discord.GuildID, uID discord.UserID) (*discord.P ps := iv.(*presences) - ps.mut.Lock() - defer ps.mut.Unlock() + ps.mut.RLock() + defer ps.mut.RUnlock() p, ok := ps.presences[uID] if ok { @@ -60,8 +60,8 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error ps := iv.(*presences) - ps.mut.Lock() - defer ps.mut.Unlock() + ps.mut.RLock() + defer ps.mut.RUnlock() var presences = make([]discord.Presence, 0, len(ps.presences)) for _, p := range ps.presences { @@ -71,7 +71,7 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error return presences, nil } -func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error { +func (s *Presence) PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) ps := iv.(*presences) @@ -85,7 +85,7 @@ func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, upda } if _, ok := ps.presences[p.User.ID]; !ok || update { - ps.presences[p.User.ID] = p + ps.presences[p.User.ID] = *p } return nil diff --git a/state/store/defaultstore/role.go b/state/store/defaultstore/role.go index 903290d..522b5fd 100644 --- a/state/store/defaultstore/role.go +++ b/state/store/defaultstore/role.go @@ -15,7 +15,7 @@ type Role struct { var _ store.RoleStore = (*Role)(nil) type roles struct { - mut sync.Mutex + mut sync.RWMutex roles map[discord.RoleID]discord.Role } @@ -41,8 +41,8 @@ func (s *Role) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Ro rs := iv.(*roles) - rs.mut.Lock() - defer rs.mut.Unlock() + rs.mut.RLock() + defer rs.mut.RUnlock() r, ok := rs.roles[roleID] if ok { @@ -60,8 +60,8 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) { rs := iv.(*roles) - rs.mut.Lock() - defer rs.mut.Unlock() + rs.mut.RLock() + defer rs.mut.RUnlock() var roles = make([]discord.Role, 0, len(rs.roles)) for _, role := range rs.roles { @@ -71,14 +71,14 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) { return roles, nil } -func (s *Role) RoleSet(guildID discord.GuildID, role discord.Role, update bool) error { +func (s *Role) RoleSet(guildID discord.GuildID, role *discord.Role, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) rs := iv.(*roles) rs.mut.Lock() if _, ok := rs.roles[role.ID]; !ok || update { - rs.roles[role.ID] = role + rs.roles[role.ID] = *role } rs.mut.Unlock() diff --git a/state/store/defaultstore/voicestate.go b/state/store/defaultstore/voicestate.go index e7531f5..09ad2ee 100644 --- a/state/store/defaultstore/voicestate.go +++ b/state/store/defaultstore/voicestate.go @@ -15,7 +15,7 @@ type VoiceState struct { var _ store.VoiceStateStore = (*VoiceState)(nil) type voiceStates struct { - mut sync.Mutex + mut sync.RWMutex voiceStates map[discord.UserID]discord.VoiceState } @@ -43,8 +43,8 @@ func (s *VoiceState) VoiceState( vs := iv.(*voiceStates) - vs.mut.Lock() - defer vs.mut.Unlock() + vs.mut.RLock() + defer vs.mut.RUnlock() v, ok := vs.voiceStates[userID] if ok { @@ -62,8 +62,8 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, vs := iv.(*voiceStates) - vs.mut.Lock() - defer vs.mut.Unlock() + vs.mut.RLock() + defer vs.mut.RUnlock() var states = make([]discord.VoiceState, 0, len(vs.voiceStates)) for _, state := range vs.voiceStates { @@ -74,7 +74,7 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, } func (s *VoiceState) VoiceStateSet( - guildID discord.GuildID, voiceState discord.VoiceState, update bool) error { + guildID discord.GuildID, voiceState *discord.VoiceState, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) @@ -82,7 +82,7 @@ func (s *VoiceState) VoiceStateSet( vs.mut.Lock() if _, ok := vs.voiceStates[voiceState.UserID]; !ok || update { - vs.voiceStates[voiceState.UserID] = voiceState + vs.voiceStates[voiceState.UserID] = *voiceState } vs.mut.Unlock() diff --git a/state/store/store.go b/state/store/store.go index d2c9550..09825c9 100644 --- a/state/store/store.go +++ b/state/store/store.go @@ -143,6 +143,12 @@ type Resetter interface { Reset() error } +type CoreStorer interface { + Resetter + Lock() + Unlock() +} + var _ Resetter = (*noop)(nil) func (noop) Reset() error { return nil } @@ -176,8 +182,8 @@ type ChannelStore interface { // Both ChannelSet and ChannelRemove should switch on Type to know if it's a // private channel or not. - ChannelSet(c discord.Channel, update bool) error - ChannelRemove(discord.Channel) error + ChannelSet(c *discord.Channel, update bool) error + ChannelRemove(*discord.Channel) error } var _ ChannelStore = (*noop)(nil) @@ -194,10 +200,10 @@ func (noop) Channels(discord.GuildID) ([]discord.Channel, error) { func (noop) PrivateChannels() ([]discord.Channel, error) { return nil, ErrNotFound } -func (noop) ChannelSet(discord.Channel, bool) error { +func (noop) ChannelSet(*discord.Channel, bool) error { return nil } -func (noop) ChannelRemove(discord.Channel) error { +func (noop) ChannelRemove(*discord.Channel) error { return nil } @@ -232,7 +238,7 @@ type GuildStore interface { Guild(discord.GuildID) (*discord.Guild, error) Guilds() ([]discord.Guild, error) - GuildSet(g discord.Guild, update bool) error + GuildSet(g *discord.Guild, update bool) error GuildRemove(id discord.GuildID) error } @@ -240,7 +246,7 @@ var _ GuildStore = (*noop)(nil) func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound } func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound } -func (noop) GuildSet(discord.Guild, bool) error { return nil } +func (noop) GuildSet(*discord.Guild, bool) error { return nil } func (noop) GuildRemove(discord.GuildID) error { return nil } // MemberStore is the store interface for all members. @@ -250,7 +256,7 @@ type MemberStore interface { Member(discord.GuildID, discord.UserID) (*discord.Member, error) Members(discord.GuildID) ([]discord.Member, error) - MemberSet(guildID discord.GuildID, m discord.Member, update bool) error + MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error MemberRemove(discord.GuildID, discord.UserID) error } @@ -262,7 +268,7 @@ func (noop) Member(discord.GuildID, discord.UserID) (*discord.Member, error) { func (noop) Members(discord.GuildID) ([]discord.Member, error) { return nil, ErrNotFound } -func (noop) MemberSet(discord.GuildID, discord.Member, bool) error { +func (noop) MemberSet(discord.GuildID, *discord.Member, bool) error { return nil } func (noop) MemberRemove(discord.GuildID, discord.UserID) error { @@ -289,7 +295,7 @@ type MessageStore interface { // If update is set to true, MessageSet will check if a message with the // id of the passed message is stored, and update it if so. Otherwise, if // there is no such message, it will be discarded. - MessageSet(m discord.Message, update bool) error + MessageSet(m *discord.Message, update bool) error MessageRemove(discord.ChannelID, discord.MessageID) error } @@ -304,7 +310,7 @@ func (noop) Message(discord.ChannelID, discord.MessageID) (*discord.Message, err func (noop) Messages(discord.ChannelID) ([]discord.Message, error) { return nil, ErrNotFound } -func (noop) MessageSet(discord.Message, bool) error { +func (noop) MessageSet(*discord.Message, bool) error { return nil } func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error { @@ -319,7 +325,7 @@ type PresenceStore interface { Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) Presences(discord.GuildID) ([]discord.Presence, error) - PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error + PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error PresenceRemove(discord.GuildID, discord.UserID) error } @@ -331,7 +337,7 @@ func (noop) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) func (noop) Presences(discord.GuildID) ([]discord.Presence, error) { return nil, ErrNotFound } -func (noop) PresenceSet(discord.GuildID, discord.Presence, bool) error { +func (noop) PresenceSet(discord.GuildID, *discord.Presence, bool) error { return nil } func (noop) PresenceRemove(discord.GuildID, discord.UserID) error { @@ -345,7 +351,7 @@ type RoleStore interface { Role(discord.GuildID, discord.RoleID) (*discord.Role, error) Roles(discord.GuildID) ([]discord.Role, error) - RoleSet(guildID discord.GuildID, r discord.Role, update bool) error + RoleSet(guildID discord.GuildID, r *discord.Role, update bool) error RoleRemove(discord.GuildID, discord.RoleID) error } @@ -353,7 +359,7 @@ var _ RoleStore = (*noop)(nil) func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound } func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound } -func (noop) RoleSet(discord.GuildID, discord.Role, bool) error { return nil } +func (noop) RoleSet(discord.GuildID, *discord.Role, bool) error { return nil } func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil } // VoiceStateStore is the store interface for all voice states. @@ -363,7 +369,7 @@ type VoiceStateStore interface { VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) - VoiceStateSet(guildID discord.GuildID, s discord.VoiceState, update bool) error + VoiceStateSet(guildID discord.GuildID, s *discord.VoiceState, update bool) error VoiceStateRemove(discord.GuildID, discord.UserID) error } @@ -375,7 +381,7 @@ func (noop) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, er func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { return nil, ErrNotFound } -func (noop) VoiceStateSet(discord.GuildID, discord.VoiceState, bool) error { +func (noop) VoiceStateSet(discord.GuildID, *discord.VoiceState, bool) error { return nil } func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error { diff --git a/utils/handler/handler.go b/utils/handler/handler.go index b2e8619..c56b17a 100644 --- a/utils/handler/handler.go +++ b/utils/handler/handler.go @@ -28,12 +28,8 @@ import ( // Handler is a container for command handlers. A zero-value instance is a valid // instance. type Handler struct { - // Synchronous controls whether to spawn each event handler in its own - // goroutine. Default false (meaning goroutines are spawned). - Synchronous bool - - mutex sync.RWMutex - slab slab + mutex sync.RWMutex + events map[reflect.Type]slab // nil type for interfaces } func New() *Handler { @@ -43,22 +39,24 @@ func New() *Handler { // Call calls all handlers with the given event. This is an internal method; use // with care. func (h *Handler) Call(ev interface{}) { - var evV = reflect.ValueOf(ev) - var evT = evV.Type() + v := reflect.ValueOf(ev) + t := v.Type() h.mutex.RLock() defer h.mutex.RUnlock() - for _, entry := range h.slab.Entries { - if entry.isInvalid() || entry.not(evT) { + for _, entry := range h.events[t].Entries { + if entry.isInvalid() { continue } + entry.Call(v) + } - if h.Synchronous { - entry.call(evV) - } else { - go entry.call(evV) + for _, entry := range h.events[nil].Entries { + if entry.isInvalid() || entry.not(t) { + continue } + entry.Call(v) } } @@ -145,7 +143,19 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca // h.AddHandler(ch) // func (h *Handler) AddHandler(handler interface{}) (rm func()) { - rm, err := h.addHandler(handler) + rm, err := h.addHandler(handler, false) + if err != nil { + panic(err) + } + return rm +} + +// AddSyncHandler is a synchronous variant of AddHandler. Handlers added using +// this method will block the Call method, which is helpful if the user needs to +// rely on the order of events arriving. Handlers added using this method should +// not block for very long, as it may clog up other handlers. +func (h *Handler) AddSyncHandler(handler interface{}) (rm func()) { + rm, err := h.addHandler(handler, true) if err != nil { panic(err) } @@ -167,23 +177,56 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) { } }() - return h.addHandler(handler) + return h.addHandler(handler, false) } -func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { +// AddSyncHandlerCheck is the safe-guarded version of AddSyncHandler. It is +// similar to AddHandlerCheck. +func (h *Handler) AddSyncHandlerCheck(handler interface{}) (rm func(), err error) { + // Reflect would actually panic if anything goes wrong, so this is just in + // case. + defer func() { + if rec := recover(); rec != nil { + if recErr, ok := rec.(error); ok { + err = recErr + } else { + err = fmt.Errorf("%v", rec) + } + } + }() + + return h.addHandler(handler, true) +} + +func (h *Handler) addHandler(fn interface{}, sync bool) (rm func(), err error) { // Reflect the handler - r, err := newHandler(fn) + r, err := newHandler(fn, sync) if err != nil { return nil, errors.Wrap(err, "handler reflect failed") } + var id int + var t reflect.Type + if !r.isIface { + t = r.event + } + h.mutex.Lock() - id := h.slab.Put(r) + + if h.events == nil { + h.events = make(map[reflect.Type]slab, 10) + } + + slab := h.events[t] + id = slab.Put(r) + h.events[t] = slab + h.mutex.Unlock() return func() { h.mutex.Lock() - popped := h.slab.Pop(id) + slab := h.events[t] + popped := slab.Pop(id) h.mutex.Unlock() popped.cleanup() @@ -193,20 +236,22 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { type handler struct { event reflect.Type // underlying type; arg0 or chan underlying type callback reflect.Value - isIface bool chanclose reflect.Value // IsValid() if chan + isIface bool + isSync bool } // newHandler reflects either a channel or a function into a handler. A function // must only have a single argument being the event and no return, and a channel // must have the event type as the underlying type. -func newHandler(unknown interface{}) (handler, error) { +func newHandler(unknown interface{}, sync bool) (handler, error) { fnV := reflect.ValueOf(unknown) fnT := fnV.Type() // underlying event type - var handler = handler{ + handler := handler{ callback: fnV, + isSync: sync, } switch fnT.Kind() { @@ -249,6 +294,14 @@ func (h handler) not(event reflect.Type) bool { return h.event != event } +func (h handler) Call(event reflect.Value) { + if h.isSync { + h.call(event) + } else { + go h.call(event) + } +} + func (h handler) call(event reflect.Value) { if h.chanclose.IsValid() { reflect.Select([]reflect.SelectCase{ diff --git a/utils/handler/handler_test.go b/utils/handler/handler_test.go index 3a6ba9f..6336664 100644 --- a/utils/handler/handler_test.go +++ b/utils/handler/handler_test.go @@ -63,7 +63,7 @@ func TestHandler(t *testing.T) { h, err := newHandler(func(m *gateway.MessageCreateEvent) { results <- m.Content - }) + }, false) if err != nil { t.Fatal(err) } @@ -88,7 +88,7 @@ func TestHandler(t *testing.T) { func TestHandlerChan(t *testing.T) { var results = make(chan *gateway.MessageCreateEvent) - h, err := newHandler(results) + h, err := newHandler(results, false) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestHandlerChanCancel(t *testing.T) { // unbuffered. var results = make(chan *gateway.MessageCreateEvent) - h, err := newHandler(results) + h, err := newHandler(results, false) if err != nil { t.Fatal(err) } @@ -161,7 +161,7 @@ func TestHandlerInterface(t *testing.T) { h, err := newHandler(func(m interface{}) { results <- m - }) + }, false) if err != nil { t.Fatal(err) } @@ -277,7 +277,7 @@ func TestHandlerChanFor(t *testing.T) { } func BenchmarkReflect(b *testing.B) { - h, err := newHandler(func(m *gateway.MessageCreateEvent) {}) + h, err := newHandler(func(m *gateway.MessageCreateEvent) {}, false) if err != nil { b.Fatal(err) } diff --git a/utils/handler/slab.go b/utils/handler/slab.go index a33dd25..abd0b95 100644 --- a/utils/handler/slab.go +++ b/utils/handler/slab.go @@ -1,8 +1,8 @@ package handler type slabEntry struct { - handler index int + handler } func (entry slabEntry) isInvalid() bool { @@ -18,13 +18,13 @@ type slab struct { func (s *slab) Put(entry handler) int { if s.free == len(s.Entries) { index := len(s.Entries) - s.Entries = append(s.Entries, slabEntry{entry, -1}) + s.Entries = append(s.Entries, slabEntry{-1, entry}) s.free++ return index } next := s.Entries[s.free].index - s.Entries[s.free] = slabEntry{entry, -1} + s.Entries[s.free] = slabEntry{-1, entry} i := s.free s.free = next @@ -38,7 +38,7 @@ func (s *slab) Get(i int) handler { func (s *slab) Pop(i int) handler { popped := s.Entries[i].handler - s.Entries[i] = slabEntry{handler{}, s.free} + s.Entries[i] = slabEntry{s.free, handler{}} s.free = i return popped } diff --git a/voice/session_test.go b/voice/session_test.go index e6e6b6c..7ebfbda 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -36,7 +36,7 @@ func TestIntegration(t *testing.T) { AddIntents(s.Gateway) func() { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() if err := s.Open(ctx); err != nil {