diff --git a/gateway/ready.go b/gateway/ready.go index b16df8f..a818fa7 100644 --- a/gateway/ready.go +++ b/gateway/ready.go @@ -233,27 +233,35 @@ type ( } ) -// ConvertSupplementalMember converts a SupplementalMember to a regular Member. -func ConvertSupplementalMember(sm SupplementalMember) discord.Member { - return discord.Member{ - User: discord.User{ID: sm.UserID}, - Nick: sm.Nick, - RoleIDs: sm.RoleIDs, - Joined: sm.Joined, - BoostedSince: sm.BoostedSince, - Deaf: sm.Deaf, - Mute: sm.Mute, - IsPending: sm.IsPending, +// ConvertSupplementalMembers converts a SupplementalMember to a regular Member. +func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member { + members := make([]discord.Member, len(sms)) + for i, sm := range sms { + members[i] = discord.Member{ + User: discord.User{ID: sm.UserID}, + Nick: sm.Nick, + RoleIDs: sm.RoleIDs, + Joined: sm.Joined, + BoostedSince: sm.BoostedSince, + Deaf: sm.Deaf, + Mute: sm.Mute, + IsPending: sm.IsPending, + } } + return members } -// ConvertSupplementalPresence converts a SupplementalPresence to a regular +// ConvertSupplementalPresences converts a SupplementalPresence to a regular // Presence with an empty GuildID. -func ConvertSupplementalPresence(sp SupplementalPresence) discord.Presence { - return discord.Presence{ - User: discord.User{ID: sp.UserID}, - Status: sp.Status, - Activities: sp.Activities, - ClientStatus: sp.ClientStatus, +func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence { + presences := make([]discord.Presence, len(sps)) + for i, sp := range sps { + presences[i] = discord.Presence{ + User: discord.User{ID: sp.UserID}, + Status: sp.Status, + Activities: sp.Activities, + ClientStatus: sp.ClientStatus, + } } + return presences } diff --git a/state/state.go b/state/state.go index d32d281..ce50443 100644 --- a/state/state.go +++ b/state/state.go @@ -27,7 +27,9 @@ var ( // The user should initialize handlers and intents in the opts function. func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc { return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { - return NewFromSession(session.NewCustomShard(m, id), defaultstore.New()), nil + state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New()) + opts(m, state) + return state, nil } } @@ -363,7 +365,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { } if s.tracksChannel(c) { - err = s.Cabinet.ChannelSet(*c, false) + err = s.Cabinet.ChannelSet(c, false) } return @@ -383,8 +385,8 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, c := range cs { - if err = s.Cabinet.ChannelSet(c, false); err != nil { + for i := range cs { + if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil { return } } @@ -404,7 +406,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return nil, err } - return c, s.Cabinet.ChannelSet(*c, false) + return c, s.Cabinet.ChannelSet(c, false) } // PrivateChannels gets the direct messages of the user. @@ -420,8 +422,8 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { return nil, err } - for _, c := range cs { - if err := s.Cabinet.ChannelSet(c, false); err != nil { + for i := range cs { + if err := s.Cabinet.ChannelSet(&cs[i], false); err != nil { return nil, err } } @@ -509,8 +511,8 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, g := range gs { - if err = s.Cabinet.GuildSet(g, false); err != nil { + for i := range gs { + if err = s.Cabinet.GuildSet(&gs[i], false); err != nil { return } } @@ -546,8 +548,8 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error } if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - for _, m := range ms { - if err = s.Cabinet.MemberSet(guildID, m, false); err != nil { + for i := range ms { + if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil { return } } @@ -579,7 +581,7 @@ func (s *State) Message( go func() { c, cerr = s.Session.Channel(channelID) if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - cerr = s.Cabinet.ChannelSet(*c, false) + cerr = s.Cabinet.ChannelSet(c, false) } wg.Done() @@ -688,8 +690,9 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes i = len(apiMessages) } - for _, m := range apiMessages[:i] { - if err := s.Cabinet.MessageSet(m, false); err != nil { + msgs := apiMessages[:i] + for i := range msgs { + if err := s.Cabinet.MessageSet(&msgs[i], false); err != nil { return nil, err } } @@ -745,14 +748,14 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di return } - for _, r := range rs { + for i, r := range rs { if r.ID == roleID { r := r // copy to prevent mem aliasing target = &r } if s.Gateway.HasIntents(gateway.IntentGuilds) { - if err = s.RoleSet(guildID, r, false); err != nil { + if err = s.RoleSet(guildID, &rs[i], false); err != nil { return } } @@ -777,8 +780,8 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { } if s.Gateway.HasIntents(gateway.IntentGuilds) { - for _, r := range rs { - if err := s.RoleSet(guildID, r, false); err != nil { + for i := range rs { + if err := s.RoleSet(guildID, &rs[i], false); err != nil { return rs, err } } @@ -790,7 +793,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - err = s.Cabinet.GuildSet(*g, false) + err = s.Cabinet.GuildSet(g, false) } return @@ -799,7 +802,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { m, err = s.Session.Member(gID, uID) if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { - err = s.Cabinet.MemberSet(gID, *m, false) + err = s.Cabinet.MemberSet(gID, m, false) } return @@ -808,12 +811,14 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord // tracksMessage reports whether the state would track the passed message and // messages from the same channel. func (s *State) tracksMessage(m *discord.Message) bool { - return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || + return s.Gateway.Identifier.Intents == nil || + (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) } // tracksChannel reports whether the state would track the passed channel. func (s *State) tracksChannel(c *discord.Channel) bool { - return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || + return s.Gateway.Identifier.Intents == nil || + (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || !c.GuildID.IsValid() } diff --git a/state/state_events.go b/state/state_events.go index 5208275..036bb36 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -9,7 +9,7 @@ import ( ) func (s *State) hookSession() { - s.Session.AddHandler(func(event interface{}) { + s.Session.AddSyncHandler(func(event interface{}) { // Call the pre-handler before the state handler. if s.PreHandler != nil { s.PreHandler.Call(event) @@ -68,15 +68,15 @@ func (s *State) onEvent(iface interface{}) { } // Handle guild presences - for _, p := range ev.Presences { - if err := s.Cabinet.PresenceSet(p.GuildID, p, false); err != nil { + for i, presence := range ev.Presences { + if err := s.Cabinet.PresenceSet(presence.GuildID, &ev.Presences[i], false); err != nil { s.stateErr(err, "failed to set presence in Ready") } } // Handle private channels - for _, ch := range ev.PrivateChannels { - if err := s.Cabinet.ChannelSet(ch, false); err != nil { + for i := range ev.PrivateChannels { + if err := s.Cabinet.ChannelSet(&ev.PrivateChannels[i], false); err != nil { s.stateErr(err, "failed to set channel in Ready") } } @@ -90,16 +90,17 @@ func (s *State) onEvent(iface interface{}) { // Handle guilds for _, guild := range ev.Guilds { // Handle guild voice states - for _, v := range guild.VoiceStates { + for i := range guild.VoiceStates { + v := &guild.VoiceStates[i] if err := s.Cabinet.VoiceStateSet(guild.ID, v, false); err != nil { s.stateErr(err, "failed to set guild voice state in Ready Supplemental") } } } - for _, friend := range ev.MergedPresences.Friends { - sPresence := gateway.ConvertSupplementalPresence(friend) - if err := s.Cabinet.PresenceSet(0, sPresence, false); err != nil { + friendPresences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Friends) + for i := range friendPresences { + if err := s.Cabinet.PresenceSet(0, &friendPresences[i], false); err != nil { s.stateErr(err, "failed to set friend presence in Ready Supplemental") } } @@ -110,16 +111,16 @@ func (s *State) onEvent(iface interface{}) { for i := 0; i < len(ready.Guilds) && i < len(ev.MergedMembers); i++ { guild := ready.Guilds[i] - for _, member := range ev.MergedMembers[i] { - sMember := gateway.ConvertSupplementalMember(member) - if err := s.Cabinet.MemberSet(guild.ID, sMember, false); err != nil { + members := gateway.ConvertSupplementalMembers(ev.MergedMembers[i]) + for i := range members { + if err := s.Cabinet.MemberSet(guild.ID, &members[i], false); err != nil { s.stateErr(err, "failed to set friend presence in Ready Supplemental") } } - for _, member := range ev.MergedPresences.Guilds[i] { - sPresence := gateway.ConvertSupplementalPresence(member) - if err := s.Cabinet.PresenceSet(guild.ID, sPresence, false); err != nil { + presences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Guilds[i]) + for i := range presences { + if err := s.Cabinet.PresenceSet(guild.ID, &presences[i], false); err != nil { s.stateErr(err, "failed to set member presence in Ready Supplemental") } } @@ -129,7 +130,7 @@ func (s *State) onEvent(iface interface{}) { s.batchLog(storeGuildCreate(s.Cabinet, ev)) case *gateway.GuildUpdateEvent: - if err := s.Cabinet.GuildSet(ev.Guild, true); err != nil { + if err := s.Cabinet.GuildSet(&ev.Guild, true); err != nil { s.stateErr(err, "failed to update guild in state") } @@ -139,7 +140,7 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.GuildMemberAddEvent: - if err := s.Cabinet.MemberSet(ev.GuildID, ev.Member, false); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Member, false); err != nil { s.stateErr(err, "failed to add a member in state") } @@ -153,7 +154,7 @@ func (s *State) onEvent(iface interface{}) { // Update available fields from ev into m ev.Update(m) - if err := s.Cabinet.MemberSet(ev.GuildID, *m, true); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil { s.stateErr(err, "failed to update a member in state") } @@ -163,25 +164,25 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.GuildMembersChunkEvent: - for _, m := range ev.Members { - if err := s.Cabinet.MemberSet(ev.GuildID, m, false); err != nil { + for i := range ev.Members { + if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Members[i], false); err != nil { s.stateErr(err, "failed to add a member from chunk in state") } } - for _, p := range ev.Presences { - if err := s.Cabinet.PresenceSet(ev.GuildID, p, false); err != nil { + for i := range ev.Presences { + if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presences[i], false); err != nil { s.stateErr(err, "failed to add a presence from chunk in state") } } case *gateway.GuildRoleCreateEvent: - if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, false); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, false); err != nil { s.stateErr(err, "failed to add a role in state") } case *gateway.GuildRoleUpdateEvent: - if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, true); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, true); err != nil { s.stateErr(err, "failed to update a role in state") } @@ -196,17 +197,17 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.ChannelCreateEvent: - if err := s.Cabinet.ChannelSet(ev.Channel, false); err != nil { + if err := s.Cabinet.ChannelSet(&ev.Channel, false); err != nil { s.stateErr(err, "failed to create a channel in state") } case *gateway.ChannelUpdateEvent: - if err := s.Cabinet.ChannelSet(ev.Channel, true); err != nil { + if err := s.Cabinet.ChannelSet(&ev.Channel, true); err != nil { s.stateErr(err, "failed to update a channel in state") } case *gateway.ChannelDeleteEvent: - if err := s.Cabinet.ChannelRemove(ev.Channel); err != nil { + if err := s.Cabinet.ChannelRemove(&ev.Channel); err != nil { s.stateErr(err, "failed to remove a channel in state") } @@ -214,12 +215,12 @@ func (s *State) onEvent(iface interface{}) { // not tracked. case *gateway.MessageCreateEvent: - if err := s.Cabinet.MessageSet(ev.Message, false); err != nil { + if err := s.Cabinet.MessageSet(&ev.Message, false); err != nil { s.stateErr(err, "failed to add a message in state") } case *gateway.MessageUpdateEvent: - if err := s.Cabinet.MessageSet(ev.Message, true); err != nil { + if err := s.Cabinet.MessageSet(&ev.Message, true); err != nil { s.stateErr(err, "failed to update a message in state") } @@ -238,12 +239,17 @@ func (s *State) onEvent(iface interface{}) { case *gateway.MessageReactionAddEvent: s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool { if i := findReaction(m.Reactions, ev.Emoji); i > -1 { + // Copy the reactions slice so it's not racy. + m.Reactions = append([]discord.Reaction(nil), m.Reactions...) m.Reactions[i].Count++ } else { var me bool if u, _ := s.Cabinet.Me(); u != nil { me = ev.UserID == u.ID } + old := m.Reactions + m.Reactions = make([]discord.Reaction, 0, len(old)+1) + m.Reactions = append(m.Reactions, old...) m.Reactions = append(m.Reactions, discord.Reaction{ Count: 1, Me: me, @@ -261,18 +267,21 @@ func (s *State) onEvent(iface interface{}) { } r := &m.Reactions[i] - r.Count-- + newCount := r.Count - 1 switch { - case r.Count < 1: // If the count is 0: - // Remove the reaction. - m.Reactions = append(m.Reactions[:i], m.Reactions[i+1:]...) + case newCount < 1: // If the count is 0: + old := m.Reactions + m.Reactions = make([]discord.Reaction, len(m.Reactions)-1) + copy(m.Reactions[0:], old[:i]) + copy(m.Reactions[i:], old[i+1:]) case r.Me: // If reaction removal is the user's u, err := s.Cabinet.Me() if err == nil && ev.UserID == u.ID { r.Me = false } + r.Count-- } return true @@ -295,13 +304,13 @@ func (s *State) onEvent(iface interface{}) { }) case *gateway.PresenceUpdateEvent: - if err := s.Cabinet.PresenceSet(ev.GuildID, ev.Presence, true); err != nil { + if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presence, true); err != nil { s.stateErr(err, "failed to update presence in state") } case *gateway.PresencesReplaceEvent: for _, p := range *ev { - if err := s.Cabinet.PresenceSet(p.GuildID, p.Presence, true); err != nil { + if err := s.Cabinet.PresenceSet(p.GuildID, &p.Presence, true); err != nil { s.stateErr(err, "failed to update presence in state") } } @@ -332,7 +341,7 @@ func (s *State) onEvent(iface interface{}) { s.stateErr(err, "failed to remove voice state from state") } } else { - if err := s.Cabinet.VoiceStateSet(vs.GuildID, *vs, true); err != nil { + if err := s.Cabinet.VoiceStateSet(vs.GuildID, vs, true); err != nil { s.stateErr(err, "failed to update voice state in state") } } @@ -342,6 +351,7 @@ func (s *State) onEvent(iface interface{}) { func (s *State) stateErr(err error, wrap string) { s.StateLog(errors.Wrap(err, wrap)) } + func (s *State) batchLog(errors []error) { for _, err := range errors { s.StateLog(err) @@ -355,10 +365,16 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func if err != nil { return } + + // Copy the messages. + cpy := *m + m = &cpy + if !fn(m) { return } - if err := s.Cabinet.MessageSet(*m, true); err != nil { + + if err := s.Cabinet.MessageSet(m, true); err != nil { s.stateErr(err, "failed to save message in reaction add") } } @@ -379,7 +395,7 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err stack, errs := newErrorStack() - if err := cab.GuildSet(guild.Guild, false); err != nil { + if err := cab.GuildSet(&guild.Guild, false); err != nil { errs(err, "failed to set guild in Ready") } @@ -391,8 +407,8 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err } // Handle guild member - for _, m := range guild.Members { - if err := cab.MemberSet(guild.ID, m, false); err != nil { + for i := range guild.Members { + if err := cab.MemberSet(guild.ID, &guild.Members[i], false); err != nil { errs(err, "failed to set guild member in Ready") } } @@ -400,36 +416,40 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err // Handle guild channels for _, ch := range guild.Channels { // I HATE Discord. + ch := ch ch.GuildID = guild.ID - if err := cab.ChannelSet(ch, false); err != nil { + if err := cab.ChannelSet(&ch, false); err != nil { errs(err, "failed to set guild channel in Ready") } } // Handle threads. for _, ch := range guild.Threads { + ch := ch ch.GuildID = guild.ID - if err := cab.ChannelSet(ch, false); err != nil { + if err := cab.ChannelSet(&ch, false); err != nil { errs(err, "failed to set guild thread in Ready") } } // Handle guild presences for _, p := range guild.Presences { + p := p p.GuildID = guild.ID - if err := cab.PresenceSet(guild.ID, p, false); err != nil { + if err := cab.PresenceSet(guild.ID, &p, false); err != nil { errs(err, "failed to set guild presence in Ready") } } // Handle guild voice states for _, v := range guild.VoiceStates { + v := v v.GuildID = guild.ID - if err := cab.VoiceStateSet(guild.ID, v, false); err != nil { + if err := cab.VoiceStateSet(guild.ID, &v, false); err != nil { errs(err, "failed to set guild voice state in Ready") } } diff --git a/state/store/defaultstore/channel.go b/state/store/defaultstore/channel.go index 80a1eed..bc33ed9 100644 --- a/state/store/defaultstore/channel.go +++ b/state/store/defaultstore/channel.go @@ -2,6 +2,7 @@ package defaultstore import ( "errors" + "fmt" "sync" "github.com/diamondburned/arikawa/v3/discord" @@ -13,20 +14,18 @@ type Channel struct { // Channel references must be protected under the same mutex. - privates map[discord.UserID]*discord.Channel - privateChs []*discord.Channel - - channels map[discord.ChannelID]*discord.Channel - guildChs map[discord.GuildID][]*discord.Channel + channels map[discord.ChannelID]discord.Channel + privates map[discord.UserID]discord.ChannelID + guildChs map[discord.GuildID][]discord.ChannelID } var _ store.ChannelStore = (*Channel)(nil) func NewChannel() *Channel { return &Channel{ - privates: map[discord.UserID]*discord.Channel{}, - channels: map[discord.ChannelID]*discord.Channel{}, - guildChs: map[discord.GuildID][]*discord.Channel{}, + channels: map[discord.ChannelID]discord.Channel{}, + privates: map[discord.UserID]discord.ChannelID{}, + guildChs: map[discord.GuildID][]discord.ChannelID{}, } } @@ -34,9 +33,9 @@ func (s *Channel) Reset() error { s.mut.Lock() defer s.mut.Unlock() - s.privates = map[discord.UserID]*discord.Channel{} - s.channels = map[discord.ChannelID]*discord.Channel{} - s.guildChs = map[discord.GuildID][]*discord.Channel{} + s.channels = map[discord.ChannelID]discord.Channel{} + s.privates = map[discord.UserID]discord.ChannelID{} + s.guildChs = map[discord.GuildID][]discord.ChannelID{} return nil } @@ -50,20 +49,19 @@ func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) { return nil, store.ErrNotFound } - cpy := *ch - return &cpy, nil + return &ch, nil } func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - ch, ok := s.privates[recipient] + id, ok := s.privates[recipient] if !ok { return nil, store.ErrNotFound } - cpy := *ch + cpy := s.channels[id] return &cpy, nil } @@ -72,16 +70,20 @@ func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - chRefs, ok := s.guildChs[guildID] + chIDs, ok := s.guildChs[guildID] if !ok { return nil, store.ErrNotFound } // Reading chRefs is also covered by the global mutex. - var channels = make([]discord.Channel, len(chRefs)) - for i, chRef := range chRefs { - channels[i] = *chRef + var channels = make([]discord.Channel, 0, len(chIDs)) + for _, chID := range chIDs { + ch, ok := s.channels[chID] + if !ok { + continue + } + channels = append(channels, ch) } return channels, nil @@ -92,42 +94,47 @@ func (s *Channel) PrivateChannels() ([]discord.Channel, error) { s.mut.RLock() defer s.mut.RUnlock() - if len(s.privateChs) == 0 { + groupDMs := s.guildChs[0] + + if len(s.privates) == 0 && len(groupDMs) == 0 { return nil, store.ErrNotFound } - var channels = make([]discord.Channel, len(s.privateChs)) - for i, ch := range s.privateChs { - channels[i] = *ch + var channels = make([]discord.Channel, 0, len(s.privates)+len(groupDMs)) + for _, chID := range s.privates { + if ch, ok := s.channels[chID]; ok { + channels = append(channels, ch) + } + } + for _, chID := range groupDMs { + if ch, ok := s.channels[chID]; ok { + channels = append(channels, ch) + } } return channels, nil } // ChannelSet sets the Direct Message or Guild channel into the state. -func (s *Channel) ChannelSet(channel discord.Channel, update bool) error { +func (s *Channel) ChannelSet(channel *discord.Channel, update bool) error { + cpy := *channel + s.mut.Lock() defer s.mut.Unlock() // Update the reference if we can. - if ch, ok := s.channels[channel.ID]; ok { - if update { - *ch = channel - } - return nil - } + s.channels[channel.ID] = cpy switch channel.Type { case discord.DirectMessage: // Safety bound check. if len(channel.DMRecipients) != 1 { - return errors.New("DirectMessage channel does not have 1 recipient") + return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID) } - s.privates[channel.DMRecipients[0].ID] = &channel - fallthrough + s.privates[channel.DMRecipients[0].ID] = channel.ID + return nil case discord.GroupDM: - s.privateChs = append(s.privateChs, &channel) - s.channels[channel.ID] = &channel + s.guildChs[0] = addChannelID(s.guildChs[0], channel.ID) return nil } @@ -137,16 +144,11 @@ func (s *Channel) ChannelSet(channel discord.Channel, update bool) error { return errors.New("invalid guildID for guild channel") } - s.channels[channel.ID] = &channel - - channels, _ := s.guildChs[channel.GuildID] - channels = append(channels, &channel) - s.guildChs[channel.GuildID] = channels - + s.guildChs[channel.GuildID] = addChannelID(s.guildChs[channel.GuildID], channel.ID) return nil } -func (s *Channel) ChannelRemove(channel discord.Channel) error { +func (s *Channel) ChannelRemove(channel *discord.Channel) error { s.mut.Lock() defer s.mut.Unlock() @@ -158,49 +160,42 @@ func (s *Channel) ChannelRemove(channel discord.Channel) error { case discord.DirectMessage: // Safety bound check. if len(channel.DMRecipients) != 1 { - return errors.New("DirectMessage channel does not have 1 recipient") + return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID) } delete(s.privates, channel.DMRecipients[0].ID) - fallthrough + return nil case discord.GroupDM: - for i, priv := range s.privateChs { - if priv.ID == channel.ID { - s.privateChs = removeChannel(s.privateChs, i) - break - } - } + s.guildChs[0] = removeChannelID(s.guildChs[0], channel.ID) return nil } - // Wipe the channel off the guilds index, if available. - channels, ok := s.guildChs[channel.GuildID] - if !ok { - return nil - } - - for i, ch := range channels { - if ch.ID == channel.ID { - s.guildChs[channel.GuildID] = removeChannel(channels, i) - break - } - } - + s.guildChs[channel.GuildID] = removeChannelID(s.guildChs[channel.GuildID], channel.ID) return nil } -// removeChannel removes the given channel with the index from the given +func addChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID { + for _, ch := range channels { + if ch == id { + return channels + } + } + if channels == nil { + channels = make([]discord.ChannelID, 0, 5) + } + return append(channels, id) +} + +// removeChannelID removes the given channel with the index from the given // channels slice in an unordered fashion. -func removeChannel(channels []*discord.Channel, i int) []*discord.Channel { - // Fast unordered delete. Not sure if there's a benefit in doing - // this over using a map, but I guess the memory usage is less and - // there's no copying. - - // Move the last channel to the current channel, set the last - // channel there to a nil value to unreference its children, then - // slice the last channel off. - channels[i] = channels[len(channels)-1] - channels[len(channels)-1] = nil - channels = channels[:len(channels)-1] - +func removeChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID { + for i, ch := range channels { + if ch == id { + // Move the last channel to the current channel, then slice the last + // channel off. + channels[i] = channels[len(channels)-1] + channels = channels[:len(channels)-1] + break + } + } return channels } diff --git a/state/store/defaultstore/guild.go b/state/store/defaultstore/guild.go index 2bae4e6..43868e8 100644 --- a/state/store/defaultstore/guild.go +++ b/state/store/defaultstore/guild.go @@ -33,13 +33,12 @@ func (s *Guild) Guild(id discord.GuildID) (*discord.Guild, error) { s.mut.RLock() defer s.mut.RUnlock() - ch, ok := s.guilds[id] - if !ok { - return nil, store.ErrNotFound + g, ok := s.guilds[id] + if ok { + return &g, nil } - // implicit copy - return &ch, nil + return nil, store.ErrNotFound } func (s *Guild) Guilds() ([]discord.Guild, error) { @@ -58,10 +57,12 @@ func (s *Guild) Guilds() ([]discord.Guild, error) { return gs, nil } -func (s *Guild) GuildSet(guild discord.Guild, update bool) error { +func (s *Guild) GuildSet(guild *discord.Guild, update bool) error { + cpy := *guild + s.mut.Lock() if _, ok := s.guilds[guild.ID]; !ok || update { - s.guilds[guild.ID] = guild + s.guilds[guild.ID] = cpy } s.mut.Unlock() diff --git a/state/store/defaultstore/member.go b/state/store/defaultstore/member.go index 7706e1e..7aa659f 100644 --- a/state/store/defaultstore/member.go +++ b/state/store/defaultstore/member.go @@ -13,7 +13,7 @@ type Member struct { } type guildMembers struct { - mut sync.Mutex + mut sync.RWMutex members map[discord.UserID]discord.Member } @@ -41,8 +41,8 @@ func (s *Member) Member(guildID discord.GuildID, userID discord.UserID) (*discor gm := iv.(*guildMembers) - gm.mut.Lock() - defer gm.mut.Unlock() + gm.mut.RLock() + defer gm.mut.RUnlock() m, ok := gm.members[userID] if ok { @@ -60,8 +60,8 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) { gm := iv.(*guildMembers) - gm.mut.Lock() - defer gm.mut.Unlock() + gm.mut.RLock() + defer gm.mut.RUnlock() var members = make([]discord.Member, 0, len(gm.members)) for _, m := range gm.members { @@ -71,13 +71,13 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) { return members, nil } -func (s *Member) MemberSet(guildID discord.GuildID, m discord.Member, update bool) error { +func (s *Member) MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) gm := iv.(*guildMembers) gm.mut.Lock() if _, ok := gm.members[m.User.ID]; !ok || update { - gm.members[m.User.ID] = m + gm.members[m.User.ID] = *m } gm.mut.Unlock() diff --git a/state/store/defaultstore/message.go b/state/store/defaultstore/message.go index d7d7ca3..c658722 100644 --- a/state/store/defaultstore/message.go +++ b/state/store/defaultstore/message.go @@ -16,7 +16,7 @@ type Message struct { var _ store.MessageStore = (*Message)(nil) type messages struct { - mut sync.Mutex + mut sync.RWMutex messages []discord.Message } @@ -43,8 +43,8 @@ func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*disco msgs := iv.(*messages) - msgs.mut.Lock() - defer msgs.mut.Unlock() + msgs.mut.RLock() + defer msgs.mut.RUnlock() for _, m := range msgs.messages { if m.ID == mID { @@ -63,8 +63,8 @@ func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, erro msgs := iv.(*messages) - msgs.mut.Lock() - defer msgs.mut.Unlock() + msgs.mut.RLock() + defer msgs.mut.RUnlock() return append([]discord.Message(nil), msgs.messages...), nil } @@ -73,7 +73,7 @@ func (s *Message) MaxMessages() int { return s.maxMsgs } -func (s *Message) MessageSet(message discord.Message, update bool) error { +func (s *Message) MessageSet(message *discord.Message, update bool) error { if s.maxMsgs <= 0 { return nil } @@ -102,19 +102,19 @@ func (s *Message) MessageSet(message discord.Message, update bool) error { } if len(msgs.messages) == 0 { - msgs.messages = []discord.Message{message} + msgs.messages = []discord.Message{*message} } if pos := messageInsertPosition(message, msgs.messages); pos < 0 { // Messages are full, drop the oldest messages to make room. if len(msgs.messages) == s.maxMsgs { copy(msgs.messages[1:], msgs.messages) - msgs.messages[0] = message + msgs.messages[0] = *message } else { - msgs.messages = append([]discord.Message{message}, msgs.messages...) + msgs.messages = append([]discord.Message{*message}, msgs.messages...) } } else if pos > 0 && len(msgs.messages) < s.maxMsgs { - msgs.messages = append(msgs.messages, message) + msgs.messages = append(msgs.messages, *message) } // We already have this message or we can't append any more messages. @@ -131,7 +131,7 @@ func (s *Message) MessageSet(message discord.Message, update bool) error { // messageInsertPosition is biased as it will recommend adding the message even // if timestamps just match, even though the true order cannot be determined in // that case. -func messageInsertPosition(target discord.Message, messages []discord.Message) int8 { +func messageInsertPosition(target *discord.Message, messages []discord.Message) int8 { var ( targetTime = target.ID.Time() firstTime = messages[0].ID.Time() @@ -183,7 +183,7 @@ func messageInsertPosition(target discord.Message, messages []discord.Message) i } // DiffMessage fills non-empty fields from src to dst. -func DiffMessage(src discord.Message, dst *discord.Message) { +func DiffMessage(src, dst *discord.Message) { // Thanks, Discord. if src.Content != "" { dst.Content = src.Content diff --git a/state/store/defaultstore/message_test.go b/state/store/defaultstore/message_test.go index 80757f5..bb4e781 100644 --- a/state/store/defaultstore/message_test.go +++ b/state/store/defaultstore/message_test.go @@ -10,27 +10,27 @@ func populate12Store() *Message { store := NewMessage(10) // Insert a regular list of messages. - store.MessageSet(discord.Message{ID: 1 << 29, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 28, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 27, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 26, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 25, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 24, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 29, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 28, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 27, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 26, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 25, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 24, ChannelID: 1}, false) // Try to insert newer messages after inserting new messages. - store.MessageSet(discord.Message{ID: 1 << 30, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 31, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 32, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 33, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 34, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 30, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 31, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 32, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 33, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 34, ChannelID: 1}, false) // TThese messages should be discarded, due to age. - store.MessageSet(discord.Message{ID: 1 << 23, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 22, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 23, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 22, ChannelID: 1}, false) // These should be prepended. - store.MessageSet(discord.Message{ID: 1 << 35, ChannelID: 1}, false) - store.MessageSet(discord.Message{ID: 1 << 36, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 35, ChannelID: 1}, false) + store.MessageSet(&discord.Message{ID: 1 << 36, ChannelID: 1}, false) return store } @@ -57,9 +57,9 @@ func TestMessageSet(t *testing.T) { func TestMessagesUpdate(t *testing.T) { store := populate12Store() - store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true) - store.MessageSet(discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true) - store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true) + store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true) + store.MessageSet(&discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true) + store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true) expect := map[discord.MessageID]string{ 5: "edited 3", diff --git a/state/store/defaultstore/presence.go b/state/store/defaultstore/presence.go index e3fb16b..dcc5c39 100644 --- a/state/store/defaultstore/presence.go +++ b/state/store/defaultstore/presence.go @@ -13,7 +13,7 @@ type Presence struct { } type presences struct { - mut sync.Mutex + mut sync.RWMutex presences map[discord.UserID]discord.Presence } @@ -41,8 +41,8 @@ func (s *Presence) Presence(gID discord.GuildID, uID discord.UserID) (*discord.P ps := iv.(*presences) - ps.mut.Lock() - defer ps.mut.Unlock() + ps.mut.RLock() + defer ps.mut.RUnlock() p, ok := ps.presences[uID] if ok { @@ -60,8 +60,8 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error ps := iv.(*presences) - ps.mut.Lock() - defer ps.mut.Unlock() + ps.mut.RLock() + defer ps.mut.RUnlock() var presences = make([]discord.Presence, 0, len(ps.presences)) for _, p := range ps.presences { @@ -71,7 +71,7 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error return presences, nil } -func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error { +func (s *Presence) PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) ps := iv.(*presences) @@ -85,7 +85,7 @@ func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, upda } if _, ok := ps.presences[p.User.ID]; !ok || update { - ps.presences[p.User.ID] = p + ps.presences[p.User.ID] = *p } return nil diff --git a/state/store/defaultstore/role.go b/state/store/defaultstore/role.go index 903290d..522b5fd 100644 --- a/state/store/defaultstore/role.go +++ b/state/store/defaultstore/role.go @@ -15,7 +15,7 @@ type Role struct { var _ store.RoleStore = (*Role)(nil) type roles struct { - mut sync.Mutex + mut sync.RWMutex roles map[discord.RoleID]discord.Role } @@ -41,8 +41,8 @@ func (s *Role) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Ro rs := iv.(*roles) - rs.mut.Lock() - defer rs.mut.Unlock() + rs.mut.RLock() + defer rs.mut.RUnlock() r, ok := rs.roles[roleID] if ok { @@ -60,8 +60,8 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) { rs := iv.(*roles) - rs.mut.Lock() - defer rs.mut.Unlock() + rs.mut.RLock() + defer rs.mut.RUnlock() var roles = make([]discord.Role, 0, len(rs.roles)) for _, role := range rs.roles { @@ -71,14 +71,14 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) { return roles, nil } -func (s *Role) RoleSet(guildID discord.GuildID, role discord.Role, update bool) error { +func (s *Role) RoleSet(guildID discord.GuildID, role *discord.Role, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) rs := iv.(*roles) rs.mut.Lock() if _, ok := rs.roles[role.ID]; !ok || update { - rs.roles[role.ID] = role + rs.roles[role.ID] = *role } rs.mut.Unlock() diff --git a/state/store/defaultstore/voicestate.go b/state/store/defaultstore/voicestate.go index e7531f5..09ad2ee 100644 --- a/state/store/defaultstore/voicestate.go +++ b/state/store/defaultstore/voicestate.go @@ -15,7 +15,7 @@ type VoiceState struct { var _ store.VoiceStateStore = (*VoiceState)(nil) type voiceStates struct { - mut sync.Mutex + mut sync.RWMutex voiceStates map[discord.UserID]discord.VoiceState } @@ -43,8 +43,8 @@ func (s *VoiceState) VoiceState( vs := iv.(*voiceStates) - vs.mut.Lock() - defer vs.mut.Unlock() + vs.mut.RLock() + defer vs.mut.RUnlock() v, ok := vs.voiceStates[userID] if ok { @@ -62,8 +62,8 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, vs := iv.(*voiceStates) - vs.mut.Lock() - defer vs.mut.Unlock() + vs.mut.RLock() + defer vs.mut.RUnlock() var states = make([]discord.VoiceState, 0, len(vs.voiceStates)) for _, state := range vs.voiceStates { @@ -74,7 +74,7 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, } func (s *VoiceState) VoiceStateSet( - guildID discord.GuildID, voiceState discord.VoiceState, update bool) error { + guildID discord.GuildID, voiceState *discord.VoiceState, update bool) error { iv, _ := s.guilds.LoadOrStore(guildID) @@ -82,7 +82,7 @@ func (s *VoiceState) VoiceStateSet( vs.mut.Lock() if _, ok := vs.voiceStates[voiceState.UserID]; !ok || update { - vs.voiceStates[voiceState.UserID] = voiceState + vs.voiceStates[voiceState.UserID] = *voiceState } vs.mut.Unlock() diff --git a/state/store/store.go b/state/store/store.go index d2c9550..09825c9 100644 --- a/state/store/store.go +++ b/state/store/store.go @@ -143,6 +143,12 @@ type Resetter interface { Reset() error } +type CoreStorer interface { + Resetter + Lock() + Unlock() +} + var _ Resetter = (*noop)(nil) func (noop) Reset() error { return nil } @@ -176,8 +182,8 @@ type ChannelStore interface { // Both ChannelSet and ChannelRemove should switch on Type to know if it's a // private channel or not. - ChannelSet(c discord.Channel, update bool) error - ChannelRemove(discord.Channel) error + ChannelSet(c *discord.Channel, update bool) error + ChannelRemove(*discord.Channel) error } var _ ChannelStore = (*noop)(nil) @@ -194,10 +200,10 @@ func (noop) Channels(discord.GuildID) ([]discord.Channel, error) { func (noop) PrivateChannels() ([]discord.Channel, error) { return nil, ErrNotFound } -func (noop) ChannelSet(discord.Channel, bool) error { +func (noop) ChannelSet(*discord.Channel, bool) error { return nil } -func (noop) ChannelRemove(discord.Channel) error { +func (noop) ChannelRemove(*discord.Channel) error { return nil } @@ -232,7 +238,7 @@ type GuildStore interface { Guild(discord.GuildID) (*discord.Guild, error) Guilds() ([]discord.Guild, error) - GuildSet(g discord.Guild, update bool) error + GuildSet(g *discord.Guild, update bool) error GuildRemove(id discord.GuildID) error } @@ -240,7 +246,7 @@ var _ GuildStore = (*noop)(nil) func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound } func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound } -func (noop) GuildSet(discord.Guild, bool) error { return nil } +func (noop) GuildSet(*discord.Guild, bool) error { return nil } func (noop) GuildRemove(discord.GuildID) error { return nil } // MemberStore is the store interface for all members. @@ -250,7 +256,7 @@ type MemberStore interface { Member(discord.GuildID, discord.UserID) (*discord.Member, error) Members(discord.GuildID) ([]discord.Member, error) - MemberSet(guildID discord.GuildID, m discord.Member, update bool) error + MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error MemberRemove(discord.GuildID, discord.UserID) error } @@ -262,7 +268,7 @@ func (noop) Member(discord.GuildID, discord.UserID) (*discord.Member, error) { func (noop) Members(discord.GuildID) ([]discord.Member, error) { return nil, ErrNotFound } -func (noop) MemberSet(discord.GuildID, discord.Member, bool) error { +func (noop) MemberSet(discord.GuildID, *discord.Member, bool) error { return nil } func (noop) MemberRemove(discord.GuildID, discord.UserID) error { @@ -289,7 +295,7 @@ type MessageStore interface { // If update is set to true, MessageSet will check if a message with the // id of the passed message is stored, and update it if so. Otherwise, if // there is no such message, it will be discarded. - MessageSet(m discord.Message, update bool) error + MessageSet(m *discord.Message, update bool) error MessageRemove(discord.ChannelID, discord.MessageID) error } @@ -304,7 +310,7 @@ func (noop) Message(discord.ChannelID, discord.MessageID) (*discord.Message, err func (noop) Messages(discord.ChannelID) ([]discord.Message, error) { return nil, ErrNotFound } -func (noop) MessageSet(discord.Message, bool) error { +func (noop) MessageSet(*discord.Message, bool) error { return nil } func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error { @@ -319,7 +325,7 @@ type PresenceStore interface { Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) Presences(discord.GuildID) ([]discord.Presence, error) - PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error + PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error PresenceRemove(discord.GuildID, discord.UserID) error } @@ -331,7 +337,7 @@ func (noop) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) func (noop) Presences(discord.GuildID) ([]discord.Presence, error) { return nil, ErrNotFound } -func (noop) PresenceSet(discord.GuildID, discord.Presence, bool) error { +func (noop) PresenceSet(discord.GuildID, *discord.Presence, bool) error { return nil } func (noop) PresenceRemove(discord.GuildID, discord.UserID) error { @@ -345,7 +351,7 @@ type RoleStore interface { Role(discord.GuildID, discord.RoleID) (*discord.Role, error) Roles(discord.GuildID) ([]discord.Role, error) - RoleSet(guildID discord.GuildID, r discord.Role, update bool) error + RoleSet(guildID discord.GuildID, r *discord.Role, update bool) error RoleRemove(discord.GuildID, discord.RoleID) error } @@ -353,7 +359,7 @@ var _ RoleStore = (*noop)(nil) func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound } func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound } -func (noop) RoleSet(discord.GuildID, discord.Role, bool) error { return nil } +func (noop) RoleSet(discord.GuildID, *discord.Role, bool) error { return nil } func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil } // VoiceStateStore is the store interface for all voice states. @@ -363,7 +369,7 @@ type VoiceStateStore interface { VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) - VoiceStateSet(guildID discord.GuildID, s discord.VoiceState, update bool) error + VoiceStateSet(guildID discord.GuildID, s *discord.VoiceState, update bool) error VoiceStateRemove(discord.GuildID, discord.UserID) error } @@ -375,7 +381,7 @@ func (noop) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, er func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { return nil, ErrNotFound } -func (noop) VoiceStateSet(discord.GuildID, discord.VoiceState, bool) error { +func (noop) VoiceStateSet(discord.GuildID, *discord.VoiceState, bool) error { return nil } func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error { diff --git a/utils/handler/handler.go b/utils/handler/handler.go index b2e8619..c56b17a 100644 --- a/utils/handler/handler.go +++ b/utils/handler/handler.go @@ -28,12 +28,8 @@ import ( // Handler is a container for command handlers. A zero-value instance is a valid // instance. type Handler struct { - // Synchronous controls whether to spawn each event handler in its own - // goroutine. Default false (meaning goroutines are spawned). - Synchronous bool - - mutex sync.RWMutex - slab slab + mutex sync.RWMutex + events map[reflect.Type]slab // nil type for interfaces } func New() *Handler { @@ -43,22 +39,24 @@ func New() *Handler { // Call calls all handlers with the given event. This is an internal method; use // with care. func (h *Handler) Call(ev interface{}) { - var evV = reflect.ValueOf(ev) - var evT = evV.Type() + v := reflect.ValueOf(ev) + t := v.Type() h.mutex.RLock() defer h.mutex.RUnlock() - for _, entry := range h.slab.Entries { - if entry.isInvalid() || entry.not(evT) { + for _, entry := range h.events[t].Entries { + if entry.isInvalid() { continue } + entry.Call(v) + } - if h.Synchronous { - entry.call(evV) - } else { - go entry.call(evV) + for _, entry := range h.events[nil].Entries { + if entry.isInvalid() || entry.not(t) { + continue } + entry.Call(v) } } @@ -145,7 +143,19 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca // h.AddHandler(ch) // func (h *Handler) AddHandler(handler interface{}) (rm func()) { - rm, err := h.addHandler(handler) + rm, err := h.addHandler(handler, false) + if err != nil { + panic(err) + } + return rm +} + +// AddSyncHandler is a synchronous variant of AddHandler. Handlers added using +// this method will block the Call method, which is helpful if the user needs to +// rely on the order of events arriving. Handlers added using this method should +// not block for very long, as it may clog up other handlers. +func (h *Handler) AddSyncHandler(handler interface{}) (rm func()) { + rm, err := h.addHandler(handler, true) if err != nil { panic(err) } @@ -167,23 +177,56 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) { } }() - return h.addHandler(handler) + return h.addHandler(handler, false) } -func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { +// AddSyncHandlerCheck is the safe-guarded version of AddSyncHandler. It is +// similar to AddHandlerCheck. +func (h *Handler) AddSyncHandlerCheck(handler interface{}) (rm func(), err error) { + // Reflect would actually panic if anything goes wrong, so this is just in + // case. + defer func() { + if rec := recover(); rec != nil { + if recErr, ok := rec.(error); ok { + err = recErr + } else { + err = fmt.Errorf("%v", rec) + } + } + }() + + return h.addHandler(handler, true) +} + +func (h *Handler) addHandler(fn interface{}, sync bool) (rm func(), err error) { // Reflect the handler - r, err := newHandler(fn) + r, err := newHandler(fn, sync) if err != nil { return nil, errors.Wrap(err, "handler reflect failed") } + var id int + var t reflect.Type + if !r.isIface { + t = r.event + } + h.mutex.Lock() - id := h.slab.Put(r) + + if h.events == nil { + h.events = make(map[reflect.Type]slab, 10) + } + + slab := h.events[t] + id = slab.Put(r) + h.events[t] = slab + h.mutex.Unlock() return func() { h.mutex.Lock() - popped := h.slab.Pop(id) + slab := h.events[t] + popped := slab.Pop(id) h.mutex.Unlock() popped.cleanup() @@ -193,20 +236,22 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { type handler struct { event reflect.Type // underlying type; arg0 or chan underlying type callback reflect.Value - isIface bool chanclose reflect.Value // IsValid() if chan + isIface bool + isSync bool } // newHandler reflects either a channel or a function into a handler. A function // must only have a single argument being the event and no return, and a channel // must have the event type as the underlying type. -func newHandler(unknown interface{}) (handler, error) { +func newHandler(unknown interface{}, sync bool) (handler, error) { fnV := reflect.ValueOf(unknown) fnT := fnV.Type() // underlying event type - var handler = handler{ + handler := handler{ callback: fnV, + isSync: sync, } switch fnT.Kind() { @@ -249,6 +294,14 @@ func (h handler) not(event reflect.Type) bool { return h.event != event } +func (h handler) Call(event reflect.Value) { + if h.isSync { + h.call(event) + } else { + go h.call(event) + } +} + func (h handler) call(event reflect.Value) { if h.chanclose.IsValid() { reflect.Select([]reflect.SelectCase{ diff --git a/utils/handler/handler_test.go b/utils/handler/handler_test.go index 3a6ba9f..6336664 100644 --- a/utils/handler/handler_test.go +++ b/utils/handler/handler_test.go @@ -63,7 +63,7 @@ func TestHandler(t *testing.T) { h, err := newHandler(func(m *gateway.MessageCreateEvent) { results <- m.Content - }) + }, false) if err != nil { t.Fatal(err) } @@ -88,7 +88,7 @@ func TestHandler(t *testing.T) { func TestHandlerChan(t *testing.T) { var results = make(chan *gateway.MessageCreateEvent) - h, err := newHandler(results) + h, err := newHandler(results, false) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestHandlerChanCancel(t *testing.T) { // unbuffered. var results = make(chan *gateway.MessageCreateEvent) - h, err := newHandler(results) + h, err := newHandler(results, false) if err != nil { t.Fatal(err) } @@ -161,7 +161,7 @@ func TestHandlerInterface(t *testing.T) { h, err := newHandler(func(m interface{}) { results <- m - }) + }, false) if err != nil { t.Fatal(err) } @@ -277,7 +277,7 @@ func TestHandlerChanFor(t *testing.T) { } func BenchmarkReflect(b *testing.B) { - h, err := newHandler(func(m *gateway.MessageCreateEvent) {}) + h, err := newHandler(func(m *gateway.MessageCreateEvent) {}, false) if err != nil { b.Fatal(err) } diff --git a/utils/handler/slab.go b/utils/handler/slab.go index a33dd25..abd0b95 100644 --- a/utils/handler/slab.go +++ b/utils/handler/slab.go @@ -1,8 +1,8 @@ package handler type slabEntry struct { - handler index int + handler } func (entry slabEntry) isInvalid() bool { @@ -18,13 +18,13 @@ type slab struct { func (s *slab) Put(entry handler) int { if s.free == len(s.Entries) { index := len(s.Entries) - s.Entries = append(s.Entries, slabEntry{entry, -1}) + s.Entries = append(s.Entries, slabEntry{-1, entry}) s.free++ return index } next := s.Entries[s.free].index - s.Entries[s.free] = slabEntry{entry, -1} + s.Entries[s.free] = slabEntry{-1, entry} i := s.free s.free = next @@ -38,7 +38,7 @@ func (s *slab) Get(i int) handler { func (s *slab) Pop(i int) handler { popped := s.Entries[i].handler - s.Entries[i] = slabEntry{handler{}, s.free} + s.Entries[i] = slabEntry{s.free, handler{}} s.free = i return popped } diff --git a/voice/session_test.go b/voice/session_test.go index e6e6b6c..7ebfbda 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -36,7 +36,7 @@ func TestIntegration(t *testing.T) { AddIntents(s.Gateway) func() { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() if err := s.Open(ctx); err != nil {