state, handler: Refactor state storage and sync handlers

This commit refactors a lot of packages.

It refactors the handler package, removing the Synchronous field and
replacing it the AddSyncHandler API, which allows each handler to
control whether or not it should be ran synchronously independent of
other handlers. This is useful for libraries that need to guarantee the
incoming order of events.

It also refactors the store interfaces to accept more interfaces. This
is to make the API more consistent as well as reducing potential useless
copies. The public-facing state API should still be the same, so this
change will mostly concern users with their own store implementations.

Several miscellaneous functions (such as a few in package gateway) were
modified to be more suitable to other packages, but those functions
should rarely ever be used, anyway.

Several tests are also fixed within this commit, namely fixing state's
intents bug.
This commit is contained in:
diamondburned 2021-11-03 15:16:02 -07:00
parent 7f4daccd2d
commit efde3f4ea6
No known key found for this signature in database
GPG Key ID: D78C4471CE776659
16 changed files with 362 additions and 274 deletions

View File

@ -233,27 +233,35 @@ type (
}
)
// ConvertSupplementalMember converts a SupplementalMember to a regular Member.
func ConvertSupplementalMember(sm SupplementalMember) discord.Member {
return discord.Member{
User: discord.User{ID: sm.UserID},
Nick: sm.Nick,
RoleIDs: sm.RoleIDs,
Joined: sm.Joined,
BoostedSince: sm.BoostedSince,
Deaf: sm.Deaf,
Mute: sm.Mute,
IsPending: sm.IsPending,
// ConvertSupplementalMembers converts a SupplementalMember to a regular Member.
func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member {
members := make([]discord.Member, len(sms))
for i, sm := range sms {
members[i] = discord.Member{
User: discord.User{ID: sm.UserID},
Nick: sm.Nick,
RoleIDs: sm.RoleIDs,
Joined: sm.Joined,
BoostedSince: sm.BoostedSince,
Deaf: sm.Deaf,
Mute: sm.Mute,
IsPending: sm.IsPending,
}
}
return members
}
// ConvertSupplementalPresence converts a SupplementalPresence to a regular
// ConvertSupplementalPresences converts a SupplementalPresence to a regular
// Presence with an empty GuildID.
func ConvertSupplementalPresence(sp SupplementalPresence) discord.Presence {
return discord.Presence{
User: discord.User{ID: sp.UserID},
Status: sp.Status,
Activities: sp.Activities,
ClientStatus: sp.ClientStatus,
func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence {
presences := make([]discord.Presence, len(sps))
for i, sp := range sps {
presences[i] = discord.Presence{
User: discord.User{ID: sp.UserID},
Status: sp.Status,
Activities: sp.Activities,
ClientStatus: sp.ClientStatus,
}
}
return presences
}

View File

@ -27,7 +27,9 @@ var (
// The user should initialize handlers and intents in the opts function.
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
return NewFromSession(session.NewCustomShard(m, id), defaultstore.New()), nil
state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New())
opts(m, state)
return state, nil
}
}
@ -363,7 +365,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) {
}
if s.tracksChannel(c) {
err = s.Cabinet.ChannelSet(*c, false)
err = s.Cabinet.ChannelSet(c, false)
}
return
@ -383,8 +385,8 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, c := range cs {
if err = s.Cabinet.ChannelSet(c, false); err != nil {
for i := range cs {
if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil {
return
}
}
@ -404,7 +406,7 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel
return nil, err
}
return c, s.Cabinet.ChannelSet(*c, false)
return c, s.Cabinet.ChannelSet(c, false)
}
// PrivateChannels gets the direct messages of the user.
@ -420,8 +422,8 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
return nil, err
}
for _, c := range cs {
if err := s.Cabinet.ChannelSet(c, false); err != nil {
for i := range cs {
if err := s.Cabinet.ChannelSet(&cs[i], false); err != nil {
return nil, err
}
}
@ -509,8 +511,8 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, g := range gs {
if err = s.Cabinet.GuildSet(g, false); err != nil {
for i := range gs {
if err = s.Cabinet.GuildSet(&gs[i], false); err != nil {
return
}
}
@ -546,8 +548,8 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
for _, m := range ms {
if err = s.Cabinet.MemberSet(guildID, m, false); err != nil {
for i := range ms {
if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil {
return
}
}
@ -579,7 +581,7 @@ func (s *State) Message(
go func() {
c, cerr = s.Session.Channel(channelID)
if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
cerr = s.Cabinet.ChannelSet(*c, false)
cerr = s.Cabinet.ChannelSet(c, false)
}
wg.Done()
@ -688,8 +690,9 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
i = len(apiMessages)
}
for _, m := range apiMessages[:i] {
if err := s.Cabinet.MessageSet(m, false); err != nil {
msgs := apiMessages[:i]
for i := range msgs {
if err := s.Cabinet.MessageSet(&msgs[i], false); err != nil {
return nil, err
}
}
@ -745,14 +748,14 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di
return
}
for _, r := range rs {
for i, r := range rs {
if r.ID == roleID {
r := r // copy to prevent mem aliasing
target = &r
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if err = s.RoleSet(guildID, r, false); err != nil {
if err = s.RoleSet(guildID, &rs[i], false); err != nil {
return
}
}
@ -777,8 +780,8 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
for _, r := range rs {
if err := s.RoleSet(guildID, r, false); err != nil {
for i := range rs {
if err := s.RoleSet(guildID, &rs[i], false); err != nil {
return rs, err
}
}
@ -790,7 +793,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
err = s.Cabinet.GuildSet(*g, false)
err = s.Cabinet.GuildSet(g, false)
}
return
@ -799,7 +802,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(gID, uID)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) {
err = s.Cabinet.MemberSet(gID, *m, false)
err = s.Cabinet.MemberSet(gID, m, false)
}
return
@ -808,12 +811,14 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord
// tracksMessage reports whether the state would track the passed message and
// messages from the same channel.
func (s *State) tracksMessage(m *discord.Message) bool {
return (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
return s.Gateway.Identifier.Intents == nil ||
(m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
(!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages))
}
// tracksChannel reports whether the state would track the passed channel.
func (s *State) tracksChannel(c *discord.Channel) bool {
return (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
return s.Gateway.Identifier.Intents == nil ||
(c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
!c.GuildID.IsValid()
}

View File

@ -9,7 +9,7 @@ import (
)
func (s *State) hookSession() {
s.Session.AddHandler(func(event interface{}) {
s.Session.AddSyncHandler(func(event interface{}) {
// Call the pre-handler before the state handler.
if s.PreHandler != nil {
s.PreHandler.Call(event)
@ -68,15 +68,15 @@ func (s *State) onEvent(iface interface{}) {
}
// Handle guild presences
for _, p := range ev.Presences {
if err := s.Cabinet.PresenceSet(p.GuildID, p, false); err != nil {
for i, presence := range ev.Presences {
if err := s.Cabinet.PresenceSet(presence.GuildID, &ev.Presences[i], false); err != nil {
s.stateErr(err, "failed to set presence in Ready")
}
}
// Handle private channels
for _, ch := range ev.PrivateChannels {
if err := s.Cabinet.ChannelSet(ch, false); err != nil {
for i := range ev.PrivateChannels {
if err := s.Cabinet.ChannelSet(&ev.PrivateChannels[i], false); err != nil {
s.stateErr(err, "failed to set channel in Ready")
}
}
@ -90,16 +90,17 @@ func (s *State) onEvent(iface interface{}) {
// Handle guilds
for _, guild := range ev.Guilds {
// Handle guild voice states
for _, v := range guild.VoiceStates {
for i := range guild.VoiceStates {
v := &guild.VoiceStates[i]
if err := s.Cabinet.VoiceStateSet(guild.ID, v, false); err != nil {
s.stateErr(err, "failed to set guild voice state in Ready Supplemental")
}
}
}
for _, friend := range ev.MergedPresences.Friends {
sPresence := gateway.ConvertSupplementalPresence(friend)
if err := s.Cabinet.PresenceSet(0, sPresence, false); err != nil {
friendPresences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Friends)
for i := range friendPresences {
if err := s.Cabinet.PresenceSet(0, &friendPresences[i], false); err != nil {
s.stateErr(err, "failed to set friend presence in Ready Supplemental")
}
}
@ -110,16 +111,16 @@ func (s *State) onEvent(iface interface{}) {
for i := 0; i < len(ready.Guilds) && i < len(ev.MergedMembers); i++ {
guild := ready.Guilds[i]
for _, member := range ev.MergedMembers[i] {
sMember := gateway.ConvertSupplementalMember(member)
if err := s.Cabinet.MemberSet(guild.ID, sMember, false); err != nil {
members := gateway.ConvertSupplementalMembers(ev.MergedMembers[i])
for i := range members {
if err := s.Cabinet.MemberSet(guild.ID, &members[i], false); err != nil {
s.stateErr(err, "failed to set friend presence in Ready Supplemental")
}
}
for _, member := range ev.MergedPresences.Guilds[i] {
sPresence := gateway.ConvertSupplementalPresence(member)
if err := s.Cabinet.PresenceSet(guild.ID, sPresence, false); err != nil {
presences := gateway.ConvertSupplementalPresences(ev.MergedPresences.Guilds[i])
for i := range presences {
if err := s.Cabinet.PresenceSet(guild.ID, &presences[i], false); err != nil {
s.stateErr(err, "failed to set member presence in Ready Supplemental")
}
}
@ -129,7 +130,7 @@ func (s *State) onEvent(iface interface{}) {
s.batchLog(storeGuildCreate(s.Cabinet, ev))
case *gateway.GuildUpdateEvent:
if err := s.Cabinet.GuildSet(ev.Guild, true); err != nil {
if err := s.Cabinet.GuildSet(&ev.Guild, true); err != nil {
s.stateErr(err, "failed to update guild in state")
}
@ -139,7 +140,7 @@ func (s *State) onEvent(iface interface{}) {
}
case *gateway.GuildMemberAddEvent:
if err := s.Cabinet.MemberSet(ev.GuildID, ev.Member, false); err != nil {
if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Member, false); err != nil {
s.stateErr(err, "failed to add a member in state")
}
@ -153,7 +154,7 @@ func (s *State) onEvent(iface interface{}) {
// Update available fields from ev into m
ev.Update(m)
if err := s.Cabinet.MemberSet(ev.GuildID, *m, true); err != nil {
if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil {
s.stateErr(err, "failed to update a member in state")
}
@ -163,25 +164,25 @@ func (s *State) onEvent(iface interface{}) {
}
case *gateway.GuildMembersChunkEvent:
for _, m := range ev.Members {
if err := s.Cabinet.MemberSet(ev.GuildID, m, false); err != nil {
for i := range ev.Members {
if err := s.Cabinet.MemberSet(ev.GuildID, &ev.Members[i], false); err != nil {
s.stateErr(err, "failed to add a member from chunk in state")
}
}
for _, p := range ev.Presences {
if err := s.Cabinet.PresenceSet(ev.GuildID, p, false); err != nil {
for i := range ev.Presences {
if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presences[i], false); err != nil {
s.stateErr(err, "failed to add a presence from chunk in state")
}
}
case *gateway.GuildRoleCreateEvent:
if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, false); err != nil {
if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, false); err != nil {
s.stateErr(err, "failed to add a role in state")
}
case *gateway.GuildRoleUpdateEvent:
if err := s.Cabinet.RoleSet(ev.GuildID, ev.Role, true); err != nil {
if err := s.Cabinet.RoleSet(ev.GuildID, &ev.Role, true); err != nil {
s.stateErr(err, "failed to update a role in state")
}
@ -196,17 +197,17 @@ func (s *State) onEvent(iface interface{}) {
}
case *gateway.ChannelCreateEvent:
if err := s.Cabinet.ChannelSet(ev.Channel, false); err != nil {
if err := s.Cabinet.ChannelSet(&ev.Channel, false); err != nil {
s.stateErr(err, "failed to create a channel in state")
}
case *gateway.ChannelUpdateEvent:
if err := s.Cabinet.ChannelSet(ev.Channel, true); err != nil {
if err := s.Cabinet.ChannelSet(&ev.Channel, true); err != nil {
s.stateErr(err, "failed to update a channel in state")
}
case *gateway.ChannelDeleteEvent:
if err := s.Cabinet.ChannelRemove(ev.Channel); err != nil {
if err := s.Cabinet.ChannelRemove(&ev.Channel); err != nil {
s.stateErr(err, "failed to remove a channel in state")
}
@ -214,12 +215,12 @@ func (s *State) onEvent(iface interface{}) {
// not tracked.
case *gateway.MessageCreateEvent:
if err := s.Cabinet.MessageSet(ev.Message, false); err != nil {
if err := s.Cabinet.MessageSet(&ev.Message, false); err != nil {
s.stateErr(err, "failed to add a message in state")
}
case *gateway.MessageUpdateEvent:
if err := s.Cabinet.MessageSet(ev.Message, true); err != nil {
if err := s.Cabinet.MessageSet(&ev.Message, true); err != nil {
s.stateErr(err, "failed to update a message in state")
}
@ -238,12 +239,17 @@ func (s *State) onEvent(iface interface{}) {
case *gateway.MessageReactionAddEvent:
s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool {
if i := findReaction(m.Reactions, ev.Emoji); i > -1 {
// Copy the reactions slice so it's not racy.
m.Reactions = append([]discord.Reaction(nil), m.Reactions...)
m.Reactions[i].Count++
} else {
var me bool
if u, _ := s.Cabinet.Me(); u != nil {
me = ev.UserID == u.ID
}
old := m.Reactions
m.Reactions = make([]discord.Reaction, 0, len(old)+1)
m.Reactions = append(m.Reactions, old...)
m.Reactions = append(m.Reactions, discord.Reaction{
Count: 1,
Me: me,
@ -261,18 +267,21 @@ func (s *State) onEvent(iface interface{}) {
}
r := &m.Reactions[i]
r.Count--
newCount := r.Count - 1
switch {
case r.Count < 1: // If the count is 0:
// Remove the reaction.
m.Reactions = append(m.Reactions[:i], m.Reactions[i+1:]...)
case newCount < 1: // If the count is 0:
old := m.Reactions
m.Reactions = make([]discord.Reaction, len(m.Reactions)-1)
copy(m.Reactions[0:], old[:i])
copy(m.Reactions[i:], old[i+1:])
case r.Me: // If reaction removal is the user's
u, err := s.Cabinet.Me()
if err == nil && ev.UserID == u.ID {
r.Me = false
}
r.Count--
}
return true
@ -295,13 +304,13 @@ func (s *State) onEvent(iface interface{}) {
})
case *gateway.PresenceUpdateEvent:
if err := s.Cabinet.PresenceSet(ev.GuildID, ev.Presence, true); err != nil {
if err := s.Cabinet.PresenceSet(ev.GuildID, &ev.Presence, true); err != nil {
s.stateErr(err, "failed to update presence in state")
}
case *gateway.PresencesReplaceEvent:
for _, p := range *ev {
if err := s.Cabinet.PresenceSet(p.GuildID, p.Presence, true); err != nil {
if err := s.Cabinet.PresenceSet(p.GuildID, &p.Presence, true); err != nil {
s.stateErr(err, "failed to update presence in state")
}
}
@ -332,7 +341,7 @@ func (s *State) onEvent(iface interface{}) {
s.stateErr(err, "failed to remove voice state from state")
}
} else {
if err := s.Cabinet.VoiceStateSet(vs.GuildID, *vs, true); err != nil {
if err := s.Cabinet.VoiceStateSet(vs.GuildID, vs, true); err != nil {
s.stateErr(err, "failed to update voice state in state")
}
}
@ -342,6 +351,7 @@ func (s *State) onEvent(iface interface{}) {
func (s *State) stateErr(err error, wrap string) {
s.StateLog(errors.Wrap(err, wrap))
}
func (s *State) batchLog(errors []error) {
for _, err := range errors {
s.StateLog(err)
@ -355,10 +365,16 @@ func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func
if err != nil {
return
}
// Copy the messages.
cpy := *m
m = &cpy
if !fn(m) {
return
}
if err := s.Cabinet.MessageSet(*m, true); err != nil {
if err := s.Cabinet.MessageSet(m, true); err != nil {
s.stateErr(err, "failed to save message in reaction add")
}
}
@ -379,7 +395,7 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
stack, errs := newErrorStack()
if err := cab.GuildSet(guild.Guild, false); err != nil {
if err := cab.GuildSet(&guild.Guild, false); err != nil {
errs(err, "failed to set guild in Ready")
}
@ -391,8 +407,8 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
}
// Handle guild member
for _, m := range guild.Members {
if err := cab.MemberSet(guild.ID, m, false); err != nil {
for i := range guild.Members {
if err := cab.MemberSet(guild.ID, &guild.Members[i], false); err != nil {
errs(err, "failed to set guild member in Ready")
}
}
@ -400,36 +416,40 @@ func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []err
// Handle guild channels
for _, ch := range guild.Channels {
// I HATE Discord.
ch := ch
ch.GuildID = guild.ID
if err := cab.ChannelSet(ch, false); err != nil {
if err := cab.ChannelSet(&ch, false); err != nil {
errs(err, "failed to set guild channel in Ready")
}
}
// Handle threads.
for _, ch := range guild.Threads {
ch := ch
ch.GuildID = guild.ID
if err := cab.ChannelSet(ch, false); err != nil {
if err := cab.ChannelSet(&ch, false); err != nil {
errs(err, "failed to set guild thread in Ready")
}
}
// Handle guild presences
for _, p := range guild.Presences {
p := p
p.GuildID = guild.ID
if err := cab.PresenceSet(guild.ID, p, false); err != nil {
if err := cab.PresenceSet(guild.ID, &p, false); err != nil {
errs(err, "failed to set guild presence in Ready")
}
}
// Handle guild voice states
for _, v := range guild.VoiceStates {
v := v
v.GuildID = guild.ID
if err := cab.VoiceStateSet(guild.ID, v, false); err != nil {
if err := cab.VoiceStateSet(guild.ID, &v, false); err != nil {
errs(err, "failed to set guild voice state in Ready")
}
}

View File

@ -2,6 +2,7 @@ package defaultstore
import (
"errors"
"fmt"
"sync"
"github.com/diamondburned/arikawa/v3/discord"
@ -13,20 +14,18 @@ type Channel struct {
// Channel references must be protected under the same mutex.
privates map[discord.UserID]*discord.Channel
privateChs []*discord.Channel
channels map[discord.ChannelID]*discord.Channel
guildChs map[discord.GuildID][]*discord.Channel
channels map[discord.ChannelID]discord.Channel
privates map[discord.UserID]discord.ChannelID
guildChs map[discord.GuildID][]discord.ChannelID
}
var _ store.ChannelStore = (*Channel)(nil)
func NewChannel() *Channel {
return &Channel{
privates: map[discord.UserID]*discord.Channel{},
channels: map[discord.ChannelID]*discord.Channel{},
guildChs: map[discord.GuildID][]*discord.Channel{},
channels: map[discord.ChannelID]discord.Channel{},
privates: map[discord.UserID]discord.ChannelID{},
guildChs: map[discord.GuildID][]discord.ChannelID{},
}
}
@ -34,9 +33,9 @@ func (s *Channel) Reset() error {
s.mut.Lock()
defer s.mut.Unlock()
s.privates = map[discord.UserID]*discord.Channel{}
s.channels = map[discord.ChannelID]*discord.Channel{}
s.guildChs = map[discord.GuildID][]*discord.Channel{}
s.channels = map[discord.ChannelID]discord.Channel{}
s.privates = map[discord.UserID]discord.ChannelID{}
s.guildChs = map[discord.GuildID][]discord.ChannelID{}
return nil
}
@ -50,20 +49,19 @@ func (s *Channel) Channel(id discord.ChannelID) (*discord.Channel, error) {
return nil, store.ErrNotFound
}
cpy := *ch
return &cpy, nil
return &ch, nil
}
func (s *Channel) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ch, ok := s.privates[recipient]
id, ok := s.privates[recipient]
if !ok {
return nil, store.ErrNotFound
}
cpy := *ch
cpy := s.channels[id]
return &cpy, nil
}
@ -72,16 +70,20 @@ func (s *Channel) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
chRefs, ok := s.guildChs[guildID]
chIDs, ok := s.guildChs[guildID]
if !ok {
return nil, store.ErrNotFound
}
// Reading chRefs is also covered by the global mutex.
var channels = make([]discord.Channel, len(chRefs))
for i, chRef := range chRefs {
channels[i] = *chRef
var channels = make([]discord.Channel, 0, len(chIDs))
for _, chID := range chIDs {
ch, ok := s.channels[chID]
if !ok {
continue
}
channels = append(channels, ch)
}
return channels, nil
@ -92,42 +94,47 @@ func (s *Channel) PrivateChannels() ([]discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
if len(s.privateChs) == 0 {
groupDMs := s.guildChs[0]
if len(s.privates) == 0 && len(groupDMs) == 0 {
return nil, store.ErrNotFound
}
var channels = make([]discord.Channel, len(s.privateChs))
for i, ch := range s.privateChs {
channels[i] = *ch
var channels = make([]discord.Channel, 0, len(s.privates)+len(groupDMs))
for _, chID := range s.privates {
if ch, ok := s.channels[chID]; ok {
channels = append(channels, ch)
}
}
for _, chID := range groupDMs {
if ch, ok := s.channels[chID]; ok {
channels = append(channels, ch)
}
}
return channels, nil
}
// ChannelSet sets the Direct Message or Guild channel into the state.
func (s *Channel) ChannelSet(channel discord.Channel, update bool) error {
func (s *Channel) ChannelSet(channel *discord.Channel, update bool) error {
cpy := *channel
s.mut.Lock()
defer s.mut.Unlock()
// Update the reference if we can.
if ch, ok := s.channels[channel.ID]; ok {
if update {
*ch = channel
}
return nil
}
s.channels[channel.ID] = cpy
switch channel.Type {
case discord.DirectMessage:
// Safety bound check.
if len(channel.DMRecipients) != 1 {
return errors.New("DirectMessage channel does not have 1 recipient")
return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID)
}
s.privates[channel.DMRecipients[0].ID] = &channel
fallthrough
s.privates[channel.DMRecipients[0].ID] = channel.ID
return nil
case discord.GroupDM:
s.privateChs = append(s.privateChs, &channel)
s.channels[channel.ID] = &channel
s.guildChs[0] = addChannelID(s.guildChs[0], channel.ID)
return nil
}
@ -137,16 +144,11 @@ func (s *Channel) ChannelSet(channel discord.Channel, update bool) error {
return errors.New("invalid guildID for guild channel")
}
s.channels[channel.ID] = &channel
channels, _ := s.guildChs[channel.GuildID]
channels = append(channels, &channel)
s.guildChs[channel.GuildID] = channels
s.guildChs[channel.GuildID] = addChannelID(s.guildChs[channel.GuildID], channel.ID)
return nil
}
func (s *Channel) ChannelRemove(channel discord.Channel) error {
func (s *Channel) ChannelRemove(channel *discord.Channel) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -158,49 +160,42 @@ func (s *Channel) ChannelRemove(channel discord.Channel) error {
case discord.DirectMessage:
// Safety bound check.
if len(channel.DMRecipients) != 1 {
return errors.New("DirectMessage channel does not have 1 recipient")
return fmt.Errorf("DirectMessage channel %d doesn't have 1 recipient", channel.ID)
}
delete(s.privates, channel.DMRecipients[0].ID)
fallthrough
return nil
case discord.GroupDM:
for i, priv := range s.privateChs {
if priv.ID == channel.ID {
s.privateChs = removeChannel(s.privateChs, i)
break
}
}
s.guildChs[0] = removeChannelID(s.guildChs[0], channel.ID)
return nil
}
// Wipe the channel off the guilds index, if available.
channels, ok := s.guildChs[channel.GuildID]
if !ok {
return nil
}
for i, ch := range channels {
if ch.ID == channel.ID {
s.guildChs[channel.GuildID] = removeChannel(channels, i)
break
}
}
s.guildChs[channel.GuildID] = removeChannelID(s.guildChs[channel.GuildID], channel.ID)
return nil
}
// removeChannel removes the given channel with the index from the given
func addChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID {
for _, ch := range channels {
if ch == id {
return channels
}
}
if channels == nil {
channels = make([]discord.ChannelID, 0, 5)
}
return append(channels, id)
}
// removeChannelID removes the given channel with the index from the given
// channels slice in an unordered fashion.
func removeChannel(channels []*discord.Channel, i int) []*discord.Channel {
// Fast unordered delete. Not sure if there's a benefit in doing
// this over using a map, but I guess the memory usage is less and
// there's no copying.
// Move the last channel to the current channel, set the last
// channel there to a nil value to unreference its children, then
// slice the last channel off.
channels[i] = channels[len(channels)-1]
channels[len(channels)-1] = nil
channels = channels[:len(channels)-1]
func removeChannelID(channels []discord.ChannelID, id discord.ChannelID) []discord.ChannelID {
for i, ch := range channels {
if ch == id {
// Move the last channel to the current channel, then slice the last
// channel off.
channels[i] = channels[len(channels)-1]
channels = channels[:len(channels)-1]
break
}
}
return channels
}

View File

@ -33,13 +33,12 @@ func (s *Guild) Guild(id discord.GuildID) (*discord.Guild, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ch, ok := s.guilds[id]
if !ok {
return nil, store.ErrNotFound
g, ok := s.guilds[id]
if ok {
return &g, nil
}
// implicit copy
return &ch, nil
return nil, store.ErrNotFound
}
func (s *Guild) Guilds() ([]discord.Guild, error) {
@ -58,10 +57,12 @@ func (s *Guild) Guilds() ([]discord.Guild, error) {
return gs, nil
}
func (s *Guild) GuildSet(guild discord.Guild, update bool) error {
func (s *Guild) GuildSet(guild *discord.Guild, update bool) error {
cpy := *guild
s.mut.Lock()
if _, ok := s.guilds[guild.ID]; !ok || update {
s.guilds[guild.ID] = guild
s.guilds[guild.ID] = cpy
}
s.mut.Unlock()

View File

@ -13,7 +13,7 @@ type Member struct {
}
type guildMembers struct {
mut sync.Mutex
mut sync.RWMutex
members map[discord.UserID]discord.Member
}
@ -41,8 +41,8 @@ func (s *Member) Member(guildID discord.GuildID, userID discord.UserID) (*discor
gm := iv.(*guildMembers)
gm.mut.Lock()
defer gm.mut.Unlock()
gm.mut.RLock()
defer gm.mut.RUnlock()
m, ok := gm.members[userID]
if ok {
@ -60,8 +60,8 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) {
gm := iv.(*guildMembers)
gm.mut.Lock()
defer gm.mut.Unlock()
gm.mut.RLock()
defer gm.mut.RUnlock()
var members = make([]discord.Member, 0, len(gm.members))
for _, m := range gm.members {
@ -71,13 +71,13 @@ func (s *Member) Members(guildID discord.GuildID) ([]discord.Member, error) {
return members, nil
}
func (s *Member) MemberSet(guildID discord.GuildID, m discord.Member, update bool) error {
func (s *Member) MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID)
gm := iv.(*guildMembers)
gm.mut.Lock()
if _, ok := gm.members[m.User.ID]; !ok || update {
gm.members[m.User.ID] = m
gm.members[m.User.ID] = *m
}
gm.mut.Unlock()

View File

@ -16,7 +16,7 @@ type Message struct {
var _ store.MessageStore = (*Message)(nil)
type messages struct {
mut sync.Mutex
mut sync.RWMutex
messages []discord.Message
}
@ -43,8 +43,8 @@ func (s *Message) Message(chID discord.ChannelID, mID discord.MessageID) (*disco
msgs := iv.(*messages)
msgs.mut.Lock()
defer msgs.mut.Unlock()
msgs.mut.RLock()
defer msgs.mut.RUnlock()
for _, m := range msgs.messages {
if m.ID == mID {
@ -63,8 +63,8 @@ func (s *Message) Messages(channelID discord.ChannelID) ([]discord.Message, erro
msgs := iv.(*messages)
msgs.mut.Lock()
defer msgs.mut.Unlock()
msgs.mut.RLock()
defer msgs.mut.RUnlock()
return append([]discord.Message(nil), msgs.messages...), nil
}
@ -73,7 +73,7 @@ func (s *Message) MaxMessages() int {
return s.maxMsgs
}
func (s *Message) MessageSet(message discord.Message, update bool) error {
func (s *Message) MessageSet(message *discord.Message, update bool) error {
if s.maxMsgs <= 0 {
return nil
}
@ -102,19 +102,19 @@ func (s *Message) MessageSet(message discord.Message, update bool) error {
}
if len(msgs.messages) == 0 {
msgs.messages = []discord.Message{message}
msgs.messages = []discord.Message{*message}
}
if pos := messageInsertPosition(message, msgs.messages); pos < 0 {
// Messages are full, drop the oldest messages to make room.
if len(msgs.messages) == s.maxMsgs {
copy(msgs.messages[1:], msgs.messages)
msgs.messages[0] = message
msgs.messages[0] = *message
} else {
msgs.messages = append([]discord.Message{message}, msgs.messages...)
msgs.messages = append([]discord.Message{*message}, msgs.messages...)
}
} else if pos > 0 && len(msgs.messages) < s.maxMsgs {
msgs.messages = append(msgs.messages, message)
msgs.messages = append(msgs.messages, *message)
}
// We already have this message or we can't append any more messages.
@ -131,7 +131,7 @@ func (s *Message) MessageSet(message discord.Message, update bool) error {
// messageInsertPosition is biased as it will recommend adding the message even
// if timestamps just match, even though the true order cannot be determined in
// that case.
func messageInsertPosition(target discord.Message, messages []discord.Message) int8 {
func messageInsertPosition(target *discord.Message, messages []discord.Message) int8 {
var (
targetTime = target.ID.Time()
firstTime = messages[0].ID.Time()
@ -183,7 +183,7 @@ func messageInsertPosition(target discord.Message, messages []discord.Message) i
}
// DiffMessage fills non-empty fields from src to dst.
func DiffMessage(src discord.Message, dst *discord.Message) {
func DiffMessage(src, dst *discord.Message) {
// Thanks, Discord.
if src.Content != "" {
dst.Content = src.Content

View File

@ -10,27 +10,27 @@ func populate12Store() *Message {
store := NewMessage(10)
// Insert a regular list of messages.
store.MessageSet(discord.Message{ID: 1 << 29, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 28, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 27, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 26, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 25, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 24, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 29, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 28, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 27, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 26, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 25, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 24, ChannelID: 1}, false)
// Try to insert newer messages after inserting new messages.
store.MessageSet(discord.Message{ID: 1 << 30, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 31, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 32, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 33, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 34, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 30, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 31, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 32, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 33, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 34, ChannelID: 1}, false)
// TThese messages should be discarded, due to age.
store.MessageSet(discord.Message{ID: 1 << 23, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 22, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 23, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 22, ChannelID: 1}, false)
// These should be prepended.
store.MessageSet(discord.Message{ID: 1 << 35, ChannelID: 1}, false)
store.MessageSet(discord.Message{ID: 1 << 36, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 35, ChannelID: 1}, false)
store.MessageSet(&discord.Message{ID: 1 << 36, ChannelID: 1}, false)
return store
}
@ -57,9 +57,9 @@ func TestMessageSet(t *testing.T) {
func TestMessagesUpdate(t *testing.T) {
store := populate12Store()
store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true)
store.MessageSet(discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true)
store.MessageSet(discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true)
store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 1"}, true)
store.MessageSet(&discord.Message{ID: 6, ChannelID: 1, Content: "edited 2"}, true)
store.MessageSet(&discord.Message{ID: 5, ChannelID: 1, Content: "edited 3"}, true)
expect := map[discord.MessageID]string{
5: "edited 3",

View File

@ -13,7 +13,7 @@ type Presence struct {
}
type presences struct {
mut sync.Mutex
mut sync.RWMutex
presences map[discord.UserID]discord.Presence
}
@ -41,8 +41,8 @@ func (s *Presence) Presence(gID discord.GuildID, uID discord.UserID) (*discord.P
ps := iv.(*presences)
ps.mut.Lock()
defer ps.mut.Unlock()
ps.mut.RLock()
defer ps.mut.RUnlock()
p, ok := ps.presences[uID]
if ok {
@ -60,8 +60,8 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error
ps := iv.(*presences)
ps.mut.Lock()
defer ps.mut.Unlock()
ps.mut.RLock()
defer ps.mut.RUnlock()
var presences = make([]discord.Presence, 0, len(ps.presences))
for _, p := range ps.presences {
@ -71,7 +71,7 @@ func (s *Presence) Presences(guildID discord.GuildID) ([]discord.Presence, error
return presences, nil
}
func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error {
func (s *Presence) PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID)
ps := iv.(*presences)
@ -85,7 +85,7 @@ func (s *Presence) PresenceSet(guildID discord.GuildID, p discord.Presence, upda
}
if _, ok := ps.presences[p.User.ID]; !ok || update {
ps.presences[p.User.ID] = p
ps.presences[p.User.ID] = *p
}
return nil

View File

@ -15,7 +15,7 @@ type Role struct {
var _ store.RoleStore = (*Role)(nil)
type roles struct {
mut sync.Mutex
mut sync.RWMutex
roles map[discord.RoleID]discord.Role
}
@ -41,8 +41,8 @@ func (s *Role) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Ro
rs := iv.(*roles)
rs.mut.Lock()
defer rs.mut.Unlock()
rs.mut.RLock()
defer rs.mut.RUnlock()
r, ok := rs.roles[roleID]
if ok {
@ -60,8 +60,8 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) {
rs := iv.(*roles)
rs.mut.Lock()
defer rs.mut.Unlock()
rs.mut.RLock()
defer rs.mut.RUnlock()
var roles = make([]discord.Role, 0, len(rs.roles))
for _, role := range rs.roles {
@ -71,14 +71,14 @@ func (s *Role) Roles(guildID discord.GuildID) ([]discord.Role, error) {
return roles, nil
}
func (s *Role) RoleSet(guildID discord.GuildID, role discord.Role, update bool) error {
func (s *Role) RoleSet(guildID discord.GuildID, role *discord.Role, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID)
rs := iv.(*roles)
rs.mut.Lock()
if _, ok := rs.roles[role.ID]; !ok || update {
rs.roles[role.ID] = role
rs.roles[role.ID] = *role
}
rs.mut.Unlock()

View File

@ -15,7 +15,7 @@ type VoiceState struct {
var _ store.VoiceStateStore = (*VoiceState)(nil)
type voiceStates struct {
mut sync.Mutex
mut sync.RWMutex
voiceStates map[discord.UserID]discord.VoiceState
}
@ -43,8 +43,8 @@ func (s *VoiceState) VoiceState(
vs := iv.(*voiceStates)
vs.mut.Lock()
defer vs.mut.Unlock()
vs.mut.RLock()
defer vs.mut.RUnlock()
v, ok := vs.voiceStates[userID]
if ok {
@ -62,8 +62,8 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState,
vs := iv.(*voiceStates)
vs.mut.Lock()
defer vs.mut.Unlock()
vs.mut.RLock()
defer vs.mut.RUnlock()
var states = make([]discord.VoiceState, 0, len(vs.voiceStates))
for _, state := range vs.voiceStates {
@ -74,7 +74,7 @@ func (s *VoiceState) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState,
}
func (s *VoiceState) VoiceStateSet(
guildID discord.GuildID, voiceState discord.VoiceState, update bool) error {
guildID discord.GuildID, voiceState *discord.VoiceState, update bool) error {
iv, _ := s.guilds.LoadOrStore(guildID)
@ -82,7 +82,7 @@ func (s *VoiceState) VoiceStateSet(
vs.mut.Lock()
if _, ok := vs.voiceStates[voiceState.UserID]; !ok || update {
vs.voiceStates[voiceState.UserID] = voiceState
vs.voiceStates[voiceState.UserID] = *voiceState
}
vs.mut.Unlock()

View File

@ -143,6 +143,12 @@ type Resetter interface {
Reset() error
}
type CoreStorer interface {
Resetter
Lock()
Unlock()
}
var _ Resetter = (*noop)(nil)
func (noop) Reset() error { return nil }
@ -176,8 +182,8 @@ type ChannelStore interface {
// Both ChannelSet and ChannelRemove should switch on Type to know if it's a
// private channel or not.
ChannelSet(c discord.Channel, update bool) error
ChannelRemove(discord.Channel) error
ChannelSet(c *discord.Channel, update bool) error
ChannelRemove(*discord.Channel) error
}
var _ ChannelStore = (*noop)(nil)
@ -194,10 +200,10 @@ func (noop) Channels(discord.GuildID) ([]discord.Channel, error) {
func (noop) PrivateChannels() ([]discord.Channel, error) {
return nil, ErrNotFound
}
func (noop) ChannelSet(discord.Channel, bool) error {
func (noop) ChannelSet(*discord.Channel, bool) error {
return nil
}
func (noop) ChannelRemove(discord.Channel) error {
func (noop) ChannelRemove(*discord.Channel) error {
return nil
}
@ -232,7 +238,7 @@ type GuildStore interface {
Guild(discord.GuildID) (*discord.Guild, error)
Guilds() ([]discord.Guild, error)
GuildSet(g discord.Guild, update bool) error
GuildSet(g *discord.Guild, update bool) error
GuildRemove(id discord.GuildID) error
}
@ -240,7 +246,7 @@ var _ GuildStore = (*noop)(nil)
func (noop) Guild(discord.GuildID) (*discord.Guild, error) { return nil, ErrNotFound }
func (noop) Guilds() ([]discord.Guild, error) { return nil, ErrNotFound }
func (noop) GuildSet(discord.Guild, bool) error { return nil }
func (noop) GuildSet(*discord.Guild, bool) error { return nil }
func (noop) GuildRemove(discord.GuildID) error { return nil }
// MemberStore is the store interface for all members.
@ -250,7 +256,7 @@ type MemberStore interface {
Member(discord.GuildID, discord.UserID) (*discord.Member, error)
Members(discord.GuildID) ([]discord.Member, error)
MemberSet(guildID discord.GuildID, m discord.Member, update bool) error
MemberSet(guildID discord.GuildID, m *discord.Member, update bool) error
MemberRemove(discord.GuildID, discord.UserID) error
}
@ -262,7 +268,7 @@ func (noop) Member(discord.GuildID, discord.UserID) (*discord.Member, error) {
func (noop) Members(discord.GuildID) ([]discord.Member, error) {
return nil, ErrNotFound
}
func (noop) MemberSet(discord.GuildID, discord.Member, bool) error {
func (noop) MemberSet(discord.GuildID, *discord.Member, bool) error {
return nil
}
func (noop) MemberRemove(discord.GuildID, discord.UserID) error {
@ -289,7 +295,7 @@ type MessageStore interface {
// If update is set to true, MessageSet will check if a message with the
// id of the passed message is stored, and update it if so. Otherwise, if
// there is no such message, it will be discarded.
MessageSet(m discord.Message, update bool) error
MessageSet(m *discord.Message, update bool) error
MessageRemove(discord.ChannelID, discord.MessageID) error
}
@ -304,7 +310,7 @@ func (noop) Message(discord.ChannelID, discord.MessageID) (*discord.Message, err
func (noop) Messages(discord.ChannelID) ([]discord.Message, error) {
return nil, ErrNotFound
}
func (noop) MessageSet(discord.Message, bool) error {
func (noop) MessageSet(*discord.Message, bool) error {
return nil
}
func (noop) MessageRemove(discord.ChannelID, discord.MessageID) error {
@ -319,7 +325,7 @@ type PresenceStore interface {
Presence(discord.GuildID, discord.UserID) (*discord.Presence, error)
Presences(discord.GuildID) ([]discord.Presence, error)
PresenceSet(guildID discord.GuildID, p discord.Presence, update bool) error
PresenceSet(guildID discord.GuildID, p *discord.Presence, update bool) error
PresenceRemove(discord.GuildID, discord.UserID) error
}
@ -331,7 +337,7 @@ func (noop) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error)
func (noop) Presences(discord.GuildID) ([]discord.Presence, error) {
return nil, ErrNotFound
}
func (noop) PresenceSet(discord.GuildID, discord.Presence, bool) error {
func (noop) PresenceSet(discord.GuildID, *discord.Presence, bool) error {
return nil
}
func (noop) PresenceRemove(discord.GuildID, discord.UserID) error {
@ -345,7 +351,7 @@ type RoleStore interface {
Role(discord.GuildID, discord.RoleID) (*discord.Role, error)
Roles(discord.GuildID) ([]discord.Role, error)
RoleSet(guildID discord.GuildID, r discord.Role, update bool) error
RoleSet(guildID discord.GuildID, r *discord.Role, update bool) error
RoleRemove(discord.GuildID, discord.RoleID) error
}
@ -353,7 +359,7 @@ var _ RoleStore = (*noop)(nil)
func (noop) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) { return nil, ErrNotFound }
func (noop) Roles(discord.GuildID) ([]discord.Role, error) { return nil, ErrNotFound }
func (noop) RoleSet(discord.GuildID, discord.Role, bool) error { return nil }
func (noop) RoleSet(discord.GuildID, *discord.Role, bool) error { return nil }
func (noop) RoleRemove(discord.GuildID, discord.RoleID) error { return nil }
// VoiceStateStore is the store interface for all voice states.
@ -363,7 +369,7 @@ type VoiceStateStore interface {
VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error)
VoiceStates(discord.GuildID) ([]discord.VoiceState, error)
VoiceStateSet(guildID discord.GuildID, s discord.VoiceState, update bool) error
VoiceStateSet(guildID discord.GuildID, s *discord.VoiceState, update bool) error
VoiceStateRemove(discord.GuildID, discord.UserID) error
}
@ -375,7 +381,7 @@ func (noop) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, er
func (noop) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) {
return nil, ErrNotFound
}
func (noop) VoiceStateSet(discord.GuildID, discord.VoiceState, bool) error {
func (noop) VoiceStateSet(discord.GuildID, *discord.VoiceState, bool) error {
return nil
}
func (noop) VoiceStateRemove(discord.GuildID, discord.UserID) error {

View File

@ -28,12 +28,8 @@ import (
// Handler is a container for command handlers. A zero-value instance is a valid
// instance.
type Handler struct {
// Synchronous controls whether to spawn each event handler in its own
// goroutine. Default false (meaning goroutines are spawned).
Synchronous bool
mutex sync.RWMutex
slab slab
mutex sync.RWMutex
events map[reflect.Type]slab // nil type for interfaces
}
func New() *Handler {
@ -43,22 +39,24 @@ func New() *Handler {
// Call calls all handlers with the given event. This is an internal method; use
// with care.
func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
v := reflect.ValueOf(ev)
t := v.Type()
h.mutex.RLock()
defer h.mutex.RUnlock()
for _, entry := range h.slab.Entries {
if entry.isInvalid() || entry.not(evT) {
for _, entry := range h.events[t].Entries {
if entry.isInvalid() {
continue
}
entry.Call(v)
}
if h.Synchronous {
entry.call(evV)
} else {
go entry.call(evV)
for _, entry := range h.events[nil].Entries {
if entry.isInvalid() || entry.not(t) {
continue
}
entry.Call(v)
}
}
@ -145,7 +143,19 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca
// h.AddHandler(ch)
//
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler)
rm, err := h.addHandler(handler, false)
if err != nil {
panic(err)
}
return rm
}
// AddSyncHandler is a synchronous variant of AddHandler. Handlers added using
// this method will block the Call method, which is helpful if the user needs to
// rely on the order of events arriving. Handlers added using this method should
// not block for very long, as it may clog up other handlers.
func (h *Handler) AddSyncHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler, true)
if err != nil {
panic(err)
}
@ -167,23 +177,56 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
}
}()
return h.addHandler(handler)
return h.addHandler(handler, false)
}
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
// AddSyncHandlerCheck is the safe-guarded version of AddSyncHandler. It is
// similar to AddHandlerCheck.
func (h *Handler) AddSyncHandlerCheck(handler interface{}) (rm func(), err error) {
// Reflect would actually panic if anything goes wrong, so this is just in
// case.
defer func() {
if rec := recover(); rec != nil {
if recErr, ok := rec.(error); ok {
err = recErr
} else {
err = fmt.Errorf("%v", rec)
}
}
}()
return h.addHandler(handler, true)
}
func (h *Handler) addHandler(fn interface{}, sync bool) (rm func(), err error) {
// Reflect the handler
r, err := newHandler(fn)
r, err := newHandler(fn, sync)
if err != nil {
return nil, errors.Wrap(err, "handler reflect failed")
}
var id int
var t reflect.Type
if !r.isIface {
t = r.event
}
h.mutex.Lock()
id := h.slab.Put(r)
if h.events == nil {
h.events = make(map[reflect.Type]slab, 10)
}
slab := h.events[t]
id = slab.Put(r)
h.events[t] = slab
h.mutex.Unlock()
return func() {
h.mutex.Lock()
popped := h.slab.Pop(id)
slab := h.events[t]
popped := slab.Pop(id)
h.mutex.Unlock()
popped.cleanup()
@ -193,20 +236,22 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
type handler struct {
event reflect.Type // underlying type; arg0 or chan underlying type
callback reflect.Value
isIface bool
chanclose reflect.Value // IsValid() if chan
isIface bool
isSync bool
}
// newHandler reflects either a channel or a function into a handler. A function
// must only have a single argument being the event and no return, and a channel
// must have the event type as the underlying type.
func newHandler(unknown interface{}) (handler, error) {
func newHandler(unknown interface{}, sync bool) (handler, error) {
fnV := reflect.ValueOf(unknown)
fnT := fnV.Type()
// underlying event type
var handler = handler{
handler := handler{
callback: fnV,
isSync: sync,
}
switch fnT.Kind() {
@ -249,6 +294,14 @@ func (h handler) not(event reflect.Type) bool {
return h.event != event
}
func (h handler) Call(event reflect.Value) {
if h.isSync {
h.call(event)
} else {
go h.call(event)
}
}
func (h handler) call(event reflect.Value) {
if h.chanclose.IsValid() {
reflect.Select([]reflect.SelectCase{

View File

@ -63,7 +63,7 @@ func TestHandler(t *testing.T) {
h, err := newHandler(func(m *gateway.MessageCreateEvent) {
results <- m.Content
})
}, false)
if err != nil {
t.Fatal(err)
}
@ -88,7 +88,7 @@ func TestHandler(t *testing.T) {
func TestHandlerChan(t *testing.T) {
var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results)
h, err := newHandler(results, false)
if err != nil {
t.Fatal(err)
}
@ -115,7 +115,7 @@ func TestHandlerChanCancel(t *testing.T) {
// unbuffered.
var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results)
h, err := newHandler(results, false)
if err != nil {
t.Fatal(err)
}
@ -161,7 +161,7 @@ func TestHandlerInterface(t *testing.T) {
h, err := newHandler(func(m interface{}) {
results <- m
})
}, false)
if err != nil {
t.Fatal(err)
}
@ -277,7 +277,7 @@ func TestHandlerChanFor(t *testing.T) {
}
func BenchmarkReflect(b *testing.B) {
h, err := newHandler(func(m *gateway.MessageCreateEvent) {})
h, err := newHandler(func(m *gateway.MessageCreateEvent) {}, false)
if err != nil {
b.Fatal(err)
}

View File

@ -1,8 +1,8 @@
package handler
type slabEntry struct {
handler
index int
handler
}
func (entry slabEntry) isInvalid() bool {
@ -18,13 +18,13 @@ type slab struct {
func (s *slab) Put(entry handler) int {
if s.free == len(s.Entries) {
index := len(s.Entries)
s.Entries = append(s.Entries, slabEntry{entry, -1})
s.Entries = append(s.Entries, slabEntry{-1, entry})
s.free++
return index
}
next := s.Entries[s.free].index
s.Entries[s.free] = slabEntry{entry, -1}
s.Entries[s.free] = slabEntry{-1, entry}
i := s.free
s.free = next
@ -38,7 +38,7 @@ func (s *slab) Get(i int) handler {
func (s *slab) Pop(i int) handler {
popped := s.Entries[i].handler
s.Entries[i] = slabEntry{handler{}, s.free}
s.Entries[i] = slabEntry{s.free, handler{}}
s.free = i
return popped
}

View File

@ -36,7 +36,7 @@ func TestIntegration(t *testing.T) {
AddIntents(s.Gateway)
func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
if err := s.Open(ctx); err != nil {