1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-07 12:38:05 +00:00
arikawa/session/shard/manager.go
diamondburned 54cadd2f45 gateway: Refactor for a better concurrent API
This commit refactors the whole package gateway as well as utils/ws
(formerly utils/wsutil) and voice/voicegateway. The new refactor
utilizes a design pattern involving a concurrent loop and an arriving
event channel.

An additional change was made to the way gateway events are typed.
Before, pretty much any type would satisfy a gateway event type, since
the actual type was just interface{}. The new refactor defines a
concrete interface that events can implement:

    type Event interface {
        Op() OpCode
        EventType() EventType
    }

Using this interface, the user can easily add custom gateway events
independently of the library without relying on string maps. This adds a
lot of type safety into the library and makes type-switching on Event
types much more reasonable.

Gateway error callbacks are also almost entirely removed in favor of
custom gateway events. A catch-all can easily be added like this:

    s.AddHandler(func(err error) {
        log.Println("gateway error:", err)
    })
2021-12-14 13:49:34 -08:00

313 lines
7.8 KiB
Go

package shard
import (
"context"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/backoff"
"github.com/pkg/errors"
)
// updateIdentifier asks Discord for the bot's gateway information using the
// token stored in id, then rewrites id in place from the response:
//
//   - id.Shard becomes shard 0 of the recommended shard count (clamped to at
//     least 1);
//   - id.IdentifyGlobalLimit gets a burst equal to the remaining start limit,
//     switching to the total start limit once the reset delay has elapsed;
//   - id.IdentifyShortLimit gets a burst equal to the max concurrency.
//
// It returns the gateway URL exactly as reported by Discord, or the error
// from the API call.
func updateIdentifier(ctx context.Context, id *gateway.Identifier) (url string, err error) {
	client := api.NewClient(id.Token).WithContext(ctx)

	botData, err := client.BotURL()
	if err != nil {
		return "", err
	}

	// Never configure fewer than one shard.
	shardCount := botData.Shards
	if shardCount < 1 {
		shardCount = 1
	}
	id.Shard = &gateway.Shard{0, shardCount}

	limits := botData.StartLimit

	// Allow only what is left of the start limit for now, and restore the
	// full allowance once Discord's reset delay has passed.
	id.IdentifyGlobalLimit.SetBurst(limits.Remaining)
	resetAt := time.Now().Add(limits.ResetAfter.Duration())
	id.IdentifyGlobalLimit.SetBurstAt(resetAt, limits.Total)

	// Update the maximum number of identify requests allowed per 5s.
	id.IdentifyShortLimit.SetBurst(limits.MaxConcurrency)

	return botData.URL, nil
}
// Manager is the manager responsible for handling all sharding on this
// instance. An instance of Manager must never be copied.
type Manager struct {
	// shards are the shards managed by this Manager. They are sorted in
	// ascending order by their shard ID. rescale sets this to nil while a
	// replacement set is being built.
	shards []ShardState
	// gatewayURL is the value returned by GatewayURL. The constructor stores
	// it with gateway.AddGatewayParams applied.
	gatewayURL string
	// mutex guards shards, gatewayURL and rescaling.
	mutex sync.RWMutex
	rescaling *rescalingState // nil unless rescaling
	// new is the constructor used to create each Shard, both when the
	// Manager is built and when it is rescaled.
	new NewShardFunc
}
// rescalingState tracks a rescale in progress; Manager.rescaling points to it
// while the rescale goroutine is active.
type rescalingState struct {
	// haltRescale cancels the context used by the rescale goroutine.
	haltRescale context.CancelFunc
	// rescaleDone is marked done when the rescale goroutine returns.
	rescaleDone sync.WaitGroup
}
// NewManager creates a Manager for the bot with the given token, using as
// many shards (gateways) as Discord recommends. Each shard is created with fn.
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
	ctx := context.Background()
	ident := gateway.DefaultIdentifier(token)

	gatewayURL, err := updateIdentifier(ctx, &ident)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get gateway info")
	}
	return NewIdentifiedManagerWithURL(gatewayURL, ident, fn)
}
// NewIdentifiedManager creates a new Manager using the given
// gateway.IdentifyCommand. The total number of shards is taken from
// idData.Shard (defaulting to gateway.DefaultShard when nil) instead of being
// queried from Discord; the shard ID part is ignored. Discord is still queried
// for the gateway URL and the identify rate limits.
//
// This function should rarely be used, since the shard information will be
// queried from Discord if it's required to shard anyway.
func NewIdentifiedManager(idData gateway.IdentifyCommand, fn NewShardFunc) (*Manager, error) {
	if idData.Shard == nil {
		idData.Shard = gateway.DefaultShard
	}
	wantShard := idData.Shard

	ident := gateway.NewIdentifier(idData)

	gatewayURL, err := updateIdentifier(context.Background(), &ident)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get gateway info")
	}

	// updateIdentifier replaced the shard with Discord's recommendation;
	// put the caller's configuration back.
	ident.Shard = wantShard

	return NewIdentifiedManagerWithURL(gatewayURL, ident, fn)
}
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
// and gateway URL. It behaves similarly to NewIdentifiedManager, but does not
// contact Discord: the number of shards is taken from id.Shard and the gateway
// parameters are appended to url.
//
// A nil id.Shard is treated as gateway.DefaultShard, and a shard total below 1
// is raised to 1, so the Manager always manages at least one shard. Every
// shard shares id's identify rate limiters and is created by calling fn. If fn
// fails, the error is returned wrapped with the failing shard's index.
func NewIdentifiedManagerWithURL(
	url string, id gateway.Identifier, fn NewShardFunc) (*Manager, error) {

	// Mirror NewIdentifiedManager; without this, NumShards below would
	// dereference a nil pointer.
	if id.Shard == nil {
		id.Shard = gateway.DefaultShard
	}

	// Mirror updateIdentifier: an empty Manager would make FromGuildID
	// divide by zero.
	numShards := id.Shard.NumShards()
	if numShards < 1 {
		numShards = 1
	}

	m := &Manager{
		gatewayURL: gateway.AddGatewayParams(url),
		shards:     make([]ShardState, numShards),
		new:        fn,
	}

	var err error
	for i := range m.shards {
		data := id.IdentifyCommand
		data.Shard = &gateway.Shard{i, numShards}

		m.shards[i] = ShardState{
			ID: gateway.Identifier{
				IdentifyCommand:     data,
				IdentifyShortLimit:  id.IdentifyShortLimit,
				IdentifyGlobalLimit: id.IdentifyGlobalLimit,
			},
		}

		m.shards[i].Shard, err = fn(m, &m.shards[i].ID)
		if err != nil {
			return nil, errors.Wrapf(err, "failed to create shard %d/%d", i, numShards-1)
		}
	}

	return m, nil
}
// GatewayURL returns the gateway URL currently stored in the Manager. It is
// safe to call concurrently with the other Manager methods.
func (m *Manager) GatewayURL() string {
	m.mutex.RLock()
	gatewayURL := m.gatewayURL
	m.mutex.RUnlock()

	return gatewayURL
}
// NumShards reports how many shards the Manager currently holds; this is 0
// while a rescale is replacing them. It only takes a read lock, so the caller
// may use it from inside ForEach.
func (m *Manager) NumShards() int {
	m.mutex.RLock()
	count := len(m.shards)
	m.mutex.RUnlock()

	return count
}
// Shard gets the shard with the given ID. It returns nil when ix is out of
// range (negative, or not less than NumShards), which includes any ix while
// the Manager is rescaling.
func (m *Manager) Shard(ix int) Shard {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	// Without the lower bound a negative ix would panic instead of
	// returning nil.
	if ix < 0 || ix >= len(m.shards) {
		return nil
	}
	return m.shards[ix].Shard
}
// FromGuildID returns the Shard and the shard ID for the guild with the given
// ID, using Discord's (guildID >> 22) % numShards routing formula.
//
// If the Manager currently holds no shards (as happens while it is
// rescaling), shard is nil and ix is 0.
func (m *Manager) FromGuildID(guildID discord.GuildID) (shard Shard, ix int) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	// rescale empties m.shards; taking a modulo by zero would panic.
	if len(m.shards) == 0 {
		return nil, 0
	}

	ix = int(uint64(guildID>>22) % uint64(len(m.shards)))
	return m.shards[ix].Shard, ix
}
// ForEach invokes f once per shard, in ascending shard-ID order, while
// holding the Manager's read lock. Inside f, the caller can learn the number
// of shards by asserting Shard to get the IdentifyData or by calling
// m.NumShards.
func (m *Manager) ForEach(f func(shard Shard)) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	for i := range m.shards {
		state := m.shards[i]
		f(state.Shard)
	}
}
// Open opens every gateway handled by this Manager by passing its shards to
// OpenShards with ctx, holding the Manager's write lock for the duration. If
// an error occurs, Open will attempt to close all previously opened gateways
// before returning.
func (m *Manager) Open(ctx context.Context) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	shards := m.shards
	return OpenShards(ctx, shards)
}
// Close closes all gateways handled by this Manager; it will stop rescaling if
// the manager is currently being rescaled. If an error occurs, Close will
// attempt to close all remaining gateways first, before returning.
//
// Any in-progress rescale is halted and waited for before the shards are
// closed, so shards that a finishing rescale installed are closed as well.
func (m *Manager) Close() error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for m.rescaling != nil {
		state := m.rescaling

		// Halt and wait WITHOUT holding the mutex: the rescale goroutine has
		// to acquire m.mutex before it can finish, so waiting while holding
		// it would deadlock.
		m.mutex.Unlock()
		state.haltRescale()
		state.rescaleDone.Wait()
		m.mutex.Lock()

		// A halted rescale leaves m.rescaling set, so clear it here. If a
		// newer rescale replaced it meanwhile, loop and halt that one too.
		if m.rescaling == state {
			m.rescaling = nil
		}
	}

	return CloseShards(m.shards)
}
// Rescale rescales the manager asynchronously. A new goroutine takes the
// current shards out of the Manager and closes them. It then repeatedly asks
// Discord for the recommended shard count and builds and opens a new set of
// shards, backing off between failed attempts. Rescale does nothing if a
// rescale is already running, and Close halts one that is in progress.
//
// The caller MUST NOT call Rescale in the constructor function; doing so WILL
// cause the state to be inconsistent and eventually crash and burn and destroy
// us all.
func (m *Manager) Rescale() {
	go func() {
		m.rescale()
	}()
}
// rescale performs one complete rescale. It takes the current shards out of
// the Manager and closes them. It then repeatedly calls tryRescale, backing off
// between attempts, until an attempt succeeds or the rescale is halted through
// rescalingState.haltRescale (which Close does). It returns immediately if a
// rescale is already in progress or if there are no shards to rescale from.
func (m *Manager) rescale() {
	m.mutex.Lock()

	// Exit if we're already rescaling. Also exit if there are no shards: the
	// new shards are built from the identify data of the current ones.
	if m.rescaling != nil || len(m.shards) == 0 {
		m.mutex.Unlock()
		return
	}

	// Create a new context to allow the caller to cancel rescaling.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	state := &rescalingState{haltRescale: cancel}
	// Add before publishing the state so that a Close that observes it always
	// waits for this goroutine.
	state.rescaleDone.Add(1)
	defer state.rescaleDone.Done()
	m.rescaling = state

	// Take the old list of shards for ourselves. Copy the identify data out
	// first, because m.shards stays nil until the rescale completes.
	oldShards := m.shards
	idData := oldShards[0].ID.IdentifyCommand
	m.shards = nil

	m.mutex.Unlock()

	// Close the shards outside the lock. This should be fairly quick, but it
	// allows the caller to halt rescaling while we're closing or opening the
	// shards. Closing is best-effort: the shards are being replaced anyway,
	// so the error is deliberately dropped.
	_ = CloseShards(oldShards)

	backoffT := backoff.NewTimer(time.Second, 15*time.Minute)
	defer backoffT.Stop()

	for {
		if m.tryRescale(ctx, idData) {
			return
		}

		select {
		case <-backoffT.Next():
		case <-ctx.Done():
			return
		}
	}
}

// tryRescale makes a single rescale attempt, using idData as the template for
// every new shard's identify command. It must be called without m.mutex held.
// It takes the mutex only briefly, to publish the new gateway URL and, on
// success, the new shards. It reports whether the attempt succeeded.
func (m *Manager) tryRescale(ctx context.Context, idData gateway.IdentifyCommand) bool {
	newID := gateway.NewIdentifier(idData)

	// Query Discord without holding the mutex so that readers such as
	// GatewayURL and NumShards are not blocked on the network.
	url, err := updateIdentifier(ctx, &newID)
	if err != nil {
		return false
	}

	numShards := newID.Shard.NumShards()

	m.mutex.Lock()
	// Match the constructor: GatewayURL promises the gateway parameters are
	// always appended.
	m.gatewayURL = gateway.AddGatewayParams(url)
	m.mutex.Unlock()

	// Build the new shards outside the lock; they are published only after
	// all of them have been opened.
	newShards := make([]ShardState, numShards)
	for i := range newShards {
		data := newID.IdentifyCommand
		// The total must come from numShards: m.shards is nil while rescaling.
		data.Shard = &gateway.Shard{i, numShards}

		newShards[i] = ShardState{
			ID: gateway.Identifier{
				IdentifyCommand:     data,
				IdentifyShortLimit:  newID.IdentifyShortLimit,
				IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
			},
		}

		newShards[i].Shard, err = m.new(m, &newShards[i].ID)
		if err != nil {
			return false
		}
	}

	if err := OpenShards(ctx, newShards); err != nil {
		return false
	}

	m.mutex.Lock()
	m.shards = newShards
	m.rescaling = nil
	m.mutex.Unlock()

	return true
}