Gateway: State refactored into smaller components

This commit is contained in:
diamondburned (Forefront) 2020-02-29 18:50:50 -08:00
parent f0102d765f
commit 11465c62bd
5 changed files with 411 additions and 188 deletions

97
:w Normal file
View File

@ -0,0 +1,97 @@
package state
import (
"errors"
"github.com/diamondburned/arikawa/discord"
)
// Store is the state storage. It should handle mutex itself, and it should only
// concern itself with the local state.
type Store interface {
StoreGetter
StoreModifier
}
type StoreMe interface {
Me() (*discord.User, error)
MyselfSet(me *discord.User) error
}
type StoreChannel interface {
Channel(id discord.Snowflake) (*discord.Channel, error)
Channels(guildID discord.Snowflake) ([]discord.Channel, error)
PrivateChannels() ([]discord.Channel, error)
// ChannelSet should switch on Type to know if it's a private channel or
// not.
ChannelSet(*discord.Channel) error
ChannelRemove(*discord.Channel) error
}
type StoreEmoji interface {
Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error)
Emojis(guildID discord.Snowflake) ([]discord.Emoji, error)
EmojiSet(guildID discord.Snowflake, emojis []discord.Emoji) error
}
type StoreGuild interface {
Guild(id discord.Snowflake) (*discord.Guild, error)
Guilds() ([]discord.Guild, error)
GuildSet(*discord.Guild) error
GuildRemove(id discord.Snowflake) error
}
type StoreMember interface {
Member(guildID, userID discord.Snowflake) (*discord.Member, error)
Members(guildID discord.Snowflake) ([]discord.Member, error)
MemberSet(guildID discord.Snowflake, member *discord.Member) error
MemberRemove(guildID, userID discord.Snowflake) error
}
type StoreMessage interface {
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.
MessageSet(*discord.Message) error
MessageRemove(channelID, messageID discord.Snowflake) error
}
type StorePresence interface {
// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
Presences(guildID discord.Snowflake) ([]discord.Presence, error)
PresenceSet(guildID discord.Snowflake, presence *discord.Presence) error
PresenceRemove(guildID, userID discord.Snowflake) error
}
// All methods in StoreGetter will be wrapped by the State. If the State can't
// find anything in the storage, it will call the API itself and automatically
// add what's missing into the storage.
//
// Methods that return with a slice should pay attention to race conditions that
// would mutate the underlying slice (and as a result the returned slice as
// well). The best way to avoid this is to copy the whole slice, like
// DefaultStore does.
type StoreGetter interface {
Role(guildID, roleID discord.Snowflake) (*discord.Role, error)
Roles(guildID discord.Snowflake) ([]discord.Role, error)
}
type StoreModifier interface {
RoleSet(guildID discord.Snowflake, role *discord.Role) error
RoleRemove(guildID, roleID discord.Snowflake) error
// This should reset all the state to zero/null.
Reset() error
}
// ErrStoreNotFound is an error that a store can use to return when something
// isn't in the storage. There is no strict restrictions on what uses this (the
// default one does, though), so be advised.
var ErrStoreNotFound = errors.New("item not found in store")

View File

@ -103,7 +103,7 @@ func TestContext(t *testing.T) {
t.Fatal("given's Context field is nil")
}
if given.Ctx.State.Store == nil {
if given.Ctx.State.StoreMessage == nil {
t.Fatal("given's State is nil")
}
})

View File

@ -49,6 +49,20 @@ type State struct {
fewMutex sync.Mutex
}
// Store serves as an option that NewFromSession uses to add in stores. All
// fields are optional.
type Store struct {
StoreMe
StoreChannel
StoreEmoji
StoreGuild
StoreMember
StoreMessage
StorePresence
StoreRole
Resetter
}
func NewFromSession(s *session.Session, store Store) (*State, error) {
state := &State{
Session: s,
@ -148,9 +162,7 @@ func (s *State) MemberColor(guildID, userID discord.Snowflake) discord.Color {
////
func (s *State) Permissions(
channelID, userID discord.Snowflake) (discord.Permissions, error) {
func (s *State) Permissions(channelID, userID discord.Snowflake) (discord.Permissions, error) {
ch, err := s.Channel(channelID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get channel")
@ -172,51 +184,67 @@ func (s *State) Permissions(
////
func (s *State) Me() (*discord.User, error) {
u, err := s.Store.Me()
if err == nil {
return u, nil
if s.StoreMe != nil {
u, err := s.StoreMe.Me()
if err == nil {
return u, nil
}
}
u, err = s.Session.Me()
u, err := s.Session.Me()
if err != nil {
return nil, err
}
return u, s.Store.MyselfSet(u)
if s.StoreMe != nil {
return u, s.StoreMe.MyselfSet(u)
}
return u, nil
}
////
func (s *State) Channel(id discord.Snowflake) (*discord.Channel, error) {
c, err := s.Store.Channel(id)
if err == nil {
return c, nil
if s.StoreChannel != nil {
c, err := s.StoreChannel.Channel(id)
if err == nil {
return c, nil
}
}
c, err = s.Session.Channel(id)
c, err := s.Session.Channel(id)
if err != nil {
return nil, err
}
return c, s.Store.ChannelSet(c)
if s.StoreChannel != nil {
return c, s.StoreChannel.ChannelSet(c)
}
return c, nil
}
func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
c, err := s.Store.Channels(guildID)
if err == nil {
return c, nil
if s.StoreChannel != nil {
c, err := s.StoreChannel.Channels(guildID)
if err == nil {
return c, nil
}
}
c, err = s.Session.Channels(guildID)
c, err := s.Session.Channels(guildID)
if err != nil {
return nil, err
}
for _, ch := range c {
ch := ch
if s.StoreChannel != nil {
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
return nil, err
if err := s.StoreChannel.ChannelSet(&ch); err != nil {
return nil, err
}
}
}
@ -225,12 +253,12 @@ func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
////
func (s *State) Emoji(
guildID, emojiID discord.Snowflake) (*discord.Emoji, error) {
e, err := s.Store.Emoji(guildID, emojiID)
if err == nil {
return e, nil
func (s *State) Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error) {
if s.StoreEmoji != nil {
e, err := s.StoreEmoji.Emoji(guildID, emojiID)
if err == nil {
return e, nil
}
}
es, err := s.Session.Emojis(guildID)
@ -238,8 +266,10 @@ func (s *State) Emoji(
return nil, err
}
if err := s.Store.EmojiSet(guildID, es); err != nil {
return nil, err
if s.StoreEmoji != nil {
if err := s.StoreEmoji.EmojiSet(guildID, es); err != nil {
return nil, err
}
}
for _, e := range es {
@ -252,9 +282,11 @@ func (s *State) Emoji(
}
func (s *State) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
e, err := s.Store.Emojis(guildID)
if err == nil {
return e, nil
if s.StoreEmoji != nil {
e, err := s.StoreEmoji.Emojis(guildID)
if err == nil {
return e, nil
}
}
es, err := s.Session.Emojis(guildID)
@ -262,42 +294,54 @@ func (s *State) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
return nil, err
}
return es, s.Store.EmojiSet(guildID, es)
if s.StoreEmoji != nil {
return es, s.StoreEmoji.EmojiSet(guildID, es)
}
return es, nil
}
////
func (s *State) Guild(id discord.Snowflake) (*discord.Guild, error) {
c, err := s.Store.Guild(id)
if err == nil {
return c, nil
if s.StoreGuild != nil {
c, err := s.StoreGuild.Guild(id)
if err == nil {
return c, nil
}
}
c, err = s.Session.Guild(id)
c, err := s.Session.Guild(id)
if err != nil {
return nil, err
}
return c, s.Store.GuildSet(c)
if s.StoreGuild != nil {
return c, s.StoreGuild.GuildSet(c)
}
return c, nil
}
// Guilds will only fill a maximum of 100 guilds from the API.
func (s *State) Guilds() ([]discord.Guild, error) {
c, err := s.Store.Guilds()
if err == nil {
return c, nil
if s.StoreGuild != nil {
c, err := s.StoreGuild.Guilds()
if err == nil {
return c, nil
}
}
c, err = s.Session.Guilds(MaxFetchGuilds)
c, err := s.Session.Guilds(MaxFetchGuilds)
if err != nil {
return nil, err
}
for _, ch := range c {
ch := ch
if err := s.Store.GuildSet(&ch); err != nil {
return nil, err
if s.StoreGuild != nil {
for i := range c {
if err := s.StoreGuild.GuildSet(&c[i]); err != nil {
return nil, err
}
}
}
@ -306,56 +350,68 @@ func (s *State) Guilds() ([]discord.Guild, error) {
////
func (s *State) Member(
guildID, userID discord.Snowflake) (*discord.Member, error) {
m, err := s.Store.Member(guildID, userID)
if err == nil {
return m, nil
}
m, err = s.Session.Member(guildID, userID)
if err != nil {
return nil, err
}
return m, s.Store.MemberSet(guildID, m)
}
func (s *State) Members(guildID discord.Snowflake) ([]discord.Member, error) {
ms, err := s.Store.Members(guildID)
if err == nil {
return ms, nil
}
ms, err = s.Session.Members(guildID, MaxFetchMembers)
if err != nil {
return nil, err
}
for _, m := range ms {
if err := s.Store.MemberSet(guildID, &m); err != nil {
return nil, err
func (s *State) Member(guildID, userID discord.Snowflake) (*discord.Member, error) {
if s.StoreMember != nil {
m, err := s.StoreMember.Member(guildID, userID)
if err == nil {
return m, nil
}
}
return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
GuildID: []discord.Snowflake{guildID},
Presences: true,
})
m, err := s.Session.Member(guildID, userID)
if err != nil {
return nil, err
}
if s.StoreMember != nil {
return m, s.StoreMember.MemberSet(guildID, m)
}
return m, nil
}
// Members when called for its first time may not return a lot.
func (s *State) Members(guildID discord.Snowflake) ([]discord.Member, error) {
if s.StoreMember != nil {
ms, err := s.StoreMember.Members(guildID)
if err == nil {
return ms, nil
}
}
ms, err := s.Session.Members(guildID, MaxFetchMembers)
if err != nil {
return nil, err
}
if s.StoreMember != nil {
for _, m := range ms {
if err := s.StoreMember.MemberSet(guildID, &m); err != nil {
return nil, err
}
}
// idk why I wrote this
return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
GuildID: []discord.Snowflake{guildID},
Presences: true,
})
}
return ms, nil
}
////
func (s *State) Message(
channelID, messageID discord.Snowflake) (*discord.Message, error) {
m, err := s.Store.Message(channelID, messageID)
if err == nil {
return m, nil
func (s *State) Message(channelID, messageID discord.Snowflake) (*discord.Message, error) {
if s.StoreMessage != nil {
m, err := s.StoreMessage.Message(channelID, messageID)
if err == nil {
return m, nil
}
}
m, err = s.Session.Message(channelID, messageID)
m, err := s.Session.Message(channelID, messageID)
if err != nil {
return nil, err
}
@ -367,34 +423,42 @@ func (s *State) Message(
m.GuildID = c.GuildID
}
return m, s.Store.MessageSet(m)
}
// Messages fetches maximum 100 messages from the API, if it has to. There is no
// limit if it's from the State storage.
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
// TODO: Think of a design that doesn't rely on MaxMessages().
var maxMsgs = s.MaxMessages()
ms, err := s.Store.Messages(channelID)
if err == nil {
// If the state already has as many messages as it can, skip the API.
if maxMsgs <= len(ms) {
return ms, nil
}
// Is the channel tiny?
s.fewMutex.Lock()
if _, ok := s.fewMessages[channelID]; ok {
s.fewMutex.Unlock()
return ms, nil
}
// No, fetch from the state.
s.fewMutex.Unlock()
if s.StoreMessage != nil {
return m, s.StoreMessage.MessageSet(m)
}
ms, err = s.Session.Messages(channelID, uint(maxMsgs))
return m, nil
}
// Messages fetches maximum 100 messages from the API, if it has to, or it will
// use the limit from the Message State.
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
var maxMsgs = 100
if s.StoreMessage != nil {
// TODO: Think of a design that doesn't rely on MaxMessages().
maxMsgs = s.StoreMessage.MaxMessages()
ms, err := s.StoreMessage.Messages(channelID)
if err == nil {
// If the state already has as many messages as it can, skip the API.
if maxMsgs <= len(ms) {
return ms, nil
}
// Is the channel tiny?
s.fewMutex.Lock()
if _, ok := s.fewMessages[channelID]; ok {
s.fewMutex.Unlock()
return ms, nil
}
// No, fetch from the state.
s.fewMutex.Unlock()
}
}
ms, err := s.Session.Messages(channelID, uint(maxMsgs))
if err != nil {
return nil, err
}
@ -410,11 +474,15 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
guildID = c.GuildID
}
if s.StoreMessage == nil {
return ms, nil
}
for i := range ms {
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
ms[i].GuildID = guildID
if err := s.Store.MessageSet(&ms[i]); err != nil {
if err := s.StoreMessage.MessageSet(&ms[i]); err != nil {
return nil, err
}
}
@ -436,9 +504,14 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
////
// Presence checks the state for user presences. If no guildID is given, it will
// look for the presence in all guilds.
// look for the presence in all guilds. The function will error out if
// StorePresences is nil.
func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) {
p, err := s.Store.Presence(guildID, userID)
if s.StorePresence == nil {
return nil, ErrStoreNotFound
}
p, err := s.StorePresence.Presence(guildID, userID)
if err == nil {
return p, nil
}
@ -451,7 +524,7 @@ func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence,
}
for _, g := range g {
if p, err := s.Store.Presence(g.ID, userID); err == nil {
if p, err := s.StorePresence.Presence(g.ID, userID); err == nil {
return p, nil
}
}
@ -460,17 +533,22 @@ func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence,
return nil, err
}
// Presences only returns presences if StorePresences is not nil.
func (s *State) Presences(guildID discord.Snowflake) ([]discord.Presence, error) {
return s.Store.Presences(guildID)
if s.StorePresence == nil {
return nil, ErrStoreNotFound
}
return s.StorePresence.Presences(guildID)
}
////
func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
r, err := s.Store.Role(guildID, roleID)
if err == nil {
return r, nil
if s.StoreRole != nil {
r, err := s.StoreRole.Role(guildID, roleID)
if err == nil {
return r, nil
}
}
rs, err := s.Session.Roles(guildID)
@ -487,30 +565,40 @@ func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
role = &r
}
if err := s.RoleSet(guildID, &r); err != nil {
return role, err
if s.StoreRole != nil {
if err := s.StoreRole.RoleSet(guildID, &r); err != nil {
return nil, err
}
}
}
if role == nil {
return nil, ErrStoreNotFound
}
return role, nil
}
func (s *State) Roles(guildID discord.Snowflake) ([]discord.Role, error) {
rs, err := s.Store.Roles(guildID)
if err == nil {
return rs, nil
if s.StoreRole != nil {
rs, err := s.StoreRole.Roles(guildID)
if err == nil {
return rs, nil
}
}
rs, err = s.Session.Roles(guildID)
rs, err := s.Session.Roles(guildID)
if err != nil {
return nil, err
}
for _, r := range rs {
r := r
if s.StoreRole != nil {
for _, r := range rs {
r := r
if err := s.RoleSet(guildID, &r); err != nil {
return rs, err
if err := s.StoreRole.RoleSet(guildID, &r); err != nil {
return rs, err
}
}
}

View File

@ -8,9 +8,23 @@ import (
// Store is the state storage. It should handle mutex itself, and it should only
// concern itself with the local state.
type Store interface {
StoreGetter
StoreModifier
// type Store interface {
// StoreMe
// StoreChannel
// StoreEmoji
// StoreGuild
// StoreMember
// StoreMessage
// StorePresence
// StoreRole
// // This should reset all the state to zero/null.
// Reset() error
// }
// Resetter is an optional state reset function that stores could implement.
type Resetter interface {
Reset() error
}
// All methods in StoreGetter will be wrapped by the State. If the State can't
@ -21,62 +35,66 @@ type Store interface {
// would mutate the underlying slice (and as a result the returned slice as
// well). The best way to avoid this is to copy the whole slice, like
// DefaultStore does.
type StoreGetter interface {
Me() (*discord.User, error)
type (
StoreMe interface {
Me() (*discord.User, error)
MyselfSet(me *discord.User) error
}
Channel(id discord.Snowflake) (*discord.Channel, error)
Channels(guildID discord.Snowflake) ([]discord.Channel, error)
PrivateChannels() ([]discord.Channel, error)
StoreChannel interface {
Channel(id discord.Snowflake) (*discord.Channel, error)
Channels(guildID discord.Snowflake) ([]discord.Channel, error)
PrivateChannels() ([]discord.Channel, error)
Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error)
Emojis(guildID discord.Snowflake) ([]discord.Emoji, error)
// ChannelSet should switch on Type to know if it's a private channel or
// not.
ChannelSet(*discord.Channel) error
ChannelRemove(*discord.Channel) error
}
Guild(id discord.Snowflake) (*discord.Guild, error)
Guilds() ([]discord.Guild, error)
StoreEmoji interface {
Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error)
Emojis(guildID discord.Snowflake) ([]discord.Emoji, error)
EmojiSet(guildID discord.Snowflake, emojis []discord.Emoji) error
}
Member(guildID, userID discord.Snowflake) (*discord.Member, error)
Members(guildID discord.Snowflake) ([]discord.Member, error)
StoreGuild interface {
Guild(id discord.Snowflake) (*discord.Guild, error)
Guilds() ([]discord.Guild, error)
GuildSet(*discord.Guild) error
GuildRemove(id discord.Snowflake) error
}
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.
StoreMember interface {
Member(guildID, userID discord.Snowflake) (*discord.Member, error)
Members(guildID discord.Snowflake) ([]discord.Member, error)
MemberSet(guildID discord.Snowflake, member *discord.Member) error
MemberRemove(guildID, userID discord.Snowflake) error
}
// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
Presences(guildID discord.Snowflake) ([]discord.Presence, error)
StoreMessage interface {
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.
MessageSet(*discord.Message) error
MessageRemove(channelID, messageID discord.Snowflake) error
}
Role(guildID, roleID discord.Snowflake) (*discord.Role, error)
Roles(guildID discord.Snowflake) ([]discord.Role, error)
}
StorePresence interface {
// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
Presences(guildID discord.Snowflake) ([]discord.Presence, error)
PresenceSet(guildID discord.Snowflake, presence *discord.Presence) error
PresenceRemove(guildID, userID discord.Snowflake) error
}
type StoreModifier interface {
MyselfSet(me *discord.User) error
// ChannelSet should switch on Type to know if it's a private channel or
// not.
ChannelSet(*discord.Channel) error
ChannelRemove(*discord.Channel) error
EmojiSet(guildID discord.Snowflake, emojis []discord.Emoji) error
GuildSet(*discord.Guild) error
GuildRemove(id discord.Snowflake) error
MemberSet(guildID discord.Snowflake, member *discord.Member) error
MemberRemove(guildID, userID discord.Snowflake) error
MessageSet(*discord.Message) error
MessageRemove(channelID, messageID discord.Snowflake) error
PresenceSet(guildID discord.Snowflake, presence *discord.Presence) error
PresenceRemove(guildID, userID discord.Snowflake) error
RoleSet(guildID discord.Snowflake, role *discord.Role) error
RoleRemove(guildID, roleID discord.Snowflake) error
// This should reset all the state to zero/null.
Reset() error
}
StoreRole interface {
Role(guildID, roleID discord.Snowflake) (*discord.Role, error)
Roles(guildID discord.Snowflake) ([]discord.Role, error)
RoleSet(guildID discord.Snowflake, role *discord.Role) error
RoleRemove(guildID, roleID discord.Snowflake) error
}
)
// ErrStoreNotFound is an error that a store can use to return when something
// isn't in the storage. There is no strict restrictions on what uses this (the

View File

@ -30,9 +30,19 @@ type DefaultStoreOptions struct {
MaxMessages uint // default 50
}
var _ Store = (*DefaultStore)(nil)
var (
_ StoreMe = (*DefaultStore)(nil)
_ StoreChannel = (*DefaultStore)(nil)
_ StoreEmoji = (*DefaultStore)(nil)
_ StoreGuild = (*DefaultStore)(nil)
_ StoreMember = (*DefaultStore)(nil)
_ StoreMessage = (*DefaultStore)(nil)
_ StorePresence = (*DefaultStore)(nil)
_ StoreRole = (*DefaultStore)(nil)
_ Resetter = (*DefaultStore)(nil)
)
func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore {
func NewDefaultStore(opts *DefaultStoreOptions) Store {
if opts == nil {
opts = &DefaultStoreOptions{
MaxMessages: 50,
@ -44,7 +54,17 @@ func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore {
}
ds.Reset()
return ds
return Store{
StoreMe: ds,
StoreChannel: ds,
StoreEmoji: ds,
StoreGuild: ds,
StoreMember: ds,
StoreMessage: ds,
StorePresence: ds,
StoreRole: ds,
Resetter: ds,
}
}
func (s *DefaultStore) Reset() error {