diff --git a/gateway/ready.go b/gateway/ready.go index daf91e6..4852189 100644 --- a/gateway/ready.go +++ b/gateway/ready.go @@ -148,18 +148,21 @@ type ( // of ReadySupplementalEvent. It has slight differences to discord.Member. SupplementalMember struct { UserID discord.UserID `json:"user_id"` + Nick string `json:"nick,omitempty"` RoleIDs []discord.RoleID `json:"roles"` - JoinedAt discord.Timestamp `json:"joined_at"` - HoistedRole discord.RoleID `json:"hoisted_role"` + GuildID discord.GuildID `json:"guild_id,omitempty"` + IsPending bool `json:"is_pending,omitempty"` + HoistedRole discord.RoleID `json:"hoisted_role"` Mute bool `json:"mute"` Deaf bool `json:"deaf"` - Nick string `json:"nick,omitempty"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - IsPending bool `json:"is_pending,omitempty"` - PremiumSince discord.Timestamp `json:"premium_since,omitempty"` + // Joined specifies when the user joined the guild. + Joined discord.Timestamp `json:"joined_at"` + + // BoostedSince specifies when the user started boosting the guild. + BoostedSince discord.Timestamp `json:"premium_since,omitempty"` } // FriendSourceFlags describes sources that friend requests could be sent @@ -221,7 +224,7 @@ type ( } // MergedPresences is the struct for presences of guilds' members and - // friends. It is undocumented. + // friends. It is undocumented. MergedPresences struct { Guilds [][]SupplementalPresence `json:"guilds"` Friends []SupplementalPresence `json:"friends"` @@ -243,3 +246,27 @@ type ( LastModified discord.UnixMsTimestamp `json:"last_modified,omitempty"` } ) + +// ConvertSupplementalMember converts a SupplementalMember to a regular Member. +func ConvertSupplementalMember(sm SupplementalMember) discord.Member { + return discord.Member{ + User: discord.User{ID: sm.UserID}, + Nick: sm.Nick, + RoleIDs: sm.RoleIDs, + Joined: sm.Joined, + BoostedSince: sm.BoostedSince, + Deaf: sm.Deaf, + Mute: sm.Mute, + } +} + +// ConvertSupplementalPresence converts a SupplementalPresence to a regular +// Presence with an empty GuildID. +func ConvertSupplementalPresence(sp SupplementalPresence) Presence { + return Presence{ + User: discord.User{ID: sp.UserID}, + Status: sp.Status, + Activities: sp.Activities, + ClientStatus: sp.ClientStatus, + } +} diff --git a/internal/moreatomic/syncmap.go b/internal/moreatomic/syncmap.go new file mode 100644 index 0000000..0677f18 --- /dev/null +++ b/internal/moreatomic/syncmap.go @@ -0,0 +1,61 @@ +package moreatomic + +import ( + "sync" + "sync/atomic" +) + +// Map is a thread-safe map that is a wrapper around sync.Map with slight API +// additions. +type Map struct { + smap atomic.Value + ctor func() interface{} +} + +type sentinelType struct{} + +var sentinel = sentinelType{} + +func NewMap(ctor func() interface{}) *Map { + smap := atomic.Value{} + smap.Store(&sync.Map{}) + return &Map{smap, ctor} +} + +// Reset swaps the internal map out with a fresh one, dropping the old map. This +// method never errors. +func (sm *Map) Reset() error { + sm.smap.Store(&sync.Map{}) + return nil +} + +// LoadOrStore loads an existing value or stores a new value created from the +// given constructor then return that value. +func (sm *Map) LoadOrStore(k interface{}) (v interface{}, loaded bool) { + smap := sm.smap.Load().(*sync.Map) + + v, loaded = smap.LoadOrStore(k, sentinel) + if loaded { + v = sm.ctor() + smap.Store(k, v) + } + + return +} + +// Load loads an existing value; it returns ok set to false if there is no +// value with that key. +func (sm *Map) Load(k interface{}) (lv interface{}, ok bool) { + smap := sm.smap.Load().(*sync.Map) + + for { + lv, ok = smap.Load(k) + if !ok { + return nil, false + } + + if lv != sentinel { + return lv, true + } + } +} diff --git a/state/event_dispatcher.go b/state/event_dispatcher.go index b4e7aaf..966da8f 100644 --- a/state/event_dispatcher.go +++ b/state/event_dispatcher.go @@ -19,19 +19,22 @@ func (s *State) handleReady(ev *gateway.ReadyEvent) { } func (s *State) handleGuildCreate(ev *gateway.GuildCreateEvent) { + switch { // this guild was unavailable, but has come back online - if s.unavailableGuilds.Delete(ev.ID) { + case s.unavailableGuilds.Delete(ev.ID): s.Handler.Call(&GuildAvailableEvent{ GuildCreateEvent: ev, }) - // the guild was already unavailable when connecting to the gateway - // we can dispatch a belated GuildReadyEvent - } else if s.unreadyGuilds.Delete(ev.ID) { + // the guild was already unavailable when connecting to the gateway + // we can dispatch a belated GuildReadyEvent + case s.unreadyGuilds.Delete(ev.ID): s.Handler.Call(&GuildReadyEvent{ GuildCreateEvent: ev, }) - } else { // we don't know this guild, hence we just joined it + + // we don't know this guild, hence we just joined it + default: s.Handler.Call(&GuildJoinEvent{ GuildCreateEvent: ev, }) diff --git a/state/state.go b/state/state.go index 6668e34..9c48744 100644 --- a/state/state.go +++ b/state/state.go @@ -10,6 +10,8 @@ import ( "github.com/diamondburned/arikawa/v2/gateway" "github.com/diamondburned/arikawa/v2/internal/moreatomic" "github.com/diamondburned/arikawa/v2/session" + "github.com/diamondburned/arikawa/v2/state/store" + "github.com/diamondburned/arikawa/v2/state/store/defaultstore" "github.com/diamondburned/arikawa/v2/utils/handler" "github.com/pkg/errors" @@ -58,7 +60,7 @@ var ( // will be empty, while the Member structure expects it to be there. type State struct { *session.Session - Store + store.Cabinet // *: State doesn't actually keep track of pinned messages. @@ -97,7 +99,7 @@ type State struct { // New creates a new state. func New(token string) (*State, error) { - return NewWithStore(token, NewDefaultStore(nil)) + return NewWithStore(token, defaultstore.New()) } // NewWithIntents creates a new state with the given gateway intents. For more @@ -108,24 +110,24 @@ func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { return nil, err } - return NewFromSession(s, NewDefaultStore(nil)) + return NewFromSession(s, defaultstore.New()) } -func NewWithStore(token string, store Store) (*State, error) { +func NewWithStore(token string, cabinet store.Cabinet) (*State, error) { s, err := session.New(token) if err != nil { return nil, err } - return NewFromSession(s, store) + return NewFromSession(s, cabinet) } // NewFromSession never returns an error. This API is kept for backwards // compatibility. -func NewFromSession(s *session.Session, store Store) (*State, error) { +func NewFromSession(s *session.Session, cabinet store.Cabinet) (*State, error) { state := &State{ Session: s, - Store: store, + Cabinet: cabinet, Handler: handler.New(), StateLog: func(err error) {}, readyMu: new(sync.Mutex), @@ -148,18 +150,18 @@ func (s *State) WithContext(ctx context.Context) *State { return &copied } -// Ready takes in a callback to access the Ready event in a thread-safe manner. -// As it acquires a mutex for thread-safety, the callback shouldn't do anything -// blocking to prevent stalling the state updates. It should also not reference -// or copy the Ready instance, as that instance will not be thread-safe. +// Ready returns a copy of the Ready event. Although this function is safe to +// call concurrently, its values should still not be changed, as certain types +// like slices are not concurrent-safe. // -// Note that the Ready that passed in will never be nil; if Ready events are not -// received yet, then the pointer will point to State's zero-value Ready -// instance. -func (s *State) Ready(fn func(*gateway.ReadyEvent)) { +// Note that if Ready events are not received yet, then the returned event will +// be a zero-value Ready instance. +func (s *State) Ready() gateway.ReadyEvent { s.readyMu.Lock() - fn(&s.ready) + r := s.ready s.readyMu.Unlock() + + return r } //// Helper methods @@ -217,19 +219,19 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis var wg sync.WaitGroup var ( - g *discord.Guild - gerr = ErrStoreNotFound + g *discord.Guild + m *discord.Member - m *discord.Member - merr = ErrStoreNotFound + gerr = store.ErrNotFound + merr = store.ErrNotFound ) if s.Gateway.HasIntents(gateway.IntentGuilds) { - g, gerr = s.Store.Guild(guildID) + g, gerr = s.Cabinet.Guild(guildID) } if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - m, merr = s.Store.Member(guildID, userID) + m, merr = s.Cabinet.Member(guildID, userID) } switch { @@ -272,19 +274,19 @@ func (s *State) Permissions( var wg sync.WaitGroup var ( - g *discord.Guild - gerr = ErrStoreNotFound + g *discord.Guild + m *discord.Member - m *discord.Member - merr = ErrStoreNotFound + gerr = store.ErrNotFound + merr = store.ErrNotFound ) if s.Gateway.HasIntents(gateway.IntentGuilds) { - g, gerr = s.Store.Guild(ch.GuildID) + g, gerr = s.Cabinet.Guild(ch.GuildID) } if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - m, merr = s.Store.Member(ch.GuildID, userID) + m, merr = s.Cabinet.Member(ch.GuildID, userID) } switch { @@ -317,7 +319,7 @@ func (s *State) Permissions( //// func (s *State) Me() (*discord.User, error) { - u, err := s.Store.Me() + u, err := s.Cabinet.Me() if err == nil { return u, nil } @@ -327,13 +329,13 @@ func (s *State) Me() (*discord.User, error) { return nil, err } - return u, s.Store.MyselfSet(*u) + return u, s.Cabinet.MyselfSet(*u) } //// func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { - c, err = s.Store.Channel(id) + c, err = s.Cabinet.Channel(id) if err == nil && s.tracksChannel(c) { return } @@ -344,7 +346,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { } if s.tracksChannel(c) { - err = s.Store.ChannelSet(*c) + err = s.Cabinet.ChannelSet(*c) } return @@ -352,7 +354,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) { if s.Gateway.HasIntents(gateway.IntentGuilds) { - cs, err = s.Store.Channels(guildID) + cs, err = s.Cabinet.Channels(guildID) if err == nil { return } @@ -365,7 +367,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err if s.Gateway.HasIntents(gateway.IntentGuilds) { for _, c := range cs { - if err = s.Store.ChannelSet(c); err != nil { + if err = s.Cabinet.ChannelSet(c); err != nil { return } } @@ -375,7 +377,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err } func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { - c, err := s.Store.CreatePrivateChannel(recipient) + c, err := s.Cabinet.CreatePrivateChannel(recipient) if err == nil { return c, nil } @@ -385,13 +387,13 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return nil, err } - return c, s.Store.ChannelSet(*c) + return c, s.Cabinet.ChannelSet(*c) } // PrivateChannels gets the direct messages of the user. // This is not supported for bots. func (s *State) PrivateChannels() ([]discord.Channel, error) { - cs, err := s.Store.PrivateChannels() + cs, err := s.Cabinet.PrivateChannels() if err == nil { return cs, nil } @@ -402,7 +404,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { } for _, c := range cs { - if err := s.Store.ChannelSet(c); err != nil { + if err := s.Cabinet.ChannelSet(c); err != nil { return nil, err } } @@ -416,7 +418,7 @@ func (s *State) Emoji( guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) { if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { - e, err = s.Store.Emoji(guildID, emojiID) + e, err = s.Cabinet.Emoji(guildID, emojiID) if err == nil { return } @@ -429,7 +431,7 @@ func (s *State) Emoji( return nil, err } - if err = s.Store.EmojiSet(guildID, es); err != nil { + if err = s.Cabinet.EmojiSet(guildID, es); err != nil { return } @@ -439,12 +441,12 @@ func (s *State) Emoji( } } - return nil, ErrStoreNotFound + return nil, store.ErrNotFound } func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) { if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { - es, err = s.Store.Emojis(guildID) + es, err = s.Cabinet.Emojis(guildID) if err == nil { return } @@ -456,7 +458,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) } if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { - err = s.Store.EmojiSet(guildID, es) + err = s.Cabinet.EmojiSet(guildID, es) } return @@ -466,7 +468,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { if s.Gateway.HasIntents(gateway.IntentGuilds) { - c, err := s.Store.Guild(id) + c, err := s.Cabinet.Guild(id) if err == nil { return c, nil } @@ -478,7 +480,7 @@ func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { // Guilds will only fill a maximum of 100 guilds from the API. func (s *State) Guilds() (gs []discord.Guild, err error) { if s.Gateway.HasIntents(gateway.IntentGuilds) { - gs, err = s.Store.Guilds() + gs, err = s.Cabinet.Guilds() if err == nil { return } @@ -491,7 +493,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { if s.Gateway.HasIntents(gateway.IntentGuilds) { for _, g := range gs { - if err = s.Store.GuildSet(g); err != nil { + if err = s.Cabinet.GuildSet(g); err != nil { return } } @@ -504,7 +506,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - m, err := s.Store.Member(guildID, userID) + m, err := s.Cabinet.Member(guildID, userID) if err == nil { return m, nil } @@ -515,7 +517,7 @@ func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) { if s.Gateway.HasIntents(gateway.IntentGuildMembers) { - ms, err = s.Store.Members(guildID) + ms, err = s.Cabinet.Members(guildID) if err == nil { return } @@ -528,7 +530,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error if s.Gateway.HasIntents(gateway.IntentGuildMembers) { for _, m := range ms { - if err = s.Store.MemberSet(guildID, m); err != nil { + if err = s.Cabinet.MemberSet(guildID, m); err != nil { return } } @@ -542,7 +544,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error func (s *State) Message( channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { - m, err := s.Store.Message(channelID, messageID) + m, err := s.Cabinet.Message(channelID, messageID) if err == nil && s.tracksMessage(m) { return m, nil } @@ -551,16 +553,16 @@ func (s *State) Message( wg sync.WaitGroup c *discord.Channel - cerr = ErrStoreNotFound + cerr = store.ErrNotFound ) - c, cerr = s.Store.Channel(channelID) + c, cerr = s.Cabinet.Channel(channelID) if cerr != nil || !s.tracksChannel(c) { wg.Add(1) go func() { c, cerr = s.Session.Channel(channelID) if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - cerr = s.Store.ChannelSet(*c) + cerr = s.Cabinet.ChannelSet(*c) } wg.Done() @@ -582,7 +584,7 @@ func (s *State) Message( m.GuildID = c.GuildID if s.tracksMessage(m) { - err = s.Store.MessageSet(*m) + err = s.Cabinet.MessageSet(*m) } return m, err @@ -594,7 +596,7 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) // TODO: Think of a design that doesn't rely on MaxMessages(). var maxMsgs = s.MaxMessages() - ms, err := s.Store.Messages(channelID) + ms, err := s.Cabinet.Messages(channelID) if err == nil && (len(ms) == 0 || s.tracksMessage(&ms[0])) { // If the state already has as many messages as it can, skip the API. if maxMsgs <= len(ms) { @@ -635,7 +637,7 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) // Set the guild ID, fine if it's 0 (it's already 0 anyway). ms[i].GuildID = guildID - if err := s.Store.MessageSet(ms[i]); err != nil { + if err := s.Cabinet.MessageSet(ms[i]); err != nil { return nil, err } } @@ -659,41 +661,39 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) // Presence checks the state for user presences. If no guildID is given, it // will look for the presence in all cached guilds. -func (s *State) Presence( - guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { - +func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*gateway.Presence, error) { if !s.Gateway.HasIntents(gateway.IntentGuildPresences) { - return nil, ErrStoreNotFound + return nil, store.ErrNotFound } // If there's no guild ID, look in all guilds - if !guildID.IsValid() { + if !gID.IsValid() { if !s.Gateway.HasIntents(gateway.IntentGuilds) { - return nil, ErrStoreNotFound + return nil, store.ErrNotFound } - g, err := s.Store.Guilds() + g, err := s.Cabinet.Guilds() if err != nil { return nil, err } for _, g := range g { - if p, err := s.Store.Presence(g.ID, userID); err == nil { + if p, err := s.Cabinet.Presence(g.ID, uID); err == nil { return p, nil } } - return nil, ErrStoreNotFound + return nil, store.ErrNotFound } - return s.Store.Presence(guildID, userID) + return s.Cabinet.Presence(gID, uID) } //// func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) { if s.Gateway.HasIntents(gateway.IntentGuilds) { - target, err = s.Store.Role(guildID, roleID) + target, err = s.Cabinet.Role(guildID, roleID) if err == nil { return } @@ -718,14 +718,14 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di } if target == nil { - return nil, ErrStoreNotFound + return nil, store.ErrNotFound } return } func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { - rs, err := s.Store.Roles(guildID) + rs, err := s.Cabinet.Roles(guildID) if err == nil { return rs, nil } @@ -749,18 +749,16 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { - err = s.Store.GuildSet(*g) + err = s.Cabinet.GuildSet(*g) } return } -func (s *State) fetchMember( - guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) { - - m, err = s.Session.Member(guildID, userID) +func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { + m, err = s.Session.Member(gID, uID) if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { - err = s.Store.MemberSet(guildID, *m) + err = s.Cabinet.MemberSet(gID, *m) } return diff --git a/state/state_events.go b/state/state_events.go index 8d0f030..c788154 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -5,6 +5,7 @@ import ( "github.com/diamondburned/arikawa/v2/discord" "github.com/diamondburned/arikawa/v2/gateway" + "github.com/diamondburned/arikawa/v2/state/store" ) func (s *State) hookSession() { @@ -56,57 +57,53 @@ func (s *State) onEvent(iface interface{}) { s.ready = *ev // Reset the store before proceeding. - if err := s.Store.Reset(); err != nil { + if err := s.Cabinet.Reset(); err != nil { s.stateErr(err, "failed to reset state on READY") } - // Handle presences - for _, p := range ev.Presences { - if err := s.Store.PresenceSet(0, p); err != nil { - s.stateErr(err, "failed to set global presence") - } - } - // Handle guilds for i := range ev.Guilds { - s.batchLog(storeGuildCreate(s.Store, &ev.Guilds[i])) + s.batchLog(storeGuildCreate(s.Cabinet, &ev.Guilds[i])) } // Handle private channels for _, ch := range ev.PrivateChannels { - if err := s.Store.ChannelSet(ch); err != nil { + if err := s.Cabinet.ChannelSet(ch); err != nil { s.stateErr(err, "failed to set channel in state") } } // Handle user - if err := s.Store.MyselfSet(ev.User); err != nil { + if err := s.Cabinet.MyselfSet(ev.User); err != nil { s.stateErr(err, "failed to set self in state") } // Release the ready mutex only after we're done with everything. s.readyMu.Unlock() + case *gateway.ReadySupplementalEvent: + // TODO + case *gateway.GuildCreateEvent: - s.batchLog(storeGuildCreate(s.Store, ev)) + s.batchLog(storeGuildCreate(s.Cabinet, ev)) case *gateway.GuildUpdateEvent: - if err := s.Store.GuildSet(ev.Guild); err != nil { + if err := s.Cabinet.GuildSet(ev.Guild); err != nil { s.stateErr(err, "failed to update guild in state") } case *gateway.GuildDeleteEvent: - if err := s.Store.GuildRemove(ev.ID); err != nil && !ev.Unavailable { + if err := s.Cabinet.GuildRemove(ev.ID); err != nil && !ev.Unavailable { s.stateErr(err, "failed to delete guild in state") } case *gateway.GuildMemberAddEvent: - if err := s.Store.MemberSet(ev.GuildID, ev.Member); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, ev.Member); err != nil { s.stateErr(err, "failed to add a member in state") } case *gateway.GuildMemberUpdateEvent: - m, err := s.Store.Member(ev.GuildID, ev.User.ID) + m, err := s.Cabinet.Member(ev.GuildID, ev.User.ID) if err != nil { // We can't do much here. m = &discord.Member{} @@ -115,60 +112,60 @@ func (s *State) onEvent(iface interface{}) { // Update available fields from ev into m ev.Update(m) - if err := s.Store.MemberSet(ev.GuildID, *m); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, *m); err != nil { s.stateErr(err, "failed to update a member in state") } case *gateway.GuildMemberRemoveEvent: - if err := s.Store.MemberRemove(ev.GuildID, ev.User.ID); err != nil { + if err := s.Cabinet.MemberRemove(ev.GuildID, ev.User.ID); err != nil { s.stateErr(err, "failed to remove a member in state") } case *gateway.GuildMembersChunkEvent: for _, m := range ev.Members { - if err := s.Store.MemberSet(ev.GuildID, m); err != nil { + if err := s.Cabinet.MemberSet(ev.GuildID, m); err != nil { s.stateErr(err, "failed to add a member from chunk in state") } } for _, p := range ev.Presences { - if err := s.Store.PresenceSet(ev.GuildID, p); err != nil { + if err := s.Cabinet.PresenceSet(ev.GuildID, p); err != nil { s.stateErr(err, "failed to add a presence from chunk in state") } } case *gateway.GuildRoleCreateEvent: - if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role); err != nil { s.stateErr(err, "failed to add a role in state") } case *gateway.GuildRoleUpdateEvent: - if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil { + if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role); err != nil { s.stateErr(err, "failed to update a role in state") } case *gateway.GuildRoleDeleteEvent: - if err := s.Store.RoleRemove(ev.GuildID, ev.RoleID); err != nil { + if err := s.Cabinet.RoleRemove(ev.GuildID, ev.RoleID); err != nil { s.stateErr(err, "failed to remove a role in state") } case *gateway.GuildEmojisUpdateEvent: - if err := s.Store.EmojiSet(ev.GuildID, ev.Emojis); err != nil { + if err := s.Cabinet.EmojiSet(ev.GuildID, ev.Emojis); err != nil { s.stateErr(err, "failed to update emojis in state") } case *gateway.ChannelCreateEvent: - if err := s.Store.ChannelSet(ev.Channel); err != nil { + if err := s.Cabinet.ChannelSet(ev.Channel); err != nil { s.stateErr(err, "failed to create a channel in state") } case *gateway.ChannelUpdateEvent: - if err := s.Store.ChannelSet(ev.Channel); err != nil { + if err := s.Cabinet.ChannelSet(ev.Channel); err != nil { s.stateErr(err, "failed to update a channel in state") } case *gateway.ChannelDeleteEvent: - if err := s.Store.ChannelRemove(ev.Channel); err != nil { + if err := s.Cabinet.ChannelRemove(ev.Channel); err != nil { s.stateErr(err, "failed to remove a channel in state") } @@ -176,23 +173,23 @@ func (s *State) onEvent(iface interface{}) { // not tracked. case *gateway.MessageCreateEvent: - if err := s.Store.MessageSet(ev.Message); err != nil { + if err := s.Cabinet.MessageSet(ev.Message); err != nil { s.stateErr(err, "failed to add a message in state") } case *gateway.MessageUpdateEvent: - if err := s.Store.MessageSet(ev.Message); err != nil { + if err := s.Cabinet.MessageSet(ev.Message); err != nil { s.stateErr(err, "failed to update a message in state") } case *gateway.MessageDeleteEvent: - if err := s.Store.MessageRemove(ev.ChannelID, ev.ID); err != nil { + if err := s.Cabinet.MessageRemove(ev.ChannelID, ev.ID); err != nil { s.stateErr(err, "failed to delete a message in state") } case *gateway.MessageDeleteBulkEvent: for _, id := range ev.IDs { - if err := s.Store.MessageRemove(ev.ChannelID, id); err != nil { + if err := s.Cabinet.MessageRemove(ev.ChannelID, id); err != nil { s.stateErr(err, "failed to delete bulk messages in state") } } @@ -203,7 +200,7 @@ func (s *State) onEvent(iface interface{}) { m.Reactions[i].Count++ } else { var me bool - if u, _ := s.Store.Me(); u != nil { + if u, _ := s.Cabinet.Me(); u != nil { me = ev.UserID == u.ID } m.Reactions = append(m.Reactions, discord.Reaction{ @@ -231,7 +228,7 @@ func (s *State) onEvent(iface interface{}) { m.Reactions = append(m.Reactions[:i], m.Reactions[i+1:]...) case r.Me: // If reaction removal is the user's - u, err := s.Store.Me() + u, err := s.Cabinet.Me() if err == nil && ev.UserID == u.ID { r.Me = false } @@ -257,51 +254,44 @@ func (s *State) onEvent(iface interface{}) { }) case *gateway.PresenceUpdateEvent: - if err := s.Store.PresenceSet(ev.GuildID, ev.Presence); err != nil { + if err := s.Cabinet.PresenceSet(ev.GuildID, ev.Presence); err != nil { s.stateErr(err, "failed to update presence in state") } case *gateway.PresencesReplaceEvent: for _, p := range *ev { - if err := s.Store.PresenceSet(p.GuildID, p); err != nil { + if err := s.Cabinet.PresenceSet(p.GuildID, p.Presence); err != nil { s.stateErr(err, "failed to update presence in state") } } case *gateway.SessionsReplaceEvent: + // TODO case *gateway.UserGuildSettingsUpdateEvent: - s.readyMu.Lock() - for i, ugs := range s.ready.UserGuildSettings { - if ugs.GuildID == ev.GuildID { - s.ready.UserGuildSettings[i] = ev.UserGuildSettings - } - } - s.readyMu.Unlock() + // TODO case *gateway.UserSettingsUpdateEvent: s.readyMu.Lock() - s.ready.Settings = &ev.UserSettings + s.ready.UserSettings = &ev.UserSettings s.readyMu.Unlock() case *gateway.UserNoteUpdateEvent: - s.readyMu.Lock() - s.ready.Notes[ev.ID] = ev.Note - s.readyMu.Unlock() + // TODO case *gateway.UserUpdateEvent: - if err := s.Store.MyselfSet(ev.User); err != nil { + if err := s.Cabinet.MyselfSet(ev.User); err != nil { s.stateErr(err, "failed to update myself from USER_UPDATE") } case *gateway.VoiceStateUpdateEvent: vs := &ev.VoiceState if vs.ChannelID == 0 { - if err := s.Store.VoiceStateRemove(vs.GuildID, vs.UserID); err != nil { + if err := s.Cabinet.VoiceStateRemove(vs.GuildID, vs.UserID); err != nil { s.stateErr(err, "failed to remove voice state from state") } } else { - if err := s.Store.VoiceStateSet(vs.GuildID, *vs); err != nil { + if err := s.Cabinet.VoiceStateSet(vs.GuildID, *vs); err != nil { s.stateErr(err, "failed to update voice state in state") } } @@ -320,14 +310,14 @@ func (s *State) batchLog(errors []error) { // Helper functions func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func(m *discord.Message) bool) { - m, err := s.Store.Message(ch, msg) + m, err := s.Cabinet.Message(ch, msg) if err != nil { return } if !fn(m) { return } - if err := s.Store.MessageSet(*m); err != nil { + if err := s.Cabinet.MessageSet(*m); err != nil { s.stateErr(err, "failed to save message in reaction add") } } @@ -341,27 +331,27 @@ func findReaction(rs []discord.Reaction, emoji discord.Emoji) int { return -1 } -func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error { +func storeGuildCreate(cab store.Cabinet, guild *gateway.GuildCreateEvent) []error { if guild.Unavailable { return nil } stack, errs := newErrorStack() - if err := store.GuildSet(guild.Guild); err != nil { + if err := cab.GuildSet(guild.Guild); err != nil { errs(err, "failed to set guild in Ready") } // Handle guild emojis if guild.Emojis != nil { - if err := store.EmojiSet(guild.ID, guild.Emojis); err != nil { + if err := cab.EmojiSet(guild.ID, guild.Emojis); err != nil { errs(err, "failed to set guild emojis") } } // Handle guild member for _, m := range guild.Members { - if err := store.MemberSet(guild.ID, m); err != nil { + if err := cab.MemberSet(guild.ID, m); err != nil { errs(err, "failed to set guild member in Ready") } } @@ -371,21 +361,21 @@ func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error { // I HATE Discord. ch.GuildID = guild.ID - if err := store.ChannelSet(ch); err != nil { + if err := cab.ChannelSet(ch); err != nil { errs(err, "failed to set guild channel in Ready") } } // Handle guild presences for _, p := range guild.Presences { - if err := store.PresenceSet(guild.ID, p); err != nil { + if err := cab.PresenceSet(guild.ID, p); err != nil { errs(err, "failed to set guild presence in Ready") } } // Handle guild voice states for _, v := range guild.VoiceStates { - if err := store.VoiceStateSet(guild.ID, v); err != nil { + if err := cab.VoiceStateSet(guild.ID, v); err != nil { errs(err, "failed to set guild voice state in Ready") } } diff --git a/state/store.go b/state/store.go deleted file mode 100644 index 5f92c6e..0000000 --- a/state/store.go +++ /dev/null @@ -1,134 +0,0 @@ -package state - -import ( - "errors" - - "github.com/diamondburned/arikawa/v2/discord" -) - -// Store is the state storage. It should handle mutex itself, and it should only -// concern itself with the local state. -type Store interface { - StoreGetter - StoreModifier - StoreResetter -} - -// All methods in StoreGetter will be wrapped by the State. If the State can't -// find anything in the storage, it will call the API itself and automatically -// add what's missing into the storage. -// -// Methods that return with a slice should pay attention to race conditions that -// would mutate the underlying slice (and as a result the returned slice as -// well). The best way to avoid this is to copy the whole slice, like -// DefaultStore does. -// -// These methods should not care about returning slices in order, unless -// explicitly stated against. -type StoreGetter interface { - Me() (*discord.User, error) - - // Channel should check for both DM and guild channels. - Channel(id discord.ChannelID) (*discord.Channel, error) - Channels(guildID discord.GuildID) ([]discord.Channel, error) - - // same API as (*api.Client) - CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) - PrivateChannels() ([]discord.Channel, error) - - Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) - Emojis(guildID discord.GuildID) ([]discord.Emoji, error) - - Guild(id discord.GuildID) (*discord.Guild, error) - Guilds() ([]discord.Guild, error) - - Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) - Members(guildID discord.GuildID) ([]discord.Member, error) - - Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) - // Messages should return messages ordered from latest to earliest. - Messages(channelID discord.ChannelID) ([]discord.Message, error) - MaxMessages() int // used to know if the state is filled or not. - - // These don't get fetched from the API, it's Gateway only. - Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) - Presences(guildID discord.GuildID) ([]discord.Presence, error) - - Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) - Roles(guildID discord.GuildID) ([]discord.Role, error) - - VoiceState(guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) - VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) -} - -type StoreModifier interface { - MyselfSet(me discord.User) error - - // ChannelSet should switch on Type to know if it's a private channel or - // not. - ChannelSet(discord.Channel) error - ChannelRemove(discord.Channel) error - - // EmojiSet should delete all old emojis before setting new ones. - EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error - - GuildSet(discord.Guild) error - GuildRemove(id discord.GuildID) error - - MemberSet(guildID discord.GuildID, member discord.Member) error - MemberRemove(guildID discord.GuildID, userID discord.UserID) error - - // MessageSet should prepend messages into the slice, the latest being in - // front. - MessageSet(discord.Message) error - MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error - - PresenceSet(guildID discord.GuildID, presence discord.Presence) error - PresenceRemove(guildID discord.GuildID, userID discord.UserID) error - - RoleSet(guildID discord.GuildID, role discord.Role) error - RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error - - VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error - VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error -} - -// StoreResetter is used by the state to reset the store on every Ready event. -type StoreResetter interface { - // Reset resets the store to a new valid instance. - Reset() error -} - -// ErrStoreNotFound is an error that a store can use to return when something -// isn't in the storage. There is no strict restrictions on what uses this (the -// default one does, though), so be advised. -var ErrStoreNotFound = errors.New("item not found in store") - -// DiffMessage fills non-empty fields from src to dst. -func DiffMessage(src discord.Message, dst *discord.Message) { - // Thanks, Discord. - if src.Content != "" { - dst.Content = src.Content - } - if src.EditedTimestamp.IsValid() { - dst.EditedTimestamp = src.EditedTimestamp - } - if src.Mentions != nil { - dst.Mentions = src.Mentions - } - if src.Embeds != nil { - dst.Embeds = src.Embeds - } - if src.Attachments != nil { - dst.Attachments = src.Attachments - } - if src.Timestamp.IsValid() { - dst.Timestamp = src.Timestamp - } - if src.Author.ID.IsValid() { - dst.Author = src.Author - } - if src.Reactions != nil { - dst.Reactions = src.Reactions - } -} diff --git a/state/store/defaultstore/channel.go b/state/store/defaultstore/channel.go new file mode 100644 index 0000000..ee6e537 --- /dev/null +++ b/state/store/defaultstore/channel.go @@ -0,0 +1,179 @@ +package defaultstore + +import ( + "errors" + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Channel struct { + mut sync.RWMutex + + // Channel references must be protected under the same mutex. + + privates map[discord.UserID]*discord.Channel + channels map[discord.ChannelID]*discord.Channel + guildChs map[discord.GuildID][]*discord.Channel +} + +var _ store.ChannelStore = (*Channel)(nil) + +func NewChannel() *Channel { + return &Channel{ + privates: map[discord.UserID]*discord.Channel{}, + channels: map[discord.ChannelID]*discord.Channel{}, + guildChs: map[discord.GuildID][]*discord.Channel{}, + } +} + +func (s *Channel) Reset() error { + s.mut.Lock() + defer s.mut.Unlock() + + s.privates = map[discord.UserID]*discord.Channel{} + s.channels = map[discord.ChannelID]*discord.Channel{} + s.guildChs = map[discord.GuildID][]*discord.Channel{} + + return nil +} + +func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + ch, ok := s.channels[id] + if !ok { + return nil, store.ErrNotFound + } + + cpy := *ch + return &cpy, nil +} + +func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + ch, ok := s.privates[recipient] + if !ok { + return nil, store.ErrNotFound + } + + cpy := *ch + return &cpy, nil +} + +// Channels returns a list of Guild channels randomly ordered. +func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + chRefs, ok := s.guildChs[guildID] + if !ok { + return nil, store.ErrNotFound + } + + // Reading chRefs is also covered by the global mutex. + + var channels = make([]discord.Channel, len(chRefs)) + for i, chRef := range chRefs { + channels[i] = *chRef + } + + return channels, nil +} + +// PrivateChannels returns a list of Direct Message channels randomly ordered. +func (s *Channel) PrivateChannels() ([]discord.Channel, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + if len(s.privates) == 0 { + return nil, store.ErrNotFound + } + + var channels = make([]discord.Channel, 0, len(s.privates)) + for _, ch := range s.privates { + channels = append(channels, *ch) + } + + return channels, nil +} + +// ChannelSet sets the Direct Message or Guild channl into the state. If the +// channel doesn't have 1 (one) DMRecipients, then it must have a valid GuildID, +// otherwise an error will be returned. +func (s *Channel) ChannelSet(channel discord.Channel) error { + s.mut.Lock() + defer s.mut.Unlock() + + // Update the reference if we can. + if ch, ok := s.channels[channel.ID]; ok { + *ch = channel + return nil + } + + if len(channel.DMRecipients) == 1 { + s.privates[channel.DMRecipients[0].ID] = &channel + s.channels[channel.ID] = &channel + return nil + } + + // Invalid channel case, as we need the GuildID to search for this channel. + if !channel.GuildID.IsValid() { + return errors.New("invalid guildID for guild channel") + } + + // Always ensure that if the channel is in the slice, then it will be in the + // map. + + s.channels[channel.ID] = &channel + + channels, _ := s.guildChs[channel.GuildID] + channels = append(channels, &channel) + s.guildChs[channel.GuildID] = channels + + return nil +} + +func (s *Channel) ChannelRemove(channel discord.Channel) error { + s.mut.Lock() + defer s.mut.Unlock() + + delete(s.channels, channel.ID) + + if len(channel.DMRecipients) == 1 { + delete(s.privates, channel.DMRecipients[0].ID) + return nil + } + + channels, ok := s.guildChs[channel.GuildID] + if !ok { + return nil + } + + for i, ch := range channels { + if ch.ID != channel.ID { + continue + } + + // Fast unordered delete. Not sure if there's a benefit in doing + // this over using a map, but I guess the memory usage is less and + // there's no copying. + + // Move the last channel to the current channel, set the last + // channel there to a nil value to unreference its children, then + // slice the last channel off. + channels[i] = channels[len(channels)-1] + channels[len(channels)-1] = nil + channels = channels[:len(channels)-1] + + s.guildChs[channel.GuildID] = channels + + break + } + + return nil +} diff --git a/state/store/defaultstore/defaultstore.go b/state/store/defaultstore/defaultstore.go new file mode 100644 index 0000000..556d79b --- /dev/null +++ b/state/store/defaultstore/defaultstore.go @@ -0,0 +1,21 @@ +// Package defaultstore provides thread-safe store implementations that store +// state values in memory. +package defaultstore + +import "github.com/diamondburned/arikawa/v2/state/store" + +// New creates a new cabinet instance of defaultstore. For Message, it creates a +// Message store with a limit of 100 messages. +func New() store.Cabinet { + return store.Cabinet{ + MeStore: NewMe(), + ChannelStore: NewChannel(), + EmojiStore: NewEmoji(), + GuildStore: NewGuild(), + MemberStore: NewMember(), + MessageStore: NewMessage(100), + PresenceStore: NewPresence(), + RoleStore: NewRole(), + VoiceStateStore: NewVoiceState(), + } +} diff --git a/state/store/defaultstore/emoji.go b/state/store/defaultstore/emoji.go new file mode 100644 index 0000000..fa09e0c --- /dev/null +++ b/state/store/defaultstore/emoji.go @@ -0,0 +1,84 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Emoji struct { + guilds moreatomic.Map +} + +type emojis struct { + mut sync.Mutex + emojis []discord.Emoji +} + +var _ store.EmojiStore = (*Emoji)(nil) + +func NewEmoji() *Emoji { + return &Emoji{ + guilds: *moreatomic.NewMap(func() interface{} { + return &emojis{ + emojis: []discord.Emoji{}, + } + }), + } +} + +func (s *Emoji) Reset() error { + s.guilds.Reset() + return nil +} + +func (s *Emoji) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + es := iv.(*emojis) + + es.mut.Lock() + defer es.mut.Unlock() + + for _, emoji := range es.emojis { + if emoji.ID == emojiID { + // Emoji is an implicit copy made by range, so we could do this + // safely. + return &emoji, nil + } + } + + return nil, store.ErrNotFound +} + +func (s *Emoji) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + es := iv.(*emojis) + + es.mut.Lock() + defer es.mut.Unlock() + + // We're never modifying the slice internals ourselves, so this is fine. + return es.emojis, nil +} + +func (s *Emoji) EmojiSet(guildID discord.GuildID, allEmojis []discord.Emoji) error { + iv, _ := s.guilds.LoadOrStore(guildID) + + es := iv.(*emojis) + + es.mut.Lock() + es.emojis = allEmojis + es.mut.Unlock() + + return nil +} diff --git a/state/store/defaultstore/guild.go b/state/store/defaultstore/guild.go new file mode 100644 index 0000000..0046ae4 --- /dev/null +++ b/state/store/defaultstore/guild.go @@ -0,0 +1,73 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Guild struct { + mut sync.RWMutex + guilds map[discord.GuildID]discord.Guild +} + +var _ store.GuildStore = (*Guild)(nil) + +func NewGuild() *Guild { + return &Guild{ + guilds: map[discord.GuildID]discord.Guild{}, + } +} + +func (s *Guild) Reset() error { + s.mut.Lock() + defer s.mut.Unlock() + + s.guilds = map[discord.GuildID]discord.Guild{} + + return nil +} + +func (s *Guild) Guild(id discord.GuildID) (*discord.Guild, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + ch, ok := s.guilds[id] + if !ok { + return nil, store.ErrNotFound + } + + // implicit copy + return &ch, nil +} + +func (s *Guild) Guilds() ([]discord.Guild, error) { + s.mut.RLock() + defer s.mut.RUnlock() + + if len(s.guilds) == 0 { + return nil, store.ErrNotFound + } + + var gs = make([]discord.Guild, 0, len(s.guilds)) + for _, g := range s.guilds { + gs = append(gs, g) + } + + return gs, nil +} + +func (s *Guild) GuildSet(guild discord.Guild) error { + s.mut.Lock() + s.guilds[guild.ID] = guild + s.mut.Unlock() + return nil +} + +func (s *Guild) GuildRemove(id discord.GuildID) error { + s.mut.Lock() + delete(s.guilds, id) + s.mut.Unlock() + return nil +} diff --git a/state/store/defaultstore/me.go b/state/store/defaultstore/me.go new file mode 100644 index 0000000..53e1a12 --- /dev/null +++ b/state/store/defaultstore/me.go @@ -0,0 +1,47 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Me struct { + mut sync.RWMutex + self discord.User +} + +var _ store.MeStore = (*Me)(nil) + +func NewMe() *Me { + return &Me{} +} + +func (m *Me) Reset() error { + m.mut.Lock() + m.self = discord.User{} + m.mut.Unlock() + + return nil +} + +func (m *Me) Me() (*discord.User, error) { + m.mut.RLock() + self := m.self + m.mut.RUnlock() + + if !self.ID.IsValid() { + return nil, store.ErrNotFound + } + + return &self, nil +} + +func (m *Me) MyselfSet(me discord.User) error { + m.mut.Lock() + m.self = me + m.mut.Unlock() + + return nil +} diff --git a/state/store/defaultstore/member.go b/state/store/defaultstore/member.go new file mode 100644 index 0000000..d4b686a --- /dev/null +++ b/state/store/defaultstore/member.go @@ -0,0 +1,98 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Member struct { + guilds moreatomic.Map // discord.GuildID -> *guildMembers +} + +type guildMembers struct { + mut sync.Mutex + members map[discord.UserID]discord.Member +} + +var _ store.MemberStore = (*Member)(nil) + +func NewMember() *Member { + return &Member{ + guilds: *moreatomic.NewMap(func() interface{} { + return &guildMembers{ + members: make(map[discord.UserID]discord.Member, 1), + } + }), + } +} + +func (s *Member) Reset() error { + return s.guilds.Reset() +} + +func (s *Member) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + gm := iv.(*guildMembers) + + gm.mut.Lock() + defer gm.mut.Unlock() + + m, ok := gm.members[userID] + if ok { + return &m, nil + } + + return nil, store.ErrNotFound +} + +func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + gm := iv.(*guildMembers) + + gm.mut.Lock() + defer gm.mut.Unlock() + + var members = make([]discord.Member, 0, len(gm.members)) + for _, m := range gm.members { + members = append(members, m) + } + + return members, nil +} + +func (s *Member) MemberSet(guildID discord.GuildID, member discord.Member) error { + iv, _ := s.guilds.LoadOrStore(guildID) + gm := iv.(*guildMembers) + + gm.mut.Lock() + gm.members[member.User.ID] = member + gm.mut.Unlock() + + return nil +} + +func (s *Member) MemberRemove(guildID discord.GuildID, userID discord.UserID) error { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil + } + + gm := iv.(*guildMembers) + + gm.mut.Lock() + delete(gm.members, userID) + gm.mut.Unlock() + + return nil +} diff --git a/state/store/defaultstore/message.go b/state/store/defaultstore/message.go new file mode 100644 index 0000000..49edadf --- /dev/null +++ b/state/store/defaultstore/message.go @@ -0,0 +1,162 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Message struct { + channels moreatomic.Map + maxMsgs int +} + +var _ store.MessageStore = (*Message)(nil) + +type messages struct { + mut sync.Mutex + messages []discord.Message +} + +func NewMessage(maxMsgs int) *Message { + return &Message{ + channels: *moreatomic.NewMap(func() interface{} { + return &messages{ + messages: []discord.Message{}, // never use a nil slice + } + }), + } +} + +func (s *Message) Reset() error { + return s.channels.Reset() +} + +func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*discord.Message, error) { + iv, ok := s.channels.Load(chID) + if !ok { + return nil, store.ErrNotFound + } + + msgs := iv.(*messages) + + msgs.mut.Lock() + defer msgs.mut.Unlock() + + for _, m := range msgs.messages { + if m.ID == mID { + return &m, nil + } + } + + return nil, store.ErrNotFound +} + +func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, error) { + iv, ok := s.channels.Load(channelID) + if !ok { + return nil, store.ErrNotFound + } + + msgs := iv.(*messages) + + msgs.mut.Lock() + defer msgs.mut.Unlock() + + return append([]discord.Message(nil), msgs.messages...), nil +} + +func (s *Message) MaxMessages() int { + return s.maxMsgs +} + +func (s *Message) MessageSet(message discord.Message) error { + iv, _ := s.channels.LoadOrStore(message.ChannelID) + + msgs := iv.(*messages) + + msgs.mut.Lock() + defer msgs.mut.Unlock() + + // Check if we already have the message. + for i, m := range msgs.messages { + if m.ID == message.ID { + DiffMessage(message, &m) + msgs.messages[i] = m + return nil + } + } + + // Order: latest to earliest, similar to the API. + + var end = len(msgs.messages) + if max := s.MaxMessages(); end >= max { + // If the end (length) is larger than the maximum amount, then cap it. + end = max + } else { + // Else, append an empty message to the end. + msgs.messages = append(msgs.messages, discord.Message{}) + // Increment to update the length. + end++ + } + + // Copy hack to prepend. This copies the 0th-(end-1)th entries to + // 1st-endth. + copy(msgs.messages[1:end], msgs.messages[0:end-1]) + // Then, set the 0th entry. + msgs.messages[0] = message + + return nil +} + +// DiffMessage fills non-empty fields from src to dst. +func DiffMessage(src discord.Message, dst *discord.Message) { + // Thanks, Discord. + if src.Content != "" { + dst.Content = src.Content + } + if src.EditedTimestamp.IsValid() { + dst.EditedTimestamp = src.EditedTimestamp + } + if src.Mentions != nil { + dst.Mentions = src.Mentions + } + if src.Embeds != nil { + dst.Embeds = src.Embeds + } + if src.Attachments != nil { + dst.Attachments = src.Attachments + } + if src.Timestamp.IsValid() { + dst.Timestamp = src.Timestamp + } + if src.Author.ID.IsValid() { + dst.Author = src.Author + } + if src.Reactions != nil { + dst.Reactions = src.Reactions + } +} + +func (s *Message) MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error { + iv, ok := s.channels.Load(channelID) + if !ok { + return nil + } + + msgs := iv.(*messages) + + msgs.mut.Lock() + defer msgs.mut.Unlock() + + for i, m := range msgs.messages { + if m.ID == messageID { + msgs.messages = append(msgs.messages[:i], msgs.messages[i+1:]...) + return nil + } + } + + return nil +} diff --git a/state/store/defaultstore/presence.go b/state/store/defaultstore/presence.go new file mode 100644 index 0000000..fd73a85 --- /dev/null +++ b/state/store/defaultstore/presence.go @@ -0,0 +1,106 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/gateway" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Presence struct { + guilds moreatomic.Map +} + +type presences struct { + mut sync.Mutex + presences map[discord.UserID]gateway.Presence +} + +var _ store.PresenceStore = (*Presence)(nil) + +func NewPresence() *Presence { + return &Presence{ + guilds: *moreatomic.NewMap(func() interface{} { + return &presences{ + presences: make(map[discord.UserID]gateway.Presence, 1), + } + }), + } +} + +func (s *Presence) Reset() error { + return s.guilds.Reset() +} + +func (s *Presence) Presence(gID discord.GuildID, uID discord.UserID) (*gateway.Presence, error) { + iv, ok := s.guilds.Load(gID) + if !ok { + return nil, store.ErrNotFound + } + + ps := iv.(*presences) + + ps.mut.Lock() + defer ps.mut.Unlock() + + p, ok := ps.presences[uID] + if ok { + return &p, nil + } + + return nil, store.ErrNotFound +} + +func (s *Presence) Presences(guildID discord.GuildID) ([]gateway.Presence, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + ps := iv.(*presences) + + ps.mut.Lock() + defer ps.mut.Unlock() + + var presences = make([]gateway.Presence, 0, len(ps.presences)) + for _, p := range ps.presences { + presences = append(presences, p) + } + + return presences, nil +} + +func (s *Presence) PresenceSet(guildID discord.GuildID, presence gateway.Presence) error { + iv, _ := s.guilds.LoadOrStore(guildID) + + ps := iv.(*presences) + + ps.mut.Lock() + defer ps.mut.Unlock() + + // Shitty if check is better than a realloc every time. + if ps.presences == nil { + ps.presences = make(map[discord.UserID]gateway.Presence, 1) + } + + ps.presences[presence.User.ID] = presence + + return nil +} + +func (s *Presence) PresenceRemove(guildID discord.GuildID, userID discord.UserID) error { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil + } + + ps := iv.(*presences) + + ps.mut.Lock() + delete(ps.presences, userID) + ps.mut.Unlock() + + return nil +} diff --git a/state/store/defaultstore/role.go b/state/store/defaultstore/role.go new file mode 100644 index 0000000..c62ce59 --- /dev/null +++ b/state/store/defaultstore/role.go @@ -0,0 +1,99 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type Role struct { + guilds moreatomic.Map +} + +var _ store.RoleStore = (*Role)(nil) + +type roles struct { + mut sync.Mutex + roles map[discord.RoleID]discord.Role +} + +func NewRole() *Role { + return &Role{ + guilds: *moreatomic.NewMap(func() interface{} { + return &roles{ + roles: make(map[discord.RoleID]discord.Role, 1), + } + }), + } +} + +func (s *Role) Reset() error { + return s.guilds.Reset() +} + +func (s *Role) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + rs := iv.(*roles) + + rs.mut.Lock() + defer rs.mut.Unlock() + + r, ok := rs.roles[roleID] + if ok { + return &r, nil + } + + return nil, store.ErrNotFound +} + +func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + rs := iv.(*roles) + + rs.mut.Lock() + defer rs.mut.Unlock() + + var roles = make([]discord.Role, 0, len(rs.roles)) + for _, role := range rs.roles { + roles = append(roles, role) + } + + return roles, nil +} + +func (s *Role) RoleSet(guildID discord.GuildID, role discord.Role) error { + iv, _ := s.guilds.LoadOrStore(guildID) + + rs := iv.(*roles) + + rs.mut.Lock() + rs.roles[role.ID] = role + rs.mut.Unlock() + + return nil +} + +func (s *Role) RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil + } + + rs := iv.(*roles) + + rs.mut.Lock() + delete(rs.roles, roleID) + rs.mut.Unlock() + + return nil +} diff --git a/state/store/defaultstore/voicestate.go b/state/store/defaultstore/voicestate.go new file mode 100644 index 0000000..d76f084 --- /dev/null +++ b/state/store/defaultstore/voicestate.go @@ -0,0 +1,101 @@ +package defaultstore + +import ( + "sync" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/internal/moreatomic" + "github.com/diamondburned/arikawa/v2/state/store" +) + +type VoiceState struct { + guilds moreatomic.Map +} + +var _ store.VoiceStateStore = (*VoiceState)(nil) + +type voiceStates struct { + mut sync.Mutex + voiceStates map[discord.UserID]discord.VoiceState +} + +func NewVoiceState() *VoiceState { + return &VoiceState{ + guilds: *moreatomic.NewMap(func() interface{} { + return &voiceStates{ + voiceStates: make(map[discord.UserID]discord.VoiceState, 1), + } + }), + } +} + +func (s *VoiceState) Reset() error { + return s.guilds.Reset() +} + +func (s *VoiceState) VoiceState( + guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) { + + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + vs := iv.(*voiceStates) + + vs.mut.Lock() + defer vs.mut.Unlock() + + v, ok := vs.voiceStates[userID] + if ok { + return &v, nil + } + + return nil, store.ErrNotFound +} + +func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil, store.ErrNotFound + } + + vs := iv.(*voiceStates) + + vs.mut.Lock() + defer vs.mut.Unlock() + + var states = make([]discord.VoiceState, 0, len(vs.voiceStates)) + for _, state := range vs.voiceStates { + states = append(states, state) + } + + return states, nil +} + +func (s *VoiceState) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error { + iv, _ := s.guilds.LoadOrStore(guildID) + + vs := iv.(*voiceStates) + + vs.mut.Lock() + vs.voiceStates[voiceState.UserID] = voiceState + vs.mut.Unlock() + + return nil +} + +func (s *VoiceState) VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error { + iv, ok := s.guilds.Load(guildID) + if !ok { + return nil + } + + vs := iv.(*voiceStates) + + vs.mut.Lock() + delete(vs.voiceStates, userID) + vs.mut.Unlock() + + return nil +} diff --git a/state/store/store.go b/state/store/store.go new file mode 100644 index 0000000..5d1bd65 --- /dev/null +++ b/state/store/store.go @@ -0,0 +1,375 @@ +// Package store contains interfaces of the state's storage and its +// implementations. +// +// Getter Methods +// +// All getter methods will be wrapped by the State. If the State can't find +// anything in the storage, it will call the API itself and automatically add +// what's missing into the storage. +// +// Methods that return with a slice should pay attention to race conditions that +// would mutate the underlying slice (and as a result the returned slice as +// well). The best way to avoid this is to copy the whole slice, like +// defaultstore implementations do. +// +// Getter methods should not care about returning slices in order, unless +// explicitly stated against. +// +// ErrNotFound Rules +// +// If a getter method cannot find something, it should return ErrNotFound. +// Callers including State may check if the error is ErrNotFound to do something +// else. For example, if Guilds currently stores nothing, then it should return +// an empty slice and a nil error. +// +// In some cases, there may not be a way to know whether or not the store is +// unpopulated or is actually empty. In that case, implementations can return +// ErrNotFound when either happens. This will make State refetch from the API, +// so it is not ideal. +// +// Remove Methods +// +// Remove methods should return a nil error if the item it wants to delete is +// not found. This helps save some additional work in some cases. +package store + +import ( + "errors" + "fmt" + + "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/gateway" +) + +// ErrNotFound is an error that a store can use to return when something isn't +// in the storage. There is no strict restrictions on what uses this (the +// default one does, though), so be advised. +var ErrNotFound = errors.New("item not found in store") + +// Cabinet combines all store interfaces into one but allows swapping individual +// stores out for another. Since the struct only consists of interfaces, it can +// be copied around. +type Cabinet struct { + MeStore + ChannelStore + EmojiStore + GuildStore + MemberStore + MessageStore + PresenceStore + RoleStore + VoiceStateStore +} + +// Reset resets everything inside the container. +func (sc *Cabinet) Reset() error { + errors := []error{ + sc.MeStore.Reset(), + sc.ChannelStore.Reset(), + sc.EmojiStore.Reset(), + sc.GuildStore.Reset(), + sc.MemberStore.Reset(), + sc.MessageStore.Reset(), + sc.PresenceStore.Reset(), + sc.RoleStore.Reset(), + sc.VoiceStateStore.Reset(), + } + + nonNils := errors[:0] + + for _, err := range errors { + if err != nil { + nonNils = append(nonNils, err) + } + } + + if len(nonNils) > 0 { + return ResetErrors(nonNils) + } + + return nil +} + +// ResetErrors represents the multiple errors when StoreContainer is being +// resetted. A ResetErrors value must have at least 1 error. +type ResetErrors []error + +// Error formats ResetErrors, showing the number of errors and the last error. +func (errs ResetErrors) Error() string { + return fmt.Sprintf( + "encountered %d reset errors (last: %v)", + len(errs), errs[len(errs)-1], + ) +} + +// Unwrap returns the last error in the list. +func (errs ResetErrors) Unwrap() error { + return errs[len(errs)-1] +} + +// append adds the error only if it is not nil. +func (errs *ResetErrors) append(err error) { + if err != nil { + *errs = append(*errs, err) + } +} + +// Noop is a no-op implementation of all store interfaces. Its getters will +// always return ErrNotFound, and its setters will never return an error. +var Noop = noop{} + +// NoopCabinet is a store cabinet with all store methods set to the Noop +// implementations. +var NoopCabinet = Cabinet{ + MeStore: Noop, + ChannelStore: Noop, + EmojiStore: Noop, + GuildStore: Noop, + MemberStore: Noop, + MessageStore: Noop, + PresenceStore: Noop, + RoleStore: Noop, + VoiceStateStore: Noop, +} + +// noop is the Noop type that implements methods. +type noop struct{} + +// Resetter is an interface to reset the store on every Ready event. +type Resetter interface { + // Reset resets the store to a new valid instance. + Reset() error +} + +var _ Resetter = (*noop)(nil) + +func (noop) Reset() error { return nil } + +// MeStore is the store interface for the current user. +type MeStore interface { + Resetter + + Me() (*discord.User, error) + MyselfSet(me discord.User) error +} + +func (noop) Me() (*discord.User, error) { return nil, ErrNotFound } +func (noop) MyselfSet(discord.User) error { return nil } + +// ChannelStore is the store interface for all channels. +type ChannelStore interface { + Resetter + + // ChannelStore searches for both DM and guild channels. + Channel(discord.ChannelID) (*discord.Channel, error) + // CreatePrivateChannelStore searches for private channels by the recipient ID. + // It has the same API as *api.Client does. + CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) + + // Channels returns only channels from a guild. + Channels(discord.GuildID) ([]discord.Channel, error) + // PrivateChannels returns all private channels from the state. + PrivateChannels() ([]discord.Channel, error) + + // Both ChannelSet and ChannelRemove should switch on Type to know if it's a + // private channel or not. + + ChannelSet(discord.Channel) error + ChannelRemove(discord.Channel) error +} + +var _ ChannelStore = (*noop)(nil) + +func (noop) Channel(discord.ChannelID) (*discord.Channel, error) { + return nil, ErrNotFound +} +func (noop) CreatePrivateChannel(discord.UserID) (*discord.Channel, error) { + return nil, ErrNotFound +} +func (noop) Channels(discord.GuildID) ([]discord.Channel, error) { + return nil, ErrNotFound +} +func (noop) PrivateChannels() ([]discord.Channel, error) { + return nil, ErrNotFound +} +func (noop) ChannelSet(discord.Channel) error { + return nil +} +func (noop) ChannelRemove(discord.Channel) error { + return nil +} + +// EmojiStore is the store interface for all emojis. +type EmojiStore interface { + Resetter + + Emoji(discord.GuildID, discord.EmojiID) (*discord.Emoji, error) + Emojis(discord.GuildID) ([]discord.Emoji, error) + + // EmojiSet should delete all old emojis before setting new ones. The given + // emojis slice will be a complete list of all emojis. + EmojiSet(discord.GuildID, []discord.Emoji) error +} + +var _ EmojiStore = (*noop)(nil) + +func (noop) Emoji(discord.GuildID, discord.EmojiID) (*discord.Emoji, error) { + return nil, ErrNotFound +} +func (noop) Emojis(discord.GuildID) ([]discord.Emoji, error) { + return nil, ErrNotFound +} +func (noop) EmojiSet(discord.GuildID, []discord.Emoji) error { + return nil +} + +// GuildStore is the store interface for all guilds. +type GuildStore interface { + Resetter + + Guild(discord.GuildID) (*discord.Guild, error) + Guilds() ([]discord.Guild, error) + + GuildSet(discord.Guild) error + GuildRemove(id discord.GuildID) error +} + +var _ GuildStore = (*noop)(nil) + +func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound } +func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound } +func (noop) GuildSet(discord.Guild) error { return nil } +func (noop) GuildRemove(discord.GuildID) error { return nil } + +// MemberStore is the store interface for all members. +type MemberStore interface { + Resetter + + Member(discord.GuildID, discord.UserID) (*discord.Member, error) + Members(discord.GuildID) ([]discord.Member, error) + + MemberSet(discord.GuildID, discord.Member) error + MemberRemove(discord.GuildID, discord.UserID) error +} + +var _ MemberStore = (*noop)(nil) + +func (noop) Member(discord.GuildID, discord.UserID) (*discord.Member, error) { + return nil, ErrNotFound +} +func (noop) Members(discord.GuildID) ([]discord.Member, error) { + return nil, ErrNotFound +} +func (noop) MemberSet(discord.GuildID, discord.Member) error { + return nil +} +func (noop) MemberRemove(discord.GuildID, discord.UserID) error { + return nil +} + +// MessageStore is the store interface for all messages. +type MessageStore interface { + Resetter + + // MaxMessages returns the maximum number of messages. It is used to know if + // the state cache is filled or not for one channel + MaxMessages() int + + Message(discord.ChannelID, discord.MessageID) (*discord.Message, error) + // Messages should return messages ordered from latest to earliest. + Messages(discord.ChannelID) ([]discord.Message, error) + + // MessageSet should prepend messages into the slice, the latest being in + // front. + MessageSet(discord.Message) error + MessageRemove(discord.ChannelID, discord.MessageID) error +} + +var _ MessageStore = (*noop)(nil) + +func (noop) MaxMessages() int { + return 0 +} +func (noop) Message(discord.ChannelID, discord.MessageID) (*discord.Message, error) { + return nil, ErrNotFound +} +func (noop) Messages(discord.ChannelID) ([]discord.Message, error) { + return nil, ErrNotFound +} +func (noop) MessageSet(discord.Message) error { + return nil +} +func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error { + return nil +} + +// PresenceStore is the store interface for all user presences. Presences don't get +// fetched from the API; they will only be updated through the Gateway. +type PresenceStore interface { + Resetter + + Presence(discord.GuildID, discord.UserID) (*gateway.Presence, error) + Presences(discord.GuildID) ([]gateway.Presence, error) + + PresenceSet(discord.GuildID, gateway.Presence) error + PresenceRemove(discord.GuildID, discord.UserID) error +} + +var _ PresenceStore = (*noop)(nil) + +func (noop) Presence(discord.GuildID, discord.UserID) (*gateway.Presence, error) { + return nil, ErrNotFound +} +func (noop) Presences(discord.GuildID) ([]gateway.Presence, error) { + return nil, ErrNotFound +} +func (noop) PresenceSet(discord.GuildID, gateway.Presence) error { + return nil +} +func (noop) PresenceRemove(discord.GuildID, discord.UserID) error { + return nil +} + +// RoleStore is the store interface for all member roles. +type RoleStore interface { + Resetter + + Role(discord.GuildID, discord.RoleID) (*discord.Role, error) + Roles(discord.GuildID) ([]discord.Role, error) + + RoleSet(discord.GuildID, discord.Role) error + RoleRemove(discord.GuildID, discord.RoleID) error +} + +var _ RoleStore = (*noop)(nil) + +func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound } +func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound } +func (noop) RoleSet(discord.GuildID, discord.Role) error { return nil } +func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil } + +// VoiceStateStore is the store interface for all voice states. +type VoiceStateStore interface { + Resetter + + VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) + VoiceStates(discord.GuildID) ([]discord.VoiceState, error) + + VoiceStateSet(discord.GuildID, discord.VoiceState) error + VoiceStateRemove(discord.GuildID, discord.UserID) error +} + +var _ VoiceStateStore = (*noop)(nil) + +func (noop) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) { + return nil, ErrNotFound +} +func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { + return nil, ErrNotFound +} +func (noop) VoiceStateSet(discord.GuildID, discord.VoiceState) error { + return nil +} +func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error { + return nil +} diff --git a/state/store_default.go b/state/store_default.go deleted file mode 100644 index b731a0b..0000000 --- a/state/store_default.go +++ /dev/null @@ -1,702 +0,0 @@ -package state - -import ( - "sync" - - "github.com/diamondburned/arikawa/v2/discord" -) - -// TODO: make an ExpiryStore - -type DefaultStore struct { - DefaultStoreOptions - - self discord.User - - // includes normal and private - privates map[discord.ChannelID]discord.Channel - guilds map[discord.GuildID]discord.Guild - - roles map[discord.GuildID][]discord.Role - emojis map[discord.GuildID][]discord.Emoji - channels map[discord.GuildID][]discord.Channel - presences map[discord.GuildID][]discord.Presence - voiceStates map[discord.GuildID][]discord.VoiceState - messages map[discord.ChannelID][]discord.Message - - // special case; optimize for lots of members - members map[discord.GuildID]map[discord.UserID]discord.Member - - mut sync.RWMutex -} - -type DefaultStoreOptions struct { - MaxMessages uint // default 50 -} - -var _ Store = (*DefaultStore)(nil) - -func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore { - if opts == nil { - opts = &DefaultStoreOptions{ - MaxMessages: 50, - } - } - - ds := &DefaultStore{DefaultStoreOptions: *opts} - ds.Reset() - - return ds -} - -func (s *DefaultStore) Reset() error { - s.mut.Lock() - defer s.mut.Unlock() - - s.self = discord.User{} - - s.privates = map[discord.ChannelID]discord.Channel{} - s.guilds = map[discord.GuildID]discord.Guild{} - - s.roles = map[discord.GuildID][]discord.Role{} - s.emojis = map[discord.GuildID][]discord.Emoji{} - s.channels = map[discord.GuildID][]discord.Channel{} - s.presences = map[discord.GuildID][]discord.Presence{} - s.voiceStates = map[discord.GuildID][]discord.VoiceState{} - s.messages = map[discord.ChannelID][]discord.Message{} - - s.members = map[discord.GuildID]map[discord.UserID]discord.Member{} - - return nil -} - -//// - -func (s *DefaultStore) Me() (*discord.User, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - if !s.self.ID.IsValid() { - return nil, ErrStoreNotFound - } - - return &s.self, nil -} - -func (s *DefaultStore) MyselfSet(me discord.User) error { - s.mut.Lock() - s.self = me - s.mut.Unlock() - - return nil -} - -//// - -func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - if ch, ok := s.privates[id]; ok { - // implicit copy - return &ch, nil - } - - for _, chs := range s.channels { - for _, ch := range chs { - if ch.ID == id { - return &ch, nil - } - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - chs, ok := s.channels[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.Channel{}, chs...), nil -} - -// CreatePrivateChannel searches in the cache for a private channel. It makes no -// API calls. -func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - // slow way - for _, ch := range s.privates { - if ch.Type != discord.DirectMessage || len(ch.DMRecipients) == 0 { - continue - } - if ch.DMRecipients[0].ID == recipient { - // Return an implicit copy made by range. - return &ch, nil - } - } - return nil, ErrStoreNotFound -} - -// PrivateChannels returns a list of Direct Message channels randomly ordered. -func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - var chs = make([]discord.Channel, 0, len(s.privates)) - for i := range s.privates { - chs = append(chs, s.privates[i]) - } - - return chs, nil -} - -func (s *DefaultStore) ChannelSet(channel discord.Channel) error { - s.mut.Lock() - defer s.mut.Unlock() - - if !channel.GuildID.IsValid() { - s.privates[channel.ID] = channel - - } else { - chs := s.channels[channel.GuildID] - - for i, ch := range chs { - if ch.ID == channel.ID { - // Also from discordgo. - if channel.Permissions == nil { - channel.Permissions = ch.Permissions - } - - // Found, just edit - chs[i] = channel - - return nil - } - } - - chs = append(chs, channel) - s.channels[channel.GuildID] = chs - } - - return nil -} - -func (s *DefaultStore) ChannelRemove(channel discord.Channel) error { - s.mut.Lock() - defer s.mut.Unlock() - - chs, ok := s.channels[channel.GuildID] - if !ok { - return ErrStoreNotFound - } - - for i, ch := range chs { - if ch.ID == channel.ID { - // Fast unordered delete. - chs[i] = chs[len(chs)-1] - chs = chs[:len(chs)-1] - - s.channels[channel.GuildID] = chs - return nil - } - } - - return ErrStoreNotFound -} - -//// - -func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - emojis, ok := s.emojis[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - for _, emoji := range emojis { - if emoji.ID == emojiID { - // Emoji is an implicit copy, so we could do this safely. - return &emoji, nil - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - emojis, ok := s.emojis[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.Emoji{}, emojis...), nil -} - -func (s *DefaultStore) EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error { - s.mut.Lock() - defer s.mut.Unlock() - - // A nil slice is acceptable, as we'll make a new slice later on and set it. - s.emojis[guildID] = emojis - - return nil -} - -//// - -func (s *DefaultStore) Guild(id discord.GuildID) (*discord.Guild, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - ch, ok := s.guilds[id] - if !ok { - return nil, ErrStoreNotFound - } - - // implicit copy - return &ch, nil -} - -func (s *DefaultStore) Guilds() ([]discord.Guild, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - if len(s.guilds) == 0 { - return nil, ErrStoreNotFound - } - - var gs = make([]discord.Guild, 0, len(s.guilds)) - for _, g := range s.guilds { - gs = append(gs, g) - } - - return gs, nil -} - -func (s *DefaultStore) GuildSet(guild discord.Guild) error { - s.mut.Lock() - defer s.mut.Unlock() - - s.guilds[guild.ID] = guild - return nil -} - -func (s *DefaultStore) GuildRemove(id discord.GuildID) error { - s.mut.Lock() - defer s.mut.Unlock() - - if _, ok := s.guilds[id]; !ok { - return ErrStoreNotFound - } - - delete(s.guilds, id) - return nil -} - -//// - -func (s *DefaultStore) Member( - guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { - - s.mut.RLock() - defer s.mut.RUnlock() - - ms, ok := s.members[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - m, ok := ms[userID] - if ok { - return &m, nil - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Members(guildID discord.GuildID) ([]discord.Member, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - ms, ok := s.members[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - var members = make([]discord.Member, 0, len(ms)) - for _, m := range ms { - members = append(members, m) - } - - return members, nil -} - -func (s *DefaultStore) MemberSet(guildID discord.GuildID, member discord.Member) error { - s.mut.Lock() - defer s.mut.Unlock() - - ms, ok := s.members[guildID] - if !ok { - ms = make(map[discord.UserID]discord.Member, 1) - } - - ms[member.User.ID] = member - s.members[guildID] = ms - - return nil -} - -func (s *DefaultStore) MemberRemove(guildID discord.GuildID, userID discord.UserID) error { - s.mut.Lock() - defer s.mut.Unlock() - - ms, ok := s.members[guildID] - if !ok { - return ErrStoreNotFound - } - - if _, ok := ms[userID]; !ok { - return ErrStoreNotFound - } - - delete(ms, userID) - return nil -} - -//// - -func (s *DefaultStore) Message( - channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { - - s.mut.RLock() - defer s.mut.RUnlock() - - ms, ok := s.messages[channelID] - if !ok { - return nil, ErrStoreNotFound - } - - for _, m := range ms { - if m.ID == messageID { - return &m, nil - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Messages(channelID discord.ChannelID) ([]discord.Message, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - ms, ok := s.messages[channelID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.Message{}, ms...), nil -} - -func (s *DefaultStore) MaxMessages() int { - return int(s.DefaultStoreOptions.MaxMessages) -} - -func (s *DefaultStore) MessageSet(message discord.Message) error { - s.mut.Lock() - defer s.mut.Unlock() - - ms, ok := s.messages[message.ChannelID] - if !ok { - ms = make([]discord.Message, 0, s.MaxMessages()+1) - } - - // Check if we already have the message. - for i, m := range ms { - if m.ID == message.ID { - DiffMessage(message, &m) - ms[i] = m - return nil - } - } - - // Order: latest to earliest, similar to the API. - - var end = len(ms) - if max := s.MaxMessages(); end >= max { - // If the end (length) is larger than the maximum amount, then cap it. - end = max - } else { - // Else, append an empty message to the end. - ms = append(ms, discord.Message{}) - // Increment to update the length. - end++ - } - - // Copy hack to prepend. This copies the 0th-(end-1)th entries to - // 1st-endth. - copy(ms[1:end], ms[0:end-1]) - // Then, set the 0th entry. - ms[0] = message - - s.messages[message.ChannelID] = ms - return nil -} - -func (s *DefaultStore) MessageRemove( - channelID discord.ChannelID, messageID discord.MessageID) error { - - s.mut.Lock() - defer s.mut.Unlock() - - ms, ok := s.messages[channelID] - if !ok { - return ErrStoreNotFound - } - - for i, m := range ms { - if m.ID == messageID { - ms = append(ms[:i], ms[i+1:]...) - s.messages[channelID] = ms - return nil - } - } - - return ErrStoreNotFound -} - -//// - -func (s *DefaultStore) Presence( - guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { - - s.mut.RLock() - defer s.mut.RUnlock() - - ps, ok := s.presences[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - for _, p := range ps { - if p.User.ID == userID { - return &p, nil - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Presences(guildID discord.GuildID) ([]discord.Presence, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - ps, ok := s.presences[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.Presence{}, ps...), nil -} - -func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence discord.Presence) error { - s.mut.Lock() - defer s.mut.Unlock() - - ps, _ := s.presences[guildID] - - for i, p := range ps { - if p.User.ID == presence.User.ID { - // Change the backing array. - ps[i] = presence - return nil - } - } - - ps = append(ps, presence) - s.presences[guildID] = ps - return nil -} - -func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.UserID) error { - s.mut.Lock() - defer s.mut.Unlock() - - ps, ok := s.presences[guildID] - if !ok { - return ErrStoreNotFound - } - - for i, p := range ps { - if p.User.ID == userID { - ps[i] = ps[len(ps)-1] - ps = ps[:len(ps)-1] - - s.presences[guildID] = ps - return nil - } - } - - return ErrStoreNotFound -} - -//// - -func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - rs, ok := s.roles[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - for _, r := range rs { - if r.ID == roleID { - return &r, nil - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) Roles(guildID discord.GuildID) ([]discord.Role, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - rs, ok := s.roles[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.Role{}, rs...), nil -} - -func (s *DefaultStore) RoleSet(guildID discord.GuildID, role discord.Role) error { - s.mut.Lock() - defer s.mut.Unlock() - - // A nil slice is fine, since we can just append the role. - rs, _ := s.roles[guildID] - - for i, r := range rs { - if r.ID == role.ID { - // This changes the backing array, so we don't need to reset the - // slice. - rs[i] = role - return nil - } - } - - rs = append(rs, role) - s.roles[guildID] = rs - return nil -} - -func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error { - s.mut.Lock() - defer s.mut.Unlock() - - rs, ok := s.roles[guildID] - if !ok { - return ErrStoreNotFound - } - - for i, r := range rs { - if r.ID == roleID { - // Fast delete. - rs[i] = rs[len(rs)-1] - rs = rs[:len(rs)-1] - - s.roles[guildID] = rs - return nil - } - } - - return ErrStoreNotFound -} - -//// - -func (s *DefaultStore) VoiceState( - guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) { - - s.mut.RLock() - defer s.mut.RUnlock() - - states, ok := s.voiceStates[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - for _, vs := range states { - if vs.UserID == userID { - return &vs, nil - } - } - - return nil, ErrStoreNotFound -} - -func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) { - s.mut.RLock() - defer s.mut.RUnlock() - - states, ok := s.voiceStates[guildID] - if !ok { - return nil, ErrStoreNotFound - } - - return append([]discord.VoiceState{}, states...), nil -} - -func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error { - s.mut.Lock() - defer s.mut.Unlock() - - states, _ := s.voiceStates[guildID] - - for i, vs := range states { - if vs.UserID == voiceState.UserID { - // change the backing array - states[i] = voiceState - return nil - } - } - - states = append(states, voiceState) - s.voiceStates[guildID] = states - return nil -} - -func (s *DefaultStore) VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error { - s.mut.Lock() - defer s.mut.Unlock() - - states, ok := s.voiceStates[guildID] - if !ok { - return ErrStoreNotFound - } - - for i, vs := range states { - if vs.UserID == userID { - states = append(states[:i], states[i+1:]...) - s.voiceStates[guildID] = states - - return nil - } - } - - return ErrStoreNotFound -} diff --git a/state/store_noop.go b/state/store_noop.go deleted file mode 100644 index f020c1b..0000000 --- a/state/store_noop.go +++ /dev/null @@ -1,166 +0,0 @@ -package state - -import ( - "errors" - - "github.com/diamondburned/arikawa/v2/discord" -) - -// NoopStore could be embedded by other structs for partial state -// implementation. All Getters will return ErrNotImplemented, and all Setters -// will return no error. -type NoopStore struct{} - -var _ Store = (*NoopStore)(nil) - -var ErrNotImplemented = errors.New("state is not implemented") - -func (NoopStore) Reset() error { - return nil -} - -func (NoopStore) Me() (*discord.User, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) MyselfSet(discord.User) error { - return nil -} - -func (NoopStore) Channel(discord.ChannelID) (*discord.Channel, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Channels(discord.GuildID) ([]discord.Channel, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) CreatePrivateChannel(discord.UserID) (*discord.Channel, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) PrivateChannels() ([]discord.Channel, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) ChannelSet(discord.Channel) error { - return nil -} - -func (NoopStore) ChannelRemove(discord.Channel) error { - return nil -} - -func (NoopStore) Emoji(discord.GuildID, discord.EmojiID) (*discord.Emoji, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Emojis(discord.GuildID) ([]discord.Emoji, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) EmojiSet(discord.GuildID, []discord.Emoji) error { - return nil -} - -func (NoopStore) Guild(discord.GuildID) (*discord.Guild, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Guilds() ([]discord.Guild, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) GuildSet(discord.Guild) error { - return nil -} - -func (NoopStore) GuildRemove(discord.GuildID) error { - return nil -} - -func (NoopStore) Member(discord.GuildID, discord.UserID) (*discord.Member, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Members(discord.GuildID) ([]discord.Member, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) MemberSet(discord.GuildID, discord.Member) error { - return nil -} - -func (NoopStore) MemberRemove(discord.GuildID, discord.UserID) error { - return nil -} - -func (NoopStore) Message(discord.ChannelID, discord.MessageID) (*discord.Message, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Messages(discord.ChannelID) ([]discord.Message, error) { - return nil, ErrNotImplemented -} - -// MaxMessages will always return 100 messages, so the API can fetch that -// many. -func (NoopStore) MaxMessages() int { - return 100 -} - -func (NoopStore) MessageSet(discord.Message) error { - return nil -} - -func (NoopStore) MessageRemove(discord.ChannelID, discord.MessageID) error { - return nil -} - -func (NoopStore) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Presences(discord.GuildID) ([]discord.Presence, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) PresenceSet(discord.GuildID, discord.Presence) error { - return nil -} - -func (NoopStore) PresenceRemove(discord.GuildID, discord.UserID) error { - return nil -} - -func (NoopStore) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) Roles(discord.GuildID) ([]discord.Role, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) RoleSet(discord.GuildID, discord.Role) error { - return nil -} - -func (NoopStore) RoleRemove(discord.GuildID, discord.RoleID) error { - return nil -} - -func (NoopStore) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { - return nil, ErrNotImplemented -} - -func (NoopStore) VoiceStateSet(discord.GuildID, discord.VoiceState) error { - return ErrNotImplemented -} - -func (NoopStore) VoiceStateRemove(discord.GuildID, discord.UserID) error { - return ErrNotImplemented -}