From 3230916c4588ef2f869a43137d52bbaddfcea23c Mon Sep 17 00:00:00 2001 From: Maximilian von Lindern <48887425+mavolin@users.noreply.github.com> Date: Thu, 19 Nov 2020 19:43:31 +0100 Subject: [PATCH] State: don't check store if resource is not tracked through intents (#163) Partially reviewed; good for the most part. --- bot/extras/middlewares/middlewares_test.go | 19 ++ gateway/gateway.go | 5 + state/state.go | 335 +++++++++++++-------- state/state_events.go | 6 +- state/store.go | 1 + 5 files changed, 234 insertions(+), 132 deletions(-) diff --git a/bot/extras/middlewares/middlewares_test.go b/bot/extras/middlewares/middlewares_test.go index 5393222..f351034 100644 --- a/bot/extras/middlewares/middlewares_test.go +++ b/bot/extras/middlewares/middlewares_test.go @@ -7,12 +7,22 @@ import ( "github.com/diamondburned/arikawa/v2/bot" "github.com/diamondburned/arikawa/v2/discord" "github.com/diamondburned/arikawa/v2/gateway" + "github.com/diamondburned/arikawa/v2/session" "github.com/diamondburned/arikawa/v2/state" ) func TestAdminOnly(t *testing.T) { var ctx = &bot.Context{ State: &state.State{ + Session: &session.Session{ + Gateway: &gateway.Gateway{ + Identifier: &gateway.Identifier{ + IdentifyData: gateway.IdentifyData{ + Intents: gateway.IntentGuilds | gateway.IntentGuildMembers, + }, + }, + }, + }, Store: &mockStore{}, }, } @@ -50,6 +60,15 @@ func TestAdminOnly(t *testing.T) { func TestGuildOnly(t *testing.T) { var ctx = &bot.Context{ State: &state.State{ + Session: &session.Session{ + Gateway: &gateway.Gateway{ + Identifier: &gateway.Identifier{ + IdentifyData: gateway.IdentifyData{ + Intents: gateway.IntentGuilds, + }, + }, + }, + }, Store: &mockStore{}, }, } diff --git a/gateway/gateway.go b/gateway/gateway.go index 8d07ebd..d0b2ab1 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -175,6 +175,11 @@ func (g *Gateway) AddIntents(i Intents) { g.Identifier.Intents |= i } +// HasIntents reports if the Gateway has the passed Intents. +func (g *Gateway) HasIntents(intents Intents) bool { + return g.Identifier.Intents.Has(intents) +} + // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { wsutil.WSDebug("Trying to close. Pacemaker check skipped.") diff --git a/state/state.go b/state/state.go index e9bd4df..6668e34 100644 --- a/state/state.go +++ b/state/state.go @@ -17,7 +17,7 @@ import ( var ( MaxFetchMembers uint = 1000 - MaxFetchGuilds uint = 10 + MaxFetchGuilds uint = 100 ) // State is the cache to store events coming from Discord as well as data from @@ -216,8 +216,21 @@ func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) { var wg sync.WaitGroup - g, gerr := s.Store.Guild(guildID) - m, merr := s.Store.Member(guildID, userID) + var ( + g *discord.Guild + gerr = ErrStoreNotFound + + m *discord.Member + merr = ErrStoreNotFound + ) + + if s.Gateway.HasIntents(gateway.IntentGuilds) { + g, gerr = s.Store.Guild(guildID) + } + + if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + m, merr = s.Store.Member(guildID, userID) + } switch { case gerr != nil && merr != nil: @@ -258,8 +271,21 @@ func (s *State) Permissions( var wg sync.WaitGroup - g, gerr := s.Store.Guild(ch.GuildID) - m, merr := s.Store.Member(ch.GuildID, userID) + var ( + g *discord.Guild + gerr = ErrStoreNotFound + + m *discord.Member + merr = ErrStoreNotFound + ) + + if s.Gateway.HasIntents(gateway.IntentGuilds) { + g, gerr = s.Store.Guild(ch.GuildID) + } + + if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + m, merr = s.Store.Member(ch.GuildID, userID) + } switch { case gerr != nil && merr != nil: @@ -306,40 +332,46 @@ func (s *State) Me() (*discord.User, error) { //// -func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) { - c, err := s.Store.Channel(id) - if err == nil { - return c, nil +func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { + c, err = s.Store.Channel(id) + if err == nil && s.tracksChannel(c) { + return } c, err = s.Session.Channel(id) if err != nil { - return nil, err + return } - return c, s.Store.ChannelSet(*c) + if s.tracksChannel(c) { + err = s.Store.ChannelSet(*c) + } + + return } -func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) { - c, err := s.Store.Channels(guildID) - if err == nil { - return c, nil - } - - c, err = s.Session.Channels(guildID) - if err != nil { - return nil, err - } - - for _, ch := range c { - ch := ch - - if err := s.Store.ChannelSet(ch); err != nil { - return nil, err +func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) { + if s.Gateway.HasIntents(gateway.IntentGuilds) { + cs, err = s.Store.Channels(guildID) + if err == nil { + return } } - return c, nil + cs, err = s.Session.Channels(guildID) + if err != nil { + return + } + + if s.Gateway.HasIntents(gateway.IntentGuilds) { + for _, c := range cs { + if err = s.Store.ChannelSet(c); err != nil { + return + } + } + } + + return } func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { @@ -356,36 +388,40 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return c, s.Store.ChannelSet(*c) } +// PrivateChannels gets the direct messages of the user. +// This is not supported for bots. func (s *State) PrivateChannels() ([]discord.Channel, error) { - c, err := s.Store.PrivateChannels() + cs, err := s.Store.PrivateChannels() if err == nil { - return c, nil + return cs, nil } - c, err = s.Session.PrivateChannels() + cs, err = s.Session.PrivateChannels() if err != nil { return nil, err } - for _, ch := range c { - ch := ch - - if err := s.Store.ChannelSet(ch); err != nil { + for _, c := range cs { + if err := s.Store.ChannelSet(c); err != nil { return nil, err } } - return c, nil + return cs, nil } //// func (s *State) Emoji( - guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) { + guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) { - e, err := s.Store.Emoji(guildID, emojiID) - if err == nil { - return e, nil + if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + e, err = s.Store.Emoji(guildID, emojiID) + if err == nil { + return + } + } else { // Fast path + return s.Session.Emoji(guildID, emojiID) } es, err := s.Session.Emojis(guildID) @@ -393,8 +429,8 @@ func (s *State) Emoji( return nil, err } - if err := s.Store.EmojiSet(guildID, es); err != nil { - return nil, err + if err = s.Store.EmojiSet(guildID, es); err != nil { + return } for _, e := range es { @@ -406,86 +442,99 @@ func (s *State) Emoji( return nil, ErrStoreNotFound } -func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) { - e, err := s.Store.Emojis(guildID) - if err == nil { - return e, nil +func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) { + if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + es, err = s.Store.Emojis(guildID) + if err == nil { + return + } } - es, err := s.Session.Emojis(guildID) + es, err = s.Session.Emojis(guildID) if err != nil { - return nil, err + return } - return es, s.Store.EmojiSet(guildID, es) + if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + err = s.Store.EmojiSet(guildID, es) + } + + return } //// func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { - c, err := s.Store.Guild(id) - if err == nil { - return c, nil + if s.Gateway.HasIntents(gateway.IntentGuilds) { + c, err := s.Store.Guild(id) + if err == nil { + return c, nil + } } return s.fetchGuild(id) } // Guilds will only fill a maximum of 100 guilds from the API. -func (s *State) Guilds() ([]discord.Guild, error) { - c, err := s.Store.Guilds() - if err == nil { - return c, nil - } - - c, err = s.Session.Guilds(MaxFetchGuilds) - if err != nil { - return nil, err - } - - for _, ch := range c { - ch := ch - - if err := s.Store.GuildSet(ch); err != nil { - return nil, err +func (s *State) Guilds() (gs []discord.Guild, err error) { + if s.Gateway.HasIntents(gateway.IntentGuilds) { + gs, err = s.Store.Guilds() + if err == nil { + return } } - return c, nil + gs, err = s.Session.Guilds(MaxFetchGuilds) + if err != nil { + return + } + + if s.Gateway.HasIntents(gateway.IntentGuilds) { + for _, g := range gs { + if err = s.Store.GuildSet(g); err != nil { + return + } + } + } + + return } //// func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { - m, err := s.Store.Member(guildID, userID) - if err == nil { - return m, nil + if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + m, err := s.Store.Member(guildID, userID) + if err == nil { + return m, nil + } } return s.fetchMember(guildID, userID) } -func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) { - ms, err := s.Store.Members(guildID) - if err == nil { - return ms, nil +func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) { + if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + ms, err = s.Store.Members(guildID) + if err == nil { + return + } } ms, err = s.Session.Members(guildID, MaxFetchMembers) if err != nil { - return nil, err + return } - for _, m := range ms { - if err := s.Store.MemberSet(guildID, m); err != nil { - return nil, err + if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + for _, m := range ms { + if err = s.Store.MemberSet(guildID, m); err != nil { + return + } } } - return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{ - GuildID: []discord.GuildID{guildID}, - Presences: true, - }) + return } //// @@ -494,18 +543,23 @@ func (s *State) Message( channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) { m, err := s.Store.Message(channelID, messageID) - if err == nil { + if err == nil && s.tracksMessage(m) { return m, nil } - var wg sync.WaitGroup + var ( + wg sync.WaitGroup - c, cerr := s.Store.Channel(channelID) - if cerr != nil { + c *discord.Channel + cerr = ErrStoreNotFound + ) + + c, cerr = s.Store.Channel(channelID) + if cerr != nil || !s.tracksChannel(c) { wg.Add(1) go func() { c, cerr = s.Session.Channel(channelID) - if cerr == nil { + if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { cerr = s.Store.ChannelSet(*c) } @@ -527,17 +581,21 @@ func (s *State) Message( m.ChannelID = c.ID m.GuildID = c.GuildID - return m, s.Store.MessageSet(*m) + if s.tracksMessage(m) { + err = s.Store.MessageSet(*m) + } + + return m, err } -// Messages fetches maximum 100 messages from the API, if it has to. There is no -// limit if it's from the State storage. +// Messages fetches maximum 100 messages from the API, if it has to. There is +// no limit if it's from the State storage. func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) { // TODO: Think of a design that doesn't rely on MaxMessages(). var maxMsgs = s.MaxMessages() ms, err := s.Store.Messages(channelID) - if err == nil { + if err == nil && (len(ms) == 0 || s.tracksMessage(&ms[0])) { // If the state already has as many messages as it can, skip the API. if maxMsgs <= len(ms) { return ms, nil @@ -570,14 +628,16 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) guildID = c.GuildID } - // Iterate in reverse, since the store is expected to prepend the latest - // messages. - for i := len(ms) - 1; i >= 0; i-- { - // Set the guild ID, fine if it's 0 (it's already 0 anyway). - ms[i].GuildID = guildID + if len(ms) > 0 && s.tracksMessage(&ms[0]) { + // Iterate in reverse, since the store is expected to prepend the latest + // messages. + for i := len(ms) - 1; i >= 0; i-- { + // Set the guild ID, fine if it's 0 (it's already 0 anyway). + ms[i].GuildID = guildID - if err := s.Store.MessageSet(ms[i]); err != nil { - return nil, err + if err := s.Store.MessageSet(ms[i]); err != nil { + return nil, err + } } } @@ -597,19 +657,22 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) //// -// Presence checks the state for user presences. If no guildID is given, it will -// look for the presence in all guilds. +// Presence checks the state for user presences. If no guildID is given, it +// will look for the presence in all cached guilds. func (s *State) Presence( guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) { - p, err := s.Store.Presence(guildID, userID) - if err == nil { - return p, nil + if !s.Gateway.HasIntents(gateway.IntentGuildPresences) { + return nil, ErrStoreNotFound } // If there's no guild ID, look in all guilds if !guildID.IsValid() { - g, err := s.Guilds() + if !s.Gateway.HasIntents(gateway.IntentGuilds) { + return nil, ErrStoreNotFound + } + + g, err := s.Store.Guilds() if err != nil { return nil, err } @@ -619,43 +682,46 @@ func (s *State) Presence( return p, nil } } + + return nil, ErrStoreNotFound } - return nil, err + return s.Store.Presence(guildID, userID) } //// -func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) { - r, err := s.Store.Role(guildID, roleID) - if err == nil { - return r, nil +func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) { + if s.Gateway.HasIntents(gateway.IntentGuilds) { + target, err = s.Store.Role(guildID, roleID) + if err == nil { + return + } } rs, err := s.Session.Roles(guildID) if err != nil { - return nil, err + return } - var role *discord.Role - for _, r := range rs { - r := r - if r.ID == roleID { - role = &r + r := r // copy to prevent mem aliasing + target = &r } - if err := s.RoleSet(guildID, r); err != nil { - return role, err + if s.Gateway.HasIntents(gateway.IntentGuilds) { + if err = s.RoleSet(guildID, r); err != nil { + return + } } } - if role == nil { + if target == nil { return nil, ErrStoreNotFound } - return role, nil + return } func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { @@ -669,11 +735,11 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { return nil, err } - for _, r := range rs { - r := r - - if err := s.RoleSet(guildID, r); err != nil { - return rs, err + if s.Gateway.HasIntents(gateway.IntentGuilds) { + for _, r := range rs { + if err := s.RoleSet(guildID, r); err != nil { + return rs, err + } } } @@ -682,7 +748,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) - if err == nil { + if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { err = s.Store.GuildSet(*g) } @@ -693,9 +759,22 @@ func (s *State) fetchMember( guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) { m, err = s.Session.Member(guildID, userID) - if err == nil { + if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { err = s.Store.MemberSet(guildID, *m) } return } + +// tracksMessage reports whether the state would track the passed message and +// messages from the same channel. +func (s *State) tracksMessage(m *discord.Message) bool { + return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || + (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) +} + +// tracksChannel reports whether the state would track the passed channel. +func (s *State) tracksChannel(c *discord.Channel) bool { + return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || + !c.GuildID.IsValid() +} diff --git a/state/state_events.go b/state/state_events.go index 2e4afcc..8d0f030 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -56,10 +56,8 @@ func (s *State) onEvent(iface interface{}) { s.ready = *ev // Reset the store before proceeding. - if resetter, ok := s.Store.(StoreResetter); ok { - if err := resetter.Reset(); err != nil { - s.stateErr(err, "Failed to reset state on READY") - } + if err := s.Store.Reset(); err != nil { + s.stateErr(err, "failed to reset state on READY") } // Handle presences diff --git a/state/store.go b/state/store.go index 3b4ba76..5f92c6e 100644 --- a/state/store.go +++ b/state/store.go @@ -11,6 +11,7 @@ import ( type Store interface { StoreGetter StoreModifier + StoreResetter } // All methods in StoreGetter will be wrapped by the State. If the State can't