mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-12-01 03:03:48 +00:00
687 lines
15 KiB
Go
687 lines
15 KiB
Go
// Package state provides interfaces for a local or remote state, as well as
|
|
// abstractions around the REST API and Gateway events.
|
|
package state
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
|
|
"github.com/diamondburned/arikawa/v2/discord"
|
|
"github.com/diamondburned/arikawa/v2/gateway"
|
|
"github.com/diamondburned/arikawa/v2/internal/moreatomic"
|
|
"github.com/diamondburned/arikawa/v2/session"
|
|
"github.com/diamondburned/arikawa/v2/utils/handler"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Limits used when the state has to fall back to the API.
var (
	// MaxFetchMembers is the limit passed to Session.Members when Members
	// cannot be served from the store.
	MaxFetchMembers uint = 1000
	// MaxFetchGuilds is the limit passed to Session.Guilds when Guilds cannot
	// be served from the store.
	MaxFetchGuilds uint = 10
)
|
|
|
|
// State is the cache to store events coming from Discord as well as data from
|
|
// API calls.
|
|
//
|
|
// Store
|
|
//
|
|
// The state basically provides abstractions on top of the API and the state
|
|
// storage (Store). The state storage is effectively a set of interfaces which
|
|
// allow arbitrary backends to be implemented.
|
|
//
|
|
// The default storage backend is a typical in-memory structure consisting of
|
|
// maps and slices. Custom backend implementations could embed this storage
|
|
// backend as an in-memory fallback. A good example of this would be embedding
|
|
// the default store for messages only, while handling everything else in Redis.
|
|
//
|
|
// The package also provides a no-op store (NoopStore) that implementations
|
|
// could embed. This no-op store will always return an error, which makes the
|
|
// state fetch information from the API. The setters are all no-ops, so the
|
|
// fetched data won't be updated.
|
|
//
|
|
// Handler
|
|
//
|
|
// The state uses its own handler over session's to make all handlers run after
|
|
// the state updates itself. A PreHandler is exposed in any case the user needs
|
|
// the handlers to run before the state updates itself. Refer to that field's
|
|
// documentation.
|
|
//
|
|
// The state also provides extra events and overrides to make up for Discord's
|
|
// inconsistencies in data. The following are known instances of such.
|
|
//
|
|
// The Guild Create event is split up to make the state's Guild Available, Guild
|
|
// Ready and Guild Join events. Refer to these events' documentations for more
|
|
// information.
|
|
//
|
|
// The Message Create and Message Update events with the Member field provided
|
|
// will have the User field copied from Author. This is because the User field
|
|
// will be empty, while the Member structure expects it to be there.
|
|
type State struct {
|
|
*session.Session
|
|
Store
|
|
|
|
// *: State doesn't actually keep track of pinned messages.
|
|
|
|
// Ready is not updated by the state.
|
|
Ready gateway.ReadyEvent
|
|
|
|
// StateLog logs all errors that come from the state cache. This includes
|
|
// not found errors. Defaults to a no-op, as state errors aren't that
|
|
// important.
|
|
StateLog func(error)
|
|
|
|
// PreHandler is the manual hook that is executed before the State handler
|
|
// is. This should only be used for low-level operations.
|
|
// It's recommended to set Synchronous to true if you mutate the events.
|
|
PreHandler *handler.Handler // default nil
|
|
|
|
// Command handler with inherited methods. Ran after PreHandler. You should
|
|
// most of the time use this instead of Session's, to avoid race conditions
|
|
// with the State.
|
|
*handler.Handler
|
|
|
|
// List of channels with few messages, so it doesn't bother hitting the API
|
|
// again.
|
|
fewMessages map[discord.ChannelID]struct{}
|
|
fewMutex *sync.Mutex
|
|
|
|
// unavailableGuilds is a set of discord.GuildIDs of guilds that became
|
|
// unavailable when already connected to the gateway, i.e. sent in a
|
|
// GuildUnavailableEvent.
|
|
unavailableGuilds *moreatomic.GuildIDSet
|
|
// unreadyGuilds is a set of discord.GuildIDs of guilds that were
|
|
// unavailable when connecting to the gateway, i.e. they had Unavailable
|
|
// set to true during Ready.
|
|
unreadyGuilds *moreatomic.GuildIDSet
|
|
}
|
|
|
|
// New creates a new state.
|
|
func New(token string) (*State, error) {
|
|
return NewWithStore(token, NewDefaultStore(nil))
|
|
}
|
|
|
|
// NewWithIntents creates a new state with the given gateway intents. For more
|
|
// information, refer to gateway.Intents.
|
|
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
|
|
s, err := session.NewWithIntents(token, intents...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewFromSession(s, NewDefaultStore(nil))
|
|
}
|
|
|
|
func NewWithStore(token string, store Store) (*State, error) {
|
|
s, err := session.New(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewFromSession(s, store)
|
|
}
|
|
|
|
// NewFromSession never returns an error. This API is kept for backwards
|
|
// compatibility.
|
|
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
|
state := &State{
|
|
Session: s,
|
|
Store: store,
|
|
Handler: handler.New(),
|
|
StateLog: func(err error) {},
|
|
fewMessages: map[discord.ChannelID]struct{}{},
|
|
fewMutex: new(sync.Mutex),
|
|
unavailableGuilds: moreatomic.NewGuildIDSet(),
|
|
unreadyGuilds: moreatomic.NewGuildIDSet(),
|
|
}
|
|
state.hookSession()
|
|
return state, nil
|
|
}
|
|
|
|
// WithContext returns a shallow copy of State with the context replaced in the
|
|
// API client. All methods called on the State will use this given context. This
|
|
// method is thread-safe.
|
|
func (s *State) WithContext(ctx context.Context) *State {
|
|
copied := *s
|
|
copied.Client = copied.Client.WithContext(ctx)
|
|
|
|
return &copied
|
|
}
|
|
|
|
//// Helper methods
|
|
|
|
func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {
|
|
if !message.GuildID.IsValid() {
|
|
return message.Author.Username
|
|
}
|
|
|
|
if message.Member != nil {
|
|
if message.Member.Nick != "" {
|
|
return message.Member.Nick
|
|
}
|
|
return message.Author.Username
|
|
}
|
|
|
|
n, err := s.MemberDisplayName(message.GuildID, message.Author.ID)
|
|
if err != nil {
|
|
return message.Author.Username
|
|
}
|
|
|
|
return n
|
|
}
|
|
|
|
func (s *State) MemberDisplayName(guildID discord.GuildID, userID discord.UserID) (string, error) {
|
|
member, err := s.Member(guildID, userID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if member.Nick == "" {
|
|
return member.User.Username, nil
|
|
}
|
|
|
|
return member.Nick, nil
|
|
}
|
|
|
|
func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, error) {
|
|
if !message.GuildID.IsValid() { // this is a dm
|
|
return discord.DefaultMemberColor, nil
|
|
}
|
|
|
|
if message.Member != nil {
|
|
guild, err := s.Guild(message.GuildID)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return discord.MemberColor(*guild, *message.Member), nil
|
|
}
|
|
|
|
return s.MemberColor(message.GuildID, message.Author.ID)
|
|
}
|
|
|
|
func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) {
|
|
var wg sync.WaitGroup
|
|
|
|
g, gerr := s.Store.Guild(guildID)
|
|
m, merr := s.Store.Member(guildID, userID)
|
|
|
|
switch {
|
|
case gerr != nil && merr != nil:
|
|
wg.Add(1)
|
|
go func() {
|
|
g, gerr = s.fetchGuild(guildID)
|
|
wg.Done()
|
|
}()
|
|
|
|
m, merr = s.fetchMember(guildID, userID)
|
|
case gerr != nil:
|
|
g, gerr = s.fetchGuild(guildID)
|
|
case merr != nil:
|
|
m, merr = s.fetchMember(guildID, userID)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if gerr != nil {
|
|
return 0, errors.Wrap(merr, "failed to get guild")
|
|
}
|
|
if merr != nil {
|
|
return 0, errors.Wrap(merr, "failed to get member")
|
|
}
|
|
|
|
return discord.MemberColor(*g, *m), nil
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Permissions(
|
|
channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) {
|
|
|
|
ch, err := s.Channel(channelID)
|
|
if err != nil {
|
|
return 0, errors.Wrap(err, "failed to get channel")
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
g, gerr := s.Store.Guild(ch.GuildID)
|
|
m, merr := s.Store.Member(ch.GuildID, userID)
|
|
|
|
switch {
|
|
case gerr != nil && merr != nil:
|
|
wg.Add(1)
|
|
go func() {
|
|
g, gerr = s.fetchGuild(ch.GuildID)
|
|
wg.Done()
|
|
}()
|
|
|
|
m, merr = s.fetchMember(ch.GuildID, userID)
|
|
case gerr != nil:
|
|
g, gerr = s.fetchGuild(ch.GuildID)
|
|
case merr != nil:
|
|
m, merr = s.fetchMember(ch.GuildID, userID)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if gerr != nil {
|
|
return 0, errors.Wrap(merr, "failed to get guild")
|
|
}
|
|
if merr != nil {
|
|
return 0, errors.Wrap(merr, "failed to get member")
|
|
}
|
|
|
|
return discord.CalcOverwrites(*g, *ch, *m), nil
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Me() (*discord.User, error) {
|
|
u, err := s.Store.Me()
|
|
if err == nil {
|
|
return u, nil
|
|
}
|
|
|
|
u, err = s.Session.Me()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return u, s.Store.MyselfSet(*u)
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) {
|
|
c, err := s.Store.Channel(id)
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
c, err = s.Session.Channel(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return c, s.Store.ChannelSet(*c)
|
|
}
|
|
|
|
func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
|
|
c, err := s.Store.Channels(guildID)
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
c, err = s.Session.Channels(guildID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, ch := range c {
|
|
ch := ch
|
|
|
|
if err := s.Store.ChannelSet(ch); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
|
|
c, err := s.Store.CreatePrivateChannel(recipient)
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
c, err = s.Session.CreatePrivateChannel(recipient)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return c, s.Store.ChannelSet(*c)
|
|
}
|
|
|
|
func (s *State) PrivateChannels() ([]discord.Channel, error) {
|
|
c, err := s.Store.PrivateChannels()
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
c, err = s.Session.PrivateChannels()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, ch := range c {
|
|
ch := ch
|
|
|
|
if err := s.Store.ChannelSet(ch); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Emoji(
|
|
guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
|
|
|
|
e, err := s.Store.Emoji(guildID, emojiID)
|
|
if err == nil {
|
|
return e, nil
|
|
}
|
|
|
|
es, err := s.Session.Emojis(guildID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.Store.EmojiSet(guildID, es); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, e := range es {
|
|
if e.ID == emojiID {
|
|
return &e, nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrStoreNotFound
|
|
}
|
|
|
|
func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
|
|
e, err := s.Store.Emojis(guildID)
|
|
if err == nil {
|
|
return e, nil
|
|
}
|
|
|
|
es, err := s.Session.Emojis(guildID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return es, s.Store.EmojiSet(guildID, es)
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
|
|
c, err := s.Store.Guild(id)
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
return s.fetchGuild(id)
|
|
}
|
|
|
|
// Guilds will only fill a maximum of 100 guilds from the API.
|
|
func (s *State) Guilds() ([]discord.Guild, error) {
|
|
c, err := s.Store.Guilds()
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
|
|
c, err = s.Session.Guilds(MaxFetchGuilds)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, ch := range c {
|
|
ch := ch
|
|
|
|
if err := s.Store.GuildSet(ch); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
|
|
m, err := s.Store.Member(guildID, userID)
|
|
if err == nil {
|
|
return m, nil
|
|
}
|
|
|
|
return s.fetchMember(guildID, userID)
|
|
}
|
|
|
|
func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
|
|
ms, err := s.Store.Members(guildID)
|
|
if err == nil {
|
|
return ms, nil
|
|
}
|
|
|
|
ms, err = s.Session.Members(guildID, MaxFetchMembers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, m := range ms {
|
|
if err := s.Store.MemberSet(guildID, m); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
|
|
GuildID: []discord.GuildID{guildID},
|
|
Presences: true,
|
|
})
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Message(
|
|
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
|
|
|
|
m, err := s.Store.Message(channelID, messageID)
|
|
if err == nil {
|
|
return m, nil
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
c, cerr := s.Store.Channel(channelID)
|
|
if cerr != nil {
|
|
wg.Add(1)
|
|
go func() {
|
|
c, cerr = s.Session.Channel(channelID)
|
|
if cerr == nil {
|
|
cerr = s.Store.ChannelSet(*c)
|
|
}
|
|
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
m, err = s.Session.Message(channelID, messageID)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "unable to fetch message")
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if cerr != nil {
|
|
return nil, errors.Wrap(cerr, "unable to fetch channel")
|
|
}
|
|
|
|
m.ChannelID = c.ID
|
|
m.GuildID = c.GuildID
|
|
|
|
return m, s.Store.MessageSet(*m)
|
|
}
|
|
|
|
// Messages fetches maximum 100 messages from the API, if it has to. There is no
|
|
// limit if it's from the State storage.
|
|
func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
|
|
// TODO: Think of a design that doesn't rely on MaxMessages().
|
|
var maxMsgs = s.MaxMessages()
|
|
|
|
ms, err := s.Store.Messages(channelID)
|
|
if err == nil {
|
|
// If the state already has as many messages as it can, skip the API.
|
|
if maxMsgs <= len(ms) {
|
|
return ms, nil
|
|
}
|
|
|
|
// Is the channel tiny?
|
|
s.fewMutex.Lock()
|
|
if _, ok := s.fewMessages[channelID]; ok {
|
|
s.fewMutex.Unlock()
|
|
return ms, nil
|
|
}
|
|
|
|
// No, fetch from the state.
|
|
s.fewMutex.Unlock()
|
|
}
|
|
|
|
ms, err = s.Session.Messages(channelID, uint(maxMsgs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// New messages fetched weirdly does not have GuildID filled. We'll try and
|
|
// get it for consistency with incoming message creates.
|
|
var guildID discord.GuildID
|
|
|
|
// A bit too convoluted, but whatever.
|
|
c, err := s.Channel(channelID)
|
|
if err == nil {
|
|
// If it's 0, it's 0 anyway. We don't need a check here.
|
|
guildID = c.GuildID
|
|
}
|
|
|
|
// Iterate in reverse, since the store is expected to prepend the latest
|
|
// messages.
|
|
for i := len(ms) - 1; i >= 0; i-- {
|
|
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
|
|
ms[i].GuildID = guildID
|
|
|
|
if err := s.Store.MessageSet(ms[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if len(ms) < maxMsgs {
|
|
// Tiny channel, store this.
|
|
s.fewMutex.Lock()
|
|
s.fewMessages[channelID] = struct{}{}
|
|
s.fewMutex.Unlock()
|
|
|
|
return ms, nil
|
|
}
|
|
|
|
// Since the latest messages are at the end and we already know the maxMsgs,
|
|
// we could slice this right away.
|
|
return ms[:maxMsgs], nil
|
|
}
|
|
|
|
////
|
|
|
|
// Presence checks the state for user presences. If no guildID is given, it will
|
|
// look for the presence in all guilds.
|
|
func (s *State) Presence(
|
|
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
|
|
|
|
p, err := s.Store.Presence(guildID, userID)
|
|
if err == nil {
|
|
return p, nil
|
|
}
|
|
|
|
// If there's no guild ID, look in all guilds
|
|
if !guildID.IsValid() {
|
|
g, err := s.Guilds()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, g := range g {
|
|
if p, err := s.Store.Presence(g.ID, userID); err == nil {
|
|
return p, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
////
|
|
|
|
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
|
|
r, err := s.Store.Role(guildID, roleID)
|
|
if err == nil {
|
|
return r, nil
|
|
}
|
|
|
|
rs, err := s.Session.Roles(guildID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var role *discord.Role
|
|
|
|
for _, r := range rs {
|
|
r := r
|
|
|
|
if r.ID == roleID {
|
|
role = &r
|
|
}
|
|
|
|
if err := s.RoleSet(guildID, r); err != nil {
|
|
return role, err
|
|
}
|
|
}
|
|
|
|
if role == nil {
|
|
return nil, ErrStoreNotFound
|
|
}
|
|
|
|
return role, nil
|
|
}
|
|
|
|
func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
|
|
rs, err := s.Store.Roles(guildID)
|
|
if err == nil {
|
|
return rs, nil
|
|
}
|
|
|
|
rs, err = s.Session.Roles(guildID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, r := range rs {
|
|
r := r
|
|
|
|
if err := s.RoleSet(guildID, r); err != nil {
|
|
return rs, err
|
|
}
|
|
}
|
|
|
|
return rs, nil
|
|
}
|
|
|
|
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
|
|
g, err = s.Session.Guild(id)
|
|
if err == nil {
|
|
err = s.Store.GuildSet(*g)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (s *State) fetchMember(
|
|
guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
|
|
|
|
m, err = s.Session.Member(guildID, userID)
|
|
if err == nil {
|
|
err = s.Store.MemberSet(guildID, *m)
|
|
}
|
|
|
|
return
|
|
}
|