diff --git a/internal/testenv/testenv.go b/internal/testenv/testenv.go index 78f3881..84aee1b 100644 --- a/internal/testenv/testenv.go +++ b/internal/testenv/testenv.go @@ -60,7 +60,7 @@ func getEnv() { return } - shardCount := 3 + shardCount := 2 if c, err := strconv.Atoi(os.Getenv("SHARD_COUNT")); err == nil { shardCount = c } diff --git a/session/shard/shard_test.go b/session/shard/shard_test.go index a5219b8..af0a719 100644 --- a/session/shard/shard_test.go +++ b/session/shard/shard_test.go @@ -34,7 +34,7 @@ func TestSharding(t *testing.T) { t.Fatal("failed to make shard manager:", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() go func() { diff --git a/state/state.go b/state/state.go index 7362bd5..771aa14 100644 --- a/state/state.go +++ b/state/state.go @@ -348,7 +348,9 @@ func (s *State) Me() (*discord.User, error) { return nil, err } - return u, s.Cabinet.MyselfSet(*u, false) + s.Cabinet.MyselfSet(*u, false) + + return u, nil } //// @@ -365,7 +367,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { } if s.tracksChannel(c) { - err = s.Cabinet.ChannelSet(c, false) + s.Cabinet.ChannelSet(c, false) } return @@ -386,9 +388,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err if s.HasIntents(gateway.IntentGuilds) { for i := range cs { - if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil { - return - } + s.Cabinet.ChannelSet(&cs[i], false) } } @@ -406,7 +406,9 @@ func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel return nil, err } - return c, s.Cabinet.ChannelSet(c, false) + s.Cabinet.ChannelSet(c, false) + + return c, nil } // PrivateChannels gets the direct messages of the user. @@ -423,9 +425,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { } for i := range cs { - if err := s.Cabinet.ChannelSet(&cs[i], false); err != nil { - return nil, err - } + s.Cabinet.ChannelSet(&cs[i], false) } return cs, nil @@ -450,9 +450,7 @@ func (s *State) Emoji( return nil, err } - if err = s.Cabinet.EmojiSet(guildID, es, false); err != nil { - return - } + s.Cabinet.EmojiSet(guildID, es, false) for _, e := range es { if e.ID == emojiID { @@ -477,7 +475,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) } if s.HasIntents(gateway.IntentGuildEmojis) { - err = s.Cabinet.EmojiSet(guildID, es, false) + s.Cabinet.EmojiSet(guildID, es, false) } return @@ -512,9 +510,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { if s.HasIntents(gateway.IntentGuilds) { for i := range gs { - if err = s.Cabinet.GuildSet(&gs[i], false); err != nil { - return - } + s.Cabinet.GuildSet(&gs[i], false) } } @@ -549,9 +545,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error if s.HasIntents(gateway.IntentGuildMembers) { for i := range ms { - if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil { - return - } + s.Cabinet.MemberSet(guildID, &ms[i], false) } } @@ -581,7 +575,7 @@ func (s *State) Message( go func() { c, cerr = s.Session.Channel(channelID) if cerr == nil && s.HasIntents(gateway.IntentGuilds) { - cerr = s.Cabinet.ChannelSet(c, false) + s.Cabinet.ChannelSet(c, false) } wg.Done() @@ -692,9 +686,7 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes msgs := apiMessages[:i] for i := range msgs { - if err := s.Cabinet.MessageSet(&msgs[i], false); err != nil { - return nil, err - } + s.Cabinet.MessageSet(&msgs[i], false) } } @@ -755,9 +747,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di } if s.HasIntents(gateway.IntentGuilds) { - if err = s.RoleSet(guildID, &rs[i], false); err != nil { - return - } + s.RoleSet(guildID, &rs[i], false) } } @@ -781,9 +771,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { if s.HasIntents(gateway.IntentGuilds) { for i := range rs { - if err := s.RoleSet(guildID, &rs[i], false); err != nil { - return rs, err - } + s.RoleSet(guildID, &rs[i], false) } } @@ -793,7 +781,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) if err == nil && s.HasIntents(gateway.IntentGuilds) { - err = s.Cabinet.GuildSet(g, false) + s.Cabinet.GuildSet(g, false) } return @@ -802,7 +790,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { m, err = s.Session.Member(gID, uID) if err == nil && s.HasIntents(gateway.IntentGuildMembers) { - err = s.Cabinet.MemberSet(gID, m, false) + s.Cabinet.MemberSet(gID, m, false) } return diff --git a/utils/bot/ctx_shard_test.go b/utils/bot/ctx_shard_test.go index 4d81246..4a12a95 100644 --- a/utils/bot/ctx_shard_test.go +++ b/utils/bot/ctx_shard_test.go @@ -44,7 +44,7 @@ func TestSharding(t *testing.T) { t.Fatal("failed to make shard manager:", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() go func() { diff --git a/utils/ws/gateway.go b/utils/ws/gateway.go index d86c02a..d361c62 100644 --- a/utils/ws/gateway.go +++ b/utils/ws/gateway.go @@ -321,6 +321,9 @@ func (g *Gateway) spin(ctx context.Context, h Handler) { return } + // Everything went well. Invalidate the error. + g.lastError = nil + case <-g.heart.C: h.SendHeartbeat(ctx) @@ -337,12 +340,21 @@ func (g *Gateway) spin(ctx context.Context, h Handler) { // Keep track of the last error for notifying. var err error + retryLoop: for try := 0; g.opts.ReconnectAttempt == 0 || try < g.opts.ReconnectAttempt; try++ { g.srcOp, err = g.ws.Dial(ctx) if err == nil { break } + // Exit if the context expired. + select { + case <-ctx.Done(): + err = ctx.Err() + break retryLoop + default: + } + // Signal an error before retrying. g.SendError(ConnectionError{err})