From e79132f2c5167fe78687eb406767fba303527588 Mon Sep 17 00:00:00 2001 From: diamondburned Date: Tue, 28 Jul 2020 12:00:01 -0700 Subject: [PATCH] State: Breaking API to fix race conditions in store --- state/state.go | 52 ++++--- state/state_events.go | 73 ++++----- state/store.go | 21 +-- state/store_default.go | 344 +++++++++++++++++++---------------------- state/store_noop.go | 18 +-- 5 files changed, 243 insertions(+), 265 deletions(-) diff --git a/state/state.go b/state/state.go index f363dd4..64fec7c 100644 --- a/state/state.go +++ b/state/state.go @@ -17,7 +17,7 @@ import ( var ( MaxFetchMembers uint = 1000 - MaxFetchGuilds uint = 100 + MaxFetchGuilds uint = 10 ) // State is the cache to store events coming from Discord as well as data from @@ -80,12 +80,10 @@ type State struct { // with the State. *handler.Handler - unhooker func() - // List of channels with few messages, so it doesn't bother hitting the API // again. fewMessages map[discord.ChannelID]struct{} - fewMutex *sync.Mutex + fewMutex sync.Mutex // unavailableGuilds is a set of discord.GuildIDs of guilds that became // unavailable when already connected to the gateway, i.e. sent in a @@ -131,7 +129,7 @@ func NewFromSession(s *session.Session, store Store) (*State, error) { Handler: handler.New(), StateLog: func(err error) {}, fewMessages: map[discord.ChannelID]struct{}{}, - fewMutex: new(sync.Mutex), + fewMutex: sync.Mutex{}, unavailableGuilds: moreatomic.NewGuildIDSet(), unreadyGuilds: moreatomic.NewGuildIDSet(), } @@ -235,7 +233,9 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis //// -func (s *State) Permissions(channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) { +func (s *State) Permissions( + channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) { + ch, err := s.Channel(channelID) if err != nil { return 0, errors.Wrap(err, "failed to get channel") @@ -286,7 +286,7 @@ func (s *State) Me() (*discord.User, error) { return nil, err } - return u, s.Store.MyselfSet(u) + return u, s.Store.MyselfSet(*u) } //// @@ -302,7 +302,7 @@ func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) { return nil, err } - return c, s.Store.ChannelSet(c) + return c, s.Store.ChannelSet(*c) } func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) { @@ -319,7 +319,7 @@ func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) { for _, ch := range c { ch := ch - if err := s.Store.ChannelSet(&ch); err != nil { + if err := s.Store.ChannelSet(ch); err != nil { return nil, err } } @@ -338,7 +338,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return nil, err } - return c, s.Store.ChannelSet(c) + return c, s.Store.ChannelSet(*c) } func (s *State) PrivateChannels() ([]discord.Channel, error) { @@ -355,7 +355,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { for _, ch := range c { ch := ch - if err := s.Store.ChannelSet(&ch); err != nil { + if err := s.Store.ChannelSet(ch); err != nil { return nil, err } } @@ -431,7 +431,7 @@ func (s *State) Guilds() ([]discord.Guild, error) { for _, ch := range c { ch := ch - if err := s.Store.GuildSet(&ch); err != nil { + if err := s.Store.GuildSet(ch); err != nil { return nil, err } } @@ -462,7 +462,7 @@ func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) { } for _, m := range ms { - if err := s.Store.MemberSet(guildID, &m); err != nil { + if err := s.Store.MemberSet(guildID, m); err != nil { return nil, err } } @@ -475,7 +475,9 @@ func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) { //// -func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { +func (s *State) Message( + channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { + m, err := s.Store.Message(channelID, messageID) if err == nil { return m, nil @@ -489,7 +491,7 @@ func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID go func() { c, cerr = s.Session.Channel(channelID) if cerr == nil { - cerr = s.Store.ChannelSet(c) + cerr = s.Store.ChannelSet(*c) } wg.Done() @@ -510,7 +512,7 @@ func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID m.ChannelID = c.ID m.GuildID = c.GuildID - return m, s.Store.MessageSet(m) + return m, s.Store.MessageSet(*m) } // Messages fetches maximum 100 messages from the API, if it has to. There is no @@ -559,7 +561,7 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) // Set the guild ID, fine if it's 0 (it's already 0 anyway). ms[i].GuildID = guildID - if err := s.Store.MessageSet(&ms[i]); err != nil { + if err := s.Store.MessageSet(ms[i]); err != nil { return nil, err } } @@ -582,7 +584,9 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) // Presence checks the state for user presences. If no guildID is given, it will // look for the presence in all guilds. -func (s *State) Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { +func (s *State) Presence( + guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { + p, err := s.Store.Presence(guildID, userID) if err == nil { return p, nil @@ -627,7 +631,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.R role = &r } - if err := s.RoleSet(guildID, &r); err != nil { + if err := s.RoleSet(guildID, r); err != nil { return role, err } } @@ -649,7 +653,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { for _, r := range rs { r := r - if err := s.RoleSet(guildID, &r); err != nil { + if err := s.RoleSet(guildID, r); err != nil { return rs, err } } @@ -660,16 +664,18 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) if err == nil { - err = s.Store.GuildSet(g) + err = s.Store.GuildSet(*g) } return } -func (s *State) fetchMember(guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) { +func (s *State) fetchMember( + guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) { + m, err = s.Session.Member(guildID, userID) if err == nil { - err = s.Store.MemberSet(guildID, m) + err = s.Store.MemberSet(guildID, *m) } return diff --git a/state/state_events.go b/state/state_events.go index 4a27bbe..f638930 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -8,7 +8,7 @@ import ( ) func (s *State) hookSession() { - s.unhooker = s.Session.AddHandler(func(event interface{}) { + s.Session.AddHandler(func(event interface{}) { // Call the pre-handler before the state handler. if s.PreHandler != nil { s.PreHandler.Call(event) @@ -55,9 +55,7 @@ func (s *State) onEvent(iface interface{}) { // Handle presences for _, p := range ev.Presences { - p := p - - if err := s.Store.PresenceSet(0, &p); err != nil { + if err := s.Store.PresenceSet(0, p); err != nil { s.stateErr(err, "failed to set global presence") } } @@ -68,19 +66,19 @@ func (s *State) onEvent(iface interface{}) { } // Handle private channels - for i := range ev.PrivateChannels { - if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil { + for _, ch := range ev.PrivateChannels { + if err := s.Store.ChannelSet(ch); err != nil { s.stateErr(err, "failed to set channel in state") } } // Handle user - if err := s.Store.MyselfSet(&ev.User); err != nil { + if err := s.Store.MyselfSet(ev.User); err != nil { s.stateErr(err, "failed to set self in state") } case *gateway.GuildUpdateEvent: - if err := s.Store.GuildSet(&ev.Guild); err != nil { + if err := s.Store.GuildSet(ev.Guild); err != nil { s.stateErr(err, "failed to update guild in state") } @@ -90,7 +88,7 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.GuildMemberAddEvent: - if err := s.Store.MemberSet(ev.GuildID, &ev.Member); err != nil { + if err := s.Store.MemberSet(ev.GuildID, ev.Member); err != nil { s.stateErr(err, "failed to add a member in state") } @@ -104,7 +102,7 @@ func (s *State) onEvent(iface interface{}) { // Update available fields from ev into m ev.Update(m) - if err := s.Store.MemberSet(ev.GuildID, m); err != nil { + if err := s.Store.MemberSet(ev.GuildID, *m); err != nil { s.stateErr(err, "failed to update a member in state") } @@ -115,28 +113,24 @@ func (s *State) onEvent(iface interface{}) { case *gateway.GuildMembersChunkEvent: for _, m := range ev.Members { - m := m - - if err := s.Store.MemberSet(ev.GuildID, &m); err != nil { + if err := s.Store.MemberSet(ev.GuildID, m); err != nil { s.stateErr(err, "failed to add a member from chunk in state") } } for _, p := range ev.Presences { - p := p - - if err := s.Store.PresenceSet(ev.GuildID, &p); err != nil { + if err := s.Store.PresenceSet(ev.GuildID, p); err != nil { s.stateErr(err, "failed to add a presence from chunk in state") } } case *gateway.GuildRoleCreateEvent: - if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil { + if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil { s.stateErr(err, "failed to add a role in state") } case *gateway.GuildRoleUpdateEvent: - if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil { + if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil { s.stateErr(err, "failed to update a role in state") } @@ -151,17 +145,17 @@ func (s *State) onEvent(iface interface{}) { } case *gateway.ChannelCreateEvent: - if err := s.Store.ChannelSet(&ev.Channel); err != nil { + if err := s.Store.ChannelSet(ev.Channel); err != nil { s.stateErr(err, "failed to create a channel in state") } case *gateway.ChannelUpdateEvent: - if err := s.Store.ChannelSet(&ev.Channel); err != nil { + if err := s.Store.ChannelSet(ev.Channel); err != nil { s.stateErr(err, "failed to update a channel in state") } case *gateway.ChannelDeleteEvent: - if err := s.Store.ChannelRemove(&ev.Channel); err != nil { + if err := s.Store.ChannelRemove(ev.Channel); err != nil { s.stateErr(err, "failed to remove a channel in state") } @@ -169,12 +163,12 @@ func (s *State) onEvent(iface interface{}) { // not tracked. case *gateway.MessageCreateEvent: - if err := s.Store.MessageSet(&ev.Message); err != nil { + if err := s.Store.MessageSet(ev.Message); err != nil { s.stateErr(err, "failed to add a message in state") } case *gateway.MessageUpdateEvent: - if err := s.Store.MessageSet(&ev.Message); err != nil { + if err := s.Store.MessageSet(ev.Message); err != nil { s.stateErr(err, "failed to update a message in state") } @@ -250,15 +244,13 @@ func (s *State) onEvent(iface interface{}) { }) case *gateway.PresenceUpdateEvent: - if err := s.Store.PresenceSet(ev.GuildID, &ev.Presence); err != nil { + if err := s.Store.PresenceSet(ev.GuildID, ev.Presence); err != nil { s.stateErr(err, "failed to update presence in state") } case *gateway.PresencesReplaceEvent: - for i := range *ev { - p := (*ev)[i] - - if err := s.Store.PresenceSet(p.GuildID, &p); err != nil { + for _, p := range *ev { + if err := s.Store.PresenceSet(p.GuildID, p); err != nil { s.stateErr(err, "failed to update presence in state") } } @@ -279,7 +271,7 @@ func (s *State) onEvent(iface interface{}) { s.Ready.Notes[ev.ID] = ev.Note case *gateway.UserUpdateEvent: - if err := s.Store.MyselfSet(&ev.User); err != nil { + if err := s.Store.MyselfSet(ev.User); err != nil { s.stateErr(err, "failed to update myself from USER_UPDATE") } @@ -290,7 +282,7 @@ func (s *State) onEvent(iface interface{}) { s.stateErr(err, "failed to remove voice state from state") } } else { - if err := s.Store.VoiceStateSet(vs.GuildID, vs); err != nil { + if err := s.Store.VoiceStateSet(vs.GuildID, *vs); err != nil { s.stateErr(err, "failed to update voice state in state") } } @@ -316,7 +308,7 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func if !fn(m) { return } - if err := s.Store.MessageSet(m); err != nil { + if err := s.Store.MessageSet(*m); err != nil { s.stateErr(err, "failed to save message in reaction add") } } @@ -337,7 +329,7 @@ func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error { stack, errs := newErrorStack() - if err := store.GuildSet(&guild.Guild); err != nil { + if err := store.GuildSet(guild.Guild); err != nil { errs(err, "failed to set guild in Ready") } @@ -349,33 +341,32 @@ func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error { } // Handle guild member - for i := range guild.Members { - if err := store.MemberSet(guild.ID, &guild.Members[i]); err != nil { + for _, m := range guild.Members { + if err := store.MemberSet(guild.ID, m); err != nil { errs(err, "failed to set guild member in Ready") } } // Handle guild channels - for i := range guild.Channels { + for _, ch := range guild.Channels { // I HATE Discord. - ch := guild.Channels[i] ch.GuildID = guild.ID - if err := store.ChannelSet(&ch); err != nil { + if err := store.ChannelSet(ch); err != nil { errs(err, "failed to set guild channel in Ready") } } // Handle guild presences - for i := range guild.Presences { - if err := store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil { + for _, p := range guild.Presences { + if err := store.PresenceSet(guild.ID, p); err != nil { errs(err, "failed to set guild presence in Ready") } } // Handle guild voice states - for i := range guild.VoiceStates { - if err := store.VoiceStateSet(guild.ID, &guild.VoiceStates[i]); err != nil { + for _, v := range guild.VoiceStates { + if err := store.VoiceStateSet(guild.ID, v); err != nil { errs(err, "failed to set guild voice state in Ready") } } diff --git a/state/store.go b/state/store.go index 4d4c91c..1d8ccb5 100644 --- a/state/store.go +++ b/state/store.go @@ -21,6 +21,9 @@ type Store interface { // would mutate the underlying slice (and as a result the returned slice as // well). The best way to avoid this is to copy the whole slice, like // DefaultStore does. +// +// These methods should not care about returning slices in order, unless +// explicitly stated against. type StoreGetter interface { Me() (*discord.User, error) @@ -58,34 +61,34 @@ type StoreGetter interface { } type StoreModifier interface { - MyselfSet(me *discord.User) error + MyselfSet(me discord.User) error // ChannelSet should switch on Type to know if it's a private channel or // not. - ChannelSet(*discord.Channel) error - ChannelRemove(*discord.Channel) error + ChannelSet(discord.Channel) error + ChannelRemove(discord.Channel) error // EmojiSet should delete all old emojis before setting new ones. EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error - GuildSet(*discord.Guild) error + GuildSet(discord.Guild) error GuildRemove(id discord.GuildID) error - MemberSet(guildID discord.GuildID, member *discord.Member) error + MemberSet(guildID discord.GuildID, member discord.Member) error MemberRemove(guildID discord.GuildID, userID discord.UserID) error // MessageSet should prepend messages into the slice, the latest being in // front. - MessageSet(*discord.Message) error + MessageSet(discord.Message) error MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error - PresenceSet(guildID discord.GuildID, presence *discord.Presence) error + PresenceSet(guildID discord.GuildID, presence discord.Presence) error PresenceRemove(guildID discord.GuildID, userID discord.UserID) error - RoleSet(guildID discord.GuildID, role *discord.Role) error + RoleSet(guildID discord.GuildID, role discord.Role) error RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error - VoiceStateSet(guildID discord.GuildID, voiceState *discord.VoiceState) error + VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error } diff --git a/state/store_default.go b/state/store_default.go index bdef92e..bbb9ef6 100644 --- a/state/store_default.go +++ b/state/store_default.go @@ -1,7 +1,6 @@ package state import ( - "sort" "sync" "github.com/diamondburned/arikawa/discord" @@ -10,21 +9,25 @@ import ( // TODO: make an ExpiryStore type DefaultStore struct { - *DefaultStoreOptions + DefaultStoreOptions self discord.User // includes normal and private - privates map[discord.ChannelID]*discord.Channel - guilds map[discord.GuildID]*discord.Guild + privates map[discord.ChannelID]discord.Channel + guilds map[discord.GuildID]discord.Guild + roles map[discord.GuildID][]discord.Role + emojis map[discord.GuildID][]discord.Emoji channels map[discord.GuildID][]discord.Channel - members map[discord.GuildID][]discord.Member presences map[discord.GuildID][]discord.Presence - messages map[discord.ChannelID][]discord.Message voiceStates map[discord.GuildID][]discord.VoiceState + messages map[discord.ChannelID][]discord.Message - mut sync.Mutex + // special case; optimize for lots of members + members map[discord.GuildID]map[discord.UserID]discord.Member + + mut sync.RWMutex } type DefaultStoreOptions struct { @@ -40,9 +43,7 @@ func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore { } } - ds := &DefaultStore{ - DefaultStoreOptions: opts, - } + ds := &DefaultStore{DefaultStoreOptions: *opts} ds.Reset() return ds @@ -54,14 +55,17 @@ func (s *DefaultStore) Reset() error { s.self = discord.User{} - s.privates = map[discord.ChannelID]*discord.Channel{} - s.guilds = map[discord.GuildID]*discord.Guild{} + s.privates = map[discord.ChannelID]discord.Channel{} + s.guilds = map[discord.GuildID]discord.Guild{} + s.roles = map[discord.GuildID][]discord.Role{} + s.emojis = map[discord.GuildID][]discord.Emoji{} s.channels = map[discord.GuildID][]discord.Channel{} - s.members = map[discord.GuildID][]discord.Member{} s.presences = map[discord.GuildID][]discord.Presence{} - s.messages = map[discord.ChannelID][]discord.Message{} s.voiceStates = map[discord.GuildID][]discord.VoiceState{} + s.messages = map[discord.ChannelID][]discord.Message{} + + s.members = map[discord.GuildID]map[discord.UserID]discord.Member{} return nil } @@ -69,8 +73,8 @@ func (s *DefaultStore) Reset() error { //// func (s *DefaultStore) Me() (*discord.User, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() if !s.self.ID.Valid() { return nil, ErrStoreNotFound @@ -79,9 +83,9 @@ func (s *DefaultStore) Me() (*discord.User, error) { return &s.self, nil } -func (s *DefaultStore) MyselfSet(me *discord.User) error { +func (s *DefaultStore) MyselfSet(me discord.User) error { s.mut.Lock() - s.self = *me + s.self = me s.mut.Unlock() return nil @@ -90,11 +94,12 @@ func (s *DefaultStore) MyselfSet(me *discord.User) error { //// func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() if ch, ok := s.privates[id]; ok { - return ch, nil + // implicit copy + return &ch, nil } for _, chs := range s.channels { @@ -109,8 +114,8 @@ func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) { } func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() chs, ok := s.channels[guildID] if !ok { @@ -123,16 +128,17 @@ func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, err // CreatePrivateChannel searches in the cache for a private channel. It makes no // API calls. func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() // slow way for _, ch := range s.privates { - if ch.Type != discord.DirectMessage || len(ch.DMRecipients) < 1 { + if ch.Type != discord.DirectMessage || len(ch.DMRecipients) == 0 { continue } if ch.DMRecipients[0].ID == recipient { - return &(*ch), nil + // Return an implicit copy made by range. + return &ch, nil } } return nil, ErrStoreNotFound @@ -140,18 +146,18 @@ func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord. // PrivateChannels returns a list of Direct Message channels randomly ordered. func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() var chs = make([]discord.Channel, 0, len(s.privates)) - for _, ch := range s.privates { - chs = append(chs, *ch) + for i := range s.privates { + chs = append(chs, s.privates[i]) } return chs, nil } -func (s *DefaultStore) ChannelSet(channel *discord.Channel) error { +func (s *DefaultStore) ChannelSet(channel discord.Channel) error { s.mut.Lock() defer s.mut.Unlock() @@ -169,20 +175,20 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error { } // Found, just edit - chs[i] = *channel + chs[i] = channel return nil } } - chs = append(chs, *channel) + chs = append(chs, channel) s.channels[channel.GuildID] = chs } return nil } -func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error { +func (s *DefaultStore) ChannelRemove(channel discord.Channel) error { s.mut.Lock() defer s.mut.Unlock() @@ -193,9 +199,11 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error { for i, ch := range chs { if ch.ID == channel.ID { - chs = append(chs[:i], chs[i+1:]...) - s.channels[channel.GuildID] = chs + // Fast unordered delete. + chs[i] = chs[len(chs)-1] + chs = chs[:len(chs)-1] + s.channels[channel.GuildID] = chs return nil } } @@ -206,16 +214,17 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error { //// func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() - gd, ok := s.guilds[guildID] + emojis, ok := s.emojis[guildID] if !ok { return nil, ErrStoreNotFound } - for _, emoji := range gd.Emojis { + for _, emoji := range emojis { if emoji.ID == emojiID { + // Emoji is an implicit copy, so we could do this safely. return &emoji, nil } } @@ -224,162 +233,126 @@ func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) ( } func (s *DefaultStore) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() - gd, ok := s.guilds[guildID] + emojis, ok := s.emojis[guildID] if !ok { return nil, ErrStoreNotFound } - return append([]discord.Emoji{}, gd.Emojis...), nil + return append([]discord.Emoji{}, emojis...), nil } func (s *DefaultStore) EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error { s.mut.Lock() defer s.mut.Unlock() - gd, ok := s.guilds[guildID] - if !ok { - return ErrStoreNotFound - } + // A nil slice is acceptable, as we'll make a new slice later on and set it. + s.emojis[guildID] = emojis - filtered := emojis[:0] - -Main: - for _, enew := range emojis { - // Try and see if this emoji is already in the slice - for i, emoji := range gd.Emojis { - if emoji.ID == enew.ID { - // If it is, we simply replace it - gd.Emojis[i] = enew - - continue Main - } - } - - // If not, we add it to the slice that's to be appended. - filtered = append(filtered, enew) - } - - // Append the new emojis - gd.Emojis = append(gd.Emojis, filtered...) return nil } //// func (s *DefaultStore) Guild(id discord.GuildID) (*discord.Guild, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() ch, ok := s.guilds[id] if !ok { return nil, ErrStoreNotFound } - return ch, nil + // implicit copy + return &ch, nil } func (s *DefaultStore) Guilds() ([]discord.Guild, error) { - s.mut.Lock() + s.mut.RLock() + defer s.mut.RUnlock() if len(s.guilds) == 0 { - s.mut.Unlock() return nil, ErrStoreNotFound } var gs = make([]discord.Guild, 0, len(s.guilds)) for _, g := range s.guilds { - gs = append(gs, *g) + gs = append(gs, g) } - s.mut.Unlock() - - sort.Slice(gs, func(i, j int) bool { - return gs[i].ID > gs[j].ID - }) - return gs, nil } -func (s *DefaultStore) GuildSet(guild *discord.Guild) error { +func (s *DefaultStore) GuildSet(guild discord.Guild) error { s.mut.Lock() defer s.mut.Unlock() - if g, ok := s.guilds[guild.ID]; ok { - // preserve state stuff - if guild.Roles == nil { - guild.Roles = g.Roles - } - if guild.Emojis == nil { - guild.Emojis = g.Emojis - } - } - s.guilds[guild.ID] = guild return nil } func (s *DefaultStore) GuildRemove(id discord.GuildID) error { s.mut.Lock() - delete(s.guilds, id) - s.mut.Unlock() + defer s.mut.Unlock() + if _, ok := s.guilds[id]; !ok { + return ErrStoreNotFound + } + + delete(s.guilds, id) return nil } //// -func (s *DefaultStore) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { - s.mut.Lock() - defer s.mut.Unlock() +func (s *DefaultStore) Member( + guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { + + s.mut.RLock() + defer s.mut.RUnlock() ms, ok := s.members[guildID] if !ok { return nil, ErrStoreNotFound } - for _, m := range ms { - if m.User.ID == userID { - return &m, nil - } + m, ok := ms[userID] + if ok { + return &m, nil } return nil, ErrStoreNotFound } func (s *DefaultStore) Members(guildID discord.GuildID) ([]discord.Member, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() ms, ok := s.members[guildID] if !ok { return nil, ErrStoreNotFound } - return append([]discord.Member{}, ms...), nil + var members = make([]discord.Member, 0, len(ms)) + for _, m := range ms { + members = append(members, m) + } + + return members, nil } -func (s *DefaultStore) MemberSet(guildID discord.GuildID, member *discord.Member) error { +func (s *DefaultStore) MemberSet(guildID discord.GuildID, member discord.Member) error { s.mut.Lock() defer s.mut.Unlock() - ms := s.members[guildID] - - // Try and see if this member is already in the slice - for i, m := range ms { - if m.User.ID == member.User.ID { - // If it is, we simply replace it - ms[i] = *member - s.members[guildID] = ms - - return nil - } + ms, ok := s.members[guildID] + if !ok { + ms = make(map[discord.UserID]discord.Member, 1) } - // Append the new member - ms = append(ms, *member) + ms[member.User.ID] = member s.members[guildID] = ms return nil @@ -394,24 +367,21 @@ func (s *DefaultStore) MemberRemove(guildID discord.GuildID, userID discord.User return ErrStoreNotFound } - // Try and see if this member is already in the slice - for i, m := range ms { - if m.User.ID == userID { - ms = append(ms, ms[i+1:]...) - s.members[guildID] = ms - - return nil - } + if _, ok := ms[userID]; !ok { + return ErrStoreNotFound } - return ErrStoreNotFound + delete(ms, userID) + return nil } //// -func (s *DefaultStore) Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { - s.mut.Lock() - defer s.mut.Unlock() +func (s *DefaultStore) Message( + channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { + + s.mut.RLock() + defer s.mut.RUnlock() ms, ok := s.messages[channelID] if !ok { @@ -428,24 +398,22 @@ func (s *DefaultStore) Message(channelID discord.ChannelID, messageID discord.Me } func (s *DefaultStore) Messages(channelID discord.ChannelID) ([]discord.Message, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() ms, ok := s.messages[channelID] if !ok { return nil, ErrStoreNotFound } - cp := make([]discord.Message, len(ms)) - copy(cp, ms) - return cp, nil + return append([]discord.Message{}, ms...), nil } func (s *DefaultStore) MaxMessages() int { return int(s.DefaultStoreOptions.MaxMessages) } -func (s *DefaultStore) MessageSet(message *discord.Message) error { +func (s *DefaultStore) MessageSet(message discord.Message) error { s.mut.Lock() defer s.mut.Unlock() @@ -457,7 +425,7 @@ func (s *DefaultStore) MessageSet(message *discord.Message) error { // Check if we already have the message. for i, m := range ms { if m.ID == message.ID { - DiffMessage(*message, &m) + DiffMessage(message, &m) ms[i] = m return nil } @@ -480,13 +448,15 @@ func (s *DefaultStore) MessageSet(message *discord.Message) error { // 1st-endth. copy(ms[1:end], ms[0:end-1]) // Then, set the 0th entry. - ms[0] = *message + ms[0] = message s.messages[message.ChannelID] = ms return nil } -func (s *DefaultStore) MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error { +func (s *DefaultStore) MessageRemove( + channelID discord.ChannelID, messageID discord.MessageID) error { + s.mut.Lock() defer s.mut.Unlock() @@ -508,9 +478,11 @@ func (s *DefaultStore) MessageRemove(channelID discord.ChannelID, messageID disc //// -func (s *DefaultStore) Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { - s.mut.Lock() - defer s.mut.Unlock() +func (s *DefaultStore) Presence( + guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { + + s.mut.RLock() + defer s.mut.RUnlock() ps, ok := s.presences[guildID] if !ok { @@ -527,33 +499,32 @@ func (s *DefaultStore) Presence(guildID discord.GuildID, userID discord.UserID) } func (s *DefaultStore) Presences(guildID discord.GuildID) ([]discord.Presence, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() ps, ok := s.presences[guildID] if !ok { return nil, ErrStoreNotFound } - return ps, nil + return append([]discord.Presence{}, ps...), nil } -func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence *discord.Presence) error { +func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence discord.Presence) error { s.mut.Lock() defer s.mut.Unlock() - ps := s.presences[guildID] + ps, _ := s.presences[guildID] for i, p := range ps { if p.User.ID == presence.User.ID { - ps[i] = *presence - s.presences[guildID] = ps - + // Change the backing array. + ps[i] = presence return nil } } - ps = append(ps, *presence) + ps = append(ps, presence) s.presences[guildID] = ps return nil } @@ -569,9 +540,10 @@ func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.Us for i, p := range ps { if p.User.ID == userID { - ps = append(ps[:i], ps[i+1:]...) - s.presences[guildID] = ps + ps[i] = ps[len(ps)-1] + ps = ps[:len(ps)-1] + s.presences[guildID] = ps return nil } } @@ -582,15 +554,15 @@ func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.Us //// func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() - gd, ok := s.guilds[guildID] + rs, ok := s.roles[guildID] if !ok { return nil, ErrStoreNotFound } - for _, r := range gd.Roles { + for _, r := range rs { if r.ID == roleID { return &r, nil } @@ -600,34 +572,35 @@ func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*di } func (s *DefaultStore) Roles(guildID discord.GuildID) ([]discord.Role, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() - gd, ok := s.guilds[guildID] + rs, ok := s.roles[guildID] if !ok { return nil, ErrStoreNotFound } - return append([]discord.Role{}, gd.Roles...), nil + return append([]discord.Role{}, rs...), nil } -func (s *DefaultStore) RoleSet(guildID discord.GuildID, role *discord.Role) error { +func (s *DefaultStore) RoleSet(guildID discord.GuildID, role discord.Role) error { s.mut.Lock() defer s.mut.Unlock() - gd, ok := s.guilds[guildID] - if !ok { - return ErrStoreNotFound - } + // A nil slice is fine, since we can just append the role. + rs, _ := s.roles[guildID] - for i, r := range gd.Roles { + for i, r := range rs { if r.ID == role.ID { - gd.Roles[i] = *role + // This changes the backing array, so we don't need to reset the + // slice. + rs[i] = role return nil } } - gd.Roles = append(gd.Roles, *role) + rs = append(rs, role) + s.roles[guildID] = rs return nil } @@ -635,14 +608,18 @@ func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID s.mut.Lock() defer s.mut.Unlock() - gd, ok := s.guilds[guildID] + rs, ok := s.roles[guildID] if !ok { return ErrStoreNotFound } - for i, r := range gd.Roles { + for i, r := range rs { if r.ID == roleID { - gd.Roles = append(gd.Roles[:i], gd.Roles[i+1:]...) + // Fast delete. + rs[i] = rs[len(rs)-1] + rs = rs[:len(rs)-1] + + s.roles[guildID] = rs return nil } } @@ -652,9 +629,11 @@ func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID //// -func (s *DefaultStore) VoiceState(guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) { - s.mut.Lock() - defer s.mut.Unlock() +func (s *DefaultStore) VoiceState( + guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) { + + s.mut.RLock() + defer s.mut.RUnlock() states, ok := s.voiceStates[guildID] if !ok { @@ -671,8 +650,8 @@ func (s *DefaultStore) VoiceState(guildID discord.GuildID, userID discord.UserID } func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) { - s.mut.Lock() - defer s.mut.Unlock() + s.mut.RLock() + defer s.mut.RUnlock() states, ok := s.voiceStates[guildID] if !ok { @@ -682,22 +661,21 @@ func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceStat return append([]discord.VoiceState{}, states...), nil } -func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState *discord.VoiceState) error { +func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error { s.mut.Lock() defer s.mut.Unlock() - states := s.voiceStates[guildID] + states, _ := s.voiceStates[guildID] for i, vs := range states { if vs.UserID == voiceState.UserID { - states[i] = *voiceState - s.voiceStates[guildID] = states - + // change the backing array + states[i] = voiceState return nil } } - states = append(states, *voiceState) + states = append(states, voiceState) s.voiceStates[guildID] = states return nil } diff --git a/state/store_noop.go b/state/store_noop.go index 9bc436e..0a146b3 100644 --- a/state/store_noop.go +++ b/state/store_noop.go @@ -23,7 +23,7 @@ func (NoopStore) Me() (*discord.User, error) { return nil, ErrNotImplemented } -func (NoopStore) MyselfSet(*discord.User) error { +func (NoopStore) MyselfSet(discord.User) error { return nil } @@ -43,11 +43,11 @@ func (NoopStore) PrivateChannels() ([]discord.Channel, error) { return nil, ErrNotImplemented } -func (NoopStore) ChannelSet(*discord.Channel) error { +func (NoopStore) ChannelSet(discord.Channel) error { return nil } -func (NoopStore) ChannelRemove(*discord.Channel) error { +func (NoopStore) ChannelRemove(discord.Channel) error { return nil } @@ -71,7 +71,7 @@ func (NoopStore) Guilds() ([]discord.Guild, error) { return nil, ErrNotImplemented } -func (NoopStore) GuildSet(*discord.Guild) error { +func (NoopStore) GuildSet(discord.Guild) error { return nil } @@ -87,7 +87,7 @@ func (NoopStore) Members(discord.GuildID) ([]discord.Member, error) { return nil, ErrNotImplemented } -func (NoopStore) MemberSet(discord.GuildID, *discord.Member) error { +func (NoopStore) MemberSet(discord.GuildID, discord.Member) error { return nil } @@ -109,7 +109,7 @@ func (NoopStore) MaxMessages() int { return 100 } -func (NoopStore) MessageSet(*discord.Message) error { +func (NoopStore) MessageSet(discord.Message) error { return nil } @@ -125,7 +125,7 @@ func (NoopStore) Presences(discord.GuildID) ([]discord.Presence, error) { return nil, ErrNotImplemented } -func (NoopStore) PresenceSet(discord.GuildID, *discord.Presence) error { +func (NoopStore) PresenceSet(discord.GuildID, discord.Presence) error { return nil } @@ -141,7 +141,7 @@ func (NoopStore) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotImplemented } -func (NoopStore) RoleSet(discord.GuildID, *discord.Role) error { +func (NoopStore) RoleSet(discord.GuildID, discord.Role) error { return nil } @@ -157,7 +157,7 @@ func (NoopStore) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { return nil, ErrNotImplemented } -func (NoopStore) VoiceStateSet(discord.GuildID, *discord.VoiceState) error { +func (NoopStore) VoiceStateSet(discord.GuildID, discord.VoiceState) error { return ErrNotImplemented }