State: Breaking API to fix race conditions in store

This commit is contained in:
diamondburned 2020-07-28 12:00:01 -07:00 committed by diamondburned
parent b8f6fbbda9
commit e79132f2c5
5 changed files with 243 additions and 265 deletions

View File

@ -17,7 +17,7 @@ import (
var (
MaxFetchMembers uint = 1000
MaxFetchGuilds uint = 100
MaxFetchGuilds uint = 10
)
// State is the cache to store events coming from Discord as well as data from
@ -80,12 +80,10 @@ type State struct {
// with the State.
*handler.Handler
unhooker func()
// List of channels with few messages, so it doesn't bother hitting the API
// again.
fewMessages map[discord.ChannelID]struct{}
fewMutex *sync.Mutex
fewMutex sync.Mutex
// unavailableGuilds is a set of discord.GuildIDs of guilds that became
// unavailable when already connected to the gateway, i.e. sent in a
@ -131,7 +129,7 @@ func NewFromSession(s *session.Session, store Store) (*State, error) {
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.ChannelID]struct{}{},
fewMutex: new(sync.Mutex),
fewMutex: sync.Mutex{},
unavailableGuilds: moreatomic.NewGuildIDSet(),
unreadyGuilds: moreatomic.NewGuildIDSet(),
}
@ -235,7 +233,9 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis
////
func (s *State) Permissions(channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) {
func (s *State) Permissions(
channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) {
ch, err := s.Channel(channelID)
if err != nil {
return 0, errors.Wrap(err, "failed to get channel")
@ -286,7 +286,7 @@ func (s *State) Me() (*discord.User, error) {
return nil, err
}
return u, s.Store.MyselfSet(u)
return u, s.Store.MyselfSet(*u)
}
////
@ -302,7 +302,7 @@ func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) {
return nil, err
}
return c, s.Store.ChannelSet(c)
return c, s.Store.ChannelSet(*c)
}
func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
@ -319,7 +319,7 @@ func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
if err := s.Store.ChannelSet(ch); err != nil {
return nil, err
}
}
@ -338,7 +338,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel
return nil, err
}
return c, s.Store.ChannelSet(c)
return c, s.Store.ChannelSet(*c)
}
func (s *State) PrivateChannels() ([]discord.Channel, error) {
@ -355,7 +355,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
if err := s.Store.ChannelSet(ch); err != nil {
return nil, err
}
}
@ -431,7 +431,7 @@ func (s *State) Guilds() ([]discord.Guild, error) {
for _, ch := range c {
ch := ch
if err := s.Store.GuildSet(&ch); err != nil {
if err := s.Store.GuildSet(ch); err != nil {
return nil, err
}
}
@ -462,7 +462,7 @@ func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
}
for _, m := range ms {
if err := s.Store.MemberSet(guildID, &m); err != nil {
if err := s.Store.MemberSet(guildID, m); err != nil {
return nil, err
}
}
@ -475,7 +475,9 @@ func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
////
func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
func (s *State) Message(
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
m, err := s.Store.Message(channelID, messageID)
if err == nil {
return m, nil
@ -489,7 +491,7 @@ func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID
go func() {
c, cerr = s.Session.Channel(channelID)
if cerr == nil {
cerr = s.Store.ChannelSet(c)
cerr = s.Store.ChannelSet(*c)
}
wg.Done()
@ -510,7 +512,7 @@ func (s *State) Message(channelID discord.ChannelID, messageID discord.MessageID
m.ChannelID = c.ID
m.GuildID = c.GuildID
return m, s.Store.MessageSet(m)
return m, s.Store.MessageSet(*m)
}
// Messages fetches maximum 100 messages from the API, if it has to. There is no
@ -559,7 +561,7 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error)
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
ms[i].GuildID = guildID
if err := s.Store.MessageSet(&ms[i]); err != nil {
if err := s.Store.MessageSet(ms[i]); err != nil {
return nil, err
}
}
@ -582,7 +584,9 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error)
// Presence checks the state for user presences. If no guildID is given, it will
// look for the presence in all guilds.
func (s *State) Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
func (s *State) Presence(
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
p, err := s.Store.Presence(guildID, userID)
if err == nil {
return p, nil
@ -627,7 +631,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.R
role = &r
}
if err := s.RoleSet(guildID, &r); err != nil {
if err := s.RoleSet(guildID, r); err != nil {
return role, err
}
}
@ -649,7 +653,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
for _, r := range rs {
r := r
if err := s.RoleSet(guildID, &r); err != nil {
if err := s.RoleSet(guildID, r); err != nil {
return rs, err
}
}
@ -660,16 +664,18 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id)
if err == nil {
err = s.Store.GuildSet(g)
err = s.Store.GuildSet(*g)
}
return
}
func (s *State) fetchMember(guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
func (s *State) fetchMember(
guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(guildID, userID)
if err == nil {
err = s.Store.MemberSet(guildID, m)
err = s.Store.MemberSet(guildID, *m)
}
return

View File

@ -8,7 +8,7 @@ import (
)
func (s *State) hookSession() {
s.unhooker = s.Session.AddHandler(func(event interface{}) {
s.Session.AddHandler(func(event interface{}) {
// Call the pre-handler before the state handler.
if s.PreHandler != nil {
s.PreHandler.Call(event)
@ -55,9 +55,7 @@ func (s *State) onEvent(iface interface{}) {
// Handle presences
for _, p := range ev.Presences {
p := p
if err := s.Store.PresenceSet(0, &p); err != nil {
if err := s.Store.PresenceSet(0, p); err != nil {
s.stateErr(err, "failed to set global presence")
}
}
@ -68,19 +66,19 @@ func (s *State) onEvent(iface interface{}) {
}
// Handle private channels
for i := range ev.PrivateChannels {
if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil {
for _, ch := range ev.PrivateChannels {
if err := s.Store.ChannelSet(ch); err != nil {
s.stateErr(err, "failed to set channel in state")
}
}
// Handle user
if err := s.Store.MyselfSet(&ev.User); err != nil {
if err := s.Store.MyselfSet(ev.User); err != nil {
s.stateErr(err, "failed to set self in state")
}
case *gateway.GuildUpdateEvent:
if err := s.Store.GuildSet(&ev.Guild); err != nil {
if err := s.Store.GuildSet(ev.Guild); err != nil {
s.stateErr(err, "failed to update guild in state")
}
@ -90,7 +88,7 @@ func (s *State) onEvent(iface interface{}) {
}
case *gateway.GuildMemberAddEvent:
if err := s.Store.MemberSet(ev.GuildID, &ev.Member); err != nil {
if err := s.Store.MemberSet(ev.GuildID, ev.Member); err != nil {
s.stateErr(err, "failed to add a member in state")
}
@ -104,7 +102,7 @@ func (s *State) onEvent(iface interface{}) {
// Update available fields from ev into m
ev.Update(m)
if err := s.Store.MemberSet(ev.GuildID, m); err != nil {
if err := s.Store.MemberSet(ev.GuildID, *m); err != nil {
s.stateErr(err, "failed to update a member in state")
}
@ -115,28 +113,24 @@ func (s *State) onEvent(iface interface{}) {
case *gateway.GuildMembersChunkEvent:
for _, m := range ev.Members {
m := m
if err := s.Store.MemberSet(ev.GuildID, &m); err != nil {
if err := s.Store.MemberSet(ev.GuildID, m); err != nil {
s.stateErr(err, "failed to add a member from chunk in state")
}
}
for _, p := range ev.Presences {
p := p
if err := s.Store.PresenceSet(ev.GuildID, &p); err != nil {
if err := s.Store.PresenceSet(ev.GuildID, p); err != nil {
s.stateErr(err, "failed to add a presence from chunk in state")
}
}
case *gateway.GuildRoleCreateEvent:
if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil {
if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil {
s.stateErr(err, "failed to add a role in state")
}
case *gateway.GuildRoleUpdateEvent:
if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil {
if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil {
s.stateErr(err, "failed to update a role in state")
}
@ -151,17 +145,17 @@ func (s *State) onEvent(iface interface{}) {
}
case *gateway.ChannelCreateEvent:
if err := s.Store.ChannelSet(&ev.Channel); err != nil {
if err := s.Store.ChannelSet(ev.Channel); err != nil {
s.stateErr(err, "failed to create a channel in state")
}
case *gateway.ChannelUpdateEvent:
if err := s.Store.ChannelSet(&ev.Channel); err != nil {
if err := s.Store.ChannelSet(ev.Channel); err != nil {
s.stateErr(err, "failed to update a channel in state")
}
case *gateway.ChannelDeleteEvent:
if err := s.Store.ChannelRemove(&ev.Channel); err != nil {
if err := s.Store.ChannelRemove(ev.Channel); err != nil {
s.stateErr(err, "failed to remove a channel in state")
}
@ -169,12 +163,12 @@ func (s *State) onEvent(iface interface{}) {
// not tracked.
case *gateway.MessageCreateEvent:
if err := s.Store.MessageSet(&ev.Message); err != nil {
if err := s.Store.MessageSet(ev.Message); err != nil {
s.stateErr(err, "failed to add a message in state")
}
case *gateway.MessageUpdateEvent:
if err := s.Store.MessageSet(&ev.Message); err != nil {
if err := s.Store.MessageSet(ev.Message); err != nil {
s.stateErr(err, "failed to update a message in state")
}
@ -250,15 +244,13 @@ func (s *State) onEvent(iface interface{}) {
})
case *gateway.PresenceUpdateEvent:
if err := s.Store.PresenceSet(ev.GuildID, &ev.Presence); err != nil {
if err := s.Store.PresenceSet(ev.GuildID, ev.Presence); err != nil {
s.stateErr(err, "failed to update presence in state")
}
case *gateway.PresencesReplaceEvent:
for i := range *ev {
p := (*ev)[i]
if err := s.Store.PresenceSet(p.GuildID, &p); err != nil {
for _, p := range *ev {
if err := s.Store.PresenceSet(p.GuildID, p); err != nil {
s.stateErr(err, "failed to update presence in state")
}
}
@ -279,7 +271,7 @@ func (s *State) onEvent(iface interface{}) {
s.Ready.Notes[ev.ID] = ev.Note
case *gateway.UserUpdateEvent:
if err := s.Store.MyselfSet(&ev.User); err != nil {
if err := s.Store.MyselfSet(ev.User); err != nil {
s.stateErr(err, "failed to update myself from USER_UPDATE")
}
@ -290,7 +282,7 @@ func (s *State) onEvent(iface interface{}) {
s.stateErr(err, "failed to remove voice state from state")
}
} else {
if err := s.Store.VoiceStateSet(vs.GuildID, vs); err != nil {
if err := s.Store.VoiceStateSet(vs.GuildID, *vs); err != nil {
s.stateErr(err, "failed to update voice state in state")
}
}
@ -316,7 +308,7 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func
if !fn(m) {
return
}
if err := s.Store.MessageSet(m); err != nil {
if err := s.Store.MessageSet(*m); err != nil {
s.stateErr(err, "failed to save message in reaction add")
}
}
@ -337,7 +329,7 @@ func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error {
stack, errs := newErrorStack()
if err := store.GuildSet(&guild.Guild); err != nil {
if err := store.GuildSet(guild.Guild); err != nil {
errs(err, "failed to set guild in Ready")
}
@ -349,33 +341,32 @@ func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error {
}
// Handle guild member
for i := range guild.Members {
if err := store.MemberSet(guild.ID, &guild.Members[i]); err != nil {
for _, m := range guild.Members {
if err := store.MemberSet(guild.ID, m); err != nil {
errs(err, "failed to set guild member in Ready")
}
}
// Handle guild channels
for i := range guild.Channels {
for _, ch := range guild.Channels {
// I HATE Discord.
ch := guild.Channels[i]
ch.GuildID = guild.ID
if err := store.ChannelSet(&ch); err != nil {
if err := store.ChannelSet(ch); err != nil {
errs(err, "failed to set guild channel in Ready")
}
}
// Handle guild presences
for i := range guild.Presences {
if err := store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil {
for _, p := range guild.Presences {
if err := store.PresenceSet(guild.ID, p); err != nil {
errs(err, "failed to set guild presence in Ready")
}
}
// Handle guild voice states
for i := range guild.VoiceStates {
if err := store.VoiceStateSet(guild.ID, &guild.VoiceStates[i]); err != nil {
for _, v := range guild.VoiceStates {
if err := store.VoiceStateSet(guild.ID, v); err != nil {
errs(err, "failed to set guild voice state in Ready")
}
}

View File

@ -21,6 +21,9 @@ type Store interface {
// would mutate the underlying slice (and as a result the returned slice as
// well). The best way to avoid this is to copy the whole slice, like
// DefaultStore does.
//
// These methods should not care about returning slices in order, unless
// explicitly stated against.
type StoreGetter interface {
Me() (*discord.User, error)
@ -58,34 +61,34 @@ type StoreGetter interface {
}
type StoreModifier interface {
MyselfSet(me *discord.User) error
MyselfSet(me discord.User) error
// ChannelSet should switch on Type to know if it's a private channel or
// not.
ChannelSet(*discord.Channel) error
ChannelRemove(*discord.Channel) error
ChannelSet(discord.Channel) error
ChannelRemove(discord.Channel) error
// EmojiSet should delete all old emojis before setting new ones.
EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error
GuildSet(*discord.Guild) error
GuildSet(discord.Guild) error
GuildRemove(id discord.GuildID) error
MemberSet(guildID discord.GuildID, member *discord.Member) error
MemberSet(guildID discord.GuildID, member discord.Member) error
MemberRemove(guildID discord.GuildID, userID discord.UserID) error
// MessageSet should prepend messages into the slice, the latest being in
// front.
MessageSet(*discord.Message) error
MessageSet(discord.Message) error
MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error
PresenceSet(guildID discord.GuildID, presence *discord.Presence) error
PresenceSet(guildID discord.GuildID, presence discord.Presence) error
PresenceRemove(guildID discord.GuildID, userID discord.UserID) error
RoleSet(guildID discord.GuildID, role *discord.Role) error
RoleSet(guildID discord.GuildID, role discord.Role) error
RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error
VoiceStateSet(guildID discord.GuildID, voiceState *discord.VoiceState) error
VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error
VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error
}

View File

@ -1,7 +1,6 @@
package state
import (
"sort"
"sync"
"github.com/diamondburned/arikawa/discord"
@ -10,21 +9,25 @@ import (
// TODO: make an ExpiryStore
type DefaultStore struct {
*DefaultStoreOptions
DefaultStoreOptions
self discord.User
// includes normal and private
privates map[discord.ChannelID]*discord.Channel
guilds map[discord.GuildID]*discord.Guild
privates map[discord.ChannelID]discord.Channel
guilds map[discord.GuildID]discord.Guild
roles map[discord.GuildID][]discord.Role
emojis map[discord.GuildID][]discord.Emoji
channels map[discord.GuildID][]discord.Channel
members map[discord.GuildID][]discord.Member
presences map[discord.GuildID][]discord.Presence
messages map[discord.ChannelID][]discord.Message
voiceStates map[discord.GuildID][]discord.VoiceState
messages map[discord.ChannelID][]discord.Message
mut sync.Mutex
// special case; optimize for lots of members
members map[discord.GuildID]map[discord.UserID]discord.Member
mut sync.RWMutex
}
type DefaultStoreOptions struct {
@ -40,9 +43,7 @@ func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore {
}
}
ds := &DefaultStore{
DefaultStoreOptions: opts,
}
ds := &DefaultStore{DefaultStoreOptions: *opts}
ds.Reset()
return ds
@ -54,14 +55,17 @@ func (s *DefaultStore) Reset() error {
s.self = discord.User{}
s.privates = map[discord.ChannelID]*discord.Channel{}
s.guilds = map[discord.GuildID]*discord.Guild{}
s.privates = map[discord.ChannelID]discord.Channel{}
s.guilds = map[discord.GuildID]discord.Guild{}
s.roles = map[discord.GuildID][]discord.Role{}
s.emojis = map[discord.GuildID][]discord.Emoji{}
s.channels = map[discord.GuildID][]discord.Channel{}
s.members = map[discord.GuildID][]discord.Member{}
s.presences = map[discord.GuildID][]discord.Presence{}
s.messages = map[discord.ChannelID][]discord.Message{}
s.voiceStates = map[discord.GuildID][]discord.VoiceState{}
s.messages = map[discord.ChannelID][]discord.Message{}
s.members = map[discord.GuildID]map[discord.UserID]discord.Member{}
return nil
}
@ -69,8 +73,8 @@ func (s *DefaultStore) Reset() error {
////
func (s *DefaultStore) Me() (*discord.User, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
if !s.self.ID.Valid() {
return nil, ErrStoreNotFound
@ -79,9 +83,9 @@ func (s *DefaultStore) Me() (*discord.User, error) {
return &s.self, nil
}
func (s *DefaultStore) MyselfSet(me *discord.User) error {
func (s *DefaultStore) MyselfSet(me discord.User) error {
s.mut.Lock()
s.self = *me
s.self = me
s.mut.Unlock()
return nil
@ -90,11 +94,12 @@ func (s *DefaultStore) MyselfSet(me *discord.User) error {
////
func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
if ch, ok := s.privates[id]; ok {
return ch, nil
// implicit copy
return &ch, nil
}
for _, chs := range s.channels {
@ -109,8 +114,8 @@ func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
}
func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
chs, ok := s.channels[guildID]
if !ok {
@ -123,16 +128,17 @@ func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, err
// CreatePrivateChannel searches in the cache for a private channel. It makes no
// API calls.
func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
// slow way
for _, ch := range s.privates {
if ch.Type != discord.DirectMessage || len(ch.DMRecipients) < 1 {
if ch.Type != discord.DirectMessage || len(ch.DMRecipients) == 0 {
continue
}
if ch.DMRecipients[0].ID == recipient {
return &(*ch), nil
// Return an implicit copy made by range.
return &ch, nil
}
}
return nil, ErrStoreNotFound
@ -140,18 +146,18 @@ func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.
// PrivateChannels returns a list of Direct Message channels randomly ordered.
func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
var chs = make([]discord.Channel, 0, len(s.privates))
for _, ch := range s.privates {
chs = append(chs, *ch)
for i := range s.privates {
chs = append(chs, s.privates[i])
}
return chs, nil
}
func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
func (s *DefaultStore) ChannelSet(channel discord.Channel) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -169,20 +175,20 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
}
// Found, just edit
chs[i] = *channel
chs[i] = channel
return nil
}
}
chs = append(chs, *channel)
chs = append(chs, channel)
s.channels[channel.GuildID] = chs
}
return nil
}
func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
func (s *DefaultStore) ChannelRemove(channel discord.Channel) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -193,9 +199,11 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
for i, ch := range chs {
if ch.ID == channel.ID {
chs = append(chs[:i], chs[i+1:]...)
s.channels[channel.GuildID] = chs
// Fast unordered delete.
chs[i] = chs[len(chs)-1]
chs = chs[:len(chs)-1]
s.channels[channel.GuildID] = chs
return nil
}
}
@ -206,16 +214,17 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
////
func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
emojis, ok := s.emojis[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, emoji := range gd.Emojis {
for _, emoji := range emojis {
if emoji.ID == emojiID {
// Emoji is an implicit copy, so we could do this safely.
return &emoji, nil
}
}
@ -224,162 +233,126 @@ func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (
}
func (s *DefaultStore) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
emojis, ok := s.emojis[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Emoji{}, gd.Emojis...), nil
return append([]discord.Emoji{}, emojis...), nil
}
func (s *DefaultStore) EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error {
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
if !ok {
return ErrStoreNotFound
}
// A nil slice is acceptable, as we'll make a new slice later on and set it.
s.emojis[guildID] = emojis
filtered := emojis[:0]
Main:
for _, enew := range emojis {
// Try and see if this emoji is already in the slice
for i, emoji := range gd.Emojis {
if emoji.ID == enew.ID {
// If it is, we simply replace it
gd.Emojis[i] = enew
continue Main
}
}
// If not, we add it to the slice that's to be appended.
filtered = append(filtered, enew)
}
// Append the new emojis
gd.Emojis = append(gd.Emojis, filtered...)
return nil
}
////
func (s *DefaultStore) Guild(id discord.GuildID) (*discord.Guild, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
ch, ok := s.guilds[id]
if !ok {
return nil, ErrStoreNotFound
}
return ch, nil
// implicit copy
return &ch, nil
}
func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
s.mut.Lock()
s.mut.RLock()
defer s.mut.RUnlock()
if len(s.guilds) == 0 {
s.mut.Unlock()
return nil, ErrStoreNotFound
}
var gs = make([]discord.Guild, 0, len(s.guilds))
for _, g := range s.guilds {
gs = append(gs, *g)
gs = append(gs, g)
}
s.mut.Unlock()
sort.Slice(gs, func(i, j int) bool {
return gs[i].ID > gs[j].ID
})
return gs, nil
}
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
func (s *DefaultStore) GuildSet(guild discord.Guild) error {
s.mut.Lock()
defer s.mut.Unlock()
if g, ok := s.guilds[guild.ID]; ok {
// preserve state stuff
if guild.Roles == nil {
guild.Roles = g.Roles
}
if guild.Emojis == nil {
guild.Emojis = g.Emojis
}
}
s.guilds[guild.ID] = guild
return nil
}
func (s *DefaultStore) GuildRemove(id discord.GuildID) error {
s.mut.Lock()
delete(s.guilds, id)
s.mut.Unlock()
defer s.mut.Unlock()
if _, ok := s.guilds[id]; !ok {
return ErrStoreNotFound
}
delete(s.guilds, id)
return nil
}
////
func (s *DefaultStore) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Member(
guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.members[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, m := range ms {
if m.User.ID == userID {
return &m, nil
}
m, ok := ms[userID]
if ok {
return &m, nil
}
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Members(guildID discord.GuildID) ([]discord.Member, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.members[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Member{}, ms...), nil
var members = make([]discord.Member, 0, len(ms))
for _, m := range ms {
members = append(members, m)
}
return members, nil
}
func (s *DefaultStore) MemberSet(guildID discord.GuildID, member *discord.Member) error {
func (s *DefaultStore) MemberSet(guildID discord.GuildID, member discord.Member) error {
s.mut.Lock()
defer s.mut.Unlock()
ms := s.members[guildID]
// Try and see if this member is already in the slice
for i, m := range ms {
if m.User.ID == member.User.ID {
// If it is, we simply replace it
ms[i] = *member
s.members[guildID] = ms
return nil
}
ms, ok := s.members[guildID]
if !ok {
ms = make(map[discord.UserID]discord.Member, 1)
}
// Append the new member
ms = append(ms, *member)
ms[member.User.ID] = member
s.members[guildID] = ms
return nil
@ -394,24 +367,21 @@ func (s *DefaultStore) MemberRemove(guildID discord.GuildID, userID discord.User
return ErrStoreNotFound
}
// Try and see if this member is already in the slice
for i, m := range ms {
if m.User.ID == userID {
ms = append(ms, ms[i+1:]...)
s.members[guildID] = ms
return nil
}
if _, ok := ms[userID]; !ok {
return ErrStoreNotFound
}
return ErrStoreNotFound
delete(ms, userID)
return nil
}
////
func (s *DefaultStore) Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Message(
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.messages[channelID]
if !ok {
@ -428,24 +398,22 @@ func (s *DefaultStore) Message(channelID discord.ChannelID, messageID discord.Me
}
func (s *DefaultStore) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.messages[channelID]
if !ok {
return nil, ErrStoreNotFound
}
cp := make([]discord.Message, len(ms))
copy(cp, ms)
return cp, nil
return append([]discord.Message{}, ms...), nil
}
func (s *DefaultStore) MaxMessages() int {
return int(s.DefaultStoreOptions.MaxMessages)
}
func (s *DefaultStore) MessageSet(message *discord.Message) error {
func (s *DefaultStore) MessageSet(message discord.Message) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -457,7 +425,7 @@ func (s *DefaultStore) MessageSet(message *discord.Message) error {
// Check if we already have the message.
for i, m := range ms {
if m.ID == message.ID {
DiffMessage(*message, &m)
DiffMessage(message, &m)
ms[i] = m
return nil
}
@ -480,13 +448,15 @@ func (s *DefaultStore) MessageSet(message *discord.Message) error {
// 1st-endth.
copy(ms[1:end], ms[0:end-1])
// Then, set the 0th entry.
ms[0] = *message
ms[0] = message
s.messages[message.ChannelID] = ms
return nil
}
func (s *DefaultStore) MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error {
func (s *DefaultStore) MessageRemove(
channelID discord.ChannelID, messageID discord.MessageID) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -508,9 +478,11 @@ func (s *DefaultStore) MessageRemove(channelID discord.ChannelID, messageID disc
////
func (s *DefaultStore) Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Presence(
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ps, ok := s.presences[guildID]
if !ok {
@ -527,33 +499,32 @@ func (s *DefaultStore) Presence(guildID discord.GuildID, userID discord.UserID)
}
func (s *DefaultStore) Presences(guildID discord.GuildID) ([]discord.Presence, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
ps, ok := s.presences[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return ps, nil
return append([]discord.Presence{}, ps...), nil
}
func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence *discord.Presence) error {
func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence discord.Presence) error {
s.mut.Lock()
defer s.mut.Unlock()
ps := s.presences[guildID]
ps, _ := s.presences[guildID]
for i, p := range ps {
if p.User.ID == presence.User.ID {
ps[i] = *presence
s.presences[guildID] = ps
// Change the backing array.
ps[i] = presence
return nil
}
}
ps = append(ps, *presence)
ps = append(ps, presence)
s.presences[guildID] = ps
return nil
}
@ -569,9 +540,10 @@ func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.Us
for i, p := range ps {
if p.User.ID == userID {
ps = append(ps[:i], ps[i+1:]...)
s.presences[guildID] = ps
ps[i] = ps[len(ps)-1]
ps = ps[:len(ps)-1]
s.presences[guildID] = ps
return nil
}
}
@ -582,15 +554,15 @@ func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.Us
////
func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, r := range gd.Roles {
for _, r := range rs {
if r.ID == roleID {
return &r, nil
}
@ -600,34 +572,35 @@ func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*di
}
func (s *DefaultStore) Roles(guildID discord.GuildID) ([]discord.Role, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Role{}, gd.Roles...), nil
return append([]discord.Role{}, rs...), nil
}
func (s *DefaultStore) RoleSet(guildID discord.GuildID, role *discord.Role) error {
func (s *DefaultStore) RoleSet(guildID discord.GuildID, role discord.Role) error {
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
if !ok {
return ErrStoreNotFound
}
// A nil slice is fine, since we can just append the role.
rs, _ := s.roles[guildID]
for i, r := range gd.Roles {
for i, r := range rs {
if r.ID == role.ID {
gd.Roles[i] = *role
// This changes the backing array, so we don't need to reset the
// slice.
rs[i] = role
return nil
}
}
gd.Roles = append(gd.Roles, *role)
rs = append(rs, role)
s.roles[guildID] = rs
return nil
}
@ -635,14 +608,18 @@ func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return ErrStoreNotFound
}
for i, r := range gd.Roles {
for i, r := range rs {
if r.ID == roleID {
gd.Roles = append(gd.Roles[:i], gd.Roles[i+1:]...)
// Fast delete.
rs[i] = rs[len(rs)-1]
rs = rs[:len(rs)-1]
s.roles[guildID] = rs
return nil
}
}
@ -652,9 +629,11 @@ func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID
////
func (s *DefaultStore) VoiceState(guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) VoiceState(
guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) {
s.mut.RLock()
defer s.mut.RUnlock()
states, ok := s.voiceStates[guildID]
if !ok {
@ -671,8 +650,8 @@ func (s *DefaultStore) VoiceState(guildID discord.GuildID, userID discord.UserID
}
func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
states, ok := s.voiceStates[guildID]
if !ok {
@ -682,22 +661,21 @@ func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceStat
return append([]discord.VoiceState{}, states...), nil
}
func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState *discord.VoiceState) error {
func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error {
s.mut.Lock()
defer s.mut.Unlock()
states := s.voiceStates[guildID]
states, _ := s.voiceStates[guildID]
for i, vs := range states {
if vs.UserID == voiceState.UserID {
states[i] = *voiceState
s.voiceStates[guildID] = states
// change the backing array
states[i] = voiceState
return nil
}
}
states = append(states, *voiceState)
states = append(states, voiceState)
s.voiceStates[guildID] = states
return nil
}

View File

@ -23,7 +23,7 @@ func (NoopStore) Me() (*discord.User, error) {
return nil, ErrNotImplemented
}
func (NoopStore) MyselfSet(*discord.User) error {
func (NoopStore) MyselfSet(discord.User) error {
return nil
}
@ -43,11 +43,11 @@ func (NoopStore) PrivateChannels() ([]discord.Channel, error) {
return nil, ErrNotImplemented
}
func (NoopStore) ChannelSet(*discord.Channel) error {
func (NoopStore) ChannelSet(discord.Channel) error {
return nil
}
func (NoopStore) ChannelRemove(*discord.Channel) error {
func (NoopStore) ChannelRemove(discord.Channel) error {
return nil
}
@ -71,7 +71,7 @@ func (NoopStore) Guilds() ([]discord.Guild, error) {
return nil, ErrNotImplemented
}
func (NoopStore) GuildSet(*discord.Guild) error {
func (NoopStore) GuildSet(discord.Guild) error {
return nil
}
@ -87,7 +87,7 @@ func (NoopStore) Members(discord.GuildID) ([]discord.Member, error) {
return nil, ErrNotImplemented
}
func (NoopStore) MemberSet(discord.GuildID, *discord.Member) error {
func (NoopStore) MemberSet(discord.GuildID, discord.Member) error {
return nil
}
@ -109,7 +109,7 @@ func (NoopStore) MaxMessages() int {
return 100
}
func (NoopStore) MessageSet(*discord.Message) error {
func (NoopStore) MessageSet(discord.Message) error {
return nil
}
@ -125,7 +125,7 @@ func (NoopStore) Presences(discord.GuildID) ([]discord.Presence, error) {
return nil, ErrNotImplemented
}
func (NoopStore) PresenceSet(discord.GuildID, *discord.Presence) error {
func (NoopStore) PresenceSet(discord.GuildID, discord.Presence) error {
return nil
}
@ -141,7 +141,7 @@ func (NoopStore) Roles(discord.GuildID) ([]discord.Role, error) {
return nil, ErrNotImplemented
}
func (NoopStore) RoleSet(discord.GuildID, *discord.Role) error {
func (NoopStore) RoleSet(discord.GuildID, discord.Role) error {
return nil
}
@ -157,7 +157,7 @@ func (NoopStore) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) {
return nil, ErrNotImplemented
}
func (NoopStore) VoiceStateSet(discord.GuildID, *discord.VoiceState) error {
func (NoopStore) VoiceStateSet(discord.GuildID, discord.VoiceState) error {
return ErrNotImplemented
}