1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-12-01 08:37:23 +00:00

state, handler: Refactor state storage and sync handlers

This commit refactors a lot of packages.

It refactors the handler package, removing the Synchronous field and
replacing it the AddSyncHandler API, which allows each handler to
control whether or not it should be ran synchronously independent of
other handlers. This is useful for libraries that need to guarantee the
incoming order of events.

It also refactors the store interfaces to accept more interfaces. This
is to make the API more consistent as well as reducing potential useless
copies. The public-facing state API should still be the same, so this
change will mostly concern users with their own store implementations.

Several miscellaneous functions (such as a few in package gateway) were
modified to be more suitable to other packages, but those functions
should rarely ever be used, anyway.

Several tests are also fixed within this commit, namely fixing state's
intents bug.
This commit is contained in:
diamondburned 2021-11-03 15:16:02 -07:00
parent 7f4daccd2d
commit efde3f4ea6
No known key found for this signature in database
GPG key ID: D78C4471CE776659
16 changed files with 362 additions and 274 deletions

View file

@ -233,27 +233,35 @@ type (
} }
) )
// ConvertSupplementalMember converts a SupplementalMember to a regular Member. // ConvertSupplementalMembers converts a SupplementalMember to a regular Member.
func ConvertSupplementalMember(sm SupplementalMember) discord.Member { func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member {
return discord.Member{ members := make([]discord.Member, len(sms))
User: discord.User{ID: sm.UserID}, for i, sm := range sms {
Nick: sm.Nick, members[i] = discord.Member{
RoleIDs: sm.RoleIDs, User: discord.User{ID: sm.UserID},
Joined: sm.Joined, Nick: sm.Nick,
BoostedSince: sm.BoostedSince, RoleIDs: sm.RoleIDs,
Deaf: sm.Deaf, Joined: sm.Joined,
Mute: sm.Mute, BoostedSince: sm.BoostedSince,
IsPending: sm.IsPending, Deaf: sm.Deaf,
Mute: sm.Mute,
IsPending: sm.IsPending,
}
} }
return members
} }
// ConvertSupplementalPresence converts a SupplementalPresence to a regular // ConvertSupplementalPresences converts a SupplementalPresence to a regular
// Presence with an empty GuildID. // Presence with an empty GuildID.
func ConvertSupplementalPresence(sp SupplementalPresence) discord.Presence { func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence {
return discord.Presence{ presences := make([]discord.Presence, len(sps))
User: discord.User{ID: sp.UserID}, for i, sp := range sps {
Status: sp.Status, presences[i] = discord.Presence{
Activities: sp.Activities, User: discord.User{ID: sp.UserID},
ClientStatus: sp.ClientStatus, Status: sp.Status,
Activities: sp.Activities,
ClientStatus: sp.ClientStatus,
}
} }
return presences
} }

View file

@ -27,7 +27,9 @@ var (
// The user should initialize handlers and intents in the opts function. // The user should initialize handlers and intents in the opts function.
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc { func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
return NewFromSession(session.NewCustomShard(m, id), defaultstore.New()), nil state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New())
opts(m, state)
return state, nil
} }
} }
@ -363,7 +365,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) {
} }
if s.tracksChannel(c) { if s.tracksChannel(c) {
err = s.Cabinet.ChannelSet(*c, false) err = s.Cabinet.ChannelSet(c, false)
} }
return return
@ -383,8 +385,8 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err
} }
if s.Gateway.HasIntents(gateway.IntentGuilds) { if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, c := range cs { for i := range cs {
if err = s.Cabinet.ChannelSet(c, false); err != nil { if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil {
return return
} }
} }
@ -404,7 +406,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel
return nil, err return nil, err
} }
return c, s.Cabinet.ChannelSet(*c, false) return c, s.Cabinet.ChannelSet(c, false)
} }
// PrivateChannels gets the direct messages of the user. // PrivateChannels gets the direct messages of the user.
@ -420,8 +422,8 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
return nil, err return nil, err
} }
for _, c := range cs { for i := range cs {
if err := s.Cabinet.ChannelSet(c, false); err != nil { if err := s.Cabinet.ChannelSet(&cs[i], false); err != nil {
return nil, err return nil, err
} }
} }
@ -509,8 +511,8 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
} }
if s.Gateway.HasIntents(gateway.IntentGuilds) { if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, g := range gs { for i := range gs {
if err = s.Cabinet.GuildSet(g, false); err != nil { if err = s.Cabinet.GuildSet(&gs[i], false); err != nil {
return return
} }
} }
@ -546,8 +548,8 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error
} }
if s.Gateway.HasIntents(gateway.IntentGuildMembers) { if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
for _, m := range ms { for i := range ms {
if err = s.Cabinet.MemberSet(guildID, m, false); err != nil { if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil {
return return
} }
} }
@ -579,7 +581,7 @@ func (s *State) Message(
go func() { go func() {
c, cerr = s.Session.Channel(channelID) c, cerr = s.Session.Channel(channelID)
if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
cerr = s.Cabinet.ChannelSet(*c, false) cerr = s.Cabinet.ChannelSet(c, false)
} }
wg.Done() wg.Done()
@ -688,8 +690,9 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
i = len(apiMessages) i = len(apiMessages)
} }
for _, m := range apiMessages[:i] { msgs := apiMessages[:i]
if err := s.Cabinet.MessageSet(m, false); err != nil { for i := range msgs {
if err := s.Cabinet.MessageSet(&msgs[i], false); err != nil {
return nil, err return nil, err
} }
} }
@ -745,14 +748,14 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di
return return
} }
for _, r := range rs { for i, r := range rs {
if r.ID == roleID { if r.ID == roleID {
r := r // copy to prevent mem aliasing r := r // copy to prevent mem aliasing
target = &r target = &r
} }
if s.Gateway.HasIntents(gateway.IntentGuilds) { if s.Gateway.HasIntents(gateway.IntentGuilds) {
if err = s.RoleSet(guildID, r, false); err != nil { if err = s.RoleSet(guildID, &rs[i], false); err != nil {
return return
} }
} }
@ -777,8 +780,8 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
} }
if s.Gateway.HasIntents(gateway.IntentGuilds) { if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, r := range rs { for i := range rs {
if err := s.RoleSet(guildID, r, false); err != nil { if err := s.RoleSet(guildID, &rs[i], false); err != nil {
return rs, err return rs, err
} }
} }
@ -790,7 +793,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id) g, err = s.Session.Guild(id)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
err = s.Cabinet.GuildSet(*g, false) err = s.Cabinet.GuildSet(g, false)
} }
return return
@ -799,7 +802,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(gID, uID) m, err = s.Session.Member(gID, uID)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) {
err = s.Cabinet.MemberSet(gID, *m, false) err = s.Cabinet.MemberSet(gID, m, false)
} }
return return
@ -808,12 +811,14 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord
// tracksMessage reports whether the state would track the passed message and // tracksMessage reports whether the state would track the passed message and
// messages from the same channel. // messages from the same channel.
func (s *State) tracksMessage(m *discord.Message) bool { func (s *State) tracksMessage(m *discord.Message) bool {
return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || return s.Gateway.Identifier.Intents == nil ||
(m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
(!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages))
} }
// tracksChannel reports whether the state would track the passed channel. // tracksChannel reports whether the state would track the passed channel.
func (s *State) tracksChannel(c *discord.Channel) bool { func (s *State) tracksChannel(c *discord.Channel) bool {
return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || return s.Gateway.Identifier.Intents == nil ||
(c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
!c.GuildID.IsValid() !c.GuildID.IsValid()
} }

View file

@ -9,7 +9,7 @@ import (
) )
func (s *State) hookSession() { func (s *State) hookSession() {
s.Session.AddHandler(func(event interface{}) { s.Session.AddSyncHandler(func(event interface{}) {
// Call the pre-handler before the state handler. // Call the pre-handler before the state handler.
if s.PreHandler != nil { if s.PreHandler != nil {
s.PreHandler.Call(event) s.PreHandler.Call(event)
@ -68,15 +68,15 @@ func (s *State) onEvent(iface interface{}) {
} }
// Handle guild presences // Handle guild presences
for _, p := range ev.Presences { for i, presence := range ev.Presences {
if err := s.Cabinet.PresenceSet(p.GuildID, p, false); err != nil { if err := s.Cabinet.PresenceSet(presence.GuildID, &ev.Presences[i], false); err != nil {
s.stateErr(err, "failed to set presence in Ready") s.stateErr(err, "failed to set presence in Ready")
} }
} }
// Handle private channels // Handle private channels
for _, ch := range ev.PrivateChannels { for i := range ev.PrivateChannels {
if err := s.Cabinet.ChannelSet(ch, false); err != nil { if err := s.Cabinet.ChannelSet(&ev.PrivateChannels[i], false); err != nil {
s.stateErr(err, "failed to set channel in Ready") s.stateErr(err, "failed to set channel in Ready")
} }
} }
@ -90,16 +90,17 @@ func (s *State) onEvent(iface interface{}) {
// Handle guilds // Handle guilds
for _, guild := range ev.Guilds { for _, guild := range ev.Guilds {
// Handle guild voice states // Handle guild voice states
for _, v := range guild.VoiceStates { for i := range guild.VoiceStates {
v := &guild.VoiceStates[i]
if err := s.Cabinet.VoiceStateSet(guild.ID, v, false); err != nil { if err := s.Cabinet.VoiceStateSet(guild.ID, v, false); err != nil {
s.stateErr(err, "failed to set guild voice state in Ready Supplemental") s.stateErr(err, "failed to set guild voice state in Ready Supplemental")
} }
} }
} }
for _, friend := range ev.MergedPresences.Friends { friendPresences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Friends)
sPresence := gateway.ConvertSupplementalPresence(friend) for i := range friendPresences {
if err := s.Cabinet.PresenceSet(0, sPresence, false); err != nil { if err := s.Cabinet.PresenceSet(0, &friendPresences[i], false); err != nil {
s.stateErr(err, "failed to set friend presence in Ready Supplemental") s.stateErr(err, "failed to set friend presence in Ready Supplemental")
} }
} }
@ -110,16 +111,16 @@ func (s *State) onEvent(iface interface{}) {
for i := 0; i < len(ready.Guilds) && i < len(ev.MergedMembers); i++ { for i := 0; i < len(ready.Guilds) && i < len(ev.MergedMembers); i++ {
guild := ready.Guilds[i] guild := ready.Guilds[i]
for _, member := range ev.MergedMembers[i] { members := gateway.ConvertSupplementalMembers(ev.MergedMembers[i])
sMember := gateway.ConvertSupplementalMember(member) for i := range members {
if err := s.Cabinet.MemberSet(guild.ID, sMember, false); err != nil { if err := s.Cabinet.MemberSet(guild.ID, &members[i], false); err != nil {
s.stateErr(err, "failed to set friend presence in Ready Supplemental") s.stateErr(err, "failed to set friend presence in Ready Supplemental")
} }
} }
for _, member := range ev.MergedPresences.Guilds[i] { presences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Guilds[i])
sPresence := gateway.ConvertSupplementalPresence(member) for i := range presences {
if err := s.Cabinet.PresenceSet(guild.ID, sPresence, false); err != nil { if err := s.Cabinet.PresenceSet(guild.ID, &presences[i], false); err != nil {
s.stateErr(err, "failed to set member presence in Ready Supplemental") s.stateErr(err, "failed to set member presence in Ready Supplemental")
} }
} }
@ -129,7 +130,7 @@ func (s *State) onEvent(iface interface{}) {
s.batchLog(storeGuildCreate(s.Cabinet, ev)) s.batchLog(storeGuildCreate(s.Cabinet, ev))
case *gateway.GuildUpdateEvent: case *gateway.GuildUpdateEvent:
if err := s.Cabinet.GuildSet(ev.Guild, true); err != nil { if err := s.Cabinet.GuildSet(&ev.Guild, true); err != nil {
s.stateErr(err, "failed to update guild in state") s.stateErr(err, "failed to update guild in state")
} }
@ -139,7 +140,7 @@ func (s *State) onEvent(iface interface{}) {
} }
case *gateway.GuildMemberAddEvent: case *gateway.GuildMemberAddEvent:
if err := s.Cabinet.MemberSet(ev.GuildID, ev.Member, false); err != nil { if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Member, false); err != nil {
s.stateErr(err, "failed to add a member in state") s.stateErr(err, "failed to add a member in state")
} }
@ -153,7 +154,7 @@ func (s *State) onEvent(iface interface{}) {
// Update available fields from ev into m // Update available fields from ev into m
ev.Update(m) ev.Update(m)
if err := s.Cabinet.MemberSet(ev.GuildID, *m, true); err != nil { if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil {
s.stateErr(err, "failed to update a member in state") s.stateErr(err, "failed to update a member in state")
} }
@ -163,25 +164,25 @@ func (s *State) onEvent(iface interface{}) {
} }
case *gateway.GuildMembersChunkEvent: case *gateway.GuildMembersChunkEvent:
for _, m := range ev.Members { for i := range ev.Members {
if err := s.Cabinet.MemberSet(ev.GuildID, m, false); err != nil { if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Members[i], false); err != nil {
s.stateErr(err, "failed to add a member from chunk in state") s.stateErr(err, "failed to add a member from chunk in state")
} }
} }
for _, p := range ev.Presences { for i := range ev.Presences {
if err := s.Cabinet.PresenceSet(ev.GuildID, p, false); err != nil { if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presences[i], false); err != nil {
s.stateErr(err, "failed to add a presence from chunk in state") s.stateErr(err, "failed to add a presence from chunk in state")
} }
} }
case *gateway.GuildRoleCreateEvent: case *gateway.GuildRoleCreateEvent:
if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, false); err != nil { if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, false); err != nil {
s.stateErr(err, "failed to add a role in state") s.stateErr(err, "failed to add a role in state")
} }
case *gateway.GuildRoleUpdateEvent: case *gateway.GuildRoleUpdateEvent:
if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, true); err != nil { if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, true); err != nil {
s.stateErr(err, "failed to update a role in state") s.stateErr(err, "failed to update a role in state")
} }
@ -196,17 +197,17 @@ func (s *State) onEvent(iface interface{}) {
} }
case *gateway.ChannelCreateEvent: case *gateway.ChannelCreateEvent:
if err := s.Cabinet.ChannelSet(ev.Channel, false); err != nil { if err := s.Cabinet.ChannelSet(&ev.Channel, false); err != nil {
s.stateErr(err, "failed to create a channel in state") s.stateErr(err, "failed to create a channel in state")
} }
case *gateway.ChannelUpdateEvent: case *gateway.ChannelUpdateEvent:
if err := s.Cabinet.ChannelSet(ev.Channel, true); err != nil { if err := s.Cabinet.ChannelSet(&ev.Channel, true); err != nil {
s.stateErr(err, "failed to update a channel in state") s.stateErr(err, "failed to update a channel in state")
} }
case *gateway.ChannelDeleteEvent: case *gateway.ChannelDeleteEvent:
if err := s.Cabinet.ChannelRemove(ev.Channel); err != nil { if err := s.Cabinet.ChannelRemove(&ev.Channel); err != nil {
s.stateErr(err, "failed to remove a channel in state") s.stateErr(err, "failed to remove a channel in state")
} }
@ -214,12 +215,12 @@ func (s *State) onEvent(iface interface{}) {
// not tracked. // not tracked.
case *gateway.MessageCreateEvent: case *gateway.MessageCreateEvent:
if err := s.Cabinet.MessageSet(ev.Message, false); err != nil { if err := s.Cabinet.MessageSet(&ev.Message, false); err != nil {
s.stateErr(err, "failed to add a message in state") s.stateErr(err, "failed to add a message in state")
} }
case *gateway.MessageUpdateEvent: case *gateway.MessageUpdateEvent:
if err := s.Cabinet.MessageSet(ev.Message, true); err != nil { if err := s.Cabinet.MessageSet(&ev.Message, true); err != nil {
s.stateErr(err, "failed to update a message in state") s.stateErr(err, "failed to update a message in state")
} }
@ -238,12 +239,17 @@ func (s *State) onEvent(iface interface{}) {
case *gateway.MessageReactionAddEvent: case *gateway.MessageReactionAddEvent:
s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool { s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool {
if i := findReaction(m.Reactions, ev.Emoji); i > -1 { if i := findReaction(m.Reactions, ev.Emoji); i > -1 {
// Copy the reactions slice so it's not racy.
m.Reactions = append([]discord.Reaction(nil), m.Reactions...)
m.Reactions[i].Count++ m.Reactions[i].Count++
} else { } else {
var me bool var me bool
if u, _ := s.Cabinet.Me(); u != nil { if u, _ := s.Cabinet.Me(); u != nil {
me = ev.UserID == u.ID me = ev.UserID == u.ID
} }
old := m.Reactions
m.Reactions = make([]discord.Reaction, 0, len(old)+1)
m.Reactions = append(m.Reactions, old...)
m.Reactions = append(m.Reactions, discord.Reaction{ m.Reactions = append(m.Reactions, discord.Reaction{
Count: 1, Count: 1,
Me: me, Me: me,
@ -261,18 +267,21 @@ func (s *State) onEvent(iface interface{}) {
} }
r := &m.Reactions[i] r := &m.Reactions[i]
r.Count-- newCount := r.Count - 1
switch { switch {
case r.Count < 1: // If the count is 0: case newCount < 1: // If the count is 0:
// Remove the reaction. old := m.Reactions
m.Reactions = append(m.Reactions[:i], m.Reactions[i+1:]...) m.Reactions = make([]discord.Reaction, len(m.Reactions)-1)
copy(m.Reactions[0:], old[:i])
copy(m.Reactions[i:], old[i+1:])
case r.Me: // If reaction removal is the user's case r.Me: // If reaction removal is the user's
u, err := s.Cabinet.Me() u, err := s.Cabinet.Me()
if err == nil && ev.UserID == u.ID { if err == nil && ev.UserID == u.ID {
r.Me = false r.Me = false
} }
r.Count--
} }
return true return true
@ -295,13 +304,13 @@ func (s *State) onEvent(iface interface{}) {
}) })
case *gateway.PresenceUpdateEvent: case *gateway.PresenceUpdateEvent:
if err := s.Cabinet.PresenceSet(ev.GuildID, ev.Presence, true); err != nil { if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presence, true); err != nil {
s.stateErr(err, "failed to update presence in state") s.stateErr(err, "failed to update presence in state")
} }
case *gateway.PresencesReplaceEvent: case *gateway.PresencesReplaceEvent:
for _, p := range *ev { for _, p := range *ev {
if err := s.Cabinet.PresenceSet(p.GuildID, p.Presence, true); err != nil { if err := s.Cabinet.PresenceSet(p.GuildID, &p.Presence, true); err != nil {
s.stateErr(err, "failed to update presence in state") s.stateErr(err, "failed to update presence in state")
} }
} }
@ -332,7 +341,7 @@ func (s *State) onEvent(iface interface{}) {
s.stateErr(err, "failed to remove voice state from state") s.stateErr(err, "failed to remove voice state from state")
} }
} else { } else {
if err := s.Cabinet.VoiceStateSet(vs.GuildID, *vs, true); err != nil { if err := s.Cabinet.VoiceStateSet(vs.GuildID, vs, true); err != nil {
s.stateErr(err, "failed to update voice state in state") s.stateErr(err, "failed to update voice state in state")
} }
} }
@ -342,6 +351,7 @@ func (s *State) onEvent(iface interface{}) {
func (s *State) stateErr(err error, wrap string) { func (s *State) stateErr(err error, wrap string) {
s.StateLog(errors.Wrap(err, wrap)) s.StateLog(errors.Wrap(err, wrap))
} }
func (s *State) batchLog(errors []error) { func (s *State) batchLog(errors []error) {
for _, err := range errors { for _, err := range errors {
s.StateLog(err) s.StateLog(err)
@ -355,10 +365,16 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func
if err != nil { if err != nil {
return return
} }
// Copy the messages.
cpy := *m
m = &cpy
if !fn(m) { if !fn(m) {
return return
} }
if err := s.Cabinet.MessageSet(*m, true); err != nil {
if err := s.Cabinet.MessageSet(m, true); err != nil {
s.stateErr(err, "failed to save message in reaction add") s.stateErr(err, "failed to save message in reaction add")
} }
} }
@ -379,7 +395,7 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
stack, errs := newErrorStack() stack, errs := newErrorStack()
if err := cab.GuildSet(guild.Guild, false); err != nil { if err := cab.GuildSet(&guild.Guild, false); err != nil {
errs(err, "failed to set guild in Ready") errs(err, "failed to set guild in Ready")
} }
@ -391,8 +407,8 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
} }
// Handle guild member // Handle guild member
for _, m := range guild.Members { for i := range guild.Members {
if err := cab.MemberSet(guild.ID, m, false); err != nil { if err := cab.MemberSet(guild.ID, &guild.Members[i], false); err != nil {
errs(err, "failed to set guild member in Ready") errs(err, "failed to set guild member in Ready")
} }
} }
@ -400,36 +416,40 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
// Handle guild channels // Handle guild channels
for _, ch := range guild.Channels { for _, ch := range guild.Channels {
// I HATE Discord. // I HATE Discord.
ch := ch
ch.GuildID = guild.ID ch.GuildID = guild.ID
if err := cab.ChannelSet(ch, false); err != nil { if err := cab.ChannelSet(&ch, false); err != nil {
errs(err, "failed to set guild channel in Ready") errs(err, "failed to set guild channel in Ready")
} }
} }
// Handle threads. // Handle threads.
for _, ch := range guild.Threads { for _, ch := range guild.Threads {
ch := ch
ch.GuildID = guild.ID ch.GuildID = guild.ID
if err := cab.ChannelSet(ch, false); err != nil { if err := cab.ChannelSet(&ch, false); err != nil {
errs(err, "failed to set guild thread in Ready") errs(err, "failed to set guild thread in Ready")
} }
} }
// Handle guild presences // Handle guild presences
for _, p := range guild.Presences { for _, p := range guild.Presences {
p := p
p.GuildID = guild.ID p.GuildID = guild.ID
if err := cab.PresenceSet(guild.ID, p, false); err != nil { if err := cab.PresenceSet(guild.ID, &p, false); err != nil {
errs(err, "failed to set guild presence in Ready") errs(err, "failed to set guild presence in Ready")
} }
} }
// Handle guild voice states // Handle guild voice states
for _, v := range guild.VoiceStates { for _, v := range guild.VoiceStates {
v := v
v.GuildID = guild.ID v.GuildID = guild.ID
if err := cab.VoiceStateSet(guild.ID, v, false); err != nil { if err := cab.VoiceStateSet(guild.ID, &v, false); err != nil {
errs(err, "failed to set guild voice state in Ready") errs(err, "failed to set guild voice state in Ready")
} }
} }

View file

@ -2,6 +2,7 @@ package defaultstore
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/discord"
@ -13,20 +14,18 @@ type Channel struct {
// Channel references must be protected under the same mutex. // Channel references must be protected under the same mutex.
privates map[discord.UserID]*discord.Channel channels map[discord.ChannelID]discord.Channel
privateChs []*discord.Channel privates map[discord.UserID]discord.ChannelID
guildChs map[discord.GuildID][]discord.ChannelID
channels map[discord.ChannelID]*discord.Channel
guildChs map[discord.GuildID][]*discord.Channel
} }
var _ store.ChannelStore = (*Channel)(nil) var _ store.ChannelStore = (*Channel)(nil)
func NewChannel() *Channel { func NewChannel() *Channel {
return &Channel{ return &Channel{
privates: map[discord.UserID]*discord.Channel{}, channels: map[discord.ChannelID]discord.Channel{},
channels: map[discord.ChannelID]*discord.Channel{}, privates: map[discord.UserID]discord.ChannelID{},
guildChs: map[discord.GuildID][]*discord.Channel{}, guildChs: map[discord.GuildID][]discord.ChannelID{},
} }
} }
@ -34,9 +33,9 @@ func (s *Channel) Reset() error {
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
s.privates = map[discord.UserID]*discord.Channel{} s.channels = map[discord.ChannelID]discord.Channel{}
s.channels = map[discord.ChannelID]*discord.Channel{} s.privates = map[discord.UserID]discord.ChannelID{}
s.guildChs = map[discord.GuildID][]*discord.Channel{} s.guildChs = map[discord.GuildID][]discord.ChannelID{}
return nil return nil
} }
@ -50,20 +49,19 @@ func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) {
return nil, store.ErrNotFound return nil, store.ErrNotFound
} }
cpy := *ch return &ch, nil
return &cpy, nil
} }
func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) { func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
ch, ok := s.privates[recipient] id, ok := s.privates[recipient]
if !ok { if !ok {
return nil, store.ErrNotFound return nil, store.ErrNotFound
} }
cpy := *ch cpy := s.channels[id]
return &cpy, nil return &cpy, nil
} }
@ -72,16 +70,20 @@ func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
chRefs, ok := s.guildChs[guildID] chIDs, ok := s.guildChs[guildID]
if !ok { if !ok {
return nil, store.ErrNotFound return nil, store.ErrNotFound
} }
// Reading chRefs is also covered by the global mutex. // Reading chRefs is also covered by the global mutex.
var channels = make([]discord.Channel, len(chRefs)) var channels = make([]discord.Channel, 0, len(chIDs))
for i, chRef := range chRefs { for _, chID := range chIDs {
channels[i] = *chRef ch, ok := s.channels[chID]
if !ok {
continue
}
channels = append(channels, ch)
} }
return channels, nil return channels, nil
@ -92,42 +94,47 @@ func (s *Channel) PrivateChannels() ([]discord.Channel, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
if len(s.privateChs) == 0 { groupDMs := s.guildChs[0]
if len(s.privates) == 0 && len(groupDMs) == 0 {
return nil, store.ErrNotFound return nil, store.ErrNotFound
} }
var channels = make([]discord.Channel, len(s.privateChs)) var channels = make([]discord.Channel, 0, len(s.privates)+len(groupDMs))
for i, ch := range s.privateChs { for _, chID := range s.privates {
channels[i] = *ch if ch, ok := s.channels[chID]; ok {
channels = append(channels, ch)
}
}
for _, chID := range groupDMs {
if ch, ok := s.channels[chID]; ok {
channels = append(channels, ch)
}
} }
return channels, nil return channels, nil
} }
// ChannelSet sets the Direct Message or Guild channel into the state. // ChannelSet sets the Direct Message or Guild channel into the state.
func (s *Channel) ChannelSet(channel discord.Channel, update bool) error { func (s *Channel) ChannelSet(channel *discord.Channel, update bool) error {
cpy := *channel
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
// Update the reference if we can. // Update the reference if we can.
if ch, ok := s.channels[channel.ID]; ok { s.channels[channel.ID] = cpy
if update {
*ch = channel
}
return nil
}
switch channel.Type { switch channel.Type {
case discord.DirectMessage: case discord.DirectMessage:
// Safety bound check. // Safety bound check.
if len(channel.DMRecipients) != 1 { if len(channel.DMRecipients) != 1 {
return errors.New("DirectMessage channel does not have 1 recipient") return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID)
} }
s.privates[channel.DMRecipients[0].ID] = &channel s.privates[channel.DMRecipients[0].ID] = channel.ID
fallthrough return nil
case discord.GroupDM: case discord.GroupDM:
s.privateChs = append(s.privateChs, &channel) s.guildChs[0] = addChannelID(s.guildChs[0], channel.ID)
s.channels[channel.ID] = &channel
return nil return nil
} }
@ -137,16 +144,11 @@ func (s *Channel) ChannelSet(channel discord.Channel, update bool) error {
return errors.New("invalid guildID for guild channel") return errors.New("invalid guildID for guild channel")
} }
s.channels[channel.ID] = &channel s.guildChs[channel.GuildID] = addChannelID(s.guildChs[channel.GuildID], channel.ID)
channels, _ := s.guildChs[channel.GuildID]
channels = append(channels, &channel)
s.guildChs[channel.GuildID] = channels
return nil return nil
} }
func (s *Channel) ChannelRemove(channel discord.Channel) error { func (s *Channel) ChannelRemove(channel *discord.Channel) error {
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
@ -158,49 +160,42 @@ func (s *Channel) ChannelRemove(channel discord.Channel) error {
case discord.DirectMessage: case discord.DirectMessage:
// Safety bound check. // Safety bound check.
if len(channel.DMRecipients) != 1 { if len(channel.DMRecipients) != 1 {
return errors.New("DirectMessage channel does not have 1 recipient") return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID)
} }
delete(s.privates, channel.DMRecipients[0].ID) delete(s.privates, channel.DMRecipients[0].ID)
fallthrough return nil
case discord.GroupDM: case discord.GroupDM:
for i, priv := range s.privateChs { s.guildChs[0] = removeChannelID(s.guildChs[0], channel.ID)
if priv.ID == channel.ID {
s.privateChs = removeChannel(s.privateChs, i)
break
}
}
return nil return nil
} }
// Wipe the channel off the guilds index, if available. s.guildChs[channel.GuildID] = removeChannelID(s.guildChs[channel.GuildID], channel.ID)
channels, ok := s.guildChs[channel.GuildID]
if !ok {
return nil
}
for i, ch := range channels {
if ch.ID == channel.ID {
s.guildChs[channel.GuildID] = removeChannel(channels, i)
break
}
}
return nil return nil
} }
// removeChannel removes the given channel with the index from the given func addChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID {
for _, ch := range channels {
if ch == id {
return channels
}
}
if channels == nil {
channels = make([]discord.ChannelID, 0, 5)
}
return append(channels, id)
}
// removeChannelID removes the given channel with the index from the given
// channels slice in an unordered fashion. // channels slice in an unordered fashion.
func removeChannel(channels []*discord.Channel, i int) []*discord.Channel { func removeChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID {
// Fast unordered delete. Not sure if there's a benefit in doing for i, ch := range channels {
// this over using a map, but I guess the memory usage is less and if ch == id {
// there's no copying. // Move the last channel to the current channel, then slice the last
// channel off.
// Move the last channel to the current channel, set the last channels[i] = channels[len(channels)-1]
// channel there to a nil value to unreference its children, then channels = channels[:len(channels)-1]
// slice the last channel off. break
channels[i] = channels[len(channels)-1] }
channels[len(channels)-1] = nil }
channels = channels[:len(channels)-1]
return channels return channels
} }

View file

@ -33,13 +33,12 @@ func (s *Guild) Guild(id discord.GuildID) (*discord.Guild, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
ch, ok := s.guilds[id] g, ok := s.guilds[id]
if !ok { if ok {
return nil, store.ErrNotFound return &g, nil
} }
// implicit copy return nil, store.ErrNotFound
return &ch, nil
} }
func (s *Guild) Guilds() ([]discord.Guild, error) { func (s *Guild) Guilds() ([]discord.Guild, error) {
@ -58,10 +57,12 @@ func (s *Guild) Guilds() ([]discord.Guild, error) {
return gs, nil return gs, nil
} }
func (s *Guild) GuildSet(guild discord.Guild, update bool) error { func (s *Guild) GuildSet(guild *discord.Guild, update bool) error {
cpy := *guild
s.mut.Lock() s.mut.Lock()
if _, ok := s.guilds[guild.ID]; !ok || update { if _, ok := s.guilds[guild.ID]; !ok || update {
s.guilds[guild.ID] = guild s.guilds[guild.ID] = cpy
} }
s.mut.Unlock() s.mut.Unlock()

View file

@ -13,7 +13,7 @@ type Member struct {
} }
type guildMembers struct { type guildMembers struct {
mut sync.Mutex mut sync.RWMutex
members map[discord.UserID]discord.Member members map[discord.UserID]discord.Member
} }
@ -41,8 +41,8 @@ func (s *Member) Member(guildID discord.GuildID, userID discord.UserID) (*discor
gm := iv.(*guildMembers) gm := iv.(*guildMembers)
gm.mut.Lock() gm.mut.RLock()
defer gm.mut.Unlock() defer gm.mut.RUnlock()
m, ok := gm.members[userID] m, ok := gm.members[userID]
if ok { if ok {
@ -60,8 +60,8 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) {
gm := iv.(*guildMembers) gm := iv.(*guildMembers)
gm.mut.Lock() gm.mut.RLock()
defer gm.mut.Unlock() defer gm.mut.RUnlock()
var members = make([]discord.Member, 0, len(gm.members)) var members = make([]discord.Member, 0, len(gm.members))
for _, m := range gm.members { for _, m := range gm.members {
@ -71,13 +71,13 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) {
return members, nil return members, nil
} }
func (s *Member) MemberSet(guildID discord.GuildID, m discord.Member, update bool) error { func (s *Member) MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID) iv, _ := s.guilds.LoadOrStore(guildID)
gm := iv.(*guildMembers) gm := iv.(*guildMembers)
gm.mut.Lock() gm.mut.Lock()
if _, ok := gm.members[m.User.ID]; !ok || update { if _, ok := gm.members[m.User.ID]; !ok || update {
gm.members[m.User.ID] = m gm.members[m.User.ID] = *m
} }
gm.mut.Unlock() gm.mut.Unlock()

View file

@ -16,7 +16,7 @@ type Message struct {
var _ store.MessageStore = (*Message)(nil) var _ store.MessageStore = (*Message)(nil)
type messages struct { type messages struct {
mut sync.Mutex mut sync.RWMutex
messages []discord.Message messages []discord.Message
} }
@ -43,8 +43,8 @@ func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*disco
msgs := iv.(*messages) msgs := iv.(*messages)
msgs.mut.Lock() msgs.mut.RLock()
defer msgs.mut.Unlock() defer msgs.mut.RUnlock()
for _, m := range msgs.messages { for _, m := range msgs.messages {
if m.ID == mID { if m.ID == mID {
@ -63,8 +63,8 @@ func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, erro
msgs := iv.(*messages) msgs := iv.(*messages)
msgs.mut.Lock() msgs.mut.RLock()
defer msgs.mut.Unlock() defer msgs.mut.RUnlock()
return append([]discord.Message(nil), msgs.messages...), nil return append([]discord.Message(nil), msgs.messages...), nil
} }
@ -73,7 +73,7 @@ func (s *Message) MaxMessages() int {
return s.maxMsgs return s.maxMsgs
} }
func (s *Message) MessageSet(message discord.Message, update bool) error { func (s *Message) MessageSet(message *discord.Message, update bool) error {
if s.maxMsgs <= 0 { if s.maxMsgs <= 0 {
return nil return nil
} }
@ -102,19 +102,19 @@ func (s *Message) MessageSet(message discord.Message, update bool) error {
} }
if len(msgs.messages) == 0 { if len(msgs.messages) == 0 {
msgs.messages = []discord.Message{message} msgs.messages = []discord.Message{*message}
} }
if pos := messageInsertPosition(message, msgs.messages); pos < 0 { if pos := messageInsertPosition(message, msgs.messages); pos < 0 {
// Messages are full, drop the oldest messages to make room. // Messages are full, drop the oldest messages to make room.
if len(msgs.messages) == s.maxMsgs { if len(msgs.messages) == s.maxMsgs {
copy(msgs.messages[1:], msgs.messages) copy(msgs.messages[1:], msgs.messages)
msgs.messages[0] = message msgs.messages[0] = *message
} else { } else {
msgs.messages = append([]discord.Message{message}, msgs.messages...) msgs.messages = append([]discord.Message{*message}, msgs.messages...)
} }
} else if pos > 0 && len(msgs.messages) < s.maxMsgs { } else if pos > 0 && len(msgs.messages) < s.maxMsgs {
msgs.messages = append(msgs.messages, message) msgs.messages = append(msgs.messages, *message)
} }
// We already have this message or we can't append any more messages. // We already have this message or we can't append any more messages.
@ -131,7 +131,7 @@ func (s *Message) MessageSet(message discord.Message, update bool) error {
// messageInsertPosition is biased as it will recommend adding the message even // messageInsertPosition is biased as it will recommend adding the message even
// if timestamps just match, even though the true order cannot be determined in // if timestamps just match, even though the true order cannot be determined in
// that case. // that case.
func messageInsertPosition(target discord.Message, messages []discord.Message) int8 { func messageInsertPosition(target *discord.Message, messages []discord.Message) int8 {
var ( var (
targetTime = target.ID.Time() targetTime = target.ID.Time()
firstTime = messages[0].ID.Time() firstTime = messages[0].ID.Time()
@ -183,7 +183,7 @@ func messageInsertPosition(target discord.Message, messages []discord.Message) i
} }
// DiffMessage fills non-empty fields from src to dst. // DiffMessage fills non-empty fields from src to dst.
func DiffMessage(src discord.Message, dst *discord.Message) { func DiffMessage(src, dst *discord.Message) {
// Thanks, Discord. // Thanks, Discord.
if src.Content != "" { if src.Content != "" {
dst.Content = src.Content dst.Content = src.Content

View file

@ -10,27 +10,27 @@ func populate12Store() *Message {
store := NewMessage(10) store := NewMessage(10)
// Insert a regular list of messages. // Insert a regular list of messages.
store.MessageSet(discord.Message{ID: 1 << 29, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 29, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 28, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 28, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 27, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 27, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 26, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 26, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 25, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 25, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 24, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 24, ChannelID: 1}, false)
// Try to insert newer messages after inserting new messages. // Try to insert newer messages after inserting new messages.
store.MessageSet(discord.Message{ID: 1 << 30, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 30, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 31, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 31, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 32, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 32, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 33, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 33, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 34, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 34, ChannelID: 1}, false)
// TThese messages should be discarded, due to age. // TThese messages should be discarded, due to age.
store.MessageSet(discord.Message{ID: 1 << 23, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 23, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 22, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 22, ChannelID: 1}, false)
// These should be prepended. // These should be prepended.
store.MessageSet(discord.Message{ID: 1 << 35, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 35, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 36, ChannelID: 1}, false) store.MessageSet(&discord.Message{ID: 1 << 36, ChannelID: 1}, false)
return store return store
} }
@ -57,9 +57,9 @@ func TestMessageSet(t *testing.T) {
func TestMessagesUpdate(t *testing.T) { func TestMessagesUpdate(t *testing.T) {
store := populate12Store() store := populate12Store()
store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true) store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true)
store.MessageSet(discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true) store.MessageSet(&discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true)
store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true) store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true)
expect := map[discord.MessageID]string{ expect := map[discord.MessageID]string{
5: "edited 3", 5: "edited 3",

View file

@ -13,7 +13,7 @@ type Presence struct {
} }
type presences struct { type presences struct {
mut sync.Mutex mut sync.RWMutex
presences map[discord.UserID]discord.Presence presences map[discord.UserID]discord.Presence
} }
@ -41,8 +41,8 @@ func (s *Presence) Presence(gID discord.GuildID, uID discord.UserID) (*discord.P
ps := iv.(*presences) ps := iv.(*presences)
ps.mut.Lock() ps.mut.RLock()
defer ps.mut.Unlock() defer ps.mut.RUnlock()
p, ok := ps.presences[uID] p, ok := ps.presences[uID]
if ok { if ok {
@ -60,8 +60,8 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error
ps := iv.(*presences) ps := iv.(*presences)
ps.mut.Lock() ps.mut.RLock()
defer ps.mut.Unlock() defer ps.mut.RUnlock()
var presences = make([]discord.Presence, 0, len(ps.presences)) var presences = make([]discord.Presence, 0, len(ps.presences))
for _, p := range ps.presences { for _, p := range ps.presences {
@ -71,7 +71,7 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error
return presences, nil return presences, nil
} }
func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error { func (s *Presence) PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID) iv, _ := s.guilds.LoadOrStore(guildID)
ps := iv.(*presences) ps := iv.(*presences)
@ -85,7 +85,7 @@ func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, upda
} }
if _, ok := ps.presences[p.User.ID]; !ok || update { if _, ok := ps.presences[p.User.ID]; !ok || update {
ps.presences[p.User.ID] = p ps.presences[p.User.ID] = *p
} }
return nil return nil

View file

@ -15,7 +15,7 @@ type Role struct {
var _ store.RoleStore = (*Role)(nil) var _ store.RoleStore = (*Role)(nil)
type roles struct { type roles struct {
mut sync.Mutex mut sync.RWMutex
roles map[discord.RoleID]discord.Role roles map[discord.RoleID]discord.Role
} }
@ -41,8 +41,8 @@ func (s *Role) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Ro
rs := iv.(*roles) rs := iv.(*roles)
rs.mut.Lock() rs.mut.RLock()
defer rs.mut.Unlock() defer rs.mut.RUnlock()
r, ok := rs.roles[roleID] r, ok := rs.roles[roleID]
if ok { if ok {
@ -60,8 +60,8 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) {
rs := iv.(*roles) rs := iv.(*roles)
rs.mut.Lock() rs.mut.RLock()
defer rs.mut.Unlock() defer rs.mut.RUnlock()
var roles = make([]discord.Role, 0, len(rs.roles)) var roles = make([]discord.Role, 0, len(rs.roles))
for _, role := range rs.roles { for _, role := range rs.roles {
@ -71,14 +71,14 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) {
return roles, nil return roles, nil
} }
func (s *Role) RoleSet(guildID discord.GuildID, role discord.Role, update bool) error { func (s *Role) RoleSet(guildID discord.GuildID, role *discord.Role, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID) iv, _ := s.guilds.LoadOrStore(guildID)
rs := iv.(*roles) rs := iv.(*roles)
rs.mut.Lock() rs.mut.Lock()
if _, ok := rs.roles[role.ID]; !ok || update { if _, ok := rs.roles[role.ID]; !ok || update {
rs.roles[role.ID] = role rs.roles[role.ID] = *role
} }
rs.mut.Unlock() rs.mut.Unlock()

View file

@ -15,7 +15,7 @@ type VoiceState struct {
var _ store.VoiceStateStore = (*VoiceState)(nil) var _ store.VoiceStateStore = (*VoiceState)(nil)
type voiceStates struct { type voiceStates struct {
mut sync.Mutex mut sync.RWMutex
voiceStates map[discord.UserID]discord.VoiceState voiceStates map[discord.UserID]discord.VoiceState
} }
@ -43,8 +43,8 @@ func (s *VoiceState) VoiceState(
vs := iv.(*voiceStates) vs := iv.(*voiceStates)
vs.mut.Lock() vs.mut.RLock()
defer vs.mut.Unlock() defer vs.mut.RUnlock()
v, ok := vs.voiceStates[userID] v, ok := vs.voiceStates[userID]
if ok { if ok {
@ -62,8 +62,8 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState,
vs := iv.(*voiceStates) vs := iv.(*voiceStates)
vs.mut.Lock() vs.mut.RLock()
defer vs.mut.Unlock() defer vs.mut.RUnlock()
var states = make([]discord.VoiceState, 0, len(vs.voiceStates)) var states = make([]discord.VoiceState, 0, len(vs.voiceStates))
for _, state := range vs.voiceStates { for _, state := range vs.voiceStates {
@ -74,7 +74,7 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState,
} }
func (s *VoiceState) VoiceStateSet( func (s *VoiceState) VoiceStateSet(
guildID discord.GuildID, voiceState discord.VoiceState, update bool) error { guildID discord.GuildID, voiceState *discord.VoiceState, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID) iv, _ := s.guilds.LoadOrStore(guildID)
@ -82,7 +82,7 @@ func (s *VoiceState) VoiceStateSet(
vs.mut.Lock() vs.mut.Lock()
if _, ok := vs.voiceStates[voiceState.UserID]; !ok || update { if _, ok := vs.voiceStates[voiceState.UserID]; !ok || update {
vs.voiceStates[voiceState.UserID] = voiceState vs.voiceStates[voiceState.UserID] = *voiceState
} }
vs.mut.Unlock() vs.mut.Unlock()

View file

@ -143,6 +143,12 @@ type Resetter interface {
Reset() error Reset() error
} }
type CoreStorer interface {
Resetter
Lock()
Unlock()
}
var _ Resetter = (*noop)(nil) var _ Resetter = (*noop)(nil)
func (noop) Reset() error { return nil } func (noop) Reset() error { return nil }
@ -176,8 +182,8 @@ type ChannelStore interface {
// Both ChannelSet and ChannelRemove should switch on Type to know if it's a // Both ChannelSet and ChannelRemove should switch on Type to know if it's a
// private channel or not. // private channel or not.
ChannelSet(c discord.Channel, update bool) error ChannelSet(c *discord.Channel, update bool) error
ChannelRemove(discord.Channel) error ChannelRemove(*discord.Channel) error
} }
var _ ChannelStore = (*noop)(nil) var _ ChannelStore = (*noop)(nil)
@ -194,10 +200,10 @@ func (noop) Channels(discord.GuildID) ([]discord.Channel, error) {
func (noop) PrivateChannels() ([]discord.Channel, error) { func (noop) PrivateChannels() ([]discord.Channel, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (noop) ChannelSet(discord.Channel, bool) error { func (noop) ChannelSet(*discord.Channel, bool) error {
return nil return nil
} }
func (noop) ChannelRemove(discord.Channel) error { func (noop) ChannelRemove(*discord.Channel) error {
return nil return nil
} }
@ -232,7 +238,7 @@ type GuildStore interface {
Guild(discord.GuildID) (*discord.Guild, error) Guild(discord.GuildID) (*discord.Guild, error)
Guilds() ([]discord.Guild, error) Guilds() ([]discord.Guild, error)
GuildSet(g discord.Guild, update bool) error GuildSet(g *discord.Guild, update bool) error
GuildRemove(id discord.GuildID) error GuildRemove(id discord.GuildID) error
} }
@ -240,7 +246,7 @@ var _ GuildStore = (*noop)(nil)
func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound } func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound }
func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound } func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound }
func (noop) GuildSet(discord.Guild, bool) error { return nil } func (noop) GuildSet(*discord.Guild, bool) error { return nil }
func (noop) GuildRemove(discord.GuildID) error { return nil } func (noop) GuildRemove(discord.GuildID) error { return nil }
// MemberStore is the store interface for all members. // MemberStore is the store interface for all members.
@ -250,7 +256,7 @@ type MemberStore interface {
Member(discord.GuildID, discord.UserID) (*discord.Member, error) Member(discord.GuildID, discord.UserID) (*discord.Member, error)
Members(discord.GuildID) ([]discord.Member, error) Members(discord.GuildID) ([]discord.Member, error)
MemberSet(guildID discord.GuildID, m discord.Member, update bool) error MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error
MemberRemove(discord.GuildID, discord.UserID) error MemberRemove(discord.GuildID, discord.UserID) error
} }
@ -262,7 +268,7 @@ func (noop) Member(discord.GuildID, discord.UserID) (*discord.Member, error) {
func (noop) Members(discord.GuildID) ([]discord.Member, error) { func (noop) Members(discord.GuildID) ([]discord.Member, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (noop) MemberSet(discord.GuildID, discord.Member, bool) error { func (noop) MemberSet(discord.GuildID, *discord.Member, bool) error {
return nil return nil
} }
func (noop) MemberRemove(discord.GuildID, discord.UserID) error { func (noop) MemberRemove(discord.GuildID, discord.UserID) error {
@ -289,7 +295,7 @@ type MessageStore interface {
// If update is set to true, MessageSet will check if a message with the // If update is set to true, MessageSet will check if a message with the
// id of the passed message is stored, and update it if so. Otherwise, if // id of the passed message is stored, and update it if so. Otherwise, if
// there is no such message, it will be discarded. // there is no such message, it will be discarded.
MessageSet(m discord.Message, update bool) error MessageSet(m *discord.Message, update bool) error
MessageRemove(discord.ChannelID, discord.MessageID) error MessageRemove(discord.ChannelID, discord.MessageID) error
} }
@ -304,7 +310,7 @@ func (noop) Message(discord.ChannelID, discord.MessageID) (*discord.Message, err
func (noop) Messages(discord.ChannelID) ([]discord.Message, error) { func (noop) Messages(discord.ChannelID) ([]discord.Message, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (noop) MessageSet(discord.Message, bool) error { func (noop) MessageSet(*discord.Message, bool) error {
return nil return nil
} }
func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error { func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error {
@ -319,7 +325,7 @@ type PresenceStore interface {
Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error)
Presences(discord.GuildID) ([]discord.Presence, error) Presences(discord.GuildID) ([]discord.Presence, error)
PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error
PresenceRemove(discord.GuildID, discord.UserID) error PresenceRemove(discord.GuildID, discord.UserID) error
} }
@ -331,7 +337,7 @@ func (noop) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error)
func (noop) Presences(discord.GuildID) ([]discord.Presence, error) { func (noop) Presences(discord.GuildID) ([]discord.Presence, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (noop) PresenceSet(discord.GuildID, discord.Presence, bool) error { func (noop) PresenceSet(discord.GuildID, *discord.Presence, bool) error {
return nil return nil
} }
func (noop) PresenceRemove(discord.GuildID, discord.UserID) error { func (noop) PresenceRemove(discord.GuildID, discord.UserID) error {
@ -345,7 +351,7 @@ type RoleStore interface {
Role(discord.GuildID, discord.RoleID) (*discord.Role, error) Role(discord.GuildID, discord.RoleID) (*discord.Role, error)
Roles(discord.GuildID) ([]discord.Role, error) Roles(discord.GuildID) ([]discord.Role, error)
RoleSet(guildID discord.GuildID, r discord.Role, update bool) error RoleSet(guildID discord.GuildID, r *discord.Role, update bool) error
RoleRemove(discord.GuildID, discord.RoleID) error RoleRemove(discord.GuildID, discord.RoleID) error
} }
@ -353,7 +359,7 @@ var _ RoleStore = (*noop)(nil)
func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound } func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound }
func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound } func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound }
func (noop) RoleSet(discord.GuildID, discord.Role, bool) error { return nil } func (noop) RoleSet(discord.GuildID, *discord.Role, bool) error { return nil }
func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil } func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil }
// VoiceStateStore is the store interface for all voice states. // VoiceStateStore is the store interface for all voice states.
@ -363,7 +369,7 @@ type VoiceStateStore interface {
VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error)
VoiceStates(discord.GuildID) ([]discord.VoiceState, error) VoiceStates(discord.GuildID) ([]discord.VoiceState, error)
VoiceStateSet(guildID discord.GuildID, s discord.VoiceState, update bool) error VoiceStateSet(guildID discord.GuildID, s *discord.VoiceState, update bool) error
VoiceStateRemove(discord.GuildID, discord.UserID) error VoiceStateRemove(discord.GuildID, discord.UserID) error
} }
@ -375,7 +381,7 @@ func (noop) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, er
func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) { func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
func (noop) VoiceStateSet(discord.GuildID, discord.VoiceState, bool) error { func (noop) VoiceStateSet(discord.GuildID, *discord.VoiceState, bool) error {
return nil return nil
} }
func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error { func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error {

View file

@ -28,12 +28,8 @@ import (
// Handler is a container for command handlers. A zero-value instance is a valid // Handler is a container for command handlers. A zero-value instance is a valid
// instance. // instance.
type Handler struct { type Handler struct {
// Synchronous controls whether to spawn each event handler in its own mutex sync.RWMutex
// goroutine. Default false (meaning goroutines are spawned). events map[reflect.Type]slab // nil type for interfaces
Synchronous bool
mutex sync.RWMutex
slab slab
} }
func New() *Handler { func New() *Handler {
@ -43,22 +39,24 @@ func New() *Handler {
// Call calls all handlers with the given event. This is an internal method; use // Call calls all handlers with the given event. This is an internal method; use
// with care. // with care.
func (h *Handler) Call(ev interface{}) { func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev) v := reflect.ValueOf(ev)
var evT = evV.Type() t := v.Type()
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
for _, entry := range h.slab.Entries { for _, entry := range h.events[t].Entries {
if entry.isInvalid() || entry.not(evT) { if entry.isInvalid() {
continue continue
} }
entry.Call(v)
}
if h.Synchronous { for _, entry := range h.events[nil].Entries {
entry.call(evV) if entry.isInvalid() || entry.not(t) {
} else { continue
go entry.call(evV)
} }
entry.Call(v)
} }
} }
@ -145,7 +143,19 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca
// h.AddHandler(ch) // h.AddHandler(ch)
// //
func (h *Handler) AddHandler(handler interface{}) (rm func()) { func (h *Handler) AddHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler) rm, err := h.addHandler(handler, false)
if err != nil {
panic(err)
}
return rm
}
// AddSyncHandler is a synchronous variant of AddHandler. Handlers added using
// this method will block the Call method, which is helpful if the user needs to
// rely on the order of events arriving. Handlers added using this method should
// not block for very long, as it may clog up other handlers.
func (h *Handler) AddSyncHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler, true)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -167,23 +177,56 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
} }
}() }()
return h.addHandler(handler) return h.addHandler(handler, false)
} }
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { // AddSyncHandlerCheck is the safe-guarded version of AddSyncHandler. It is
// similar to AddHandlerCheck.
func (h *Handler) AddSyncHandlerCheck(handler interface{}) (rm func(), err error) {
// Reflect would actually panic if anything goes wrong, so this is just in
// case.
defer func() {
if rec := recover(); rec != nil {
if recErr, ok := rec.(error); ok {
err = recErr
} else {
err = fmt.Errorf("%v", rec)
}
}
}()
return h.addHandler(handler, true)
}
func (h *Handler) addHandler(fn interface{}, sync bool) (rm func(), err error) {
// Reflect the handler // Reflect the handler
r, err := newHandler(fn) r, err := newHandler(fn, sync)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "handler reflect failed") return nil, errors.Wrap(err, "handler reflect failed")
} }
var id int
var t reflect.Type
if !r.isIface {
t = r.event
}
h.mutex.Lock() h.mutex.Lock()
id := h.slab.Put(r)
if h.events == nil {
h.events = make(map[reflect.Type]slab, 10)
}
slab := h.events[t]
id = slab.Put(r)
h.events[t] = slab
h.mutex.Unlock() h.mutex.Unlock()
return func() { return func() {
h.mutex.Lock() h.mutex.Lock()
popped := h.slab.Pop(id) slab := h.events[t]
popped := slab.Pop(id)
h.mutex.Unlock() h.mutex.Unlock()
popped.cleanup() popped.cleanup()
@ -193,20 +236,22 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
type handler struct { type handler struct {
event reflect.Type // underlying type; arg0 or chan underlying type event reflect.Type // underlying type; arg0 or chan underlying type
callback reflect.Value callback reflect.Value
isIface bool
chanclose reflect.Value // IsValid() if chan chanclose reflect.Value // IsValid() if chan
isIface bool
isSync bool
} }
// newHandler reflects either a channel or a function into a handler. A function // newHandler reflects either a channel or a function into a handler. A function
// must only have a single argument being the event and no return, and a channel // must only have a single argument being the event and no return, and a channel
// must have the event type as the underlying type. // must have the event type as the underlying type.
func newHandler(unknown interface{}) (handler, error) { func newHandler(unknown interface{}, sync bool) (handler, error) {
fnV := reflect.ValueOf(unknown) fnV := reflect.ValueOf(unknown)
fnT := fnV.Type() fnT := fnV.Type()
// underlying event type // underlying event type
var handler = handler{ handler := handler{
callback: fnV, callback: fnV,
isSync: sync,
} }
switch fnT.Kind() { switch fnT.Kind() {
@ -249,6 +294,14 @@ func (h handler) not(event reflect.Type) bool {
return h.event != event return h.event != event
} }
func (h handler) Call(event reflect.Value) {
if h.isSync {
h.call(event)
} else {
go h.call(event)
}
}
func (h handler) call(event reflect.Value) { func (h handler) call(event reflect.Value) {
if h.chanclose.IsValid() { if h.chanclose.IsValid() {
reflect.Select([]reflect.SelectCase{ reflect.Select([]reflect.SelectCase{

View file

@ -63,7 +63,7 @@ func TestHandler(t *testing.T) {
h, err := newHandler(func(m *gateway.MessageCreateEvent) { h, err := newHandler(func(m *gateway.MessageCreateEvent) {
results <- m.Content results <- m.Content
}) }, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -88,7 +88,7 @@ func TestHandler(t *testing.T) {
func TestHandlerChan(t *testing.T) { func TestHandlerChan(t *testing.T) {
var results = make(chan *gateway.MessageCreateEvent) var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results) h, err := newHandler(results, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -115,7 +115,7 @@ func TestHandlerChanCancel(t *testing.T) {
// unbuffered. // unbuffered.
var results = make(chan *gateway.MessageCreateEvent) var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results) h, err := newHandler(results, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -161,7 +161,7 @@ func TestHandlerInterface(t *testing.T) {
h, err := newHandler(func(m interface{}) { h, err := newHandler(func(m interface{}) {
results <- m results <- m
}) }, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -277,7 +277,7 @@ func TestHandlerChanFor(t *testing.T) {
} }
func BenchmarkReflect(b *testing.B) { func BenchmarkReflect(b *testing.B) {
h, err := newHandler(func(m *gateway.MessageCreateEvent) {}) h, err := newHandler(func(m *gateway.MessageCreateEvent) {}, false)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View file

@ -1,8 +1,8 @@
package handler package handler
type slabEntry struct { type slabEntry struct {
handler
index int index int
handler
} }
func (entry slabEntry) isInvalid() bool { func (entry slabEntry) isInvalid() bool {
@ -18,13 +18,13 @@ type slab struct {
func (s *slab) Put(entry handler) int { func (s *slab) Put(entry handler) int {
if s.free == len(s.Entries) { if s.free == len(s.Entries) {
index := len(s.Entries) index := len(s.Entries)
s.Entries = append(s.Entries, slabEntry{entry, -1}) s.Entries = append(s.Entries, slabEntry{-1, entry})
s.free++ s.free++
return index return index
} }
next := s.Entries[s.free].index next := s.Entries[s.free].index
s.Entries[s.free] = slabEntry{entry, -1} s.Entries[s.free] = slabEntry{-1, entry}
i := s.free i := s.free
s.free = next s.free = next
@ -38,7 +38,7 @@ func (s *slab) Get(i int) handler {
func (s *slab) Pop(i int) handler { func (s *slab) Pop(i int) handler {
popped := s.Entries[i].handler popped := s.Entries[i].handler
s.Entries[i] = slabEntry{handler{}, s.free} s.Entries[i] = slabEntry{s.free, handler{}}
s.free = i s.free = i
return popped return popped
} }

View file

@ -36,7 +36,7 @@ func TestIntegration(t *testing.T) {
AddIntents(s.Gateway) AddIntents(s.Gateway)
func() { func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
if err := s.Open(ctx); err != nil { if err := s.Open(ctx); err != nil {