shard: Remake shard manager (#226)

This commit is contained in:
diamondburned 2021-06-10 16:48:32 -07:00
parent af3bedc472
commit 5b328bdab0
30 changed files with 1163 additions and 273 deletions

View File

@ -12,6 +12,7 @@ environment:
GO111MODULE: "on"
CGO_ENABLED: "1"
# Integration test variables.
SHARD_COUNT: "3"
tested: "./api,./gateway,./bot,./discord"
cov_file: "/tmp/cov_results"
dismock: "github.com/mavolin/dismock/v2/pkg/dismock"

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"log"
"os"
@ -34,8 +35,8 @@ func main() {
data := api.InteractionResponse{
Type: api.MessageInteractionWithSource,
Data: &api.InteractionResponseData{
Content: "This is a message with a button!",
Components: []discord.Component{
Content: option.NewNullableString("This is a message with a button!"),
Components: &[]discord.Component{
discord.ActionRowComponent{
Components: []discord.Component{
discord.ButtonComponent{
@ -93,10 +94,10 @@ func main() {
}
})
s.Gateway.AddIntents(gateway.IntentGuilds)
s.Gateway.AddIntents(gateway.IntentGuildMessages)
s.AddIntents(gateway.IntentGuilds)
s.AddIntents(gateway.IntentGuildMessages)
if err := s.Open(); err != nil {
if err := s.Open(context.Background()); err != nil {
log.Fatalln("failed to open:", err)
}
defer s.Close()

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"log"
"os"
@ -8,6 +9,7 @@ import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/utils/json/option"
)
// To run, do `APP_ID="APP ID" GUILD_ID="GUILD ID" BOT_TOKEN="TOKEN HERE" go run .`
@ -31,7 +33,7 @@ func main() {
data := api.InteractionResponse{
Type: api.MessageInteractionWithSource,
Data: &api.InteractionResponseData{
Content: "Pong!",
Content: option.NewNullableString("Pong!"),
},
}
@ -40,10 +42,10 @@ func main() {
}
})
s.Gateway.AddIntents(gateway.IntentGuilds)
s.Gateway.AddIntents(gateway.IntentGuildMessages)
s.AddIntents(gateway.IntentGuilds)
s.AddIntents(gateway.IntentGuildMessages)
if err := s.Open(); err != nil {
if err := s.Open(context.Background()); err != nil {
log.Fatalln("failed to open:", err)
}
defer s.Close()

61
_example/sharded/main.go Normal file
View File

@ -0,0 +1,61 @@
// Package main demonstrates a bare simple bot without a state cache. It logs
// all messages it sees into stderr.
package main
import (
"context"
"log"
"os"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/state"
)
// To run, do `BOT_TOKEN="TOKEN HERE" go run .`
func main() {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
log.Fatalln("No $BOT_TOKEN given.")
}
newShard := state.NewShardFunc(func(m *shard.Manager, s *state.State) {
// Add the needed Gateway intents.
s.AddIntents(gateway.IntentGuildMessages)
s.AddIntents(gateway.IntentDirectMessages)
s.AddHandler(func(c *gateway.MessageCreateEvent) {
_, shardIx := m.FromGuildID(c.GuildID)
log.Println(c.Author.Tag(), "sent", c.Content, "on shard", shardIx)
})
})
m, err := shard.NewManager("Bot "+token, newShard)
if err != nil {
log.Fatalln("failed to create shard manager:", err)
}
if err := m.Open(context.Background()); err != nil {
log.Fatalln("failed to connect shards:", err)
}
defer m.Close()
var shardNum int
m.ForEach(func(s shard.Shard) {
state := s.(*state.State)
u, err := state.Me()
if err != nil {
log.Fatalln("failed to get myself:", err)
}
log.Printf("Shard %d/%d started as %s", shardNum, m.NumShards()-1, u.Tag())
shardNum++
})
// Block forever.
select {}
}

View File

@ -3,6 +3,7 @@
package main
import (
"context"
"log"
"os"
@ -28,10 +29,10 @@ func main() {
})
// Add the needed Gateway intents.
s.Gateway.AddIntents(gateway.IntentGuildMessages)
s.Gateway.AddIntents(gateway.IntentDirectMessages)
s.AddIntents(gateway.IntentGuildMessages)
s.AddIntents(gateway.IntentDirectMessages)
if err := s.Open(); err != nil {
if err := s.Open(context.Background()); err != nil {
log.Fatalln("Failed to connect:", err)
}
defer s.Close()

View File

@ -2,6 +2,7 @@
package main
import (
"context"
"log"
"os"
@ -37,10 +38,10 @@ func main() {
})
// Add the needed Gateway intents.
s.Gateway.AddIntents(gateway.IntentGuildMessages)
s.Gateway.AddIntents(gateway.IntentDirectMessages)
s.AddIntents(gateway.IntentGuildMessages)
s.AddIntents(gateway.IntentDirectMessages)
if err := s.Open(); err != nil {
if err := s.Open(context.Background()); err != nil {
log.Fatalln("Failed to connect:", err)
}
defer s.Close()

36
api/bot.go Normal file
View File

@ -0,0 +1,36 @@
package api
import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/utils/httputil"
)
// BotData contains the GatewayURL as well as extra metadata on how to
// shard bots.
type BotData struct {
URL string `json:"url"`
Shards int `json:"shards,omitempty"`
StartLimit *SessionStartLimit `json:"session_start_limit"`
}
// SessionStartLimit is the information on the current session start limit. It's
// used in BotData.
type SessionStartLimit struct {
Total int `json:"total"`
Remaining int `json:"remaining"`
ResetAfter discord.Milliseconds `json:"reset_after"`
MaxConcurrency int `json:"max_concurrency"`
}
// BotURL fetches the Gateway URL along with extra metadata. The token
// passed in will NOT be prefixed with Bot.
func (c *Client) BotURL() (*BotData, error) {
var g *BotData
return g, c.RequestJSON(&g, "GET", EndpointGatewayBot)
}
// GatewayURL asks Discord for a Websocket URL to the Gateway.
func GatewayURL() (string, error) {
var g BotData
return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway)
}

View File

@ -1,6 +1,7 @@
package bot
import (
"context"
"fmt"
"log"
"os"
@ -14,7 +15,11 @@ import (
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/bot/extras/shellwords"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/state/store"
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
)
// Prefixer checks a message if it starts with the desired prefix. By default,
@ -40,8 +45,33 @@ type ArgsParser func(content string) ([]string, error)
// DefaultArgsParser implements a parser similar to that of shell's,
// implementing quotes as well as escapes.
func DefaultArgsParser() ArgsParser {
return shellwords.Parse
var DefaultArgsParser = shellwords.Parse
// NewShardFunc creates a shard constructor that shares the same internal store.
// If opts sets its own cabinet, then a new store isn't created.
func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc {
if fn == nil {
panic("bot.NewShardFunc missing fn")
}
var once sync.Once
var cab *store.Cabinet
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
state := state.NewFromSession(session.NewCustomShard(m, id), nil)
bot, err := fn(state)
if err != nil {
return nil, errors.Wrap(err, "failed to create bot instance")
}
if state.Cabinet == nil {
once.Do(func() { cab = defaultstore.New() })
state.Cabinet = cab
}
return bot, nil
}
}
// Context is the bot state for commands and subcommands.
@ -148,6 +178,8 @@ type Context struct {
// Quick access map from event types to pointers. This map will never have
// MessageCreateEvent's type.
typeCache sync.Map // map[reflect.Type][]*CommandContext
stopFunc func() // unbind function, see Start()
}
// Start quickly starts a bot with the given command. It will prepend "Bot"
@ -164,44 +196,55 @@ func Start(
token = "Bot " + token
}
s, err := state.New(token)
if err != nil {
return nil, errors.Wrap(err, "failed to create a dgo session")
}
// fail api request if they (will) take up more than 5 minutes
s.Client.Client.Timeout = 5 * time.Minute
c, err := New(s, cmd)
if err != nil {
return nil, errors.Wrap(err, "failed to create rfrouter")
}
s.Gateway.ErrorLog = func(err error) {
c.ErrorLogger(err)
}
if opts != nil {
if err := opts(c); err != nil {
newShard := NewShardFunc(func(s *state.State) (*Context, error) {
ctx, err := New(s, cmd)
if err != nil {
return nil, err
}
// fail api request if they (will) take up more than 5 minutes
ctx.Client.Client.Timeout = 5 * time.Minute
ctx.Gateway.ErrorLog = func(err error) {
ctx.ErrorLogger(err)
}
if opts != nil {
if err := opts(ctx); err != nil {
return nil, err
}
}
ctx.AddIntents(ctx.DeriveIntents())
ctx.AddIntents(gateway.IntentGuilds) // for channel event caching
return ctx, nil
})
m, err := shard.NewManager(token, newShard)
if err != nil {
return nil, errors.Wrap(err, "failed to create shard manager")
}
c.AddIntents(c.DeriveIntents())
c.AddIntents(gateway.IntentGuilds) // for channel event caching
cancel := c.Start()
if err := s.Open(); err != nil {
return nil, errors.Wrap(err, "failed to connect to Discord")
if err := m.Open(context.Background()); err == nil {
return nil, errors.Wrap(err, "failed to open")
}
return func() error {
Wait()
// remove handler first
cancel()
// then finish closing session
return s.Close()
WaitForInterrupt()
// Close the shards first.
closeErr := m.Close()
// Remove all handlers to clean up.
m.ForEach(func(s shard.Shard) {
ctx := s.(*Context)
stop := ctx.Start()
stop()
})
return closeErr
}, nil
}
@ -221,8 +264,8 @@ func Run(token string, cmd interface{}, opts func(*Context) error) {
}
}
// Wait blocks until SIGINT.
func Wait() {
// WaitForInterrupt blocks until SIGINT.
func WaitForInterrupt() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
<-sigs
@ -251,7 +294,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
ctx := &Context{
Subcommand: c,
State: s,
ParseArgs: DefaultArgsParser(),
ParseArgs: DefaultArgsParser,
HasPrefix: NewPrefix("~"),
FormatError: func(err error) string {
// Escape all pings, including @everyone.
@ -374,15 +417,34 @@ func (ctx *Context) RegisterSubcommand(cmd interface{}, names ...string) (*Subco
// emptyMentionTypes is used by Start() to not parse any mentions.
var emptyMentionTypes = []api.AllowedMentionType{}
// Start adds itself into the session handlers. This needs to be run. The
// returned function is a delete function, which removes itself from the
// Session handlers.
// Start adds itself into the session handlers. If Start is called more than
// once, then it does nothing. The caller doesn't have to call Start if they
// call Open.
//
// The returned function is a delete function, which removes itself from the
// Session handlers. The delete function is not safe to use concurrently.
func (ctx *Context) Start() func() {
return ctx.State.AddHandler(func(v interface{}) {
if err := ctx.callCmd(v); err != nil {
ctx.ErrorLogger(errors.Wrap(err, "command error"))
if ctx.stopFunc == nil {
cancel := ctx.State.AddHandler(func(v interface{}) {
if err := ctx.callCmd(v); err != nil {
ctx.ErrorLogger(errors.Wrap(err, "command error"))
}
})
ctx.stopFunc = func() {
cancel()
ctx.stopFunc = nil
}
})
}
return ctx.stopFunc
}
// Open starts the bot context and the gateway connection. It automatically
// binds the needed handlers.
func (ctx *Context) Open(cancelCtx context.Context) error {
ctx.Start()
return ctx.State.Open(cancelCtx)
}
// Call should only be used if you know what you're doing.

75
bot/ctx_shard_test.go Normal file
View File

@ -0,0 +1,75 @@
package bot
import (
"context"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/state"
)
type shardedBot struct {
Ctx *Context
readyCh chan *gateway.ReadyEvent
}
func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) {
bot.readyCh <- r
}
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
readyCh := make(chan *gateway.ReadyEvent)
newShard := NewShardFunc(func(s *state.State) (*Context, error) {
b, err := New(s, &shardedBot{nil, readyCh})
if err != nil {
return nil, err
}
b.AddIntents(gateway.IntentGuilds)
return b, nil
})
m, err := shard.NewIdentifiedManager(data, newShard)
if err != nil {
t.Fatal("failed to make shard manager:", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
go func() {
// Timeout
if err := m.Open(ctx); err != nil {
t.Error("failed to open:", err)
cancel()
}
t.Cleanup(func() {
if err := m.Close(); err != nil {
t.Error("failed to close:", err)
cancel()
}
})
}()
// Expect 4 Ready events.
for i := 0; i < env.ShardCount; i++ {
select {
case ready := <-readyCh:
now := time.Now().Format(time.StampMilli)
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
case <-ctx.Done():
t.Fatal("test expired, got", i, "shards")
}
}
}

View File

@ -97,7 +97,7 @@ func TestContext(t *testing.T) {
Subcommand: sub,
State: s,
ParseArgs: DefaultArgsParser(),
ParseArgs: DefaultArgsParser,
}
t.Run("init commands", func(t *testing.T) {
@ -396,7 +396,7 @@ func BenchmarkCall(b *testing.B) {
Subcommand: sub,
State: s,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
ParseArgs: DefaultArgsParser,
}
m := &gateway.MessageCreateEvent{
@ -424,7 +424,7 @@ func BenchmarkHelp(b *testing.B) {
Subcommand: sub,
State: s,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
ParseArgs: DefaultArgsParser,
}
b.ResetTimer()

View File

@ -182,13 +182,13 @@ type mockStore struct {
store.NoopStore
}
func mockCabinet() store.Cabinet {
c := store.NoopCabinet
func mockCabinet() *store.Cabinet {
c := *store.NoopCabinet
c.GuildStore = &mockStore{}
c.MemberStore = &mockStore{}
c.ChannelStore = &mockStore{}
return c
return &c
}
func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {

View File

@ -9,16 +9,13 @@ package gateway
import (
"context"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/utils/httputil"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/gorilla/websocket"
@ -26,9 +23,6 @@ import (
)
var (
EndpointGateway = api.Endpoint + "gateway"
EndpointGatewayBot = api.EndpointGateway + "/bot"
Version = api.Version
Encoding = "json"
)
@ -44,47 +38,26 @@ var (
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
const errCodeShardingRequired = 4011
// BotData contains the GatewayURL as well as extra metadata on how to
// shard bots.
type BotData struct {
URL string `json:"url"`
Shards int `json:"shards,omitempty"`
StartLimit *SessionStartLimit `json:"session_start_limit"`
}
// SessionStartLimit is the information on the current session start limit. It's
// used in BotData.
type SessionStartLimit struct {
Total int `json:"total"`
Remaining int `json:"remaining"`
ResetAfter discord.Milliseconds `json:"reset_after"`
MaxConcurrency int `json:"max_concurrency"`
}
// URL asks Discord for a Websocket URL to the Gateway.
func URL() (string, error) {
var g BotData
c := httputil.NewClient()
if err := c.RequestJSON(&g, "GET", EndpointGateway); err != nil {
return "", err
}
return g.URL, nil
return api.GatewayURL()
}
// BotURL fetches the Gateway URL along with extra metadata. The token
// passed in will NOT be prefixed with Bot.
func BotURL(token string) (*BotData, error) {
var g *BotData
func BotURL(token string) (*api.BotData, error) {
return api.NewClient(token).BotURL()
}
return g, httputil.NewClient().RequestJSON(
&g, "GET",
EndpointGatewayBot,
httputil.WithHeaders(http.Header{
"Authorization": {token},
}),
)
// AddGatewayParams appends into the given URL string the gateway URL
// parameters.
func AddGatewayParams(baseURL string) string {
param := url.Values{
"v": {Version},
"encoding": {Encoding},
}
return baseURL + "?" + param.Encode()
}
type Gateway struct {
@ -124,22 +97,16 @@ type Gateway struct {
// Defaults to noop.
FatalErrorCallback func(err error)
// OnScalingRequired is the function called, if Discord closes with error
// code 4011 aka Scaling Required. At the point of calling, the Gateway
// will be closed, and can, after increasing the number of shards, be
// reopened using Open. Reconnect or ReconnectCtx, however, will not be
// available as the session is invalidated.
OnScalingRequired func()
// AfterClose is called after each close or pause. It is used mainly for
// reconnections or any type of connection interruptions.
//
// Constructors will use a no-op function by default.
AfterClose func(err error)
waitGroup sync.WaitGroup
onShardingRequired func()
closed chan struct{}
waitGroup sync.WaitGroup
closed chan struct{}
}
// NewGatewayWithIntents creates a new Gateway with the given intents and the
@ -167,7 +134,7 @@ func NewGateway(token string) (*Gateway, error) {
// shared identifier.
func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
var gatewayURL string
var botData *BotData
var botData *api.BotData
var err error
if strings.HasPrefix(id.Token, "Bot ") {
@ -184,14 +151,7 @@ func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
}
}
// Parameters for the gateway
param := url.Values{
"v": {Version},
"encoding": {Encoding},
}
// Append the form to the URL
gatewayURL += "?" + param.Encode()
gatewayURL = AddGatewayParams(gatewayURL)
gateway := NewCustomIdentifiedGateway(gatewayURL, id)
// Use the supplied connect rate limit, if any.
@ -318,6 +278,18 @@ func (g *Gateway) SessionID() string {
return g.sessionID
}
// OnShardingRequired sets the function to be called if Discord closes with
// error code 4011 aka Sharding Required. When called, the Gateway will already
// be closed, and can (after increasing the number of shards) be reopened using
// Open. Reconnect or ReconnectCtx, however, will not be available as the
// session is invalidated.
//
// The gateway will completely halt what it's doing in the background when this
// callback is called.
func (g *Gateway) OnShardingRequired(fn func()) {
g.onShardingRequired = fn
}
// Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are
// reached.
func (g *Gateway) Reconnect() {
@ -349,7 +321,7 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Trying to dial, attempt", try)
// if we encounter an error, make sure we return it, and not nil
if oerr := g.OpenContext(ctx); oerr != nil {
if oerr := g.Open(ctx); oerr != nil {
err = oerr
g.ErrorLog(oerr)
@ -370,19 +342,13 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
return err
}
// Open connects to the Websocket and authenticate it. You should usually use
// this function over Start().
func (g *Gateway) Open() error {
// Open connects to the Websocket and authenticates it. You should usually use
// this function over Start(). The given context provides cancellation and
// timeout.
func (g *Gateway) Open(ctx context.Context) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.OpenContext(ctx)
}
// OpenContext connects to the Websocket and authenticates it. You should
// usually use this function over Start(). The given context provides
// cancellation and timeout.
func (g *Gateway) OpenContext(ctx context.Context) error {
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to Reconnect")
@ -456,19 +422,18 @@ func (g *Gateway) start(ctx context.Context) error {
g.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("Event loop stopped with error:", err)
// If Discord signals us sharding is required, do not attempt to
// Reconnect. Instead invalidate our session id, as we cannot resume,
// call OnShardingRequired, and exit.
var cerr *websocket.CloseError
if errors.As(err, &cerr) && cerr != nil && cerr.Code == errCodeShardingRequired {
g.ErrorLog(cerr)
g.sessionMu.Lock()
g.sessionID = ""
g.sessionMu.Unlock()
g.OnScalingRequired()
return
if err != nil && g.onShardingRequired != nil {
// If Discord signals us sharding is required, do not attempt to
// Reconnect, unless we don't know what to do. Instead invalidate
// our session ID, as we cannot resume, call OnShardingRequired, and
// exit.
var cerr *websocket.CloseError
if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired {
g.ErrorLog(cerr)
g.UseSessionID("")
g.onShardingRequired()
return
}
}
// Bail if there is no error or if the error is an explicit close, as

View File

@ -42,7 +42,10 @@ func TestInvalidToken(t *testing.T) {
t.Fatal("failed to make a Gateway:", err)
}
if err = g.Open(); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err = g.Open(ctx); err == nil {
t.Fatal("unexpected success while opening with a bad token.")
}
@ -55,10 +58,6 @@ func TestInvalidToken(t *testing.T) {
func TestIntegration(t *testing.T) {
config := testenv.Must(t)
wsutil.WSError = func(err error) {
t.Error(err)
}
var gateway *Gateway
// NewGateway should call Start for us.
@ -70,11 +69,16 @@ func TestIntegration(t *testing.T) {
g.AfterClose = func(err error) {
t.Log("closed.")
}
g.ErrorLog = func(err error) {
t.Log("gateway error:", err)
}
gateway = g
if err := g.Open(); err != nil {
t.Fatal("failed to authenticate with Discord:", err)
}
gotimeout(t, func(ctx context.Context) {
if err := g.Open(ctx); err != nil {
t.Fatal("failed to authenticate with Discord:", err)
}
})
ev := wait(t, gateway.Events)
ready, ok := ev.(*ReadyEvent)
@ -94,11 +98,7 @@ func TestIntegration(t *testing.T) {
// Sleep past the rate limiter before reconnecting:
time.Sleep(5 * time.Second)
gotimeout(t, func() {
// Try and reconnect for 20 seconds maximum.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
gotimeout(t, func(ctx context.Context) {
g.ErrorLog = func(err error) {
t.Error("unexpected error while reconnecting:", err)
}
@ -108,10 +108,10 @@ func TestIntegration(t *testing.T) {
}
})
g.ErrorLog = func(err error) { log.Println(err) }
g.ErrorLog = func(err error) { t.Log("warning:", err) }
// Wait for the desired event:
gotimeout(t, func() {
gotimeout(t, func(context.Context) {
for ev := range gateway.Events {
switch ev.(type) {
// Accept only a Resumed event.
@ -138,17 +138,21 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
}
}
func gotimeout(t *testing.T, fn func()) {
func gotimeout(t *testing.T, fn func(context.Context)) {
t.Helper()
// Try and reconnect for 20 seconds maximum.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
var done = make(chan struct{})
go func() {
fn()
fn(ctx)
done <- struct{}{}
}()
select {
case <-time.After(20 * time.Second):
case <-ctx.Done():
t.Fatal("timed out waiting for function.")
case <-done:
return

View File

@ -42,3 +42,49 @@ func (i Intents) IsPrivileged() (presences, member bool) {
// Keep this in sync with PrivilegedIntents.
return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers)
}
// EventIntents maps event types to intents.
var EventIntents = map[string]Intents{
"GUILD_CREATE": IntentGuilds,
"GUILD_UPDATE": IntentGuilds,
"GUILD_DELETE": IntentGuilds,
"GUILD_ROLE_CREATE": IntentGuilds,
"GUILD_ROLE_UPDATE": IntentGuilds,
"GUILD_ROLE_DELETE": IntentGuilds,
"CHANNEL_CREATE": IntentGuilds,
"CHANNEL_UPDATE": IntentGuilds,
"CHANNEL_DELETE": IntentGuilds,
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
"GUILD_MEMBER_ADD": IntentGuildMembers,
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
"GUILD_BAN_ADD": IntentGuildBans,
"GUILD_BAN_REMOVE": IntentGuildBans,
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
"INVITE_CREATE": IntentGuildInvites,
"INVITE_DELETE": IntentGuildInvites,
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
"PRESENCE_UPDATE": IntentGuildPresences,
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE_BULK": IntentGuildMessages,
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
}

View File

@ -1,47 +0,0 @@
package gateway
// EventIntents maps event types to intents.
var EventIntents = map[string]Intents{
"GUILD_CREATE": IntentGuilds,
"GUILD_UPDATE": IntentGuilds,
"GUILD_DELETE": IntentGuilds,
"GUILD_ROLE_CREATE": IntentGuilds,
"GUILD_ROLE_UPDATE": IntentGuilds,
"GUILD_ROLE_DELETE": IntentGuilds,
"CHANNEL_CREATE": IntentGuilds,
"CHANNEL_UPDATE": IntentGuilds,
"CHANNEL_DELETE": IntentGuilds,
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
"GUILD_MEMBER_ADD": IntentGuildMembers,
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
"GUILD_BAN_ADD": IntentGuildBans,
"GUILD_BAN_REMOVE": IntentGuildBans,
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
"INVITE_CREATE": IntentGuildInvites,
"INVITE_DELETE": IntentGuildInvites,
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
"PRESENCE_UPDATE": IntentGuildPresences,
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
"MESSAGE_DELETE_BULK": IntentGuildMessages,
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
}

312
gateway/shard/manager.go Normal file
View File

@ -0,0 +1,312 @@
package shard
import (
"context"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/backoff"
"github.com/pkg/errors"
)
func updateIdentifier(ctx context.Context, id *gateway.Identifier) (url string, err error) {
botData, err := api.NewClient(id.Token).WithContext(ctx).BotURL()
if err != nil {
return "", err
}
if botData.Shards < 1 {
botData.Shards = 1
}
id.Shard = &gateway.Shard{0, botData.Shards}
// Update the burst to be the current given time and reset it back to
// the default when the given time is reached.
id.IdentifyGlobalLimit.SetBurst(botData.StartLimit.Remaining)
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
id.IdentifyGlobalLimit.SetBurstAt(resetAt, botData.StartLimit.Total)
// Update the maximum number of identify requests allowed per 5s.
id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
return botData.URL, nil
}
// Manager is the manager responsible for handling all sharding on this
// instance. An instance of Manager must never be copied.
type Manager struct {
// shards are the *shards.shards managed by this Manager. They are
// sorted in ascending order by their shard id.
shards []ShardState
gatewayURL string
mutex sync.RWMutex
rescaling *rescalingState // nil unless rescaling
new NewShardFunc
}
type rescalingState struct {
haltRescale context.CancelFunc
rescaleDone sync.WaitGroup
}
// NewManager creates a Manager using as many gateways as recommended by
// Discord.
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
id := gateway.DefaultIdentifier(token)
url, err := updateIdentifier(context.Background(), id)
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway info")
}
return NewIdentifiedManagerWithURL(url, id, fn)
}
// NewIdentifiedManager creates a new Manager using the given
// gateway.Identifier. The total number of shards will be taken from the
// identifier instead of being queried from Discord, but the shard ID will be
// ignored.
//
// This function should rarely be used, since the shard information will be
// queried from Discord if it's required to shard anyway.
func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) {
// Ensure id.Shard is never nil.
if data.Shard == nil {
data.Shard = gateway.DefaultShard
}
id := gateway.NewIdentifier(data)
url, err := updateIdentifier(context.Background(), id)
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway info")
}
id.Shard = data.Shard
return NewIdentifiedManagerWithURL(url, id, fn)
}
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
// and gateway URL. It behaves similarly to NewIdentifiedManager.
func NewIdentifiedManagerWithURL(
url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) {
m := Manager{
gatewayURL: gateway.AddGatewayParams(url),
shards: make([]ShardState, id.Shard.NumShards()),
new: fn,
}
var err error
for i := range m.shards {
data := id.IdentifyData
data.Shard = &gateway.Shard{i, len(m.shards)}
m.shards[i] = ShardState{
ID: gateway.Identifier{
IdentifyData: data,
IdentifyShortLimit: id.IdentifyShortLimit,
IdentifyGlobalLimit: id.IdentifyGlobalLimit,
},
}
m.shards[i].Shard, err = fn(&m, &m.shards[i].ID)
if err != nil {
return nil, errors.Wrapf(err, "failed to create shard %d/%d", i, len(m.shards)-1)
}
}
return &m, nil
}
// GatewayURL returns the URL to the gateway. The URL will always have the
// needed gateway parameters appended.
func (m *Manager) GatewayURL() string {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.gatewayURL
}
// NumShards returns the total number of shards. It is OK for the caller to rely
// on NumShards while they're inside ForEach.
func (m *Manager) NumShards() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.shards)
}
// Shard gets the shard with the given ID.
func (m *Manager) Shard(ix int) Shard {
m.mutex.RLock()
defer m.mutex.RUnlock()
if ix >= len(m.shards) {
return nil
}
return m.shards[ix]
}
// FromGuildID returns the Shard and the shard ID for the guild with the given
// ID.
func (m *Manager) FromGuildID(guildID discord.GuildID) (shard Shard, ix int) {
m.mutex.RLock()
defer m.mutex.RUnlock()
ix = int(uint64(guildID>>22) % uint64(len(m.shards)))
return m.shards[ix], ix
}
// ForEach calls the given function on each shard from first to last. The caller
// can safely access the number of shards by either asserting Shard to get the
// IdentifyData or call m.NumShards.
func (m *Manager) ForEach(f func(shard Shard)) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, g := range m.shards {
f(g)
}
}
// Open opens all gateways handled by this Manager. If an error occurs, Open
// will attempt to close all previously opened gateways before returning.
func (m *Manager) Open(ctx context.Context) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return OpenShards(ctx, m.shards)
}
// Close closes all gateways handled by this Manager; it will stop rescaling if
// the manager is currently being rescaled. If an error occurs, Close will
// attempt to close all remaining gateways first, before returning.
func (m *Manager) Close() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.rescaling != nil {
m.rescaling.haltRescale()
m.rescaling.rescaleDone.Wait()
m.rescaling = nil
}
return CloseShards(m.shards)
}
// Rescale rescales the manager asynchronously. The caller MUST NOT call Rescale
// in the constructor function; doing so WILL cause the state to be inconsistent
// and eventually crash and burn and destroy us all.
func (m *Manager) Rescale() {
go m.rescale()
}
func (m *Manager) rescale() {
m.mutex.Lock()
// Exit if we're already rescaling.
if m.rescaling != nil {
m.mutex.Unlock()
return
}
// Create a new context to allow the caller to cancel rescaling.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
m.rescaling = &rescalingState{haltRescale: cancel}
m.rescaling.rescaleDone.Add(1)
defer m.rescaling.rescaleDone.Done()
// Take the old list of shards for ourselves.
oldShards := m.shards
m.shards = nil
m.mutex.Unlock()
// Close the shards outside the lock. This should be fairly quickly, but it
// allows the caller to halt rescaling while we're closing or opening the
// shards.
CloseShards(oldShards)
backoffT := backoff.NewTimer(time.Second, 15*time.Minute)
defer backoffT.Stop()
for {
if m.tryRescale(ctx) {
return
}
select {
case <-backoffT.Next():
continue
case <-ctx.Done():
return
}
}
}
// tryRescale attempts once to rescale. It assumes the mutex is unlocked and
// will unlock the mutex itself.
func (m *Manager) tryRescale(ctx context.Context) bool {
m.mutex.Lock()
data := m.shards[0].ID.IdentifyData
newID := gateway.NewIdentifier(data)
url, err := updateIdentifier(ctx, newID)
if err != nil {
m.mutex.Unlock()
return false
}
numShards := newID.Shard.NumShards()
m.gatewayURL = url
// Release the mutex early.
m.mutex.Unlock()
// Create the shards slice to set after we reacquire the mutex.
newShards := make([]ShardState, numShards)
for i := 0; i < numShards; i++ {
data := newID.IdentifyData
data.Shard = &gateway.Shard{i, len(m.shards)}
newShards[i] = ShardState{
ID: gateway.Identifier{
IdentifyData: data,
IdentifyShortLimit: newID.IdentifyShortLimit,
IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
},
}
newShards[i].Shard, err = m.new(m, &newShards[i].ID)
if err != nil {
return false
}
}
if err := OpenShards(ctx, newShards); err != nil {
return false
}
m.mutex.Lock()
m.shards = newShards
m.rescaling = nil
m.mutex.Unlock()
return true
}

85
gateway/shard/shard.go Normal file
View File

@ -0,0 +1,85 @@
package shard
import (
"context"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/pkg/errors"
)
// Shard defines a shard gateway interface that the shard manager can use.
type Shard interface {
Open(context.Context) error
Close() error
}
// NewShardFunc is the constructor to create a new gateway. For examples, see
// package session and state's. The constructor must manually connect the
// Manager's Rescale method appropriately.
//
// A new Gateway must not open any background resources until OpenCtx is called;
// if the gateway has never been opened, its Close method will never be called.
// During callback, the Manager is not locked, so the callback can use Manager's
// methods without deadlocking.
type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error)
// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with
// NewShardFunc.
var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) {
return NewGatewayShard(m, id), nil
}
// NewGatewayShard creates a new gateway that's plugged into the shard manager.
func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway {
gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id)
gw.OnShardingRequired(m.Rescale)
return gw
}
// ShardState wraps around the Gateway interface to provide additional state.
type ShardState struct {
Shard
// This is a bit wasteful: 2 constant pointers are stored here, and they
// waste GC cycles. This is unavoidable, however, since the API has to take
// in a pointer to Identifier, not IdentifyData. This is to ensure rescales
// are consistent.
ID gateway.Identifier
Opened bool
}
// ShardID returns the shard state's shard ID.
func (state ShardState) ShardID() int {
return state.ID.Shard.ShardID()
}
// OpenShards opens the gateways of the given list of shard states.
func OpenShards(ctx context.Context, shards []ShardState) error {
for i, shard := range shards {
if err := shard.Open(ctx); err != nil {
CloseShards(shards)
return errors.Wrapf(err, "failed to open shard %d/%d", i, len(shards)-1)
}
// Mark as opened so we can close them.
shards[i].Opened = true
}
return nil
}
// CloseShards closes the gateways of the given list of shard states.
func CloseShards(shards []ShardState) error {
var lastError error
for i, gw := range shards {
if gw.Opened {
if err := gw.Close(); err != nil {
lastError = err
}
shards[i].Opened = false
}
}
return lastError
}

117
internal/backoff/backoff.go Normal file
View File

@ -0,0 +1,117 @@
// Package backoff provides an exponential-backoff implementation partially
// taken from jpillora/backoff.
package backoff
import (
"math"
"math/rand"
"sync/atomic"
"time"
)
const (
factor = 2
jitter = true
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// Timer is a backoff timer.
type Timer struct {
backoff Backoff
timer *time.Timer
}
// NewTimer returns a new uninitialized timer.
func NewTimer(min, max time.Duration) Timer {
return Timer{
backoff: NewBackoff(min, max),
}
}
// Next initializes the timer if needed and returns a timer channel that fires
// when the backoff timeout is reached.
func (t *Timer) Next() <-chan time.Time {
if t.timer == nil {
t.timer = time.NewTimer(t.backoff.Next())
} else {
t.timer.Stop() // ensure drained
t.timer.Reset(t.backoff.Next())
}
return t.timer.C
}
// Stop stops the internal timer and frees its resources. It does nothing if the
// timer is uninitialized.
func (t *Timer) Stop() {
if t.timer == nil {
return
}
if !t.timer.Stop() {
<-t.timer.C // drain
}
}
// Backoff is a time.Duration counter, starting at Min. After every call to
// the Duration method the current timing is multiplied by Factor, but it
// never exceeds Max.
type Backoff struct {
min, max float64 // seconds
attempt int32 // negative == max uint32
}
// NewBackoff creates a new backoff time.Duration counter.
func NewBackoff(min, max time.Duration) Backoff {
return Backoff{
min: min.Seconds(),
max: max.Seconds(),
}
}
// Next returns the next backoff duration.
func (b *Backoff) Next() time.Duration {
return b.forAttempt(atomic.AddInt32(&b.attempt, 1) - 1)
}
const maxInt64 = float64(math.MaxInt64 - 512)
// forAttempt returns the duration for a specific attempt. This is useful if
// you have a large number of independent Backoffs, but don't want use
// unnecessary memory storing the Backoff parameters per Backoff. The first
// attempt should be 0.
func (b *Backoff) forAttempt(attempt int32) time.Duration {
if b.min >= b.max {
// short-circuit
return duration(b.max)
}
// Ensure attempt never overflows.
if attempt < 0 {
attempt = math.MaxInt32
}
// Calculate this duration.
dur := b.min * math.Pow(factor, float64(attempt))
if jitter {
dur = rand.Float64()*(dur-b.min) + b.min
}
if dur < b.min {
return duration(b.min)
}
if dur > b.max {
return duration(b.max)
}
return duration(dur)
}
// duration converts a seconds float64 to time.Duration without losing accuracy.
func duration(secs float64) time.Duration {
int, frac := math.Modf(secs)
return (time.Duration(int) * time.Second) + time.Duration(frac*float64(time.Second))
}

View File

@ -14,25 +14,6 @@ func NewCtxMutex() *CtxMutex {
}
}
// func (m *CtxMutex) TryLock() bool {
// select {
// case m.mut <- struct{}{}:
// return true
// default:
// return false
// }
// }
// func (m *CtxMutex) IsBusy() bool {
// select {
// case m.mut <- struct{}{}:
// <-m.mut
// return false
// default:
// return true
// }
// }
func (m *CtxMutex) Lock(ctx context.Context) error {
select {
case m.mut <- struct{}{}:

View File

@ -4,6 +4,7 @@ package testenv
import (
"os"
"strconv"
"sync"
"testing"
"time"
@ -15,9 +16,10 @@ import (
const PerseveranceTime = 50 * time.Minute
type Env struct {
BotToken string
ChannelID discord.ChannelID
VoiceChID discord.ChannelID
BotToken string
ChannelID discord.ChannelID
VoiceChID discord.ChannelID
ShardCount int // default 3
}
var (
@ -40,39 +42,33 @@ func GetEnv() (Env, error) {
}
func getEnv() {
var token = os.Getenv("BOT_TOKEN")
token := os.Getenv("BOT_TOKEN")
if token == "" {
globalErr = errors.New("missing $BOT_TOKEN")
return
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
globalErr = errors.New("missing $CHANNEL_ID")
return
}
id, err := discord.ParseSnowflake(cid)
id, err := discord.ParseSnowflake(os.Getenv("CHANNEL_ID"))
if err != nil {
globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
return
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
globalErr = errors.New("missing $VOICE_ID")
return
}
vid, err := discord.ParseSnowflake(sid)
vid, err := discord.ParseSnowflake(os.Getenv("VOICE_ID"))
if err != nil {
globalErr = errors.Wrap(err, "invalid $VOICE_ID")
return
}
shardCount := 3
if c, err := strconv.Atoi(os.Getenv("SHARD_COUNT")); err == nil {
shardCount = c
}
globalEnv = Env{
BotToken: token,
ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid),
BotToken: token,
ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid),
ShardCount: shardCount,
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/handleloop"
"github.com/diamondburned/arikawa/v3/utils/handler"
)
@ -28,11 +29,33 @@ type Closed struct {
Error error
}
// NewShardFunc creates a shard constructor for a session.
// Accessing any shard and adding a handler will add a handler for all shards.
func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc {
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
s := NewCustomShard(m, id)
if f != nil {
f(m, s)
}
return s, nil
}
}
// NewCustomShard creates a new session from the given shard manager and other
// parameters.
func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session {
return NewCustomSession(
shard.NewGatewayShard(m, id),
api.NewClient(id.Token),
handler.New(),
)
}
// Session manages both the API and Gateway. As such, Session inherits all of
// API's methods, as well has the Handler used for Gateway.
type Session struct {
*api.Client
Gateway *gateway.Gateway
*gateway.Gateway
// Command handler with inherited methods.
*handler.Handler
@ -92,20 +115,22 @@ func Login(email, password, mfa string) (*Session, error) {
return New(l.Token)
}
// NewWithGateway creates a new Session with the given Gateway.
func NewWithGateway(gw *gateway.Gateway) *Session {
handler := handler.New()
looper := handleloop.NewLoop(handler)
return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New())
}
// NewCustomSession constructs a bare Session from the given parameters.
func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session {
return &Session{
Gateway: gw,
// Nab off gateway's token
Client: api.NewClient(gw.Identifier.Token),
Handler: handler,
looper: looper,
Client: cl,
Handler: h,
looper: handleloop.NewLoop(h),
}
}
func (s *Session) Open() error {
func (s *Session) Open(ctx context.Context) error {
// Start the handler beforehand so no events are missed.
s.looper.Start(s.Gateway.Events)
@ -116,7 +141,7 @@ func (s *Session) Open() error {
})
}
if err := s.Gateway.Open(); err != nil {
if err := s.Gateway.Open(ctx); err != nil {
return errors.Wrap(err, "failed to start gateway")
}

View File

@ -0,0 +1,66 @@
package session
import (
"context"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
)
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
readyCh := make(chan *gateway.ReadyEvent)
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
func(m *shard.Manager, s *Session) {
now := time.Now().Format(time.StampMilli)
t.Log(now, "initializing shard")
s.Gateway.ErrorLog = func(err error) {
t.Error("gateway error:", err)
}
s.AddIntents(gateway.IntentGuilds)
s.AddHandler(readyCh)
},
))
if err != nil {
t.Fatal("failed to make shard manager:", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
go func() {
// Timeout
if err := m.Open(ctx); err != nil {
t.Error("failed to open:", err)
cancel()
}
t.Cleanup(func() {
if err := m.Close(); err != nil {
t.Error("failed to close:", err)
cancel()
}
})
}()
// Expect 4 Ready events.
for i := 0; i < env.ShardCount; i++ {
select {
case ready := <-readyCh:
now := time.Now().Format(time.StampMilli)
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
case <-ctx.Done():
t.Fatal("test expired, got", i, "shards")
}
}
}

View File

@ -8,6 +8,7 @@ import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/state/store"
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
@ -21,6 +22,32 @@ var (
MaxFetchGuilds uint = 100
)
// NewShardFunc creates a shard constructor that shares the same handler. The
// given opts function is called everytime the State is created. If it doesn't
// set a cabinet into the state, then a shared default cabinet is set instead.
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
var once sync.Once
var cab *store.Cabinet
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
state := NewFromSession(session.NewCustomShard(m, id), nil)
if opts != nil {
opts(m, state)
}
if state.Cabinet == nil {
// Create the cabinet once; use sync.Once so the constructor can be
// concurrently safe.
once.Do(func() { cab = defaultstore.New() })
state.Cabinet = cab
}
return state, nil
}
}
// State is the cache to store events coming from Discord as well as data from
// API calls.
//
@ -59,7 +86,7 @@ var (
// will be empty, while the Member structure expects it to be there.
type State struct {
*session.Session
store.Cabinet
*store.Cabinet
// *: State doesn't actually keep track of pinned messages.
@ -113,7 +140,8 @@ func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
return NewFromSession(s, defaultstore.New()), nil
}
func NewWithStore(token string, cabinet store.Cabinet) (*State, error) {
// NewWithStore creates a new state with the given store cabinet.
func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) {
s, err := session.New(token)
if err != nil {
return nil, err
@ -123,7 +151,7 @@ func NewWithStore(token string, cabinet store.Cabinet) (*State, error) {
}
// NewFromSession creates a new State from the passed Session and Cabinet.
func NewFromSession(s *session.Session, cabinet store.Cabinet) *State {
func NewFromSession(s *session.Session, cabinet *store.Cabinet) *State {
state := &State{
Session: s,
Cabinet: cabinet,
@ -625,7 +653,7 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
if len(storeMessages) >= int(limit) && limit > 0 {
return storeMessages[:limit], nil
}
// Decrease the limit, if we aren't fetching all messages.
if limit > 0 {
limit -= uint(len(storeMessages))

View File

@ -372,7 +372,7 @@ func findReaction(rs []discord.Reaction, emoji discord.Emoji) int {
return -1
}
func storeGuildCreate(cab store.Cabinet, guild *gateway.GuildCreateEvent) []error {
func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []error {
if guild.Unavailable {
return nil
}

65
state/state_shard_test.go Normal file
View File

@ -0,0 +1,65 @@
package state
import (
"context"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
)
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
readyCh := make(chan *gateway.ReadyEvent)
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
func(m *shard.Manager, s *State) {
now := time.Now().Format(time.StampMilli)
t.Log(now, "initializing shard")
s.Gateway.ErrorLog = func(err error) {
t.Error("gateway error:", err)
}
s.AddIntents(gateway.IntentGuilds)
s.AddHandler(readyCh)
},
))
if err != nil {
t.Fatal("failed to make shard manager:", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
go func() {
// Timeout
if err := m.Open(ctx); err != nil {
t.Error("failed to open:", err)
cancel()
}
t.Cleanup(func() {
if err := m.Close(); err != nil {
t.Error("failed to close:", err)
cancel()
}
})
}()
for i := 0; i < env.ShardCount; i++ {
select {
case ready := <-readyCh:
now := time.Now().Format(time.StampMilli)
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
case <-ctx.Done():
t.Fatal("test expired, got", i, "shards")
}
}
}

View File

@ -6,8 +6,8 @@ import "github.com/diamondburned/arikawa/v3/state/store"
// New creates a new cabinet instance of defaultstore. For Message, it creates a
// Message store with a limit of 100 messages.
func New() store.Cabinet {
return store.Cabinet{
func New() *store.Cabinet {
return &store.Cabinet{
MeStore: NewMe(),
ChannelStore: NewChannel(),
EmojiStore: NewEmoji(),

View File

@ -123,7 +123,7 @@ type NoopStore = noop
// NoopCabinet is a store cabinet with all store methods set to the Noop
// implementations.
var NoopCabinet = Cabinet{
var NoopCabinet = &Cabinet{
MeStore: Noop,
ChannelStore: Noop,
EmojiStore: Noop,

View File

@ -133,7 +133,7 @@ func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
return errors.Wrap(err, "SendLimiter failed")
}
WSDebug("Send is passed the rate limiting. Waiting on mutex.")
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
ws.mutex.Lock()
defer ws.mutex.Unlock()

View File

@ -1,6 +1,7 @@
package voice_test
import (
"context"
"io"
"log"
"testing"
@ -41,7 +42,7 @@ func ExampleSession() {
// This is required for bots.
voice.AddIntents(s.Gateway)
if err := s.Open(); err != nil {
if err := s.Open(context.TODO()); err != nil {
log.Fatalln("failed to open gateway:", err)
}
defer s.Close()

View File

@ -35,9 +35,15 @@ func TestIntegration(t *testing.T) {
}
AddIntents(s.Gateway)
if err := s.Open(); err != nil {
t.Fatal("Failed to connect:", err)
}
func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := s.Open(ctx); err != nil {
t.Fatal("Failed to connect:", err)
}
}()
t.Cleanup(func() { s.Close() })
// Validate the given voice channel.