1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-03-21 01:19:20 +00:00

State: don't check store if resource is not tracked through intents (#163)

Partially reviewed; good for the most part.
This commit is contained in:
Maximilian von Lindern 2020-11-19 19:43:31 +01:00 committed by GitHub
parent 8356a8a3f6
commit 3230916c45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 234 additions and 132 deletions

View file

@ -7,12 +7,22 @@ import (
"github.com/diamondburned/arikawa/v2/bot"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/session"
"github.com/diamondburned/arikawa/v2/state"
)
func TestAdminOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: gateway.IntentGuilds | gateway.IntentGuildMembers,
},
},
},
},
Store: &mockStore{},
},
}
@ -50,6 +60,15 @@ func TestAdminOnly(t *testing.T) {
func TestGuildOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: gateway.IntentGuilds,
},
},
},
},
Store: &mockStore{},
},
}

View file

@ -175,6 +175,11 @@ func (g *Gateway) AddIntents(i Intents) {
g.Identifier.Intents |= i
}
// HasIntents reports if the Gateway has the passed Intents.
func (g *Gateway) HasIntents(intents Intents) bool {
return g.Identifier.Intents.Has(intents)
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
wsutil.WSDebug("Trying to close. Pacemaker check skipped.")

View file

@ -17,7 +17,7 @@ import (
var (
MaxFetchMembers uint = 1000
MaxFetchGuilds uint = 10
MaxFetchGuilds uint = 100
)
// State is the cache to store events coming from Discord as well as data from
@ -216,8 +216,21 @@ func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color,
func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) {
var wg sync.WaitGroup
g, gerr := s.Store.Guild(guildID)
m, merr := s.Store.Member(guildID, userID)
var (
g *discord.Guild
gerr = ErrStoreNotFound
m *discord.Member
merr = ErrStoreNotFound
)
if s.Gateway.HasIntents(gateway.IntentGuilds) {
g, gerr = s.Store.Guild(guildID)
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
m, merr = s.Store.Member(guildID, userID)
}
switch {
case gerr != nil && merr != nil:
@ -258,8 +271,21 @@ func (s *State) Permissions(
var wg sync.WaitGroup
g, gerr := s.Store.Guild(ch.GuildID)
m, merr := s.Store.Member(ch.GuildID, userID)
var (
g *discord.Guild
gerr = ErrStoreNotFound
m *discord.Member
merr = ErrStoreNotFound
)
if s.Gateway.HasIntents(gateway.IntentGuilds) {
g, gerr = s.Store.Guild(ch.GuildID)
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
m, merr = s.Store.Member(ch.GuildID, userID)
}
switch {
case gerr != nil && merr != nil:
@ -306,40 +332,46 @@ func (s *State) Me() (*discord.User, error) {
////
func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) {
c, err := s.Store.Channel(id)
if err == nil {
return c, nil
func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) {
c, err = s.Store.Channel(id)
if err == nil && s.tracksChannel(c) {
return
}
c, err = s.Session.Channel(id)
if err != nil {
return nil, err
return
}
return c, s.Store.ChannelSet(*c)
if s.tracksChannel(c) {
err = s.Store.ChannelSet(*c)
}
return
}
func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
c, err := s.Store.Channels(guildID)
if err == nil {
return c, nil
}
c, err = s.Session.Channels(guildID)
if err != nil {
return nil, err
}
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(ch); err != nil {
return nil, err
func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
cs, err = s.Store.Channels(guildID)
if err == nil {
return
}
}
return c, nil
cs, err = s.Session.Channels(guildID)
if err != nil {
return
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, c := range cs {
if err = s.Store.ChannelSet(c); err != nil {
return
}
}
}
return
}
func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
@ -356,36 +388,40 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel
return c, s.Store.ChannelSet(*c)
}
// PrivateChannels gets the direct messages of the user.
// This is not supported for bots.
func (s *State) PrivateChannels() ([]discord.Channel, error) {
c, err := s.Store.PrivateChannels()
cs, err := s.Store.PrivateChannels()
if err == nil {
return c, nil
return cs, nil
}
c, err = s.Session.PrivateChannels()
cs, err = s.Session.PrivateChannels()
if err != nil {
return nil, err
}
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(ch); err != nil {
for _, c := range cs {
if err := s.Store.ChannelSet(c); err != nil {
return nil, err
}
}
return c, nil
return cs, nil
}
////
func (s *State) Emoji(
guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) {
e, err := s.Store.Emoji(guildID, emojiID)
if err == nil {
return e, nil
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
e, err = s.Store.Emoji(guildID, emojiID)
if err == nil {
return
}
} else { // Fast path
return s.Session.Emoji(guildID, emojiID)
}
es, err := s.Session.Emojis(guildID)
@ -393,8 +429,8 @@ func (s *State) Emoji(
return nil, err
}
if err := s.Store.EmojiSet(guildID, es); err != nil {
return nil, err
if err = s.Store.EmojiSet(guildID, es); err != nil {
return
}
for _, e := range es {
@ -406,86 +442,99 @@ func (s *State) Emoji(
return nil, ErrStoreNotFound
}
func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
e, err := s.Store.Emojis(guildID)
if err == nil {
return e, nil
func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) {
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
es, err = s.Store.Emojis(guildID)
if err == nil {
return
}
}
es, err := s.Session.Emojis(guildID)
es, err = s.Session.Emojis(guildID)
if err != nil {
return nil, err
return
}
return es, s.Store.EmojiSet(guildID, es)
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
err = s.Store.EmojiSet(guildID, es)
}
return
}
////
func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
c, err := s.Store.Guild(id)
if err == nil {
return c, nil
if s.Gateway.HasIntents(gateway.IntentGuilds) {
c, err := s.Store.Guild(id)
if err == nil {
return c, nil
}
}
return s.fetchGuild(id)
}
// Guilds will only fill a maximum of 100 guilds from the API.
func (s *State) Guilds() ([]discord.Guild, error) {
c, err := s.Store.Guilds()
if err == nil {
return c, nil
}
c, err = s.Session.Guilds(MaxFetchGuilds)
if err != nil {
return nil, err
}
for _, ch := range c {
ch := ch
if err := s.Store.GuildSet(ch); err != nil {
return nil, err
func (s *State) Guilds() (gs []discord.Guild, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
gs, err = s.Store.Guilds()
if err == nil {
return
}
}
return c, nil
gs, err = s.Session.Guilds(MaxFetchGuilds)
if err != nil {
return
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, g := range gs {
if err = s.Store.GuildSet(g); err != nil {
return
}
}
}
return
}
////
func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
m, err := s.Store.Member(guildID, userID)
if err == nil {
return m, nil
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
m, err := s.Store.Member(guildID, userID)
if err == nil {
return m, nil
}
}
return s.fetchMember(guildID, userID)
}
func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
ms, err := s.Store.Members(guildID)
if err == nil {
return ms, nil
func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) {
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
ms, err = s.Store.Members(guildID)
if err == nil {
return
}
}
ms, err = s.Session.Members(guildID, MaxFetchMembers)
if err != nil {
return nil, err
return
}
for _, m := range ms {
if err := s.Store.MemberSet(guildID, m); err != nil {
return nil, err
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
for _, m := range ms {
if err = s.Store.MemberSet(guildID, m); err != nil {
return
}
}
}
return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
GuildID: []discord.GuildID{guildID},
Presences: true,
})
return
}
////
@ -494,18 +543,23 @@ func (s *State) Message(
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
m, err := s.Store.Message(channelID, messageID)
if err == nil {
if err == nil && s.tracksMessage(m) {
return m, nil
}
var wg sync.WaitGroup
var (
wg sync.WaitGroup
c, cerr := s.Store.Channel(channelID)
if cerr != nil {
c *discord.Channel
cerr = ErrStoreNotFound
)
c, cerr = s.Store.Channel(channelID)
if cerr != nil || !s.tracksChannel(c) {
wg.Add(1)
go func() {
c, cerr = s.Session.Channel(channelID)
if cerr == nil {
if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
cerr = s.Store.ChannelSet(*c)
}
@ -527,17 +581,21 @@ func (s *State) Message(
m.ChannelID = c.ID
m.GuildID = c.GuildID
return m, s.Store.MessageSet(*m)
if s.tracksMessage(m) {
err = s.Store.MessageSet(*m)
}
return m, err
}
// Messages fetches maximum 100 messages from the API, if it has to. There is no
// limit if it's from the State storage.
// Messages fetches maximum 100 messages from the API, if it has to. There is
// no limit if it's from the State storage.
func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
// TODO: Think of a design that doesn't rely on MaxMessages().
var maxMsgs = s.MaxMessages()
ms, err := s.Store.Messages(channelID)
if err == nil {
if err == nil && (len(ms) == 0 || s.tracksMessage(&ms[0])) {
// If the state already has as many messages as it can, skip the API.
if maxMsgs <= len(ms) {
return ms, nil
@ -570,14 +628,16 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error)
guildID = c.GuildID
}
// Iterate in reverse, since the store is expected to prepend the latest
// messages.
for i := len(ms) - 1; i >= 0; i-- {
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
ms[i].GuildID = guildID
if len(ms) > 0 && s.tracksMessage(&ms[0]) {
// Iterate in reverse, since the store is expected to prepend the latest
// messages.
for i := len(ms) - 1; i >= 0; i-- {
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
ms[i].GuildID = guildID
if err := s.Store.MessageSet(ms[i]); err != nil {
return nil, err
if err := s.Store.MessageSet(ms[i]); err != nil {
return nil, err
}
}
}
@ -597,19 +657,22 @@ func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error)
////
// Presence checks the state for user presences. If no guildID is given, it will
// look for the presence in all guilds.
// Presence checks the state for user presences. If no guildID is given, it
// will look for the presence in all cached guilds.
func (s *State) Presence(
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
p, err := s.Store.Presence(guildID, userID)
if err == nil {
return p, nil
if !s.Gateway.HasIntents(gateway.IntentGuildPresences) {
return nil, ErrStoreNotFound
}
// If there's no guild ID, look in all guilds
if !guildID.IsValid() {
g, err := s.Guilds()
if !s.Gateway.HasIntents(gateway.IntentGuilds) {
return nil, ErrStoreNotFound
}
g, err := s.Store.Guilds()
if err != nil {
return nil, err
}
@ -619,43 +682,46 @@ func (s *State) Presence(
return p, nil
}
}
return nil, ErrStoreNotFound
}
return nil, err
return s.Store.Presence(guildID, userID)
}
////
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
r, err := s.Store.Role(guildID, roleID)
if err == nil {
return r, nil
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
target, err = s.Store.Role(guildID, roleID)
if err == nil {
return
}
}
rs, err := s.Session.Roles(guildID)
if err != nil {
return nil, err
return
}
var role *discord.Role
for _, r := range rs {
r := r
if r.ID == roleID {
role = &r
r := r // copy to prevent mem aliasing
target = &r
}
if err := s.RoleSet(guildID, r); err != nil {
return role, err
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if err = s.RoleSet(guildID, r); err != nil {
return
}
}
}
if role == nil {
if target == nil {
return nil, ErrStoreNotFound
}
return role, nil
return
}
func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
@ -669,11 +735,11 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
return nil, err
}
for _, r := range rs {
r := r
if err := s.RoleSet(guildID, r); err != nil {
return rs, err
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, r := range rs {
if err := s.RoleSet(guildID, r); err != nil {
return rs, err
}
}
}
@ -682,7 +748,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id)
if err == nil {
if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
err = s.Store.GuildSet(*g)
}
@ -693,9 +759,22 @@ func (s *State) fetchMember(
guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(guildID, userID)
if err == nil {
if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) {
err = s.Store.MemberSet(guildID, *m)
}
return
}
// tracksMessage reports whether the state would track the passed message and
// messages from the same channel.
func (s *State) tracksMessage(m *discord.Message) bool {
return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
(!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages))
}
// tracksChannel reports whether the state would track the passed channel.
func (s *State) tracksChannel(c *discord.Channel) bool {
return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
!c.GuildID.IsValid()
}

View file

@ -56,10 +56,8 @@ func (s *State) onEvent(iface interface{}) {
s.ready = *ev
// Reset the store before proceeding.
if resetter, ok := s.Store.(StoreResetter); ok {
if err := resetter.Reset(); err != nil {
s.stateErr(err, "Failed to reset state on READY")
}
if err := s.Store.Reset(); err != nil {
s.stateErr(err, "failed to reset state on READY")
}
// Handle presences

View file

@ -11,6 +11,7 @@ import (
type Store interface {
StoreGetter
StoreModifier
StoreResetter
}
// All methods in StoreGetter will be wrapped by the State. If the State can't