mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-21 17:39:25 +00:00
shard: Remake shard manager
This commit is contained in:
parent
af3bedc472
commit
adf029bd7c
|
@ -12,6 +12,7 @@ environment:
|
|||
GO111MODULE: "on"
|
||||
CGO_ENABLED: "1"
|
||||
# Integration test variables.
|
||||
SHARD_COUNT: "3"
|
||||
tested: "./api,./gateway,./bot,./discord"
|
||||
cov_file: "/tmp/cov_results"
|
||||
dismock: "github.com/mavolin/dismock/v2/pkg/dismock"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
|
@ -34,8 +35,8 @@ func main() {
|
|||
data := api.InteractionResponse{
|
||||
Type: api.MessageInteractionWithSource,
|
||||
Data: &api.InteractionResponseData{
|
||||
Content: "This is a message with a button!",
|
||||
Components: []discord.Component{
|
||||
Content: option.NewNullableString("This is a message with a button!"),
|
||||
Components: &[]discord.Component{
|
||||
discord.ActionRowComponent{
|
||||
Components: []discord.Component{
|
||||
discord.ButtonComponent{
|
||||
|
@ -93,10 +94,10 @@ func main() {
|
|||
}
|
||||
})
|
||||
|
||||
s.Gateway.AddIntents(gateway.IntentGuilds)
|
||||
s.Gateway.AddIntents(gateway.IntentGuildMessages)
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddIntents(gateway.IntentGuildMessages)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := s.Open(context.Background()); err != nil {
|
||||
log.Fatalln("failed to open:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
|
@ -8,6 +9,7 @@ import (
|
|||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
)
|
||||
|
||||
// To run, do `APP_ID="APP ID" GUILD_ID="GUILD ID" BOT_TOKEN="TOKEN HERE" go run .`
|
||||
|
@ -31,7 +33,7 @@ func main() {
|
|||
data := api.InteractionResponse{
|
||||
Type: api.MessageInteractionWithSource,
|
||||
Data: &api.InteractionResponseData{
|
||||
Content: "Pong!",
|
||||
Content: option.NewNullableString("Pong!"),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -40,10 +42,10 @@ func main() {
|
|||
}
|
||||
})
|
||||
|
||||
s.Gateway.AddIntents(gateway.IntentGuilds)
|
||||
s.Gateway.AddIntents(gateway.IntentGuildMessages)
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddIntents(gateway.IntentGuildMessages)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := s.Open(context.Background()); err != nil {
|
||||
log.Fatalln("failed to open:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
|
61
_example/sharded/main.go
Normal file
61
_example/sharded/main.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
// Package main demonstrates a bare simple bot without a state cache. It logs
|
||||
// all messages it sees into stderr.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
// To run, do `BOT_TOKEN="TOKEN HERE" go run .`
|
||||
|
||||
func main() {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
newShard := state.NewShardFunc(func(m *shard.Manager, s *state.State) {
|
||||
// Add the needed Gateway intents.
|
||||
s.AddIntents(gateway.IntentGuildMessages)
|
||||
s.AddIntents(gateway.IntentDirectMessages)
|
||||
|
||||
s.AddHandler(func(c *gateway.MessageCreateEvent) {
|
||||
_, shardIx := m.FromGuildID(c.GuildID)
|
||||
log.Println(c.Author.Tag(), "sent", c.Content, "on shard", shardIx)
|
||||
})
|
||||
})
|
||||
|
||||
m, err := shard.NewManager("Bot "+token, newShard)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to create shard manager:", err)
|
||||
}
|
||||
|
||||
if err := m.Open(context.Background()); err != nil {
|
||||
log.Fatalln("failed to connect shards:", err)
|
||||
}
|
||||
defer m.Close()
|
||||
|
||||
var shardNum int
|
||||
|
||||
m.ForEach(func(s shard.Shard) {
|
||||
state := s.(*state.State)
|
||||
|
||||
u, err := state.Me()
|
||||
if err != nil {
|
||||
log.Fatalln("failed to get myself:", err)
|
||||
}
|
||||
|
||||
log.Printf("Shard %d/%d started as %s", shardNum, m.NumShards()-1, u.Tag())
|
||||
|
||||
shardNum++
|
||||
})
|
||||
|
||||
// Block forever.
|
||||
select {}
|
||||
}
|
|
@ -3,6 +3,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
|
@ -28,10 +29,10 @@ func main() {
|
|||
})
|
||||
|
||||
// Add the needed Gateway intents.
|
||||
s.Gateway.AddIntents(gateway.IntentGuildMessages)
|
||||
s.Gateway.AddIntents(gateway.IntentDirectMessages)
|
||||
s.AddIntents(gateway.IntentGuildMessages)
|
||||
s.AddIntents(gateway.IntentDirectMessages)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := s.Open(context.Background()); err != nil {
|
||||
log.Fatalln("Failed to connect:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
|
@ -37,10 +38,10 @@ func main() {
|
|||
})
|
||||
|
||||
// Add the needed Gateway intents.
|
||||
s.Gateway.AddIntents(gateway.IntentGuildMessages)
|
||||
s.Gateway.AddIntents(gateway.IntentDirectMessages)
|
||||
s.AddIntents(gateway.IntentGuildMessages)
|
||||
s.AddIntents(gateway.IntentDirectMessages)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := s.Open(context.Background()); err != nil {
|
||||
log.Fatalln("Failed to connect:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
|
36
api/bot.go
Normal file
36
api/bot.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/utils/httputil"
|
||||
)
|
||||
|
||||
// BotData contains the GatewayURL as well as extra metadata on how to
|
||||
// shard bots.
|
||||
type BotData struct {
|
||||
URL string `json:"url"`
|
||||
Shards int `json:"shards,omitempty"`
|
||||
StartLimit *SessionStartLimit `json:"session_start_limit"`
|
||||
}
|
||||
|
||||
// SessionStartLimit is the information on the current session start limit. It's
|
||||
// used in BotData.
|
||||
type SessionStartLimit struct {
|
||||
Total int `json:"total"`
|
||||
Remaining int `json:"remaining"`
|
||||
ResetAfter discord.Milliseconds `json:"reset_after"`
|
||||
MaxConcurrency int `json:"max_concurrency"`
|
||||
}
|
||||
|
||||
// BotURL fetches the Gateway URL along with extra metadata. The token
|
||||
// passed in will NOT be prefixed with Bot.
|
||||
func (c *Client) BotURL() (*BotData, error) {
|
||||
var g *BotData
|
||||
return g, c.RequestJSON(&g, "GET", EndpointGatewayBot)
|
||||
}
|
||||
|
||||
// GatewayURL asks Discord for a Websocket URL to the Gateway.
|
||||
func GatewayURL() (string, error) {
|
||||
var g BotData
|
||||
return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway)
|
||||
}
|
148
bot/ctx.go
148
bot/ctx.go
|
@ -1,6 +1,7 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
@ -14,7 +15,11 @@ import (
|
|||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/bot/extras/shellwords"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
"github.com/diamondburned/arikawa/v3/state/store"
|
||||
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
|
||||
)
|
||||
|
||||
// Prefixer checks a message if it starts with the desired prefix. By default,
|
||||
|
@ -40,8 +45,33 @@ type ArgsParser func(content string) ([]string, error)
|
|||
|
||||
// DefaultArgsParser implements a parser similar to that of shell's,
|
||||
// implementing quotes as well as escapes.
|
||||
func DefaultArgsParser() ArgsParser {
|
||||
return shellwords.Parse
|
||||
var DefaultArgsParser = shellwords.Parse
|
||||
|
||||
// NewShardFunc creates a shard constructor that shares the same internal store.
|
||||
// If opts sets its own cabinet, then a new store isn't created.
|
||||
func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc {
|
||||
if fn == nil {
|
||||
panic("bot.NewShardFunc missing fn")
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
var cab *store.Cabinet
|
||||
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
state := state.NewFromSession(session.NewCustomShard(m, id), nil)
|
||||
|
||||
bot, err := fn(state)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create bot instance")
|
||||
}
|
||||
|
||||
if state.Cabinet == nil {
|
||||
once.Do(func() { cab = defaultstore.New() })
|
||||
state.Cabinet = cab
|
||||
}
|
||||
|
||||
return bot, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Context is the bot state for commands and subcommands.
|
||||
|
@ -148,6 +178,8 @@ type Context struct {
|
|||
// Quick access map from event types to pointers. This map will never have
|
||||
// MessageCreateEvent's type.
|
||||
typeCache sync.Map // map[reflect.Type][]*CommandContext
|
||||
|
||||
stopFunc func() // unbind function, see Start()
|
||||
}
|
||||
|
||||
// Start quickly starts a bot with the given command. It will prepend "Bot"
|
||||
|
@ -164,44 +196,55 @@ func Start(
|
|||
token = "Bot " + token
|
||||
}
|
||||
|
||||
s, err := state.New(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create a dgo session")
|
||||
}
|
||||
|
||||
// fail api request if they (will) take up more than 5 minutes
|
||||
s.Client.Client.Timeout = 5 * time.Minute
|
||||
|
||||
c, err := New(s, cmd)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create rfrouter")
|
||||
}
|
||||
|
||||
s.Gateway.ErrorLog = func(err error) {
|
||||
c.ErrorLogger(err)
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
if err := opts(c); err != nil {
|
||||
newShard := NewShardFunc(func(s *state.State) (*Context, error) {
|
||||
ctx, err := New(s, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// fail api request if they (will) take up more than 5 minutes
|
||||
ctx.Client.Client.Timeout = 5 * time.Minute
|
||||
|
||||
ctx.Gateway.ErrorLog = func(err error) {
|
||||
ctx.ErrorLogger(err)
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
if err := opts(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.AddIntents(ctx.DeriveIntents())
|
||||
ctx.AddIntents(gateway.IntentGuilds) // for channel event caching
|
||||
|
||||
return ctx, nil
|
||||
})
|
||||
|
||||
m, err := shard.NewManager(token, newShard)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create shard manager")
|
||||
}
|
||||
|
||||
c.AddIntents(c.DeriveIntents())
|
||||
c.AddIntents(gateway.IntentGuilds) // for channel event caching
|
||||
|
||||
cancel := c.Start()
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to connect to Discord")
|
||||
if err := m.Open(context.Background()); err == nil {
|
||||
return nil, errors.Wrap(err, "failed to open")
|
||||
}
|
||||
|
||||
return func() error {
|
||||
Wait()
|
||||
// remove handler first
|
||||
cancel()
|
||||
// then finish closing session
|
||||
return s.Close()
|
||||
WaitForInterrupt()
|
||||
|
||||
// Close the shards first.
|
||||
closeErr := m.Close()
|
||||
|
||||
// Remove all handlers to clean up.
|
||||
m.ForEach(func(s shard.Shard) {
|
||||
ctx := s.(*Context)
|
||||
|
||||
stop := ctx.Start()
|
||||
stop()
|
||||
})
|
||||
|
||||
return closeErr
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -221,8 +264,8 @@ func Run(token string, cmd interface{}, opts func(*Context) error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Wait blocks until SIGINT.
|
||||
func Wait() {
|
||||
// WaitForInterrupt blocks until SIGINT.
|
||||
func WaitForInterrupt() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, os.Interrupt)
|
||||
<-sigs
|
||||
|
@ -251,7 +294,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
|
|||
ctx := &Context{
|
||||
Subcommand: c,
|
||||
State: s,
|
||||
ParseArgs: DefaultArgsParser(),
|
||||
ParseArgs: DefaultArgsParser,
|
||||
HasPrefix: NewPrefix("~"),
|
||||
FormatError: func(err error) string {
|
||||
// Escape all pings, including @everyone.
|
||||
|
@ -374,15 +417,34 @@ func (ctx *Context) RegisterSubcommand(cmd interface{}, names ...string) (*Subco
|
|||
// emptyMentionTypes is used by Start() to not parse any mentions.
|
||||
var emptyMentionTypes = []api.AllowedMentionType{}
|
||||
|
||||
// Start adds itself into the session handlers. This needs to be run. The
|
||||
// returned function is a delete function, which removes itself from the
|
||||
// Session handlers.
|
||||
// Start adds itself into the session handlers. If Start is called more than
|
||||
// once, then it does nothing. The caller doesn't have to call Start if they
|
||||
// call Open.
|
||||
//
|
||||
// The returned function is a delete function, which removes itself from the
|
||||
// Session handlers. The delete function is not safe to use concurrently.
|
||||
func (ctx *Context) Start() func() {
|
||||
return ctx.State.AddHandler(func(v interface{}) {
|
||||
if err := ctx.callCmd(v); err != nil {
|
||||
ctx.ErrorLogger(errors.Wrap(err, "command error"))
|
||||
if ctx.stopFunc == nil {
|
||||
cancel := ctx.State.AddHandler(func(v interface{}) {
|
||||
if err := ctx.callCmd(v); err != nil {
|
||||
ctx.ErrorLogger(errors.Wrap(err, "command error"))
|
||||
}
|
||||
})
|
||||
|
||||
ctx.stopFunc = func() {
|
||||
cancel()
|
||||
ctx.stopFunc = nil
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return ctx.stopFunc
|
||||
}
|
||||
|
||||
// Open starts the bot context and the gateway connection. It automatically
|
||||
// binds the needed handlers.
|
||||
func (ctx *Context) Open(cancelCtx context.Context) error {
|
||||
ctx.Start()
|
||||
return ctx.State.Open(cancelCtx)
|
||||
}
|
||||
|
||||
// Call should only be used if you know what you're doing.
|
||||
|
|
75
bot/ctx_shard_test.go
Normal file
75
bot/ctx_shard_test.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
type shardedBot struct {
|
||||
Ctx *Context
|
||||
|
||||
readyCh chan *gateway.ReadyEvent
|
||||
}
|
||||
|
||||
func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) {
|
||||
bot.readyCh <- r
|
||||
}
|
||||
|
||||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
||||
newShard := NewShardFunc(func(s *state.State) (*Context, error) {
|
||||
b, err := New(s, &shardedBot{nil, readyCh})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.AddIntents(gateway.IntentGuilds)
|
||||
return b, nil
|
||||
})
|
||||
|
||||
m, err := shard.NewIdentifiedManager(data, newShard)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make shard manager:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
// Timeout
|
||||
if err := m.Open(ctx); err != nil {
|
||||
t.Error("failed to open:", err)
|
||||
cancel()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := m.Close(); err != nil {
|
||||
t.Error("failed to close:", err)
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
// Expect 4 Ready events.
|
||||
for i := 0; i < env.ShardCount; i++ {
|
||||
select {
|
||||
case ready := <-readyCh:
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("test expired, got", i, "shards")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -97,7 +97,7 @@ func TestContext(t *testing.T) {
|
|||
|
||||
Subcommand: sub,
|
||||
State: s,
|
||||
ParseArgs: DefaultArgsParser(),
|
||||
ParseArgs: DefaultArgsParser,
|
||||
}
|
||||
|
||||
t.Run("init commands", func(t *testing.T) {
|
||||
|
@ -396,7 +396,7 @@ func BenchmarkCall(b *testing.B) {
|
|||
Subcommand: sub,
|
||||
State: s,
|
||||
HasPrefix: NewPrefix("~"),
|
||||
ParseArgs: DefaultArgsParser(),
|
||||
ParseArgs: DefaultArgsParser,
|
||||
}
|
||||
|
||||
m := &gateway.MessageCreateEvent{
|
||||
|
@ -424,7 +424,7 @@ func BenchmarkHelp(b *testing.B) {
|
|||
Subcommand: sub,
|
||||
State: s,
|
||||
HasPrefix: NewPrefix("~"),
|
||||
ParseArgs: DefaultArgsParser(),
|
||||
ParseArgs: DefaultArgsParser,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
|
|
@ -182,13 +182,13 @@ type mockStore struct {
|
|||
store.NoopStore
|
||||
}
|
||||
|
||||
func mockCabinet() store.Cabinet {
|
||||
c := store.NoopCabinet
|
||||
func mockCabinet() *store.Cabinet {
|
||||
c := *store.NoopCabinet
|
||||
c.GuildStore = &mockStore{}
|
||||
c.MemberStore = &mockStore{}
|
||||
c.ChannelStore = &mockStore{}
|
||||
|
||||
return c
|
||||
return &c
|
||||
}
|
||||
|
||||
func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
|
||||
|
|
|
@ -9,16 +9,13 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v3/utils/httputil"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -26,9 +23,6 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
EndpointGateway = api.Endpoint + "gateway"
|
||||
EndpointGatewayBot = api.EndpointGateway + "/bot"
|
||||
|
||||
Version = api.Version
|
||||
Encoding = "json"
|
||||
)
|
||||
|
@ -44,47 +38,26 @@ var (
|
|||
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
|
||||
const errCodeShardingRequired = 4011
|
||||
|
||||
// BotData contains the GatewayURL as well as extra metadata on how to
|
||||
// shard bots.
|
||||
type BotData struct {
|
||||
URL string `json:"url"`
|
||||
Shards int `json:"shards,omitempty"`
|
||||
StartLimit *SessionStartLimit `json:"session_start_limit"`
|
||||
}
|
||||
|
||||
// SessionStartLimit is the information on the current session start limit. It's
|
||||
// used in BotData.
|
||||
type SessionStartLimit struct {
|
||||
Total int `json:"total"`
|
||||
Remaining int `json:"remaining"`
|
||||
ResetAfter discord.Milliseconds `json:"reset_after"`
|
||||
MaxConcurrency int `json:"max_concurrency"`
|
||||
}
|
||||
|
||||
// URL asks Discord for a Websocket URL to the Gateway.
|
||||
func URL() (string, error) {
|
||||
var g BotData
|
||||
|
||||
c := httputil.NewClient()
|
||||
if err := c.RequestJSON(&g, "GET", EndpointGateway); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return g.URL, nil
|
||||
return api.GatewayURL()
|
||||
}
|
||||
|
||||
// BotURL fetches the Gateway URL along with extra metadata. The token
|
||||
// passed in will NOT be prefixed with Bot.
|
||||
func BotURL(token string) (*BotData, error) {
|
||||
var g *BotData
|
||||
func BotURL(token string) (*api.BotData, error) {
|
||||
return api.NewClient(token).BotURL()
|
||||
}
|
||||
|
||||
return g, httputil.NewClient().RequestJSON(
|
||||
&g, "GET",
|
||||
EndpointGatewayBot,
|
||||
httputil.WithHeaders(http.Header{
|
||||
"Authorization": {token},
|
||||
}),
|
||||
)
|
||||
// AddGatewayParams appends into the given URL string the gateway URL
|
||||
// parameters.
|
||||
func AddGatewayParams(baseURL string) string {
|
||||
param := url.Values{
|
||||
"v": {Version},
|
||||
"encoding": {Encoding},
|
||||
}
|
||||
|
||||
return baseURL + "?" + param.Encode()
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
|
@ -124,22 +97,16 @@ type Gateway struct {
|
|||
// Defaults to noop.
|
||||
FatalErrorCallback func(err error)
|
||||
|
||||
// OnScalingRequired is the function called, if Discord closes with error
|
||||
// code 4011 aka Scaling Required. At the point of calling, the Gateway
|
||||
// will be closed, and can, after increasing the number of shards, be
|
||||
// reopened using Open. Reconnect or ReconnectCtx, however, will not be
|
||||
// available as the session is invalidated.
|
||||
OnScalingRequired func()
|
||||
|
||||
// AfterClose is called after each close or pause. It is used mainly for
|
||||
// reconnections or any type of connection interruptions.
|
||||
//
|
||||
// Constructors will use a no-op function by default.
|
||||
AfterClose func(err error)
|
||||
|
||||
waitGroup sync.WaitGroup
|
||||
onShardingRequired func()
|
||||
|
||||
closed chan struct{}
|
||||
waitGroup sync.WaitGroup
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// NewGatewayWithIntents creates a new Gateway with the given intents and the
|
||||
|
@ -167,7 +134,7 @@ func NewGateway(token string) (*Gateway, error) {
|
|||
// shared identifier.
|
||||
func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
|
||||
var gatewayURL string
|
||||
var botData *BotData
|
||||
var botData *api.BotData
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(id.Token, "Bot ") {
|
||||
|
@ -184,14 +151,7 @@ func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Parameters for the gateway
|
||||
param := url.Values{
|
||||
"v": {Version},
|
||||
"encoding": {Encoding},
|
||||
}
|
||||
|
||||
// Append the form to the URL
|
||||
gatewayURL += "?" + param.Encode()
|
||||
gatewayURL = AddGatewayParams(gatewayURL)
|
||||
gateway := NewCustomIdentifiedGateway(gatewayURL, id)
|
||||
|
||||
// Use the supplied connect rate limit, if any.
|
||||
|
@ -318,6 +278,18 @@ func (g *Gateway) SessionID() string {
|
|||
return g.sessionID
|
||||
}
|
||||
|
||||
// OnShardingRequired sets the function to be called if Discord closes with
|
||||
// error code 4011 aka Sharding Required. When called, the Gateway will already
|
||||
// be closed, and can (after increasing the number of shards) be reopened using
|
||||
// Open. Reconnect or ReconnectCtx, however, will not be available as the
|
||||
// session is invalidated.
|
||||
//
|
||||
// The gateway will completely halt what it's doing in the background when this
|
||||
// callback is called.
|
||||
func (g *Gateway) OnShardingRequired(fn func()) {
|
||||
g.onShardingRequired = fn
|
||||
}
|
||||
|
||||
// Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are
|
||||
// reached.
|
||||
func (g *Gateway) Reconnect() {
|
||||
|
@ -349,7 +321,7 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
|
|||
wsutil.WSDebug("Trying to dial, attempt", try)
|
||||
|
||||
// if we encounter an error, make sure we return it, and not nil
|
||||
if oerr := g.OpenContext(ctx); oerr != nil {
|
||||
if oerr := g.Open(ctx); oerr != nil {
|
||||
err = oerr
|
||||
g.ErrorLog(oerr)
|
||||
|
||||
|
@ -370,19 +342,13 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
// Open connects to the Websocket and authenticate it. You should usually use
|
||||
// this function over Start().
|
||||
func (g *Gateway) Open() error {
|
||||
// Open connects to the Websocket and authenticates it. You should usually use
|
||||
// this function over Start(). The given context provides cancellation and
|
||||
// timeout.
|
||||
func (g *Gateway) Open(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.OpenContext(ctx)
|
||||
}
|
||||
|
||||
// OpenContext connects to the Websocket and authenticates it. You should
|
||||
// usually use this function over Start(). The given context provides
|
||||
// cancellation and timeout.
|
||||
func (g *Gateway) OpenContext(ctx context.Context) error {
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to Reconnect")
|
||||
|
@ -456,19 +422,18 @@ func (g *Gateway) start(ctx context.Context) error {
|
|||
g.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("Event loop stopped with error:", err)
|
||||
|
||||
// If Discord signals us sharding is required, do not attempt to
|
||||
// Reconnect. Instead invalidate our session id, as we cannot resume,
|
||||
// call OnShardingRequired, and exit.
|
||||
var cerr *websocket.CloseError
|
||||
if errors.As(err, &cerr) && cerr != nil && cerr.Code == errCodeShardingRequired {
|
||||
g.ErrorLog(cerr)
|
||||
|
||||
g.sessionMu.Lock()
|
||||
g.sessionID = ""
|
||||
g.sessionMu.Unlock()
|
||||
|
||||
g.OnScalingRequired()
|
||||
return
|
||||
if err != nil && g.onShardingRequired != nil {
|
||||
// If Discord signals us sharding is required, do not attempt to
|
||||
// Reconnect, unless we don't know what to do. Instead invalidate
|
||||
// our session ID, as we cannot resume, call OnShardingRequired, and
|
||||
// exit.
|
||||
var cerr *websocket.CloseError
|
||||
if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired {
|
||||
g.ErrorLog(cerr)
|
||||
g.UseSessionID("")
|
||||
g.onShardingRequired()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Bail if there is no error or if the error is an explicit close, as
|
||||
|
|
|
@ -42,7 +42,10 @@ func TestInvalidToken(t *testing.T) {
|
|||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
|
||||
if err = g.Open(); err == nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = g.Open(ctx); err == nil {
|
||||
t.Fatal("unexpected success while opening with a bad token.")
|
||||
}
|
||||
|
||||
|
@ -55,10 +58,6 @@ func TestInvalidToken(t *testing.T) {
|
|||
func TestIntegration(t *testing.T) {
|
||||
config := testenv.Must(t)
|
||||
|
||||
wsutil.WSError = func(err error) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var gateway *Gateway
|
||||
|
||||
// NewGateway should call Start for us.
|
||||
|
@ -70,11 +69,16 @@ func TestIntegration(t *testing.T) {
|
|||
g.AfterClose = func(err error) {
|
||||
t.Log("closed.")
|
||||
}
|
||||
g.ErrorLog = func(err error) {
|
||||
t.Log("gateway error:", err)
|
||||
}
|
||||
gateway = g
|
||||
|
||||
if err := g.Open(); err != nil {
|
||||
t.Fatal("failed to authenticate with Discord:", err)
|
||||
}
|
||||
gotimeout(t, func(ctx context.Context) {
|
||||
if err := g.Open(ctx); err != nil {
|
||||
t.Fatal("failed to authenticate with Discord:", err)
|
||||
}
|
||||
})
|
||||
|
||||
ev := wait(t, gateway.Events)
|
||||
ready, ok := ev.(*ReadyEvent)
|
||||
|
@ -94,11 +98,7 @@ func TestIntegration(t *testing.T) {
|
|||
// Sleep past the rate limiter before reconnecting:
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
gotimeout(t, func() {
|
||||
// Try and reconnect for 20 seconds maximum.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
gotimeout(t, func(ctx context.Context) {
|
||||
g.ErrorLog = func(err error) {
|
||||
t.Error("unexpected error while reconnecting:", err)
|
||||
}
|
||||
|
@ -108,10 +108,10 @@ func TestIntegration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
g.ErrorLog = func(err error) { log.Println(err) }
|
||||
g.ErrorLog = func(err error) { t.Log("warning:", err) }
|
||||
|
||||
// Wait for the desired event:
|
||||
gotimeout(t, func() {
|
||||
gotimeout(t, func(context.Context) {
|
||||
for ev := range gateway.Events {
|
||||
switch ev.(type) {
|
||||
// Accept only a Resumed event.
|
||||
|
@ -138,17 +138,21 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
func gotimeout(t *testing.T, fn func()) {
|
||||
func gotimeout(t *testing.T, fn func(context.Context)) {
|
||||
t.Helper()
|
||||
|
||||
// Try and reconnect for 20 seconds maximum.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var done = make(chan struct{})
|
||||
go func() {
|
||||
fn()
|
||||
fn(ctx)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(20 * time.Second):
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for function.")
|
||||
case <-done:
|
||||
return
|
||||
|
|
|
@ -42,3 +42,49 @@ func (i Intents) IsPrivileged() (presences, member bool) {
|
|||
// Keep this in sync with PrivilegedIntents.
|
||||
return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers)
|
||||
}
|
||||
|
||||
// EventIntents maps event types to intents.
|
||||
var EventIntents = map[string]Intents{
|
||||
"GUILD_CREATE": IntentGuilds,
|
||||
"GUILD_UPDATE": IntentGuilds,
|
||||
"GUILD_DELETE": IntentGuilds,
|
||||
"GUILD_ROLE_CREATE": IntentGuilds,
|
||||
"GUILD_ROLE_UPDATE": IntentGuilds,
|
||||
"GUILD_ROLE_DELETE": IntentGuilds,
|
||||
"CHANNEL_CREATE": IntentGuilds,
|
||||
"CHANNEL_UPDATE": IntentGuilds,
|
||||
"CHANNEL_DELETE": IntentGuilds,
|
||||
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
|
||||
|
||||
"GUILD_MEMBER_ADD": IntentGuildMembers,
|
||||
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
|
||||
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
|
||||
|
||||
"GUILD_BAN_ADD": IntentGuildBans,
|
||||
"GUILD_BAN_REMOVE": IntentGuildBans,
|
||||
|
||||
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
|
||||
|
||||
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
|
||||
|
||||
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
|
||||
|
||||
"INVITE_CREATE": IntentGuildInvites,
|
||||
"INVITE_DELETE": IntentGuildInvites,
|
||||
|
||||
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
|
||||
|
||||
"PRESENCE_UPDATE": IntentGuildPresences,
|
||||
|
||||
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE_BULK": IntentGuildMessages,
|
||||
|
||||
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
|
||||
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
|
||||
}
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
package gateway
|
||||
|
||||
// EventIntents maps event types to intents.
|
||||
var EventIntents = map[string]Intents{
|
||||
"GUILD_CREATE": IntentGuilds,
|
||||
"GUILD_UPDATE": IntentGuilds,
|
||||
"GUILD_DELETE": IntentGuilds,
|
||||
"GUILD_ROLE_CREATE": IntentGuilds,
|
||||
"GUILD_ROLE_UPDATE": IntentGuilds,
|
||||
"GUILD_ROLE_DELETE": IntentGuilds,
|
||||
"CHANNEL_CREATE": IntentGuilds,
|
||||
"CHANNEL_UPDATE": IntentGuilds,
|
||||
"CHANNEL_DELETE": IntentGuilds,
|
||||
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
|
||||
|
||||
"GUILD_MEMBER_ADD": IntentGuildMembers,
|
||||
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
|
||||
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
|
||||
|
||||
"GUILD_BAN_ADD": IntentGuildBans,
|
||||
"GUILD_BAN_REMOVE": IntentGuildBans,
|
||||
|
||||
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
|
||||
|
||||
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
|
||||
|
||||
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
|
||||
|
||||
"INVITE_CREATE": IntentGuildInvites,
|
||||
"INVITE_DELETE": IntentGuildInvites,
|
||||
|
||||
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
|
||||
|
||||
"PRESENCE_UPDATE": IntentGuildPresences,
|
||||
|
||||
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE_BULK": IntentGuildMessages,
|
||||
|
||||
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
|
||||
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
|
||||
}
|
312
gateway/shard/manager.go
Normal file
312
gateway/shard/manager.go
Normal file
|
@ -0,0 +1,312 @@
|
|||
package shard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/internal/backoff"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func updateIdentifier(ctx context.Context, id *gateway.Identifier) (url string, err error) {
|
||||
botData, err := api.NewClient(id.Token).WithContext(ctx).BotURL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if botData.Shards < 1 {
|
||||
botData.Shards = 1
|
||||
}
|
||||
|
||||
id.Shard = &gateway.Shard{0, botData.Shards}
|
||||
|
||||
// Update the burst to be the current given time and reset it back to
|
||||
// the default when the given time is reached.
|
||||
id.IdentifyGlobalLimit.SetBurst(botData.StartLimit.Remaining)
|
||||
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
|
||||
id.IdentifyGlobalLimit.SetBurstAt(resetAt, botData.StartLimit.Total)
|
||||
|
||||
// Update the maximum number of identify requests allowed per 5s.
|
||||
id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
|
||||
|
||||
return botData.URL, nil
|
||||
}
|
||||
|
||||
// Manager is the manager responsible for handling all sharding on this
|
||||
// instance. An instance of Manager must never be copied.
|
||||
type Manager struct {
|
||||
// shards are the *shards.shards managed by this Manager. They are
|
||||
// sorted in ascending order by their shard id.
|
||||
shards []ShardState
|
||||
gatewayURL string
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
rescaling *rescalingState // nil unless rescaling
|
||||
|
||||
new NewShardFunc
|
||||
}
|
||||
|
||||
type rescalingState struct {
|
||||
haltRescale context.CancelFunc
|
||||
rescaleDone sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewManager creates a Manager using as many gateways as recommended by
|
||||
// Discord.
|
||||
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
|
||||
id := gateway.DefaultIdentifier(token)
|
||||
|
||||
url, err := updateIdentifier(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get gateway info")
|
||||
}
|
||||
|
||||
return NewIdentifiedManagerWithURL(url, id, fn)
|
||||
}
|
||||
|
||||
// NewIdentifiedManager creates a new Manager using the given
|
||||
// gateway.Identifier. The total number of shards will be taken from the
|
||||
// identifier instead of being queried from Discord, but the shard ID will be
|
||||
// ignored.
|
||||
//
|
||||
// This function should rarely be used, since the shard information will be
|
||||
// queried from Discord if it's required to shard anyway.
|
||||
func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) {
|
||||
// Ensure id.Shard is never nil.
|
||||
if data.Shard == nil {
|
||||
data.Shard = gateway.DefaultShard
|
||||
}
|
||||
|
||||
id := gateway.NewIdentifier(data)
|
||||
|
||||
url, err := updateIdentifier(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get gateway info")
|
||||
}
|
||||
|
||||
id.Shard = data.Shard
|
||||
|
||||
return NewIdentifiedManagerWithURL(url, id, fn)
|
||||
}
|
||||
|
||||
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
|
||||
// and gateway URL. It behaves similarly to NewIdentifiedManager.
|
||||
func NewIdentifiedManagerWithURL(
|
||||
url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) {
|
||||
|
||||
m := Manager{
|
||||
gatewayURL: gateway.AddGatewayParams(url),
|
||||
shards: make([]ShardState, id.Shard.NumShards()),
|
||||
new: fn,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
for i := range m.shards {
|
||||
data := id.IdentifyData
|
||||
data.Shard = &gateway.Shard{i, len(m.shards)}
|
||||
|
||||
m.shards[i] = ShardState{
|
||||
ID: gateway.Identifier{
|
||||
IdentifyData: data,
|
||||
IdentifyShortLimit: id.IdentifyShortLimit,
|
||||
IdentifyGlobalLimit: id.IdentifyGlobalLimit,
|
||||
},
|
||||
}
|
||||
|
||||
m.shards[i].Shard, err = fn(&m, &m.shards[i].ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to create shard %d/%d", i, len(m.shards)-1)
|
||||
}
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// GatewayURL returns the URL to the gateway. The URL will always have the
|
||||
// needed gateway parameters appended.
|
||||
func (m *Manager) GatewayURL() string {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return m.gatewayURL
|
||||
}
|
||||
|
||||
// NumShards returns the total number of shards. It is OK for the caller to rely
|
||||
// on NumShards while they're inside ForEach.
|
||||
func (m *Manager) NumShards() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return len(m.shards)
|
||||
}
|
||||
|
||||
// Shard gets the shard with the given ID.
|
||||
func (m *Manager) Shard(ix int) Shard {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if ix >= len(m.shards) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.shards[ix]
|
||||
}
|
||||
|
||||
// FromGuildID returns the Shard and the shard ID for the guild with the given
|
||||
// ID.
|
||||
func (m *Manager) FromGuildID(guildID discord.GuildID) (shard Shard, ix int) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
ix = int(uint64(guildID>>22) % uint64(len(m.shards)))
|
||||
return m.shards[ix], ix
|
||||
}
|
||||
|
||||
// ForEach calls the given function on each shard from first to last. The caller
|
||||
// can safely access the number of shards by either asserting Shard to get the
|
||||
// IdentifyData or call m.NumShards.
|
||||
func (m *Manager) ForEach(f func(shard Shard)) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, g := range m.shards {
|
||||
f(g)
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens all gateways handled by this Manager. If an error occurs, Open
|
||||
// will attempt to close all previously opened gateways before returning.
|
||||
func (m *Manager) Open(ctx context.Context) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return OpenShards(ctx, m.shards)
|
||||
}
|
||||
|
||||
// Close closes all gateways handled by this Manager; it will stop rescaling if
|
||||
// the manager is currently being rescaled. If an error occurs, Close will
|
||||
// attempt to close all remaining gateways first, before returning.
|
||||
func (m *Manager) Close() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.rescaling != nil {
|
||||
m.rescaling.haltRescale()
|
||||
m.rescaling.rescaleDone.Wait()
|
||||
|
||||
m.rescaling = nil
|
||||
}
|
||||
|
||||
return CloseShards(m.shards)
|
||||
}
|
||||
|
||||
// Rescale rescales the manager asynchronously. The caller MUST NOT call Rescale
|
||||
// in the constructor function; doing so WILL cause the state to be inconsistent
|
||||
// and eventually crash and burn and destroy us all.
|
||||
func (m *Manager) Rescale() {
|
||||
go m.rescale()
|
||||
}
|
||||
|
||||
func (m *Manager) rescale() {
|
||||
m.mutex.Lock()
|
||||
|
||||
// Exit if we're already rescaling.
|
||||
if m.rescaling != nil {
|
||||
m.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new context to allow the caller to cancel rescaling.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
m.rescaling = &rescalingState{haltRescale: cancel}
|
||||
m.rescaling.rescaleDone.Add(1)
|
||||
defer m.rescaling.rescaleDone.Done()
|
||||
|
||||
// Take the old list of shards for ourselves.
|
||||
oldShards := m.shards
|
||||
m.shards = nil
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Close the shards outside the lock. This should be fairly quickly, but it
|
||||
// allows the caller to halt rescaling while we're closing or opening the
|
||||
// shards.
|
||||
CloseShards(oldShards)
|
||||
|
||||
backoffT := backoff.NewTimer(time.Second, 15*time.Minute)
|
||||
defer backoffT.Stop()
|
||||
|
||||
for {
|
||||
if m.tryRescale(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-backoffT.Next():
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryRescale attempts once to rescale. It assumes the mutex is unlocked and
|
||||
// will unlock the mutex itself.
|
||||
func (m *Manager) tryRescale(ctx context.Context) bool {
|
||||
m.mutex.Lock()
|
||||
|
||||
data := m.shards[0].ID.IdentifyData
|
||||
newID := gateway.NewIdentifier(data)
|
||||
|
||||
url, err := updateIdentifier(ctx, newID)
|
||||
if err != nil {
|
||||
m.mutex.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
numShards := newID.Shard.NumShards()
|
||||
m.gatewayURL = url
|
||||
|
||||
// Release the mutex early.
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Create the shards slice to set after we reacquire the mutex.
|
||||
newShards := make([]ShardState, numShards)
|
||||
|
||||
for i := 0; i < numShards; i++ {
|
||||
data := newID.IdentifyData
|
||||
data.Shard = &gateway.Shard{i, len(m.shards)}
|
||||
|
||||
newShards[i] = ShardState{
|
||||
ID: gateway.Identifier{
|
||||
IdentifyData: data,
|
||||
IdentifyShortLimit: newID.IdentifyShortLimit,
|
||||
IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
|
||||
},
|
||||
}
|
||||
|
||||
newShards[i].Shard, err = m.new(m, &newShards[i].ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if err := OpenShards(ctx, newShards); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.shards = newShards
|
||||
m.rescaling = nil
|
||||
m.mutex.Unlock()
|
||||
|
||||
return true
|
||||
}
|
85
gateway/shard/shard.go
Normal file
85
gateway/shard/shard.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package shard
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Shard defines a shard gateway interface that the shard manager can use.
|
||||
type Shard interface {
|
||||
Open(context.Context) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// NewShardFunc is the constructor to create a new gateway. For examples, see
|
||||
// package session and state's. The constructor must manually connect the
|
||||
// Manager's Rescale method appropriately.
|
||||
//
|
||||
// A new Gateway must not open any background resources until OpenCtx is called;
|
||||
// if the gateway has never been opened, its Close method will never be called.
|
||||
// During callback, the Manager is not locked, so the callback can use Manager's
|
||||
// methods without deadlocking.
|
||||
type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error)
|
||||
|
||||
// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with
|
||||
// NewShardFunc.
|
||||
var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) {
|
||||
return NewGatewayShard(m, id), nil
|
||||
}
|
||||
|
||||
// NewGatewayShard creates a new gateway that's plugged into the shard manager.
|
||||
func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway {
|
||||
gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id)
|
||||
gw.OnShardingRequired(m.Rescale)
|
||||
return gw
|
||||
}
|
||||
|
||||
// ShardState wraps around the Gateway interface to provide additional state.
|
||||
type ShardState struct {
|
||||
Shard
|
||||
// This is a bit wasteful: 2 constant pointers are stored here, and they
|
||||
// waste GC cycles. This is unavoidable, however, since the API has to take
|
||||
// in a pointer to Identifier, not IdentifyData. This is to ensure rescales
|
||||
// are consistent.
|
||||
ID gateway.Identifier
|
||||
Opened bool
|
||||
}
|
||||
|
||||
// ShardID returns the shard state's shard ID.
|
||||
func (state ShardState) ShardID() int {
|
||||
return state.ID.Shard.ShardID()
|
||||
}
|
||||
|
||||
// OpenShards opens the gateways of the given list of shard states.
|
||||
func OpenShards(ctx context.Context, shards []ShardState) error {
|
||||
for i, shard := range shards {
|
||||
if err := shard.Open(ctx); err != nil {
|
||||
CloseShards(shards)
|
||||
return errors.Wrapf(err, "failed to open shard %d/%d", i, len(shards)-1)
|
||||
}
|
||||
|
||||
// Mark as opened so we can close them.
|
||||
shards[i].Opened = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseShards closes the gateways of the given list of shard states.
|
||||
func CloseShards(shards []ShardState) error {
|
||||
var lastError error
|
||||
|
||||
for i, gw := range shards {
|
||||
if gw.Opened {
|
||||
if err := gw.Close(); err != nil {
|
||||
lastError = err
|
||||
}
|
||||
|
||||
shards[i].Opened = false
|
||||
}
|
||||
}
|
||||
|
||||
return lastError
|
||||
}
|
117
internal/backoff/backoff.go
Normal file
117
internal/backoff/backoff.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
// Package backoff provides an exponential-backoff implementation partially
|
||||
// taken from jpillora/backoff.
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
factor = 2
|
||||
jitter = true
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Timer is a backoff timer.
|
||||
type Timer struct {
|
||||
backoff Backoff
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewTimer returns a new uninitialized timer.
|
||||
func NewTimer(min, max time.Duration) Timer {
|
||||
return Timer{
|
||||
backoff: NewBackoff(min, max),
|
||||
}
|
||||
}
|
||||
|
||||
// Next initializes the timer if needed and returns a timer channel that fires
|
||||
// when the backoff timeout is reached.
|
||||
func (t *Timer) Next() <-chan time.Time {
|
||||
if t.timer == nil {
|
||||
t.timer = time.NewTimer(t.backoff.Next())
|
||||
} else {
|
||||
t.timer.Stop() // ensure drained
|
||||
t.timer.Reset(t.backoff.Next())
|
||||
}
|
||||
|
||||
return t.timer.C
|
||||
}
|
||||
|
||||
// Stop stops the internal timer and frees its resources. It does nothing if the
|
||||
// timer is uninitialized.
|
||||
func (t *Timer) Stop() {
|
||||
if t.timer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.timer.Stop() {
|
||||
<-t.timer.C // drain
|
||||
}
|
||||
}
|
||||
|
||||
// Backoff is a time.Duration counter, starting at Min. After every call to
|
||||
// the Duration method the current timing is multiplied by Factor, but it
|
||||
// never exceeds Max.
|
||||
type Backoff struct {
|
||||
min, max float64 // seconds
|
||||
attempt int32 // negative == max uint32
|
||||
}
|
||||
|
||||
// NewBackoff creates a new backoff time.Duration counter.
|
||||
func NewBackoff(min, max time.Duration) Backoff {
|
||||
return Backoff{
|
||||
min: min.Seconds(),
|
||||
max: max.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns the next backoff duration.
|
||||
func (b *Backoff) Next() time.Duration {
|
||||
return b.forAttempt(atomic.AddInt32(&b.attempt, 1) - 1)
|
||||
}
|
||||
|
||||
const maxInt64 = float64(math.MaxInt64 - 512)
|
||||
|
||||
// forAttempt returns the duration for a specific attempt. This is useful if
|
||||
// you have a large number of independent Backoffs, but don't want use
|
||||
// unnecessary memory storing the Backoff parameters per Backoff. The first
|
||||
// attempt should be 0.
|
||||
func (b *Backoff) forAttempt(attempt int32) time.Duration {
|
||||
if b.min >= b.max {
|
||||
// short-circuit
|
||||
return duration(b.max)
|
||||
}
|
||||
|
||||
// Ensure attempt never overflows.
|
||||
if attempt < 0 {
|
||||
attempt = math.MaxInt32
|
||||
}
|
||||
|
||||
// Calculate this duration.
|
||||
dur := b.min * math.Pow(factor, float64(attempt))
|
||||
if jitter {
|
||||
dur = rand.Float64()*(dur-b.min) + b.min
|
||||
}
|
||||
|
||||
if dur < b.min {
|
||||
return duration(b.min)
|
||||
}
|
||||
if dur > b.max {
|
||||
return duration(b.max)
|
||||
}
|
||||
|
||||
return duration(dur)
|
||||
}
|
||||
|
||||
// duration converts a seconds float64 to time.Duration without losing accuracy.
|
||||
func duration(secs float64) time.Duration {
|
||||
int, frac := math.Modf(secs)
|
||||
return (time.Duration(int) * time.Second) + time.Duration(frac*float64(time.Second))
|
||||
}
|
|
@ -14,25 +14,6 @@ func NewCtxMutex() *CtxMutex {
|
|||
}
|
||||
}
|
||||
|
||||
// func (m *CtxMutex) TryLock() bool {
|
||||
// select {
|
||||
// case m.mut <- struct{}{}:
|
||||
// return true
|
||||
// default:
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (m *CtxMutex) IsBusy() bool {
|
||||
// select {
|
||||
// case m.mut <- struct{}{}:
|
||||
// <-m.mut
|
||||
// return false
|
||||
// default:
|
||||
// return true
|
||||
// }
|
||||
// }
|
||||
|
||||
func (m *CtxMutex) Lock(ctx context.Context) error {
|
||||
select {
|
||||
case m.mut <- struct{}{}:
|
||||
|
|
|
@ -4,6 +4,7 @@ package testenv
|
|||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -15,9 +16,10 @@ import (
|
|||
const PerseveranceTime = 50 * time.Minute
|
||||
|
||||
type Env struct {
|
||||
BotToken string
|
||||
ChannelID discord.ChannelID
|
||||
VoiceChID discord.ChannelID
|
||||
BotToken string
|
||||
ChannelID discord.ChannelID
|
||||
VoiceChID discord.ChannelID
|
||||
ShardCount int // default 3
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -40,39 +42,33 @@ func GetEnv() (Env, error) {
|
|||
}
|
||||
|
||||
func getEnv() {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
token := os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
globalErr = errors.New("missing $BOT_TOKEN")
|
||||
return
|
||||
}
|
||||
|
||||
var cid = os.Getenv("CHANNEL_ID")
|
||||
if cid == "" {
|
||||
globalErr = errors.New("missing $CHANNEL_ID")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(cid)
|
||||
id, err := discord.ParseSnowflake(os.Getenv("CHANNEL_ID"))
|
||||
if err != nil {
|
||||
globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
|
||||
return
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
globalErr = errors.New("missing $VOICE_ID")
|
||||
return
|
||||
}
|
||||
|
||||
vid, err := discord.ParseSnowflake(sid)
|
||||
vid, err := discord.ParseSnowflake(os.Getenv("VOICE_ID"))
|
||||
if err != nil {
|
||||
globalErr = errors.Wrap(err, "invalid $VOICE_ID")
|
||||
return
|
||||
}
|
||||
|
||||
shardCount := 3
|
||||
if c, err := strconv.Atoi(os.Getenv("SHARD_COUNT")); err == nil {
|
||||
shardCount = c
|
||||
}
|
||||
|
||||
globalEnv = Env{
|
||||
BotToken: token,
|
||||
ChannelID: discord.ChannelID(id),
|
||||
VoiceChID: discord.ChannelID(vid),
|
||||
BotToken: token,
|
||||
ChannelID: discord.ChannelID(id),
|
||||
VoiceChID: discord.ChannelID(vid),
|
||||
ShardCount: shardCount,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/handleloop"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
)
|
||||
|
@ -28,11 +29,33 @@ type Closed struct {
|
|||
Error error
|
||||
}
|
||||
|
||||
// NewShardFunc creates a shard constructor for a session.
|
||||
// Accessing any shard and adding a handler will add a handler for all shards.
|
||||
func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc {
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
s := NewCustomShard(m, id)
|
||||
if f != nil {
|
||||
f(m, s)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewCustomShard creates a new session from the given shard manager and other
|
||||
// parameters.
|
||||
func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session {
|
||||
return NewCustomSession(
|
||||
shard.NewGatewayShard(m, id),
|
||||
api.NewClient(id.Token),
|
||||
handler.New(),
|
||||
)
|
||||
}
|
||||
|
||||
// Session manages both the API and Gateway. As such, Session inherits all of
|
||||
// API's methods, as well has the Handler used for Gateway.
|
||||
type Session struct {
|
||||
*api.Client
|
||||
Gateway *gateway.Gateway
|
||||
*gateway.Gateway
|
||||
|
||||
// Command handler with inherited methods.
|
||||
*handler.Handler
|
||||
|
@ -92,20 +115,22 @@ func Login(email, password, mfa string) (*Session, error) {
|
|||
return New(l.Token)
|
||||
}
|
||||
|
||||
// NewWithGateway creates a new Session with the given Gateway.
|
||||
func NewWithGateway(gw *gateway.Gateway) *Session {
|
||||
handler := handler.New()
|
||||
looper := handleloop.NewLoop(handler)
|
||||
return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New())
|
||||
}
|
||||
|
||||
// NewCustomSession constructs a bare Session from the given parameters.
|
||||
func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session {
|
||||
return &Session{
|
||||
Gateway: gw,
|
||||
// Nab off gateway's token
|
||||
Client: api.NewClient(gw.Identifier.Token),
|
||||
Handler: handler,
|
||||
looper: looper,
|
||||
Client: cl,
|
||||
Handler: h,
|
||||
looper: handleloop.NewLoop(h),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) Open() error {
|
||||
func (s *Session) Open(ctx context.Context) error {
|
||||
// Start the handler beforehand so no events are missed.
|
||||
s.looper.Start(s.Gateway.Events)
|
||||
|
||||
|
@ -116,7 +141,7 @@ func (s *Session) Open() error {
|
|||
})
|
||||
}
|
||||
|
||||
if err := s.Gateway.Open(); err != nil {
|
||||
if err := s.Gateway.Open(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to start gateway")
|
||||
}
|
||||
|
||||
|
|
66
session/session_shard_test.go
Normal file
66
session/session_shard_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
)
|
||||
|
||||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
||||
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
|
||||
func(m *shard.Manager, s *Session) {
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "initializing shard")
|
||||
|
||||
s.Gateway.ErrorLog = func(err error) {
|
||||
t.Error("gateway error:", err)
|
||||
}
|
||||
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal("failed to make shard manager:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
// Timeout
|
||||
if err := m.Open(ctx); err != nil {
|
||||
t.Error("failed to open:", err)
|
||||
cancel()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := m.Close(); err != nil {
|
||||
t.Error("failed to close:", err)
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
// Expect 4 Ready events.
|
||||
for i := 0; i < env.ShardCount; i++ {
|
||||
select {
|
||||
case ready := <-readyCh:
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("test expired, got", i, "shards")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/state/store"
|
||||
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
|
||||
|
@ -21,6 +22,32 @@ var (
|
|||
MaxFetchGuilds uint = 100
|
||||
)
|
||||
|
||||
// NewShardFunc creates a shard constructor that shares the same handler. The
|
||||
// given opts function is called everytime the State is created. If it doesn't
|
||||
// set a cabinet into the state, then a shared default cabinet is set instead.
|
||||
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
|
||||
var once sync.Once
|
||||
var cab *store.Cabinet
|
||||
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
state := NewFromSession(session.NewCustomShard(m, id), nil)
|
||||
|
||||
if opts != nil {
|
||||
opts(m, state)
|
||||
}
|
||||
|
||||
if state.Cabinet == nil {
|
||||
// Create the cabinet once; use sync.Once so the constructor can be
|
||||
// concurrently safe.
|
||||
once.Do(func() { cab = defaultstore.New() })
|
||||
|
||||
state.Cabinet = cab
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
// State is the cache to store events coming from Discord as well as data from
|
||||
// API calls.
|
||||
//
|
||||
|
@ -59,7 +86,7 @@ var (
|
|||
// will be empty, while the Member structure expects it to be there.
|
||||
type State struct {
|
||||
*session.Session
|
||||
store.Cabinet
|
||||
*store.Cabinet
|
||||
|
||||
// *: State doesn't actually keep track of pinned messages.
|
||||
|
||||
|
@ -113,7 +140,8 @@ func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
|
|||
return NewFromSession(s, defaultstore.New()), nil
|
||||
}
|
||||
|
||||
func NewWithStore(token string, cabinet store.Cabinet) (*State, error) {
|
||||
// NewWithStore creates a new state with the given store cabinet.
|
||||
func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) {
|
||||
s, err := session.New(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -123,7 +151,7 @@ func NewWithStore(token string, cabinet store.Cabinet) (*State, error) {
|
|||
}
|
||||
|
||||
// NewFromSession creates a new State from the passed Session and Cabinet.
|
||||
func NewFromSession(s *session.Session, cabinet store.Cabinet) *State {
|
||||
func NewFromSession(s *session.Session, cabinet *store.Cabinet) *State {
|
||||
state := &State{
|
||||
Session: s,
|
||||
Cabinet: cabinet,
|
||||
|
@ -625,7 +653,7 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
|
|||
if len(storeMessages) >= int(limit) && limit > 0 {
|
||||
return storeMessages[:limit], nil
|
||||
}
|
||||
|
||||
|
||||
// Decrease the limit, if we aren't fetching all messages.
|
||||
if limit > 0 {
|
||||
limit -= uint(len(storeMessages))
|
||||
|
|
|
@ -372,7 +372,7 @@ func findReaction(rs []discord.Reaction, emoji discord.Emoji) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func storeGuildCreate(cab store.Cabinet, guild *gateway.GuildCreateEvent) []error {
|
||||
func storeGuildCreate(cab *store.Cabinet, guild *gateway.GuildCreateEvent) []error {
|
||||
if guild.Unavailable {
|
||||
return nil
|
||||
}
|
||||
|
|
65
state/state_shard_test.go
Normal file
65
state/state_shard_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
)
|
||||
|
||||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
||||
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
|
||||
func(m *shard.Manager, s *State) {
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "initializing shard")
|
||||
|
||||
s.Gateway.ErrorLog = func(err error) {
|
||||
t.Error("gateway error:", err)
|
||||
}
|
||||
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal("failed to make shard manager:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
// Timeout
|
||||
if err := m.Open(ctx); err != nil {
|
||||
t.Error("failed to open:", err)
|
||||
cancel()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := m.Close(); err != nil {
|
||||
t.Error("failed to close:", err)
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
for i := 0; i < env.ShardCount; i++ {
|
||||
select {
|
||||
case ready := <-readyCh:
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "shard", ready.Shard.ShardID(), "is ready out of", env.ShardCount)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("test expired, got", i, "shards")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,8 +6,8 @@ import "github.com/diamondburned/arikawa/v3/state/store"
|
|||
|
||||
// New creates a new cabinet instance of defaultstore. For Message, it creates a
|
||||
// Message store with a limit of 100 messages.
|
||||
func New() store.Cabinet {
|
||||
return store.Cabinet{
|
||||
func New() *store.Cabinet {
|
||||
return &store.Cabinet{
|
||||
MeStore: NewMe(),
|
||||
ChannelStore: NewChannel(),
|
||||
EmojiStore: NewEmoji(),
|
||||
|
|
|
@ -123,7 +123,7 @@ type NoopStore = noop
|
|||
|
||||
// NoopCabinet is a store cabinet with all store methods set to the Noop
|
||||
// implementations.
|
||||
var NoopCabinet = Cabinet{
|
||||
var NoopCabinet = &Cabinet{
|
||||
MeStore: Noop,
|
||||
ChannelStore: Noop,
|
||||
EmojiStore: Noop,
|
||||
|
|
|
@ -133,7 +133,7 @@ func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
|
|||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
WSDebug("Send is passed the rate limiting. Waiting on mutex.")
|
||||
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package voice_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"testing"
|
||||
|
@ -41,7 +42,7 @@ func ExampleSession() {
|
|||
// This is required for bots.
|
||||
voice.AddIntents(s.Gateway)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := s.Open(context.TODO()); err != nil {
|
||||
log.Fatalln("failed to open gateway:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
|
|
@ -35,9 +35,15 @@ func TestIntegration(t *testing.T) {
|
|||
}
|
||||
AddIntents(s.Gateway)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
t.Fatal("Failed to connect:", err)
|
||||
}
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Open(ctx); err != nil {
|
||||
t.Fatal("Failed to connect:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() { s.Close() })
|
||||
|
||||
// Validate the given voice channel.
|
||||
|
|
Loading…
Reference in a new issue