2021-06-10 23:48:32 +00:00
|
|
|
package shard
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
|
|
|
|
"github.com/diamondburned/arikawa/v3/gateway"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Shard is the minimal gateway interface that the shard Manager drives: any
// type that can be opened with a context and closed again can be managed.
type Shard interface {
	// Open opens the shard using the given context.
	Open(ctx context.Context) error
	// Close closes the shard. Shards that were never opened are not closed.
	Close() error
}
|
|
|
|
|
|
|
|
// NewShardFunc is the constructor to create a new gateway. For examples, see
|
|
|
|
// package session and state's. The constructor must manually connect the
|
|
|
|
// Manager's Rescale method appropriately.
|
|
|
|
//
|
|
|
|
// A new Gateway must not open any background resources until OpenCtx is called;
|
|
|
|
// if the gateway has never been opened, its Close method will never be called.
|
|
|
|
// During callback, the Manager is not locked, so the callback can use Manager's
|
|
|
|
// methods without deadlocking.
|
|
|
|
type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error)
|
|
|
|
|
|
|
|
// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with
|
|
|
|
// NewShardFunc.
|
|
|
|
var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) {
|
|
|
|
return NewGatewayShard(m, id), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewGatewayShard creates a new gateway that's plugged into the shard manager.
|
|
|
|
func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway {
|
|
|
|
gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id)
|
|
|
|
gw.OnShardingRequired(m.Rescale)
|
|
|
|
return gw
|
|
|
|
}
|
|
|
|
|
|
|
|
// ShardState wraps around the Gateway interface to provide additional state.
|
|
|
|
type ShardState struct {
|
2021-10-31 20:10:34 +00:00
|
|
|
Shard Shard
|
2021-06-10 23:48:32 +00:00
|
|
|
// This is a bit wasteful: 2 constant pointers are stored here, and they
|
|
|
|
// waste GC cycles. This is unavoidable, however, since the API has to take
|
|
|
|
// in a pointer to Identifier, not IdentifyData. This is to ensure rescales
|
|
|
|
// are consistent.
|
|
|
|
ID gateway.Identifier
|
|
|
|
Opened bool
|
|
|
|
}
|
|
|
|
|
|
|
|
// ShardID returns the shard state's shard ID.
|
|
|
|
func (state ShardState) ShardID() int {
|
|
|
|
return state.ID.Shard.ShardID()
|
|
|
|
}
|
|
|
|
|
|
|
|
// OpenShards opens the gateways of the given list of shard states.
|
|
|
|
func OpenShards(ctx context.Context, shards []ShardState) error {
|
|
|
|
for i, shard := range shards {
|
2021-10-31 20:10:34 +00:00
|
|
|
if err := shard.Shard.Open(ctx); err != nil {
|
2021-06-10 23:48:32 +00:00
|
|
|
CloseShards(shards)
|
|
|
|
return errors.Wrapf(err, "failed to open shard %d/%d", i, len(shards)-1)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Mark as opened so we can close them.
|
|
|
|
shards[i].Opened = true
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// CloseShards closes the gateways of the given list of shard states.
|
|
|
|
func CloseShards(shards []ShardState) error {
|
|
|
|
var lastError error
|
|
|
|
|
|
|
|
for i, gw := range shards {
|
|
|
|
if gw.Opened {
|
2021-10-31 20:10:34 +00:00
|
|
|
if err := gw.Shard.Close(); err != nil {
|
2021-06-10 23:48:32 +00:00
|
|
|
lastError = err
|
|
|
|
}
|
|
|
|
|
|
|
|
shards[i].Opened = false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return lastError
|
|
|
|
}
|