mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-10 22:16:59 +00:00
d36955acea
Signed-off-by: Cléo Rebert <cleo.rebert@gmail.com>
313 lines
7.8 KiB
Go
313 lines
7.8 KiB
Go
package shard
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/diamondburned/arikawa/v3/api"
|
|
"github.com/diamondburned/arikawa/v3/discord"
|
|
"github.com/diamondburned/arikawa/v3/gateway"
|
|
"github.com/diamondburned/arikawa/v3/internal/backoff"
|
|
)
|
|
|
|
func updateIdentifier(ctx context.Context, id *gateway.Identifier) (url string, err error) {
|
|
botData, err := api.NewClient(id.Token).WithContext(ctx).BotURL()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if botData.Shards < 1 {
|
|
botData.Shards = 1
|
|
}
|
|
|
|
id.Shard = &gateway.Shard{0, botData.Shards}
|
|
|
|
// Update the burst to be the current given time and reset it back to
|
|
// the default when the given time is reached.
|
|
id.IdentifyGlobalLimit.SetBurst(botData.StartLimit.Remaining)
|
|
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
|
|
id.IdentifyGlobalLimit.SetBurstAt(resetAt, botData.StartLimit.Total)
|
|
|
|
// Update the maximum number of identify requests allowed per 5s.
|
|
id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
|
|
|
|
return botData.URL, nil
|
|
}
|
|
|
|
// Manager is the manager responsible for handling all sharding on this
|
|
// instance. An instance of Manager must never be copied.
|
|
type Manager struct {
|
|
// shards are the *shards.shards managed by this Manager. They are
|
|
// sorted in ascending order by their shard id.
|
|
shards []ShardState
|
|
gatewayURL string
|
|
|
|
mutex sync.RWMutex
|
|
|
|
rescaling *rescalingState // nil unless rescaling
|
|
|
|
new NewShardFunc
|
|
}
|
|
|
|
// rescalingState describes an in-progress rescale so that it can be halted
// and waited on.
type rescalingState struct {
	// rescaleDone is released when the rescaling goroutine finishes.
	rescaleDone sync.WaitGroup
	// haltRescale cancels the context the rescaling goroutine runs under.
	haltRescale context.CancelFunc
}
|
|
|
|
// NewManager creates a Manager using as many gateways as recommended by
|
|
// Discord.
|
|
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
|
|
id := gateway.DefaultIdentifier(token)
|
|
|
|
url, err := updateIdentifier(context.Background(), &id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get gateway info: %w", err)
|
|
}
|
|
|
|
return NewIdentifiedManagerWithURL(url, id, fn)
|
|
}
|
|
|
|
// NewIdentifiedManager creates a new Manager using the given
|
|
// gateway.Identifier. The total number of shards will be taken from the
|
|
// identifier instead of being queried from Discord, but the shard ID will be
|
|
// ignored.
|
|
//
|
|
// This function should rarely be used, since the shard information will be
|
|
// queried from Discord if it's required to shard anyway.
|
|
func NewIdentifiedManager(idData gateway.IdentifyCommand, fn NewShardFunc) (*Manager, error) {
|
|
// Ensure id.Shard is never nil.
|
|
if idData.Shard == nil {
|
|
idData.Shard = gateway.DefaultShard
|
|
}
|
|
|
|
id := gateway.NewIdentifier(idData)
|
|
|
|
url, err := updateIdentifier(context.Background(), &id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get gateway info: %w", err)
|
|
}
|
|
|
|
id.Shard = idData.Shard
|
|
|
|
return NewIdentifiedManagerWithURL(url, id, fn)
|
|
}
|
|
|
|
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
|
|
// and gateway URL. It behaves similarly to NewIdentifiedManager.
|
|
func NewIdentifiedManagerWithURL(
|
|
url string, id gateway.Identifier, fn NewShardFunc) (*Manager, error) {
|
|
|
|
m := Manager{
|
|
gatewayURL: gateway.AddGatewayParams(url),
|
|
shards: make([]ShardState, id.Shard.NumShards()),
|
|
new: fn,
|
|
}
|
|
|
|
var err error
|
|
|
|
for i := range m.shards {
|
|
data := id.IdentifyCommand
|
|
data.Shard = &gateway.Shard{i, len(m.shards)}
|
|
|
|
m.shards[i] = ShardState{
|
|
ID: gateway.Identifier{
|
|
IdentifyCommand: data,
|
|
IdentifyShortLimit: id.IdentifyShortLimit,
|
|
IdentifyGlobalLimit: id.IdentifyGlobalLimit,
|
|
},
|
|
}
|
|
|
|
m.shards[i].Shard, err = fn(&m, &m.shards[i].ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create shard %d/%d: %w", i, len(m.shards)-1, err)
|
|
}
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
// GatewayURL returns the URL to the gateway. The URL will always have the
|
|
// needed gateway parameters appended.
|
|
func (m *Manager) GatewayURL() string {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
return m.gatewayURL
|
|
}
|
|
|
|
// NumShards returns the total number of shards. It is OK for the caller to rely
|
|
// on NumShards while they're inside ForEach.
|
|
func (m *Manager) NumShards() int {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
return len(m.shards)
|
|
}
|
|
|
|
// Shard gets the shard with the given ID.
|
|
func (m *Manager) Shard(ix int) Shard {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
if ix >= len(m.shards) {
|
|
return nil
|
|
}
|
|
|
|
return m.shards[ix].Shard
|
|
}
|
|
|
|
// FromGuildID returns the Shard and the shard ID for the guild with the given
|
|
// ID.
|
|
func (m *Manager) FromGuildID(guildID discord.GuildID) (shard Shard, ix int) {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
ix = int(uint64(guildID>>22) % uint64(len(m.shards)))
|
|
return m.shards[ix].Shard, ix
|
|
}
|
|
|
|
// ForEach calls the given function on each shard from first to last. The caller
|
|
// can safely access the number of shards by either asserting Shard to get the
|
|
// IdentifyData or call m.NumShards.
|
|
func (m *Manager) ForEach(f func(shard Shard)) {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
for _, g := range m.shards {
|
|
f(g.Shard)
|
|
}
|
|
}
|
|
|
|
// Open opens all gateways handled by this Manager. If an error occurs, Open
|
|
// will attempt to close all previously opened gateways before returning.
|
|
func (m *Manager) Open(ctx context.Context) error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
return OpenShards(ctx, m.shards)
|
|
}
|
|
|
|
// Close closes all gateways handled by this Manager; it will stop rescaling if
|
|
// the manager is currently being rescaled. If an error occurs, Close will
|
|
// attempt to close all remaining gateways first, before returning.
|
|
func (m *Manager) Close() error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
if m.rescaling != nil {
|
|
m.rescaling.haltRescale()
|
|
m.rescaling.rescaleDone.Wait()
|
|
|
|
m.rescaling = nil
|
|
}
|
|
|
|
return CloseShards(m.shards)
|
|
}
|
|
|
|
// Rescale rescales the manager asynchronously. The caller MUST NOT call Rescale
|
|
// in the constructor function; doing so WILL cause the state to be inconsistent
|
|
// and eventually crash and burn and destroy us all.
|
|
func (m *Manager) Rescale() {
|
|
go m.rescale()
|
|
}
|
|
|
|
func (m *Manager) rescale() {
|
|
m.mutex.Lock()
|
|
|
|
// Exit if we're already rescaling.
|
|
if m.rescaling != nil {
|
|
m.mutex.Unlock()
|
|
return
|
|
}
|
|
|
|
// Create a new context to allow the caller to cancel rescaling.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
m.rescaling = &rescalingState{haltRescale: cancel}
|
|
m.rescaling.rescaleDone.Add(1)
|
|
defer m.rescaling.rescaleDone.Done()
|
|
|
|
// Take the old list of shards for ourselves.
|
|
oldShards := m.shards
|
|
m.shards = nil
|
|
|
|
m.mutex.Unlock()
|
|
|
|
// Close the shards outside the lock. This should be fairly quickly, but it
|
|
// allows the caller to halt rescaling while we're closing or opening the
|
|
// shards.
|
|
CloseShards(oldShards)
|
|
|
|
backoffT := backoff.NewTimer(time.Second, 15*time.Minute)
|
|
defer backoffT.Stop()
|
|
|
|
for {
|
|
if m.tryRescale(ctx) {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-backoffT.Next():
|
|
continue
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// tryRescale attempts once to rescale. It assumes the mutex is unlocked and
|
|
// will unlock the mutex itself.
|
|
func (m *Manager) tryRescale(ctx context.Context) bool {
|
|
m.mutex.Lock()
|
|
|
|
data := m.shards[0].ID.IdentifyCommand
|
|
newID := gateway.NewIdentifier(data)
|
|
|
|
url, err := updateIdentifier(ctx, &newID)
|
|
if err != nil {
|
|
m.mutex.Unlock()
|
|
return false
|
|
}
|
|
|
|
numShards := newID.Shard.NumShards()
|
|
m.gatewayURL = url
|
|
|
|
// Release the mutex early.
|
|
m.mutex.Unlock()
|
|
|
|
// Create the shards slice to set after we reacquire the mutex.
|
|
newShards := make([]ShardState, numShards)
|
|
|
|
for i := 0; i < numShards; i++ {
|
|
data := newID.IdentifyCommand
|
|
data.Shard = &gateway.Shard{i, len(m.shards)}
|
|
|
|
newShards[i] = ShardState{
|
|
ID: gateway.Identifier{
|
|
IdentifyCommand: data,
|
|
IdentifyShortLimit: newID.IdentifyShortLimit,
|
|
IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
|
|
},
|
|
}
|
|
|
|
newShards[i].Shard, err = m.new(m, &newShards[i].ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
}
|
|
|
|
if err := OpenShards(ctx, newShards); err != nil {
|
|
return false
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
m.shards = newShards
|
|
m.rescaling = nil
|
|
m.mutex.Unlock()
|
|
|
|
return true
|
|
}
|