From adf029bd7c4ea687dddadb6c3f1fcdb4956f012b Mon Sep 17 00:00:00 2001 From: diamondburned Date: Thu, 10 Jun 2021 16:48:32 -0700 Subject: [PATCH] shard: Remake shard manager --- .build.yml | 1 + _example/buttons/main.go | 11 +- _example/commands/main.go | 10 +- _example/sharded/main.go | 61 ++++ _example/simple/main.go | 7 +- _example/undeleter/main.go | 7 +- api/bot.go | 36 +++ bot/ctx.go | 148 +++++++--- bot/ctx_shard_test.go | 75 +++++ bot/ctx_test.go | 6 +- bot/extras/middlewares/middlewares_test.go | 6 +- gateway/gateway.go | 129 ++++----- gateway/integration_test.go | 40 +-- gateway/intents.go | 46 +++ gateway/intents_map.go | 47 ---- gateway/shard/manager.go | 312 +++++++++++++++++++++ gateway/shard/shard.go | 85 ++++++ internal/backoff/backoff.go | 117 ++++++++ internal/moreatomic/mutex.go | 19 -- internal/testenv/testenv.go | 38 ++- session/session.go | 43 ++- session/session_shard_test.go | 66 +++++ state/state.go | 36 ++- state/state_events.go | 2 +- state/state_shard_test.go | 65 +++++ state/store/defaultstore/defaultstore.go | 4 +- state/store/store.go | 2 +- utils/wsutil/ws.go | 2 +- voice/session_example_test.go | 3 +- voice/session_test.go | 12 +- 30 files changed, 1163 insertions(+), 273 deletions(-) create mode 100644 _example/sharded/main.go create mode 100644 api/bot.go create mode 100644 bot/ctx_shard_test.go delete mode 100644 gateway/intents_map.go create mode 100644 gateway/shard/manager.go create mode 100644 gateway/shard/shard.go create mode 100644 internal/backoff/backoff.go create mode 100644 session/session_shard_test.go create mode 100644 state/state_shard_test.go diff --git a/.build.yml b/.build.yml index 8aaaca5..44b8f40 100644 --- a/.build.yml +++ b/.build.yml @@ -12,6 +12,7 @@ environment: GO111MODULE: "on" CGO_ENABLED: "1" # Integration test variables. + SHARD_COUNT: "3" tested: "./api,./gateway,./bot,./discord" cov_file: "/tmp/cov_results" dismock: "github.com/mavolin/dismock/v2/pkg/dismock" diff --git a/_example/buttons/main.go b/_example/buttons/main.go index 0da9e93..aba4c9c 100644 --- a/_example/buttons/main.go +++ b/_example/buttons/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "os" @@ -34,8 +35,8 @@ func main() { data := api.InteractionResponse{ Type: api.MessageInteractionWithSource, Data: &api.InteractionResponseData{ - Content: "This is a message with a button!", - Components: []discord.Component{ + Content: option.NewNullableString("This is a message with a button!"), + Components: &[]discord.Component{ discord.ActionRowComponent{ Components: []discord.Component{ discord.ButtonComponent{ @@ -93,10 +94,10 @@ func main() { } }) - s.Gateway.AddIntents(gateway.IntentGuilds) - s.Gateway.AddIntents(gateway.IntentGuildMessages) + s.AddIntents(gateway.IntentGuilds) + s.AddIntents(gateway.IntentGuildMessages) - if err := s.Open(); err != nil { + if err := s.Open(context.Background()); err != nil { log.Fatalln("failed to open:", err) } defer s.Close() diff --git a/_example/commands/main.go b/_example/commands/main.go index 226bbfa..caacbf0 100644 --- a/_example/commands/main.go +++ b/_example/commands/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "os" @@ -8,6 +9,7 @@ import ( "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" "github.com/diamondburned/arikawa/v3/session" + "github.com/diamondburned/arikawa/v3/utils/json/option" ) // To run, do `APP_ID="APP ID" GUILD_ID="GUILD ID" BOT_TOKEN="TOKEN HERE" go run .` @@ -31,7 +33,7 @@ func main() { data := api.InteractionResponse{ Type: api.MessageInteractionWithSource, Data: &api.InteractionResponseData{ - Content: "Pong!", + Content: option.NewNullableString("Pong!"), }, } @@ -40,10 +42,10 @@ func main() { } }) - s.Gateway.AddIntents(gateway.IntentGuilds) - s.Gateway.AddIntents(gateway.IntentGuildMessages) + s.AddIntents(gateway.IntentGuilds) + s.AddIntents(gateway.IntentGuildMessages) - if err := s.Open(); err != nil { + if err := s.Open(context.Background()); err != nil { log.Fatalln("failed to open:", err) } defer s.Close() diff --git a/_example/sharded/main.go b/_example/sharded/main.go new file mode 100644 index 0000000..bb32155 --- /dev/null +++ b/_example/sharded/main.go @@ -0,0 +1,61 @@ +// Package main demonstrates a bare simple bot without a state cache. It logs +// all messages it sees into stderr. +package main + +import ( + "context" + "log" + "os" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/state" +) + +// To run, do `BOT_TOKEN="TOKEN HERE" go run .` + +func main() { + var token = os.Getenv("BOT_TOKEN") + if token == "" { + log.Fatalln("No $BOT_TOKEN given.") + } + + newShard := state.NewShardFunc(func(m *shard.Manager, s *state.State) { + // Add the needed Gateway intents. + s.AddIntents(gateway.IntentGuildMessages) + s.AddIntents(gateway.IntentDirectMessages) + + s.AddHandler(func(c *gateway.MessageCreateEvent) { + _, shardIx := m.FromGuildID(c.GuildID) + log.Println(c.Author.Tag(), "sent", c.Content, "on shard", shardIx) + }) + }) + + m, err := shard.NewManager("Bot "+token, newShard) + if err != nil { + log.Fatalln("failed to create shard manager:", err) + } + + if err := m.Open(context.Background()); err != nil { + log.Fatalln("failed to connect shards:", err) + } + defer m.Close() + + var shardNum int + + m.ForEach(func(s shard.Shard) { + state := s.(*state.State) + + u, err := state.Me() + if err != nil { + log.Fatalln("failed to get myself:", err) + } + + log.Printf("Shard %d/%d started as %s", shardNum, m.NumShards()-1, u.Tag()) + + shardNum++ + }) + + // Block forever. + select {} +} diff --git a/_example/simple/main.go b/_example/simple/main.go index bd7c0ab..32890d3 100644 --- a/_example/simple/main.go +++ b/_example/simple/main.go @@ -3,6 +3,7 @@ package main import ( + "context" "log" "os" @@ -28,10 +29,10 @@ func main() { }) // Add the needed Gateway intents. - s.Gateway.AddIntents(gateway.IntentGuildMessages) - s.Gateway.AddIntents(gateway.IntentDirectMessages) + s.AddIntents(gateway.IntentGuildMessages) + s.AddIntents(gateway.IntentDirectMessages) - if err := s.Open(); err != nil { + if err := s.Open(context.Background()); err != nil { log.Fatalln("Failed to connect:", err) } defer s.Close() diff --git a/_example/undeleter/main.go b/_example/undeleter/main.go index 3c1b0a8..6e1aa21 100644 --- a/_example/undeleter/main.go +++ b/_example/undeleter/main.go @@ -2,6 +2,7 @@ package main import ( + "context" "log" "os" @@ -37,10 +38,10 @@ func main() { }) // Add the needed Gateway intents. - s.Gateway.AddIntents(gateway.IntentGuildMessages) - s.Gateway.AddIntents(gateway.IntentDirectMessages) + s.AddIntents(gateway.IntentGuildMessages) + s.AddIntents(gateway.IntentDirectMessages) - if err := s.Open(); err != nil { + if err := s.Open(context.Background()); err != nil { log.Fatalln("Failed to connect:", err) } defer s.Close() diff --git a/api/bot.go b/api/bot.go new file mode 100644 index 0000000..bf7eeda --- /dev/null +++ b/api/bot.go @@ -0,0 +1,36 @@ +package api + +import ( + "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/utils/httputil" +) + +// BotData contains the GatewayURL as well as extra metadata on how to +// shard bots. +type BotData struct { + URL string `json:"url"` + Shards int `json:"shards,omitempty"` + StartLimit *SessionStartLimit `json:"session_start_limit"` +} + +// SessionStartLimit is the information on the current session start limit. It's +// used in BotData. +type SessionStartLimit struct { + Total int `json:"total"` + Remaining int `json:"remaining"` + ResetAfter discord.Milliseconds `json:"reset_after"` + MaxConcurrency int `json:"max_concurrency"` +} + +// BotURL fetches the Gateway URL along with extra metadata. The token +// passed in will NOT be prefixed with Bot. +func (c *Client) BotURL() (*BotData, error) { + var g *BotData + return g, c.RequestJSON(&g, "GET", EndpointGatewayBot) +} + +// GatewayURL asks Discord for a Websocket URL to the Gateway. +func GatewayURL() (string, error) { + var g BotData + return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway) +} diff --git a/bot/ctx.go b/bot/ctx.go index 680c7da..379b06f 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -1,6 +1,7 @@ package bot import ( + "context" "fmt" "log" "os" @@ -14,7 +15,11 @@ import ( "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/bot/extras/shellwords" "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/session" "github.com/diamondburned/arikawa/v3/state" + "github.com/diamondburned/arikawa/v3/state/store" + "github.com/diamondburned/arikawa/v3/state/store/defaultstore" ) // Prefixer checks a message if it starts with the desired prefix. By default, @@ -40,8 +45,33 @@ type ArgsParser func(content string) ([]string, error) // DefaultArgsParser implements a parser similar to that of shell's, // implementing quotes as well as escapes. -func DefaultArgsParser() ArgsParser { - return shellwords.Parse +var DefaultArgsParser = shellwords.Parse + +// NewShardFunc creates a shard constructor that shares the same internal store. +// If opts sets its own cabinet, then a new store isn't created. +func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc { + if fn == nil { + panic("bot.NewShardFunc missing fn") + } + + var once sync.Once + var cab *store.Cabinet + + return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { + state := state.NewFromSession(session.NewCustomShard(m, id), nil) + + bot, err := fn(state) + if err != nil { + return nil, errors.Wrap(err, "failed to create bot instance") + } + + if state.Cabinet == nil { + once.Do(func() { cab = defaultstore.New() }) + state.Cabinet = cab + } + + return bot, nil + } } // Context is the bot state for commands and subcommands. @@ -148,6 +178,8 @@ type Context struct { // Quick access map from event types to pointers. This map will never have // MessageCreateEvent's type. typeCache sync.Map // map[reflect.Type][]*CommandContext + + stopFunc func() // unbind function, see Start() } // Start quickly starts a bot with the given command. It will prepend "Bot" @@ -164,44 +196,55 @@ func Start( token = "Bot " + token } - s, err := state.New(token) - if err != nil { - return nil, errors.Wrap(err, "failed to create a dgo session") - } - - // fail api request if they (will) take up more than 5 minutes - s.Client.Client.Timeout = 5 * time.Minute - - c, err := New(s, cmd) - if err != nil { - return nil, errors.Wrap(err, "failed to create rfrouter") - } - - s.Gateway.ErrorLog = func(err error) { - c.ErrorLogger(err) - } - - if opts != nil { - if err := opts(c); err != nil { + newShard := NewShardFunc(func(s *state.State) (*Context, error) { + ctx, err := New(s, cmd) + if err != nil { return nil, err } + + // fail api request if they (will) take up more than 5 minutes + ctx.Client.Client.Timeout = 5 * time.Minute + + ctx.Gateway.ErrorLog = func(err error) { + ctx.ErrorLogger(err) + } + + if opts != nil { + if err := opts(ctx); err != nil { + return nil, err + } + } + + ctx.AddIntents(ctx.DeriveIntents()) + ctx.AddIntents(gateway.IntentGuilds) // for channel event caching + + return ctx, nil + }) + + m, err := shard.NewManager(token, newShard) + if err != nil { + return nil, errors.Wrap(err, "failed to create shard manager") } - c.AddIntents(c.DeriveIntents()) - c.AddIntents(gateway.IntentGuilds) // for channel event caching - - cancel := c.Start() - - if err := s.Open(); err != nil { - return nil, errors.Wrap(err, "failed to connect to Discord") + if err := m.Open(context.Background()); err == nil { + return nil, errors.Wrap(err, "failed to open") } return func() error { - Wait() - // remove handler first - cancel() - // then finish closing session - return s.Close() + WaitForInterrupt() + + // Close the shards first. + closeErr := m.Close() + + // Remove all handlers to clean up. + m.ForEach(func(s shard.Shard) { + ctx := s.(*Context) + + stop := ctx.Start() + stop() + }) + + return closeErr }, nil } @@ -221,8 +264,8 @@ func Run(token string, cmd interface{}, opts func(*Context) error) { } } -// Wait blocks until SIGINT. -func Wait() { +// WaitForInterrupt blocks until SIGINT. +func WaitForInterrupt() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) <-sigs @@ -251,7 +294,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) { ctx := &Context{ Subcommand: c, State: s, - ParseArgs: DefaultArgsParser(), + ParseArgs: DefaultArgsParser, HasPrefix: NewPrefix("~"), FormatError: func(err error) string { // Escape all pings, including @everyone. @@ -374,15 +417,34 @@ func (ctx *Context) RegisterSubcommand(cmd interface{}, names ...string) (*Subco // emptyMentionTypes is used by Start() to not parse any mentions. var emptyMentionTypes = []api.AllowedMentionType{} -// Start adds itself into the session handlers. This needs to be run. The -// returned function is a delete function, which removes itself from the -// Session handlers. +// Start adds itself into the session handlers. If Start is called more than +// once, then it does nothing. The caller doesn't have to call Start if they +// call Open. +// +// The returned function is a delete function, which removes itself from the +// Session handlers. The delete function is not safe to use concurrently. func (ctx *Context) Start() func() { - return ctx.State.AddHandler(func(v interface{}) { - if err := ctx.callCmd(v); err != nil { - ctx.ErrorLogger(errors.Wrap(err, "command error")) + if ctx.stopFunc == nil { + cancel := ctx.State.AddHandler(func(v interface{}) { + if err := ctx.callCmd(v); err != nil { + ctx.ErrorLogger(errors.Wrap(err, "command error")) + } + }) + + ctx.stopFunc = func() { + cancel() + ctx.stopFunc = nil } - }) + } + + return ctx.stopFunc +} + +// Open starts the bot context and the gateway connection. It automatically +// binds the needed handlers. +func (ctx *Context) Open(cancelCtx context.Context) error { + ctx.Start() + return ctx.State.Open(cancelCtx) } // Call should only be used if you know what you're doing. diff --git a/bot/ctx_shard_test.go b/bot/ctx_shard_test.go new file mode 100644 index 0000000..6e0cb84 --- /dev/null +++ b/bot/ctx_shard_test.go @@ -0,0 +1,75 @@ +package bot + +import ( + "context" + "testing" + "time" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/internal/testenv" + "github.com/diamondburned/arikawa/v3/state" +) + +type shardedBot struct { + Ctx *Context + + readyCh chan *gateway.ReadyEvent +} + +func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) { + bot.readyCh <- r +} + +func TestSharding(t *testing.T) { + env := testenv.Must(t) + + data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data.Shard = &gateway.Shard{0, env.ShardCount} + + readyCh := make(chan *gateway.ReadyEvent) + + newShard := NewShardFunc(func(s *state.State) (*Context, error) { + b, err := New(s, &shardedBot{nil, readyCh}) + if err != nil { + return nil, err + } + + b.AddIntents(gateway.IntentGuilds) + return b, nil + }) + + m, err := shard.NewIdentifiedManager(data, newShard) + if err != nil { + t.Fatal("failed to make shard manager:", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + go func() { + // Timeout + if err := m.Open(ctx); err != nil { + t.Error("failed to open:", err) + cancel() + } + + t.Cleanup(func() { + if err := m.Close(); err != nil { + t.Error("failed to close:", err) + cancel() + } + }) + }() + + // Expect 4 Ready events. + for i := 0; i < env.ShardCount; i++ { + select { + case ready := <-readyCh: + now := time.Now().Format(time.StampMilli) + t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount) + case <-ctx.Done(): + t.Fatal("test expired, got", i, "shards") + } + } +} diff --git a/bot/ctx_test.go b/bot/ctx_test.go index 9e79bdd..2aae158 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -97,7 +97,7 @@ func TestContext(t *testing.T) { Subcommand: sub, State: s, - ParseArgs: DefaultArgsParser(), + ParseArgs: DefaultArgsParser, } t.Run("init commands", func(t *testing.T) { @@ -396,7 +396,7 @@ func BenchmarkCall(b *testing.B) { Subcommand: sub, State: s, HasPrefix: NewPrefix("~"), - ParseArgs: DefaultArgsParser(), + ParseArgs: DefaultArgsParser, } m := &gateway.MessageCreateEvent{ @@ -424,7 +424,7 @@ func BenchmarkHelp(b *testing.B) { Subcommand: sub, State: s, HasPrefix: NewPrefix("~"), - ParseArgs: DefaultArgsParser(), + ParseArgs: DefaultArgsParser, } b.ResetTimer() diff --git a/bot/extras/middlewares/middlewares_test.go b/bot/extras/middlewares/middlewares_test.go index 04e06d4..046b6ad 100644 --- a/bot/extras/middlewares/middlewares_test.go +++ b/bot/extras/middlewares/middlewares_test.go @@ -182,13 +182,13 @@ type mockStore struct { store.NoopStore } -func mockCabinet() store.Cabinet { - c := store.NoopCabinet +func mockCabinet() *store.Cabinet { + c := *store.NoopCabinet c.GuildStore = &mockStore{} c.MemberStore = &mockStore{} c.ChannelStore = &mockStore{} - return c + return &c } func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) { diff --git a/gateway/gateway.go b/gateway/gateway.go index a48ea26..da512fd 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -9,16 +9,13 @@ package gateway import ( "context" - "net/http" "net/url" "strings" "sync" "time" "github.com/diamondburned/arikawa/v3/api" - "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/internal/moreatomic" - "github.com/diamondburned/arikawa/v3/utils/httputil" "github.com/diamondburned/arikawa/v3/utils/json" "github.com/diamondburned/arikawa/v3/utils/wsutil" "github.com/gorilla/websocket" @@ -26,9 +23,6 @@ import ( ) var ( - EndpointGateway = api.Endpoint + "gateway" - EndpointGatewayBot = api.EndpointGateway + "/bot" - Version = api.Version Encoding = "json" ) @@ -44,47 +38,26 @@ var ( // https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes const errCodeShardingRequired = 4011 -// BotData contains the GatewayURL as well as extra metadata on how to -// shard bots. -type BotData struct { - URL string `json:"url"` - Shards int `json:"shards,omitempty"` - StartLimit *SessionStartLimit `json:"session_start_limit"` -} - -// SessionStartLimit is the information on the current session start limit. It's -// used in BotData. -type SessionStartLimit struct { - Total int `json:"total"` - Remaining int `json:"remaining"` - ResetAfter discord.Milliseconds `json:"reset_after"` - MaxConcurrency int `json:"max_concurrency"` -} - // URL asks Discord for a Websocket URL to the Gateway. func URL() (string, error) { - var g BotData - - c := httputil.NewClient() - if err := c.RequestJSON(&g, "GET", EndpointGateway); err != nil { - return "", err - } - - return g.URL, nil + return api.GatewayURL() } // BotURL fetches the Gateway URL along with extra metadata. The token // passed in will NOT be prefixed with Bot. -func BotURL(token string) (*BotData, error) { - var g *BotData +func BotURL(token string) (*api.BotData, error) { + return api.NewClient(token).BotURL() +} - return g, httputil.NewClient().RequestJSON( - &g, "GET", - EndpointGatewayBot, - httputil.WithHeaders(http.Header{ - "Authorization": {token}, - }), - ) +// AddGatewayParams appends into the given URL string the gateway URL +// parameters. +func AddGatewayParams(baseURL string) string { + param := url.Values{ + "v": {Version}, + "encoding": {Encoding}, + } + + return baseURL + "?" + param.Encode() } type Gateway struct { @@ -124,22 +97,16 @@ type Gateway struct { // Defaults to noop. FatalErrorCallback func(err error) - // OnScalingRequired is the function called, if Discord closes with error - // code 4011 aka Scaling Required. At the point of calling, the Gateway - // will be closed, and can, after increasing the number of shards, be - // reopened using Open. Reconnect or ReconnectCtx, however, will not be - // available as the session is invalidated. - OnScalingRequired func() - // AfterClose is called after each close or pause. It is used mainly for // reconnections or any type of connection interruptions. // // Constructors will use a no-op function by default. AfterClose func(err error) - waitGroup sync.WaitGroup + onShardingRequired func() - closed chan struct{} + waitGroup sync.WaitGroup + closed chan struct{} } // NewGatewayWithIntents creates a new Gateway with the given intents and the @@ -167,7 +134,7 @@ func NewGateway(token string) (*Gateway, error) { // shared identifier. func NewIdentifiedGateway(id *Identifier) (*Gateway, error) { var gatewayURL string - var botData *BotData + var botData *api.BotData var err error if strings.HasPrefix(id.Token, "Bot ") { @@ -184,14 +151,7 @@ func NewIdentifiedGateway(id *Identifier) (*Gateway, error) { } } - // Parameters for the gateway - param := url.Values{ - "v": {Version}, - "encoding": {Encoding}, - } - - // Append the form to the URL - gatewayURL += "?" + param.Encode() + gatewayURL = AddGatewayParams(gatewayURL) gateway := NewCustomIdentifiedGateway(gatewayURL, id) // Use the supplied connect rate limit, if any. @@ -318,6 +278,18 @@ func (g *Gateway) SessionID() string { return g.sessionID } +// OnShardingRequired sets the function to be called if Discord closes with +// error code 4011 aka Sharding Required. When called, the Gateway will already +// be closed, and can (after increasing the number of shards) be reopened using +// Open. Reconnect or ReconnectCtx, however, will not be available as the +// session is invalidated. +// +// The gateway will completely halt what it's doing in the background when this +// callback is called. +func (g *Gateway) OnShardingRequired(fn func()) { + g.onShardingRequired = fn +} + // Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are // reached. func (g *Gateway) Reconnect() { @@ -349,7 +321,7 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) { wsutil.WSDebug("Trying to dial, attempt", try) // if we encounter an error, make sure we return it, and not nil - if oerr := g.OpenContext(ctx); oerr != nil { + if oerr := g.Open(ctx); oerr != nil { err = oerr g.ErrorLog(oerr) @@ -370,19 +342,13 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) { return err } -// Open connects to the Websocket and authenticate it. You should usually use -// this function over Start(). -func (g *Gateway) Open() error { +// Open connects to the Websocket and authenticates it. You should usually use +// this function over Start(). The given context provides cancellation and +// timeout. +func (g *Gateway) Open(ctx context.Context) error { ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() - return g.OpenContext(ctx) -} - -// OpenContext connects to the Websocket and authenticates it. You should -// usually use this function over Start(). The given context provides -// cancellation and timeout. -func (g *Gateway) OpenContext(ctx context.Context) error { // Reconnect to the Gateway if err := g.WS.Dial(ctx); err != nil { return errors.Wrap(err, "failed to Reconnect") @@ -456,19 +422,18 @@ func (g *Gateway) start(ctx context.Context) error { g.waitGroup.Done() // mark so Close() can exit. wsutil.WSDebug("Event loop stopped with error:", err) - // If Discord signals us sharding is required, do not attempt to - // Reconnect. Instead invalidate our session id, as we cannot resume, - // call OnShardingRequired, and exit. - var cerr *websocket.CloseError - if errors.As(err, &cerr) && cerr != nil && cerr.Code == errCodeShardingRequired { - g.ErrorLog(cerr) - - g.sessionMu.Lock() - g.sessionID = "" - g.sessionMu.Unlock() - - g.OnScalingRequired() - return + if err != nil && g.onShardingRequired != nil { + // If Discord signals us sharding is required, do not attempt to + // Reconnect, unless we don't know what to do. Instead invalidate + // our session ID, as we cannot resume, call OnShardingRequired, and + // exit. + var cerr *websocket.CloseError + if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired { + g.ErrorLog(cerr) + g.UseSessionID("") + g.onShardingRequired() + return + } } // Bail if there is no error or if the error is an explicit close, as diff --git a/gateway/integration_test.go b/gateway/integration_test.go index c99310d..bf596a8 100644 --- a/gateway/integration_test.go +++ b/gateway/integration_test.go @@ -42,7 +42,10 @@ func TestInvalidToken(t *testing.T) { t.Fatal("failed to make a Gateway:", err) } - if err = g.Open(); err == nil { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err = g.Open(ctx); err == nil { t.Fatal("unexpected success while opening with a bad token.") } @@ -55,10 +58,6 @@ func TestInvalidToken(t *testing.T) { func TestIntegration(t *testing.T) { config := testenv.Must(t) - wsutil.WSError = func(err error) { - t.Error(err) - } - var gateway *Gateway // NewGateway should call Start for us. @@ -70,11 +69,16 @@ func TestIntegration(t *testing.T) { g.AfterClose = func(err error) { t.Log("closed.") } + g.ErrorLog = func(err error) { + t.Log("gateway error:", err) + } gateway = g - if err := g.Open(); err != nil { - t.Fatal("failed to authenticate with Discord:", err) - } + gotimeout(t, func(ctx context.Context) { + if err := g.Open(ctx); err != nil { + t.Fatal("failed to authenticate with Discord:", err) + } + }) ev := wait(t, gateway.Events) ready, ok := ev.(*ReadyEvent) @@ -94,11 +98,7 @@ func TestIntegration(t *testing.T) { // Sleep past the rate limiter before reconnecting: time.Sleep(5 * time.Second) - gotimeout(t, func() { - // Try and reconnect for 20 seconds maximum. - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - + gotimeout(t, func(ctx context.Context) { g.ErrorLog = func(err error) { t.Error("unexpected error while reconnecting:", err) } @@ -108,10 +108,10 @@ func TestIntegration(t *testing.T) { } }) - g.ErrorLog = func(err error) { log.Println(err) } + g.ErrorLog = func(err error) { t.Log("warning:", err) } // Wait for the desired event: - gotimeout(t, func() { + gotimeout(t, func(context.Context) { for ev := range gateway.Events { switch ev.(type) { // Accept only a Resumed event. @@ -138,17 +138,21 @@ func wait(t *testing.T, evCh chan interface{}) interface{} { } } -func gotimeout(t *testing.T, fn func()) { +func gotimeout(t *testing.T, fn func(context.Context)) { t.Helper() + // Try and reconnect for 20 seconds maximum. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + var done = make(chan struct{}) go func() { - fn() + fn(ctx) done <- struct{}{} }() select { - case <-time.After(20 * time.Second): + case <-ctx.Done(): t.Fatal("timed out waiting for function.") case <-done: return diff --git a/gateway/intents.go b/gateway/intents.go index 6414ba9..234caf7 100644 --- a/gateway/intents.go +++ b/gateway/intents.go @@ -42,3 +42,49 @@ func (i Intents) IsPrivileged() (presences, member bool) { // Keep this in sync with PrivilegedIntents. return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers) } + +// EventIntents maps event types to intents. +var EventIntents = map[string]Intents{ + "GUILD_CREATE": IntentGuilds, + "GUILD_UPDATE": IntentGuilds, + "GUILD_DELETE": IntentGuilds, + "GUILD_ROLE_CREATE": IntentGuilds, + "GUILD_ROLE_UPDATE": IntentGuilds, + "GUILD_ROLE_DELETE": IntentGuilds, + "CHANNEL_CREATE": IntentGuilds, + "CHANNEL_UPDATE": IntentGuilds, + "CHANNEL_DELETE": IntentGuilds, + "CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages, + + "GUILD_MEMBER_ADD": IntentGuildMembers, + "GUILD_MEMBER_REMOVE": IntentGuildMembers, + "GUILD_MEMBER_UPDATE": IntentGuildMembers, + + "GUILD_BAN_ADD": IntentGuildBans, + "GUILD_BAN_REMOVE": IntentGuildBans, + + "GUILD_EMOJIS_UPDATE": IntentGuildEmojis, + + "GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations, + + "WEBHOOKS_UPDATE": IntentGuildWebhooks, + + "INVITE_CREATE": IntentGuildInvites, + "INVITE_DELETE": IntentGuildInvites, + + "VOICE_STATE_UPDATE": IntentGuildVoiceStates, + + "PRESENCE_UPDATE": IntentGuildPresences, + + "MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages, + "MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages, + "MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages, + "MESSAGE_DELETE_BULK": IntentGuildMessages, + + "MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions, + "MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions, + "MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions, + "MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions, + + "TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping, +} diff --git a/gateway/intents_map.go b/gateway/intents_map.go deleted file mode 100644 index 9aa1bc4..0000000 --- a/gateway/intents_map.go +++ /dev/null @@ -1,47 +0,0 @@ -package gateway - -// EventIntents maps event types to intents. -var EventIntents = map[string]Intents{ - "GUILD_CREATE": IntentGuilds, - "GUILD_UPDATE": IntentGuilds, - "GUILD_DELETE": IntentGuilds, - "GUILD_ROLE_CREATE": IntentGuilds, - "GUILD_ROLE_UPDATE": IntentGuilds, - "GUILD_ROLE_DELETE": IntentGuilds, - "CHANNEL_CREATE": IntentGuilds, - "CHANNEL_UPDATE": IntentGuilds, - "CHANNEL_DELETE": IntentGuilds, - "CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages, - - "GUILD_MEMBER_ADD": IntentGuildMembers, - "GUILD_MEMBER_REMOVE": IntentGuildMembers, - "GUILD_MEMBER_UPDATE": IntentGuildMembers, - - "GUILD_BAN_ADD": IntentGuildBans, - "GUILD_BAN_REMOVE": IntentGuildBans, - - "GUILD_EMOJIS_UPDATE": IntentGuildEmojis, - - "GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations, - - "WEBHOOKS_UPDATE": IntentGuildWebhooks, - - "INVITE_CREATE": IntentGuildInvites, - "INVITE_DELETE": IntentGuildInvites, - - "VOICE_STATE_UPDATE": IntentGuildVoiceStates, - - "PRESENCE_UPDATE": IntentGuildPresences, - - "MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages, - "MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages, - "MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages, - "MESSAGE_DELETE_BULK": IntentGuildMessages, - - "MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions, - "MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions, - "MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions, - "MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions, - - "TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping, -} diff --git a/gateway/shard/manager.go b/gateway/shard/manager.go new file mode 100644 index 0000000..5d0c3f7 --- /dev/null +++ b/gateway/shard/manager.go @@ -0,0 +1,312 @@ +package shard + +import ( + "context" + "sync" + "time" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/internal/backoff" + "github.com/pkg/errors" +) + +func updateIdentifier(ctx context.Context, id *gateway.Identifier) (url string, err error) { + botData, err := api.NewClient(id.Token).WithContext(ctx).BotURL() + if err != nil { + return "", err + } + + if botData.Shards < 1 { + botData.Shards = 1 + } + + id.Shard = &gateway.Shard{0, botData.Shards} + + // Update the burst to be the current given time and reset it back to + // the default when the given time is reached. + id.IdentifyGlobalLimit.SetBurst(botData.StartLimit.Remaining) + resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration()) + id.IdentifyGlobalLimit.SetBurstAt(resetAt, botData.StartLimit.Total) + + // Update the maximum number of identify requests allowed per 5s. + id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency) + + return botData.URL, nil +} + +// Manager is the manager responsible for handling all sharding on this +// instance. An instance of Manager must never be copied. +type Manager struct { + // shards are the *shards.shards managed by this Manager. They are + // sorted in ascending order by their shard id. + shards []ShardState + gatewayURL string + + mutex sync.RWMutex + + rescaling *rescalingState // nil unless rescaling + + new NewShardFunc +} + +type rescalingState struct { + haltRescale context.CancelFunc + rescaleDone sync.WaitGroup +} + +// NewManager creates a Manager using as many gateways as recommended by +// Discord. +func NewManager(token string, fn NewShardFunc) (*Manager, error) { + id := gateway.DefaultIdentifier(token) + + url, err := updateIdentifier(context.Background(), id) + if err != nil { + return nil, errors.Wrap(err, "failed to get gateway info") + } + + return NewIdentifiedManagerWithURL(url, id, fn) +} + +// NewIdentifiedManager creates a new Manager using the given +// gateway.Identifier. The total number of shards will be taken from the +// identifier instead of being queried from Discord, but the shard ID will be +// ignored. +// +// This function should rarely be used, since the shard information will be +// queried from Discord if it's required to shard anyway. +func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) { + // Ensure id.Shard is never nil. + if data.Shard == nil { + data.Shard = gateway.DefaultShard + } + + id := gateway.NewIdentifier(data) + + url, err := updateIdentifier(context.Background(), id) + if err != nil { + return nil, errors.Wrap(err, "failed to get gateway info") + } + + id.Shard = data.Shard + + return NewIdentifiedManagerWithURL(url, id, fn) +} + +// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier +// and gateway URL. It behaves similarly to NewIdentifiedManager. +func NewIdentifiedManagerWithURL( + url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) { + + m := Manager{ + gatewayURL: gateway.AddGatewayParams(url), + shards: make([]ShardState, id.Shard.NumShards()), + new: fn, + } + + var err error + + for i := range m.shards { + data := id.IdentifyData + data.Shard = &gateway.Shard{i, len(m.shards)} + + m.shards[i] = ShardState{ + ID: gateway.Identifier{ + IdentifyData: data, + IdentifyShortLimit: id.IdentifyShortLimit, + IdentifyGlobalLimit: id.IdentifyGlobalLimit, + }, + } + + m.shards[i].Shard, err = fn(&m, &m.shards[i].ID) + if err != nil { + return nil, errors.Wrapf(err, "failed to create shard %d/%d", i, len(m.shards)-1) + } + } + + return &m, nil +} + +// GatewayURL returns the URL to the gateway. The URL will always have the +// needed gateway parameters appended. +func (m *Manager) GatewayURL() string { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.gatewayURL +} + +// NumShards returns the total number of shards. It is OK for the caller to rely +// on NumShards while they're inside ForEach. +func (m *Manager) NumShards() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return len(m.shards) +} + +// Shard gets the shard with the given ID. +func (m *Manager) Shard(ix int) Shard { + m.mutex.RLock() + defer m.mutex.RUnlock() + + if ix >= len(m.shards) { + return nil + } + + return m.shards[ix] +} + +// FromGuildID returns the Shard and the shard ID for the guild with the given +// ID. +func (m *Manager) FromGuildID(guildID discord.GuildID) (shard Shard, ix int) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + ix = int(uint64(guildID>>22) % uint64(len(m.shards))) + return m.shards[ix], ix +} + +// ForEach calls the given function on each shard from first to last. The caller +// can safely access the number of shards by either asserting Shard to get the +// IdentifyData or call m.NumShards. +func (m *Manager) ForEach(f func(shard Shard)) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for _, g := range m.shards { + f(g) + } +} + +// Open opens all gateways handled by this Manager. If an error occurs, Open +// will attempt to close all previously opened gateways before returning. +func (m *Manager) Open(ctx context.Context) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return OpenShards(ctx, m.shards) +} + +// Close closes all gateways handled by this Manager; it will stop rescaling if +// the manager is currently being rescaled. If an error occurs, Close will +// attempt to close all remaining gateways first, before returning. +func (m *Manager) Close() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.rescaling != nil { + m.rescaling.haltRescale() + m.rescaling.rescaleDone.Wait() + + m.rescaling = nil + } + + return CloseShards(m.shards) +} + +// Rescale rescales the manager asynchronously. The caller MUST NOT call Rescale +// in the constructor function; doing so WILL cause the state to be inconsistent +// and eventually crash and burn and destroy us all. +func (m *Manager) Rescale() { + go m.rescale() +} + +func (m *Manager) rescale() { + m.mutex.Lock() + + // Exit if we're already rescaling. + if m.rescaling != nil { + m.mutex.Unlock() + return + } + + // Create a new context to allow the caller to cancel rescaling. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m.rescaling = &rescalingState{haltRescale: cancel} + m.rescaling.rescaleDone.Add(1) + defer m.rescaling.rescaleDone.Done() + + // Take the old list of shards for ourselves. + oldShards := m.shards + m.shards = nil + + m.mutex.Unlock() + + // Close the shards outside the lock. This should be fairly quickly, but it + // allows the caller to halt rescaling while we're closing or opening the + // shards. + CloseShards(oldShards) + + backoffT := backoff.NewTimer(time.Second, 15*time.Minute) + defer backoffT.Stop() + + for { + if m.tryRescale(ctx) { + return + } + + select { + case <-backoffT.Next(): + continue + case <-ctx.Done(): + return + } + } +} + +// tryRescale attempts once to rescale. It assumes the mutex is unlocked and +// will unlock the mutex itself. +func (m *Manager) tryRescale(ctx context.Context) bool { + m.mutex.Lock() + + data := m.shards[0].ID.IdentifyData + newID := gateway.NewIdentifier(data) + + url, err := updateIdentifier(ctx, newID) + if err != nil { + m.mutex.Unlock() + return false + } + + numShards := newID.Shard.NumShards() + m.gatewayURL = url + + // Release the mutex early. + m.mutex.Unlock() + + // Create the shards slice to set after we reacquire the mutex. + newShards := make([]ShardState, numShards) + + for i := 0; i < numShards; i++ { + data := newID.IdentifyData + data.Shard = &gateway.Shard{i, len(m.shards)} + + newShards[i] = ShardState{ + ID: gateway.Identifier{ + IdentifyData: data, + IdentifyShortLimit: newID.IdentifyShortLimit, + IdentifyGlobalLimit: newID.IdentifyGlobalLimit, + }, + } + + newShards[i].Shard, err = m.new(m, &newShards[i].ID) + if err != nil { + return false + } + } + + if err := OpenShards(ctx, newShards); err != nil { + return false + } + + m.mutex.Lock() + m.shards = newShards + m.rescaling = nil + m.mutex.Unlock() + + return true +} diff --git a/gateway/shard/shard.go b/gateway/shard/shard.go new file mode 100644 index 0000000..5c08df5 --- /dev/null +++ b/gateway/shard/shard.go @@ -0,0 +1,85 @@ +package shard + +import ( + "context" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/pkg/errors" +) + +// Shard defines a shard gateway interface that the shard manager can use. +type Shard interface { + Open(context.Context) error + Close() error +} + +// NewShardFunc is the constructor to create a new gateway. For examples, see +// package session and state's. The constructor must manually connect the +// Manager's Rescale method appropriately. +// +// A new Gateway must not open any background resources until OpenCtx is called; +// if the gateway has never been opened, its Close method will never be called. +// During callback, the Manager is not locked, so the callback can use Manager's +// methods without deadlocking. +type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error) + +// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with +// NewShardFunc. +var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) { + return NewGatewayShard(m, id), nil +} + +// NewGatewayShard creates a new gateway that's plugged into the shard manager. +func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway { + gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id) + gw.OnShardingRequired(m.Rescale) + return gw +} + +// ShardState wraps around the Gateway interface to provide additional state. +type ShardState struct { + Shard + // This is a bit wasteful: 2 constant pointers are stored here, and they + // waste GC cycles. This is unavoidable, however, since the API has to take + // in a pointer to Identifier, not IdentifyData. This is to ensure rescales + // are consistent. + ID gateway.Identifier + Opened bool +} + +// ShardID returns the shard state's shard ID. +func (state ShardState) ShardID() int { + return state.ID.Shard.ShardID() +} + +// OpenShards opens the gateways of the given list of shard states. +func OpenShards(ctx context.Context, shards []ShardState) error { + for i, shard := range shards { + if err := shard.Open(ctx); err != nil { + CloseShards(shards) + return errors.Wrapf(err, "failed to open shard %d/%d", i, len(shards)-1) + } + + // Mark as opened so we can close them. + shards[i].Opened = true + } + + return nil +} + +// CloseShards closes the gateways of the given list of shard states. +func CloseShards(shards []ShardState) error { + var lastError error + + for i, gw := range shards { + if gw.Opened { + if err := gw.Close(); err != nil { + lastError = err + } + + shards[i].Opened = false + } + } + + return lastError +} diff --git a/internal/backoff/backoff.go b/internal/backoff/backoff.go new file mode 100644 index 0000000..5417d3a --- /dev/null +++ b/internal/backoff/backoff.go @@ -0,0 +1,117 @@ +// Package backoff provides an exponential-backoff implementation partially +// taken from jpillora/backoff. +package backoff + +import ( + "math" + "math/rand" + "sync/atomic" + "time" +) + +const ( + factor = 2 + jitter = true +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Timer is a backoff timer. +type Timer struct { + backoff Backoff + timer *time.Timer +} + +// NewTimer returns a new uninitialized timer. +func NewTimer(min, max time.Duration) Timer { + return Timer{ + backoff: NewBackoff(min, max), + } +} + +// Next initializes the timer if needed and returns a timer channel that fires +// when the backoff timeout is reached. +func (t *Timer) Next() <-chan time.Time { + if t.timer == nil { + t.timer = time.NewTimer(t.backoff.Next()) + } else { + t.timer.Stop() // ensure drained + t.timer.Reset(t.backoff.Next()) + } + + return t.timer.C +} + +// Stop stops the internal timer and frees its resources. It does nothing if the +// timer is uninitialized. +func (t *Timer) Stop() { + if t.timer == nil { + return + } + + if !t.timer.Stop() { + <-t.timer.C // drain + } +} + +// Backoff is a time.Duration counter, starting at Min. After every call to +// the Duration method the current timing is multiplied by Factor, but it +// never exceeds Max. +type Backoff struct { + min, max float64 // seconds + attempt int32 // negative == max uint32 +} + +// NewBackoff creates a new backoff time.Duration counter. +func NewBackoff(min, max time.Duration) Backoff { + return Backoff{ + min: min.Seconds(), + max: max.Seconds(), + } +} + +// Next returns the next backoff duration. +func (b *Backoff) Next() time.Duration { + return b.forAttempt(atomic.AddInt32(&b.attempt, 1) - 1) +} + +const maxInt64 = float64(math.MaxInt64 - 512) + +// forAttempt returns the duration for a specific attempt. This is useful if +// you have a large number of independent Backoffs, but don't want use +// unnecessary memory storing the Backoff parameters per Backoff. The first +// attempt should be 0. +func (b *Backoff) forAttempt(attempt int32) time.Duration { + if b.min >= b.max { + // short-circuit + return duration(b.max) + } + + // Ensure attempt never overflows. + if attempt < 0 { + attempt = math.MaxInt32 + } + + // Calculate this duration. + dur := b.min * math.Pow(factor, float64(attempt)) + if jitter { + dur = rand.Float64()*(dur-b.min) + b.min + } + + if dur < b.min { + return duration(b.min) + } + if dur > b.max { + return duration(b.max) + } + + return duration(dur) +} + +// duration converts a seconds float64 to time.Duration without losing accuracy. +func duration(secs float64) time.Duration { + int, frac := math.Modf(secs) + return (time.Duration(int) * time.Second) + time.Duration(frac*float64(time.Second)) +} diff --git a/internal/moreatomic/mutex.go b/internal/moreatomic/mutex.go index 0031141..87fa7f4 100644 --- a/internal/moreatomic/mutex.go +++ b/internal/moreatomic/mutex.go @@ -14,25 +14,6 @@ func NewCtxMutex() *CtxMutex { } } -// func (m *CtxMutex) TryLock() bool { -// select { -// case m.mut <- struct{}{}: -// return true -// default: -// return false -// } -// } - -// func (m *CtxMutex) IsBusy() bool { -// select { -// case m.mut <- struct{}{}: -// <-m.mut -// return false -// default: -// return true -// } -// } - func (m *CtxMutex) Lock(ctx context.Context) error { select { case m.mut <- struct{}{}: diff --git a/internal/testenv/testenv.go b/internal/testenv/testenv.go index 7c2da27..78f3881 100644 --- a/internal/testenv/testenv.go +++ b/internal/testenv/testenv.go @@ -4,6 +4,7 @@ package testenv import ( "os" + "strconv" "sync" "testing" "time" @@ -15,9 +16,10 @@ import ( const PerseveranceTime = 50 * time.Minute type Env struct { - BotToken string - ChannelID discord.ChannelID - VoiceChID discord.ChannelID + BotToken string + ChannelID discord.ChannelID + VoiceChID discord.ChannelID + ShardCount int // default 3 } var ( @@ -40,39 +42,33 @@ func GetEnv() (Env, error) { } func getEnv() { - var token = os.Getenv("BOT_TOKEN") + token := os.Getenv("BOT_TOKEN") if token == "" { globalErr = errors.New("missing $BOT_TOKEN") return } - var cid = os.Getenv("CHANNEL_ID") - if cid == "" { - globalErr = errors.New("missing $CHANNEL_ID") - return - } - - id, err := discord.ParseSnowflake(cid) + id, err := discord.ParseSnowflake(os.Getenv("CHANNEL_ID")) if err != nil { globalErr = errors.Wrap(err, "invalid $CHANNEL_ID") return } - var sid = os.Getenv("VOICE_ID") - if sid == "" { - globalErr = errors.New("missing $VOICE_ID") - return - } - - vid, err := discord.ParseSnowflake(sid) + vid, err := discord.ParseSnowflake(os.Getenv("VOICE_ID")) if err != nil { globalErr = errors.Wrap(err, "invalid $VOICE_ID") return } + shardCount := 3 + if c, err := strconv.Atoi(os.Getenv("SHARD_COUNT")); err == nil { + shardCount = c + } + globalEnv = Env{ - BotToken: token, - ChannelID: discord.ChannelID(id), - VoiceChID: discord.ChannelID(vid), + BotToken: token, + ChannelID: discord.ChannelID(id), + VoiceChID: discord.ChannelID(vid), + ShardCount: shardCount, } } diff --git a/session/session.go b/session/session.go index b6f320a..f6ba5f3 100644 --- a/session/session.go +++ b/session/session.go @@ -10,6 +10,7 @@ import ( "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/internal/handleloop" "github.com/diamondburned/arikawa/v3/utils/handler" ) @@ -28,11 +29,33 @@ type Closed struct { Error error } +// NewShardFunc creates a shard constructor for a session. +// Accessing any shard and adding a handler will add a handler for all shards. +func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc { + return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { + s := NewCustomShard(m, id) + if f != nil { + f(m, s) + } + return s, nil + } +} + +// NewCustomShard creates a new session from the given shard manager and other +// parameters. +func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session { + return NewCustomSession( + shard.NewGatewayShard(m, id), + api.NewClient(id.Token), + handler.New(), + ) +} + // Session manages both the API and Gateway. As such, Session inherits all of // API's methods, as well has the Handler used for Gateway. type Session struct { *api.Client - Gateway *gateway.Gateway + *gateway.Gateway // Command handler with inherited methods. *handler.Handler @@ -92,20 +115,22 @@ func Login(email, password, mfa string) (*Session, error) { return New(l.Token) } +// NewWithGateway creates a new Session with the given Gateway. func NewWithGateway(gw *gateway.Gateway) *Session { - handler := handler.New() - looper := handleloop.NewLoop(handler) + return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New()) +} +// NewCustomSession constructs a bare Session from the given parameters. +func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session { return &Session{ Gateway: gw, - // Nab off gateway's token - Client: api.NewClient(gw.Identifier.Token), - Handler: handler, - looper: looper, + Client: cl, + Handler: h, + looper: handleloop.NewLoop(h), } } -func (s *Session) Open() error { +func (s *Session) Open(ctx context.Context) error { // Start the handler beforehand so no events are missed. s.looper.Start(s.Gateway.Events) @@ -116,7 +141,7 @@ func (s *Session) Open() error { }) } - if err := s.Gateway.Open(); err != nil { + if err := s.Gateway.Open(ctx); err != nil { return errors.Wrap(err, "failed to start gateway") } diff --git a/session/session_shard_test.go b/session/session_shard_test.go new file mode 100644 index 0000000..1b74856 --- /dev/null +++ b/session/session_shard_test.go @@ -0,0 +1,66 @@ +package session + +import ( + "context" + "testing" + "time" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/internal/testenv" +) + +func TestSharding(t *testing.T) { + env := testenv.Must(t) + + data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data.Shard = &gateway.Shard{0, env.ShardCount} + + readyCh := make(chan *gateway.ReadyEvent) + + m, err := shard.NewIdentifiedManager(data, NewShardFunc( + func(m *shard.Manager, s *Session) { + now := time.Now().Format(time.StampMilli) + t.Log(now, "initializing shard") + + s.Gateway.ErrorLog = func(err error) { + t.Error("gateway error:", err) + } + + s.AddIntents(gateway.IntentGuilds) + s.AddHandler(readyCh) + }, + )) + if err != nil { + t.Fatal("failed to make shard manager:", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + go func() { + // Timeout + if err := m.Open(ctx); err != nil { + t.Error("failed to open:", err) + cancel() + } + + t.Cleanup(func() { + if err := m.Close(); err != nil { + t.Error("failed to close:", err) + cancel() + } + }) + }() + + // Expect 4 Ready events. + for i := 0; i < env.ShardCount; i++ { + select { + case ready := <-readyCh: + now := time.Now().Format(time.StampMilli) + t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount) + case <-ctx.Done(): + t.Fatal("test expired, got", i, "shards") + } + } +} diff --git a/state/state.go b/state/state.go index e42ec12..37c8ed4 100644 --- a/state/state.go +++ b/state/state.go @@ -8,6 +8,7 @@ import ( "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/session" "github.com/diamondburned/arikawa/v3/state/store" "github.com/diamondburned/arikawa/v3/state/store/defaultstore" @@ -21,6 +22,32 @@ var ( MaxFetchGuilds uint = 100 ) +// NewShardFunc creates a shard constructor that shares the same handler. The +// given opts function is called everytime the State is created. If it doesn't +// set a cabinet into the state, then a shared default cabinet is set instead. +func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc { + var once sync.Once + var cab *store.Cabinet + + return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { + state := NewFromSession(session.NewCustomShard(m, id), nil) + + if opts != nil { + opts(m, state) + } + + if state.Cabinet == nil { + // Create the cabinet once; use sync.Once so the constructor can be + // concurrently safe. + once.Do(func() { cab = defaultstore.New() }) + + state.Cabinet = cab + } + + return state, nil + } +} + // State is the cache to store events coming from Discord as well as data from // API calls. // @@ -59,7 +86,7 @@ var ( // will be empty, while the Member structure expects it to be there. type State struct { *session.Session - store.Cabinet + *store.Cabinet // *: State doesn't actually keep track of pinned messages. @@ -113,7 +140,8 @@ func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { return NewFromSession(s, defaultstore.New()), nil } -func NewWithStore(token string, cabinet store.Cabinet) (*State, error) { +// NewWithStore creates a new state with the given store cabinet. +func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) { s, err := session.New(token) if err != nil { return nil, err @@ -123,7 +151,7 @@ func NewWithStore(token string, cabinet store.Cabinet) (*State, error) { } // NewFromSession creates a new State from the passed Session and Cabinet. -func NewFromSession(s *session.Session, cabinet store.Cabinet) *State { +func NewFromSession(s *session.Session, cabinet *store.Cabinet) *State { state := &State{ Session: s, Cabinet: cabinet, @@ -625,7 +653,7 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes if len(storeMessages) >= int(limit) && limit > 0 { return storeMessages[:limit], nil } - + // Decrease the limit, if we aren't fetching all messages. if limit > 0 { limit -= uint(len(storeMessages)) diff --git a/state/state_events.go b/state/state_events.go index 4b2cb87..b14581f 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -372,7 +372,7 @@ func findReaction(rs []discord.Reaction, emoji discord.Emoji) int { return -1 } -func storeGuildCreate(cab store.Cabinet, guild *gateway.GuildCreateEvent) []error { +func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []error { if guild.Unavailable { return nil } diff --git a/state/state_shard_test.go b/state/state_shard_test.go new file mode 100644 index 0000000..08cd46d --- /dev/null +++ b/state/state_shard_test.go @@ -0,0 +1,65 @@ +package state + +import ( + "context" + "testing" + "time" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/internal/testenv" +) + +func TestSharding(t *testing.T) { + env := testenv.Must(t) + + data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data.Shard = &gateway.Shard{0, env.ShardCount} + + readyCh := make(chan *gateway.ReadyEvent) + + m, err := shard.NewIdentifiedManager(data, NewShardFunc( + func(m *shard.Manager, s *State) { + now := time.Now().Format(time.StampMilli) + t.Log(now, "initializing shard") + + s.Gateway.ErrorLog = func(err error) { + t.Error("gateway error:", err) + } + + s.AddIntents(gateway.IntentGuilds) + s.AddHandler(readyCh) + }, + )) + if err != nil { + t.Fatal("failed to make shard manager:", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + go func() { + // Timeout + if err := m.Open(ctx); err != nil { + t.Error("failed to open:", err) + cancel() + } + + t.Cleanup(func() { + if err := m.Close(); err != nil { + t.Error("failed to close:", err) + cancel() + } + }) + }() + + for i := 0; i < env.ShardCount; i++ { + select { + case ready := <-readyCh: + now := time.Now().Format(time.StampMilli) + t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount) + case <-ctx.Done(): + t.Fatal("test expired, got", i, "shards") + } + } +} diff --git a/state/store/defaultstore/defaultstore.go b/state/store/defaultstore/defaultstore.go index 3638fa0..92b7daf 100644 --- a/state/store/defaultstore/defaultstore.go +++ b/state/store/defaultstore/defaultstore.go @@ -6,8 +6,8 @@ import "github.com/diamondburned/arikawa/v3/state/store" // New creates a new cabinet instance of defaultstore. For Message, it creates a // Message store with a limit of 100 messages. -func New() store.Cabinet { - return store.Cabinet{ +func New() *store.Cabinet { + return &store.Cabinet{ MeStore: NewMe(), ChannelStore: NewChannel(), EmojiStore: NewEmoji(), diff --git a/state/store/store.go b/state/store/store.go index d160bd8..d68e4ef 100644 --- a/state/store/store.go +++ b/state/store/store.go @@ -123,7 +123,7 @@ type NoopStore = noop // NoopCabinet is a store cabinet with all store methods set to the Noop // implementations. -var NoopCabinet = Cabinet{ +var NoopCabinet = &Cabinet{ MeStore: Noop, ChannelStore: Noop, EmojiStore: Noop, diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go index 99ae60d..b4b5369 100644 --- a/utils/wsutil/ws.go +++ b/utils/wsutil/ws.go @@ -133,7 +133,7 @@ func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error { return errors.Wrap(err, "SendLimiter failed") } - WSDebug("Send is passed the rate limiting. Waiting on mutex.") + WSDebug("Send has passed the rate limiting. Waiting on mutex.") ws.mutex.Lock() defer ws.mutex.Unlock() diff --git a/voice/session_example_test.go b/voice/session_example_test.go index 5f6a066..207aa82 100644 --- a/voice/session_example_test.go +++ b/voice/session_example_test.go @@ -1,6 +1,7 @@ package voice_test import ( + "context" "io" "log" "testing" @@ -41,7 +42,7 @@ func ExampleSession() { // This is required for bots. voice.AddIntents(s.Gateway) - if err := s.Open(); err != nil { + if err := s.Open(context.TODO()); err != nil { log.Fatalln("failed to open gateway:", err) } defer s.Close() diff --git a/voice/session_test.go b/voice/session_test.go index 3ccfa42..e6e6b6c 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -35,9 +35,15 @@ func TestIntegration(t *testing.T) { } AddIntents(s.Gateway) - if err := s.Open(); err != nil { - t.Fatal("Failed to connect:", err) - } + func() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := s.Open(ctx); err != nil { + t.Fatal("Failed to connect:", err) + } + }() + t.Cleanup(func() { s.Close() }) // Validate the given voice channel.