mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-02 18:26:41 +00:00
Gateway: Added intent helpers and more context API support
This commit is contained in:
parent
f33b4ff7d8
commit
edb8a46ef2
|
@ -140,7 +140,8 @@ type Context struct {
|
|||
|
||||
// Start quickly starts a bot with the given command. It will prepend "Bot"
|
||||
// into the token automatically. Refer to example/ for usage.
|
||||
func Start(token string, cmd interface{},
|
||||
func Start(
|
||||
token string, cmd interface{},
|
||||
opts func(*Context) error) (wait func() error, err error) {
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
|
@ -227,6 +228,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
// AddIntent adds the given Gateway Intent into the Gateway. This is a
|
||||
// convenient function that calls Gateway's AddIntent.
|
||||
func (ctx *Context) AddIntent(i gateway.Intents) {
|
||||
ctx.Gateway.AddIntent(i)
|
||||
}
|
||||
|
||||
// Subcommands returns the slice of subcommands. To add subcommands, use
|
||||
// RegisterSubcommand().
|
||||
func (ctx *Context) Subcommands() []*Subcommand {
|
||||
|
|
|
@ -15,11 +15,18 @@ func (g *Gateway) Identify() error {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.IdentifyCtx(ctx)
|
||||
}
|
||||
|
||||
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := g.Identifier.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "can't wait for identify()")
|
||||
}
|
||||
|
||||
return g.Send(IdentifyOP, g.Identifier)
|
||||
return g.SendCtx(ctx, IdentifyOP, g.Identifier)
|
||||
}
|
||||
|
||||
type ResumeData struct {
|
||||
|
@ -31,6 +38,15 @@ type ResumeData struct {
|
|||
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
|
||||
// from a dead connection. Start() resumes from a dead connection.
|
||||
func (g *Gateway) Resume() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.ResumeCtx(ctx)
|
||||
}
|
||||
|
||||
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
|
||||
// from a dead connection. Start() resumes from a dead connection.
|
||||
func (g *Gateway) ResumeCtx(ctx context.Context) error {
|
||||
var (
|
||||
ses = g.SessionID
|
||||
seq = g.Sequence.Get()
|
||||
|
@ -40,7 +56,7 @@ func (g *Gateway) Resume() error {
|
|||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return g.Send(ResumeOP, ResumeData{
|
||||
return g.SendCtx(ctx, ResumeOP, ResumeData{
|
||||
Token: g.Identifier.Token,
|
||||
SessionID: ses,
|
||||
Sequence: seq,
|
||||
|
@ -51,7 +67,14 @@ func (g *Gateway) Resume() error {
|
|||
type HeartbeatData int
|
||||
|
||||
func (g *Gateway) Heartbeat() error {
|
||||
return g.Send(HeartbeatOP, g.Sequence.Get())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.HeartbeatCtx(ctx)
|
||||
}
|
||||
|
||||
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
|
||||
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
|
||||
}
|
||||
|
||||
type RequestGuildMembersData struct {
|
||||
|
@ -61,10 +84,20 @@ type RequestGuildMembersData struct {
|
|||
Query string `json:"query,omitempty"`
|
||||
Limit uint `json:"limit"`
|
||||
Presences bool `json:"presences,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
|
||||
return g.Send(RequestGuildMembersOP, data)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.RequestGuildMembersCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) RequestGuildMembersCtx(
|
||||
ctx context.Context, data RequestGuildMembersData) error {
|
||||
|
||||
return g.SendCtx(ctx, RequestGuildMembersOP, data)
|
||||
}
|
||||
|
||||
type UpdateVoiceStateData struct {
|
||||
|
@ -75,7 +108,16 @@ type UpdateVoiceStateData struct {
|
|||
}
|
||||
|
||||
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
|
||||
return g.Send(VoiceStateUpdateOP, data)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.UpdateVoiceStateCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateVoiceStateCtx(
|
||||
ctx context.Context, data UpdateVoiceStateData) error {
|
||||
|
||||
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
|
||||
}
|
||||
|
||||
type UpdateStatusData struct {
|
||||
|
@ -90,7 +132,14 @@ type UpdateStatusData struct {
|
|||
}
|
||||
|
||||
func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
|
||||
return g.Send(StatusUpdateOP, data)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.UpdateStatusCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
|
||||
return g.SendCtx(ctx, StatusUpdateOP, data)
|
||||
}
|
||||
|
||||
// Undocumented
|
||||
|
@ -104,5 +153,12 @@ type GuildSubscribeData struct {
|
|||
}
|
||||
|
||||
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
|
||||
return g.Send(GuildSubscriptionsOP, data)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.GuildSubscribeCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
|
||||
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
|
||||
}
|
||||
|
|
|
@ -99,11 +99,15 @@ type (
|
|||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Members []discord.Member `json:"members"`
|
||||
|
||||
ChunkIndex int `json:"chunk_index"`
|
||||
ChunkCount int `json:"chunk_count"`
|
||||
|
||||
// Whatever's not found goes here
|
||||
NotFound []string `json:"not_found,omitempty"`
|
||||
|
||||
// Only filled if requested
|
||||
Presences []discord.Presence `json:"presences,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
// GuildMemberListUpdate is an undocumented event. It's received when the
|
||||
|
|
|
@ -107,8 +107,23 @@ type Gateway struct {
|
|||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||
// information, refer to NewGatewayWithDriver.
|
||||
// NewGatewayWithIntents creates a new Gateway with the given intents and the
|
||||
// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
|
||||
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
|
||||
g, err := NewGateway(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, intent := range intents {
|
||||
g.AddIntent(intent)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// NewGateway creates a new Gateway with the default stdlib JSON driver. For
|
||||
// more information, refer to NewGatewayWithDriver.
|
||||
func NewGateway(token string) (*Gateway, error) {
|
||||
URL, err := URL()
|
||||
if err != nil {
|
||||
|
@ -141,6 +156,12 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
|
|||
}
|
||||
}
|
||||
|
||||
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
|
||||
// such, this function will only work before Open() is called.
|
||||
func (g *Gateway) AddIntent(i Intents) {
|
||||
g.Identifier.Intents |= i
|
||||
}
|
||||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
wsutil.WSDebug("Trying to close.")
|
||||
|
@ -182,10 +203,13 @@ func (g *Gateway) Close() error {
|
|||
// Reconnect tries to reconnect forever. It will resume the connection if
|
||||
// possible. If an Invalid Session is received, it will start a fresh one.
|
||||
func (g *Gateway) Reconnect() error {
|
||||
return g.ReconnectContext(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.ReconnectCtx(ctx)
|
||||
}
|
||||
|
||||
func (g *Gateway) ReconnectContext(ctx context.Context) error {
|
||||
func (g *Gateway) ReconnectCtx(ctx context.Context) error {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
|
@ -212,9 +236,15 @@ func (g *Gateway) ReconnectContext(ctx context.Context) error {
|
|||
// Open connects to the Websocket and authenticate it. You should usually use
|
||||
// this function over Start().
|
||||
func (g *Gateway) Open() error {
|
||||
return g.OpenContext(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.OpenContext(ctx)
|
||||
}
|
||||
|
||||
// OpenContext connects to the Websocket and authenticates it. Yuo should
|
||||
// usually use this function over Start(). The given context provides
|
||||
// cancellation and timeout.
|
||||
func (g *Gateway) OpenContext(ctx context.Context) error {
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
|
@ -224,7 +254,7 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
|
|||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to resume the connection
|
||||
if err := g.Start(); err != nil {
|
||||
if err := g.StartCtx(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -232,14 +262,19 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Start authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. This function doesn't block. You wouldn't usually use this
|
||||
// Start calls StartCtx with a background context. You wouldn't usually use this
|
||||
// function, but Open() instead.
|
||||
func (g *Gateway) Start() error {
|
||||
// g.available.Lock()
|
||||
// defer g.available.Unlock()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := g.start(); err != nil {
|
||||
return g.StartCtx(ctx)
|
||||
}
|
||||
|
||||
// StartCtx authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. You wouldn't usually use this function, but OpenCtx() instead.
|
||||
func (g *Gateway) StartCtx(ctx context.Context) error {
|
||||
if err := g.start(ctx); err != nil {
|
||||
wsutil.WSDebug("Start failed:", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
|
@ -249,31 +284,41 @@ func (g *Gateway) Start() error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) start() error {
|
||||
func (g *Gateway) start(ctx context.Context) error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
||||
// Make a new WaitGroup for use in background loops:
|
||||
g.waitGroup = new(sync.WaitGroup)
|
||||
|
||||
// Wait for an OP 10 Hello
|
||||
// Create a new Hello event and wait for it.
|
||||
var hello HelloEvent
|
||||
if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
// Wait for an OP 10 Hello.
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return errors.New("unexpected ws close while waiting for Hello")
|
||||
}
|
||||
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
|
||||
}
|
||||
|
||||
// Send Discord either the Identify packet (if it's a fresh connection), or
|
||||
// a Resume packet (if it's a dead connection).
|
||||
if g.SessionID == "" {
|
||||
// SessionID is empty, so this is a completely new session.
|
||||
if err := g.Identify(); err != nil {
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to identify")
|
||||
}
|
||||
} else {
|
||||
if err := g.Resume(); err != nil {
|
||||
if err := g.ResumeCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to resume")
|
||||
}
|
||||
}
|
||||
|
@ -282,7 +327,7 @@ func (g *Gateway) start() error {
|
|||
wsutil.WSDebug("Waiting for either READY or RESUMED.")
|
||||
|
||||
// WaitForEvent should
|
||||
err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool {
|
||||
err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
|
||||
switch op.EventName {
|
||||
case "READY":
|
||||
wsutil.WSDebug("Found READY event.")
|
||||
|
@ -319,7 +364,9 @@ func (g *Gateway) start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||
// SendCtx is a low-level function to send an OP payload to the Gateway. Most
|
||||
// users shouldn't touch this, unless they know what they're doing.
|
||||
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
var op = wsutil.OP{
|
||||
Code: code,
|
||||
}
|
||||
|
@ -339,5 +386,5 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return g.WS.Send(b)
|
||||
return g.WS.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ func (i *IdentifyData) SetShard(id, num int) {
|
|||
i.Shard[0], i.Shard[1] = id, num
|
||||
}
|
||||
|
||||
// Intents is a new Discord API feature that's documented at
|
||||
// Intents for the new Discord API feature, documented at
|
||||
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
|
||||
type Intents uint32
|
||||
|
||||
|
|
|
@ -107,13 +107,15 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
|
|||
select {
|
||||
case ev := <-evCh:
|
||||
return ev
|
||||
case <-time.After(10 * time.Second):
|
||||
case <-time.After(20 * time.Second):
|
||||
t.Fatal("Timed out waiting for event")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func gotimeout(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
var done = make(chan struct{})
|
||||
go func() {
|
||||
fn()
|
||||
|
@ -121,7 +123,7 @@ func gotimeout(t *testing.T, fn func()) {
|
|||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
case <-time.After(20 * time.Second):
|
||||
t.Fatal("Timed out waiting for function.")
|
||||
case <-done:
|
||||
return
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
@ -36,15 +37,21 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
g.PacerLoop.Echo()
|
||||
|
||||
case HeartbeatOP:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Server requesting a heartbeat.
|
||||
return g.PacerLoop.Pace()
|
||||
return g.PacerLoop.Pace(ctx)
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
wsutil.WSDebug("ReconnectOP received.")
|
||||
|
||||
// We must reconnect in another goroutine, as running Reconnect
|
||||
// synchronously would prevent the main event loop from exiting.
|
||||
go g.Reconnect()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
go func() { g.ReconnectCtx(ctx); cancel() }()
|
||||
|
||||
// Gracefully exit with a nil let the event handler take the signal from
|
||||
// the pacemaker.
|
||||
return nil
|
||||
|
@ -53,11 +60,16 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
// Discord expects us to sleep for no reason
|
||||
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Invalid session, try and Identify.
|
||||
if err := g.Identify(); err != nil {
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
// Can't identify, reconnect.
|
||||
go g.Reconnect()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
go func() { g.ReconnectCtx(ctx); cancel() }()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case HelloOP:
|
||||
|
|
|
@ -41,6 +41,17 @@ type Session struct {
|
|||
hstop chan struct{}
|
||||
}
|
||||
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
|
||||
g, err := gateway.NewGatewayWithIntents(token, intents...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to connect to Gateway")
|
||||
}
|
||||
|
||||
return NewWithGateway(g), nil
|
||||
}
|
||||
|
||||
// New creates a new session from a given token. Most bots should be using
|
||||
// NewWithIntents instead.
|
||||
func New(token string) (*Session, error) {
|
||||
// Create a gateway
|
||||
g, err := gateway.NewGateway(token)
|
||||
|
@ -48,7 +59,7 @@ func New(token string) (*Session, error) {
|
|||
return nil, errors.Wrap(err, "failed to connect to Gateway")
|
||||
}
|
||||
|
||||
return NewWithGateway(g), err
|
||||
return NewWithGateway(g), nil
|
||||
}
|
||||
|
||||
// Login tries to log in as a normal user account; MFA is optional.
|
||||
|
|
|
@ -97,10 +97,22 @@ type State struct {
|
|||
unreadyGuilds *moreatomic.SnowflakeSet
|
||||
}
|
||||
|
||||
// New creates a new state.
|
||||
func New(token string) (*State, error) {
|
||||
return NewWithStore(token, NewDefaultStore(nil))
|
||||
}
|
||||
|
||||
// NewWithIntents creates a new state with the given gateway intents. For more
|
||||
// information, refer to gateway.Intents.
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
|
||||
s, err := session.NewWithIntents(token, intents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewFromSession(s, NewDefaultStore(nil))
|
||||
}
|
||||
|
||||
func NewWithStore(token string, store Store) (*State, error) {
|
||||
s, err := session.New(token)
|
||||
if err != nil {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/voice/voicegateway"
|
||||
)
|
||||
|
@ -94,24 +94,23 @@ func TestIntegration(t *testing.T) {
|
|||
log.Println(append([]interface{}{caller}, v...)...)
|
||||
}
|
||||
|
||||
// heart.Debug = func(v ...interface{}) {
|
||||
// log.Println(append([]interface{}{"Pacemaker:"}, v...)...)
|
||||
// }
|
||||
|
||||
s, err := state.New("Bot " + config.BotToken)
|
||||
v, err := NewVoiceFromToken("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create a new session:", err)
|
||||
t.Fatal("Failed to create a new voice session:", err)
|
||||
}
|
||||
v.Gateway.AddIntent(gateway.IntentGuildVoiceStates)
|
||||
|
||||
v.ErrorLog = func(err error) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v := NewVoice(s)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
if err := v.Open(); err != nil {
|
||||
t.Fatal("Failed to connect:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
defer v.Close()
|
||||
|
||||
// Validate the given voice channel.
|
||||
c, err := s.Channel(config.VoiceChID)
|
||||
c, err := v.Channel(config.VoiceChID)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to get channel:", err)
|
||||
}
|
||||
|
@ -119,6 +118,8 @@ func TestIntegration(t *testing.T) {
|
|||
t.Fatal("Channel isn't a guild voice channel.")
|
||||
}
|
||||
|
||||
log.Println("The voice channel's name is", c.Name)
|
||||
|
||||
// Grab a timer to benchmark things.
|
||||
finish := timer()
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
|
@ -17,6 +19,11 @@ const Protocol = "xsalsa20_poly1305"
|
|||
|
||||
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
|
||||
|
||||
// WSTimeout is the duration to wait for a gateway operation including Session
|
||||
// to complete before erroring out. This only applies to functions that don't
|
||||
// take in a context already.
|
||||
var WSTimeout = 10 * time.Second
|
||||
|
||||
type Session struct {
|
||||
session *session.Session
|
||||
state voicegateway.State
|
||||
|
@ -52,11 +59,16 @@ func NewSession(ses *session.Session, userID discord.Snowflake) *Session {
|
|||
UserID: userID,
|
||||
},
|
||||
ErrorLog: func(err error) {},
|
||||
incoming: make(chan struct{}),
|
||||
incoming: make(chan struct{}, 2),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||
if s.state.GuildID != ev.GuildID {
|
||||
// Not our state.
|
||||
return
|
||||
}
|
||||
|
||||
// If this is true, then mutex is acquired already.
|
||||
if s.joining.Get() {
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
|
@ -73,7 +85,10 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
|||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
if err := s.reconnect(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.reconnectCtx(ctx); err != nil {
|
||||
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
|
||||
}
|
||||
}
|
||||
|
@ -95,6 +110,16 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
|||
}
|
||||
|
||||
func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannelCtx(
|
||||
ctx context.Context,
|
||||
gID, cID discord.Snowflake, muted, deafened bool) error {
|
||||
|
||||
// Acquire the mutex during join, locking during IO as well.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
@ -103,7 +128,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
|
|||
s.joining.Set(true)
|
||||
defer s.joining.Set(false) // reset when done
|
||||
|
||||
// ensure gateeway and voiceUDP is already closed.
|
||||
// Ensure gateway and voiceUDP are already closed.
|
||||
s.ensureClosed()
|
||||
|
||||
// Set the state.
|
||||
|
@ -122,7 +147,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
|
|||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
||||
// Send a Voice State Update event to the gateway.
|
||||
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
|
||||
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
||||
GuildID: gID,
|
||||
ChannelID: channelID,
|
||||
SelfMute: muted,
|
||||
|
@ -132,23 +157,37 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
|
|||
return errors.Wrap(err, "failed to send Voice State Update event")
|
||||
}
|
||||
|
||||
// Wait for replies. The above command should reply with these 2 events.
|
||||
<-s.incoming
|
||||
<-s.incoming
|
||||
// Wait for 2 replies. The above command should reply with these 2 events.
|
||||
if err := s.waitForIncoming(ctx, 2); err != nil {
|
||||
return errors.Wrap(err, "failed to wait for needed gateway events")
|
||||
}
|
||||
|
||||
// These 2 methods should've updated s.state before sending into these
|
||||
// channels. Since s.state is already filled, we can go ahead and connect.
|
||||
|
||||
return s.reconnect()
|
||||
return s.reconnectCtx(ctx)
|
||||
}
|
||||
|
||||
func (s *Session) waitForIncoming(ctx context.Context, n int) error {
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case <-s.incoming:
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconnect uses the current state to reconnect to a new gateway and UDP
|
||||
// connection.
|
||||
func (s *Session) reconnect() (err error) {
|
||||
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
||||
s.gateway = voicegateway.New(s.state)
|
||||
|
||||
// Open the voice gateway. The function will block until Ready is received.
|
||||
if err := s.gateway.Open(); err != nil {
|
||||
if err := s.gateway.OpenCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to open voice gateway")
|
||||
}
|
||||
|
||||
|
@ -156,13 +195,13 @@ func (s *Session) reconnect() (err error) {
|
|||
voiceReady := s.gateway.Ready()
|
||||
|
||||
// Prepare the UDP voice connection.
|
||||
s.voiceUDP, err = udp.DialConnection(voiceReady.Addr(), voiceReady.SSRC)
|
||||
s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to open voice UDP connection")
|
||||
}
|
||||
|
||||
// Get the session description from the voice gateway.
|
||||
d, err := s.gateway.SessionDescription(voicegateway.SelectProtocol{
|
||||
d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
|
||||
Protocol: "udp",
|
||||
Data: voicegateway.SelectProtocolData{
|
||||
Address: s.voiceUDP.GatewayIP,
|
||||
|
@ -200,17 +239,31 @@ func (s *Session) StopSpeaking() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout.
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
return s.WriteCtx(context.Background(), b)
|
||||
}
|
||||
|
||||
// WriteCtx writes into the UDP voice connection with a context for timeout.
|
||||
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return 0, ErrCannotSend
|
||||
}
|
||||
return s.voiceUDP.Write(b)
|
||||
|
||||
return s.voiceUDP.WriteCtx(ctx, b)
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.DisconnectCtx(ctx)
|
||||
}
|
||||
|
||||
func (s *Session) DisconnectCtx(ctx context.Context) error {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
|
@ -223,7 +276,7 @@ func (s *Session) Disconnect() error {
|
|||
// VoiceStateUpdateEvent, in which our handler will promptly remove the
|
||||
// session from the map.
|
||||
|
||||
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
|
||||
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
||||
GuildID: s.state.GuildID,
|
||||
ChannelID: discord.NullSnowflake,
|
||||
SelfMute: true,
|
||||
|
|
|
@ -2,6 +2,7 @@ package udp
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -11,6 +12,11 @@ import (
|
|||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
||||
// Dialer is the default dialer that this package uses for all its dialing.
|
||||
var Dialer = net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
GatewayIP string
|
||||
GatewayPort uint16
|
||||
|
@ -21,7 +27,7 @@ type Connection struct {
|
|||
timestamp uint32
|
||||
nonce [24]byte
|
||||
|
||||
conn *net.UDPConn
|
||||
conn net.Conn
|
||||
close chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
|
@ -29,15 +35,15 @@ type Connection struct {
|
|||
reply chan error
|
||||
}
|
||||
|
||||
func DialConnection(addr string, ssrc uint32) (*Connection, error) {
|
||||
// Resolve the host.
|
||||
a, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to resolve host")
|
||||
}
|
||||
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
|
||||
// // Resolve the host.
|
||||
// a, err := net.ResolveUDPAddr("udp", addr)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "failed to resolve host")
|
||||
// }
|
||||
|
||||
// Create a new UDP connection.
|
||||
conn, err := net.DialUDP("udp", nil, a)
|
||||
conn, err := Dialer.DialContext(ctx, "udp", addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to dial host")
|
||||
}
|
||||
|
@ -154,9 +160,22 @@ func (c *Connection) Close() error {
|
|||
|
||||
// Write sends bytes into the voice UDP connection.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
c.send <- b
|
||||
if err := <-c.reply; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
return c.WriteCtx(context.Background(), b)
|
||||
}
|
||||
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout.
|
||||
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
select {
|
||||
case c.send <- b:
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-c.reply:
|
||||
return len(b), err
|
||||
case <-ctx.Done():
|
||||
return len(b), ctx.Err()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,11 +31,25 @@ type Voice struct {
|
|||
mapmutex sync.Mutex
|
||||
sessions map[discord.Snowflake]*Session // guildID:Session
|
||||
|
||||
// Callbacks to remove the handlers.
|
||||
closers []func()
|
||||
|
||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
||||
ErrorLog func(err error)
|
||||
}
|
||||
|
||||
// NewVoice creates a new Voice repository wrapped around a state.
|
||||
// NewVoiceFromToken creates a new voice session from the given token.
|
||||
func NewVoiceFromToken(token string) (*Voice, error) {
|
||||
s, err := state.New(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create a new session")
|
||||
}
|
||||
|
||||
return NewVoice(s), nil
|
||||
}
|
||||
|
||||
// NewVoice creates a new Voice repository wrapped around a state. The function
|
||||
// will also automatically add the GuildVoiceStates intent, as that is required.
|
||||
func NewVoice(s *state.State) *Voice {
|
||||
v := &Voice{
|
||||
State: s,
|
||||
|
@ -44,8 +58,10 @@ func NewVoice(s *state.State) *Voice {
|
|||
}
|
||||
|
||||
// Add the required event handlers to the session.
|
||||
s.AddHandler(v.onVoiceStateUpdate)
|
||||
s.AddHandler(v.onVoiceServerUpdate)
|
||||
v.closers = []func(){
|
||||
s.AddHandler(v.onVoiceStateUpdate),
|
||||
s.AddHandler(v.onVoiceServerUpdate),
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
@ -129,6 +145,7 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
|
|||
}
|
||||
|
||||
conn = NewSession(v.Session, u.ID)
|
||||
conn.ErrorLog = v.ErrorLog
|
||||
|
||||
v.mapmutex.Lock()
|
||||
v.sessions[gID] = conn
|
||||
|
@ -139,6 +156,33 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
|
|||
return conn, conn.JoinChannel(gID, cID, muted, deafened)
|
||||
}
|
||||
|
||||
func (v *Voice) Close() error {
|
||||
err := &CloseError{
|
||||
SessionErrors: make(map[discord.Snowflake]error),
|
||||
}
|
||||
|
||||
v.mapmutex.Lock()
|
||||
defer v.mapmutex.Unlock()
|
||||
|
||||
// Remove all callback handlers.
|
||||
for _, fn := range v.closers {
|
||||
fn()
|
||||
}
|
||||
|
||||
for gID, s := range v.sessions {
|
||||
if dErr := s.Disconnect(); dErr != nil {
|
||||
err.SessionErrors[gID] = dErr
|
||||
}
|
||||
}
|
||||
|
||||
err.StateErr = v.State.Close()
|
||||
if err.HasError() {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type CloseError struct {
|
||||
SessionErrors map[discord.Snowflake]error
|
||||
StateErr error
|
||||
|
@ -163,25 +207,3 @@ func (e *CloseError) Error() string {
|
|||
|
||||
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
|
||||
}
|
||||
|
||||
func (v *Voice) Close() error {
|
||||
err := &CloseError{
|
||||
SessionErrors: make(map[discord.Snowflake]error),
|
||||
}
|
||||
|
||||
v.mapmutex.Lock()
|
||||
defer v.mapmutex.Unlock()
|
||||
|
||||
for gID, s := range v.sessions {
|
||||
if dErr := s.Disconnect(); dErr != nil {
|
||||
err.SessionErrors[gID] = dErr
|
||||
}
|
||||
}
|
||||
|
||||
err.StateErr = v.State.Close()
|
||||
if err.HasError() {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
|
@ -26,6 +27,14 @@ type IdentifyData struct {
|
|||
|
||||
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
|
||||
func (c *Gateway) Identify() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.IdentifyCtx(ctx)
|
||||
}
|
||||
|
||||
// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway.
|
||||
func (c *Gateway) IdentifyCtx(ctx context.Context) error {
|
||||
guildID := c.state.GuildID
|
||||
userID := c.state.UserID
|
||||
sessionID := c.state.SessionID
|
||||
|
@ -35,7 +44,7 @@ func (c *Gateway) Identify() error {
|
|||
return ErrMissingForIdentify
|
||||
}
|
||||
|
||||
return c.Send(IdentifyOP, IdentifyData{
|
||||
return c.SendCtx(ctx, IdentifyOP, IdentifyData{
|
||||
GuildID: guildID,
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
|
@ -58,16 +67,32 @@ type SelectProtocolData struct {
|
|||
|
||||
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
|
||||
func (c *Gateway) SelectProtocol(data SelectProtocol) error {
|
||||
return c.Send(SelectProtocolOP, data)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SelectProtocolCtx(ctx, data)
|
||||
}
|
||||
|
||||
// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
|
||||
func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error {
|
||||
return c.SendCtx(ctx, SelectProtocolOP, data)
|
||||
}
|
||||
|
||||
// OPCode 3
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
|
||||
type Heartbeat uint64
|
||||
// type Heartbeat uint64
|
||||
|
||||
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
|
||||
func (c *Gateway) Heartbeat() error {
|
||||
return c.Send(HeartbeatOP, time.Now().UnixNano())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.HeartbeatCtx(ctx)
|
||||
}
|
||||
|
||||
// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
|
||||
func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
|
||||
return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#speaking
|
||||
|
@ -89,10 +114,18 @@ type SpeakingData struct {
|
|||
|
||||
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (c *Gateway) Speaking(flag SpeakingFlag) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SpeakingCtx(ctx, flag)
|
||||
}
|
||||
|
||||
// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error {
|
||||
// How do we allow a user to stop speaking?
|
||||
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
|
||||
|
||||
return c.Send(SpeakingOP, SpeakingData{
|
||||
return c.SendCtx(ctx, SpeakingOP, SpeakingData{
|
||||
Speaking: flag,
|
||||
Delay: 0,
|
||||
SSRC: c.ready.SSRC,
|
||||
|
@ -109,6 +142,13 @@ type ResumeData struct {
|
|||
|
||||
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
|
||||
func (c *Gateway) Resume() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
return c.ResumeCtx(ctx)
|
||||
}
|
||||
|
||||
// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway.
|
||||
func (c *Gateway) ResumeCtx(ctx context.Context) error {
|
||||
guildID := c.state.GuildID
|
||||
sessionID := c.state.SessionID
|
||||
token := c.state.Token
|
||||
|
@ -117,7 +157,7 @@ func (c *Gateway) Resume() error {
|
|||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return c.Send(ResumeOP, ResumeData{
|
||||
return c.SendCtx(ctx, ResumeOP, ResumeData{
|
||||
GuildID: guildID,
|
||||
SessionID: sessionID,
|
||||
Token: token,
|
||||
|
|
|
@ -85,8 +85,12 @@ func (c *Gateway) Ready() ReadyEvent {
|
|||
return c.ready
|
||||
}
|
||||
|
||||
// Open shouldn't be used, but JoinServer instead.
|
||||
func (c *Gateway) Open() error {
|
||||
// OpenCtx shouldn't be used, but JoinServer instead.
|
||||
func (c *Gateway) OpenCtx(ctx context.Context) error {
|
||||
if c.state.Endpoint == "" {
|
||||
return errors.New("missing endpoint in state")
|
||||
}
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
|
@ -94,7 +98,7 @@ func (c *Gateway) Open() error {
|
|||
c.ws = wsutil.New(endpoint)
|
||||
|
||||
// Create a new context with a timeout for the connection.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the Gateway Gateway.
|
||||
|
@ -105,7 +109,7 @@ func (c *Gateway) Open() error {
|
|||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to start or resume the connection.
|
||||
if err := c.start(); err != nil {
|
||||
if err := c.start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -113,8 +117,8 @@ func (c *Gateway) Open() error {
|
|||
}
|
||||
|
||||
// Start .
|
||||
func (c *Gateway) start() error {
|
||||
if err := c.__start(); err != nil {
|
||||
func (c *Gateway) start(ctx context.Context) error {
|
||||
if err := c.__start(ctx); err != nil {
|
||||
wsutil.WSDebug("Start failed: ", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
|
@ -129,7 +133,7 @@ func (c *Gateway) start() error {
|
|||
}
|
||||
|
||||
// this function blocks until READY.
|
||||
func (c *Gateway) __start() error {
|
||||
func (c *Gateway) __start(ctx context.Context) error {
|
||||
// Make a new WaitGroup for use in background loops:
|
||||
c.waitGroup = new(sync.WaitGroup)
|
||||
|
||||
|
@ -139,9 +143,17 @@ func (c *Gateway) __start() error {
|
|||
wsutil.WSDebug("Waiting for Hello..")
|
||||
|
||||
var hello *HelloEvent
|
||||
_, err := wsutil.AssertEvent(<-ch, HelloOP, &hello)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
// Wait for the Hello event; return if it times out.
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return errors.New("unexpected ws close while waiting for Hello")
|
||||
}
|
||||
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Received Hello")
|
||||
|
@ -149,11 +161,11 @@ func (c *Gateway) __start() error {
|
|||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
// Turns out Hello is sent right away on connection start.
|
||||
if !c.reconnect.Get() {
|
||||
if err := c.Identify(); err != nil {
|
||||
if err := c.IdentifyCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to identify")
|
||||
}
|
||||
} else {
|
||||
if err := c.Resume(); err != nil {
|
||||
if err := c.ResumeCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to resume")
|
||||
}
|
||||
}
|
||||
|
@ -161,7 +173,7 @@ func (c *Gateway) __start() error {
|
|||
c.reconnect.Set(false)
|
||||
|
||||
// Wait for either Ready or Resumed.
|
||||
err = wsutil.WaitForEvent(c, ch, func(op *wsutil.OP) bool {
|
||||
err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool {
|
||||
return op.Code == ReadyOP || op.Code == ResumedOP
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -180,7 +192,7 @@ func (c *Gateway) __start() error {
|
|||
|
||||
if err != nil {
|
||||
c.ErrorLog(err)
|
||||
c.Reconnect()
|
||||
c.ReconnectCtx(ctx)
|
||||
// Reconnect should spawn another eventLoop in its Start function.
|
||||
}
|
||||
})
|
||||
|
@ -226,7 +238,7 @@ func (c *Gateway) Close() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Gateway) Reconnect() error {
|
||||
func (c *Gateway) ReconnectCtx(ctx context.Context) error {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
|
@ -239,7 +251,7 @@ func (c *Gateway) Reconnect() error {
|
|||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
|
||||
if err := c.Open(); err != nil {
|
||||
if err := c.OpenCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to reopen gateway")
|
||||
}
|
||||
|
||||
|
@ -248,34 +260,46 @@ func (c *Gateway) Reconnect() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Gateway) SessionDescription(sp SelectProtocol) (*SessionDescriptionEvent, error) {
|
||||
func (c *Gateway) SessionDescriptionCtx(
|
||||
ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) {
|
||||
|
||||
// Add the handler first.
|
||||
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
|
||||
return op.Code == SessionDescriptionOP
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if err := c.SelectProtocol(sp); err != nil {
|
||||
if err := c.SelectProtocolCtx(ctx, sp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sesdesc *SessionDescriptionEvent
|
||||
|
||||
// Wait for SessionDescriptionOP packet.
|
||||
if err := (<-ch).UnmarshalData(&sesdesc); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal session description")
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected close waiting for session description")
|
||||
}
|
||||
if err := e.UnmarshalData(&sesdesc); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal session description")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, errors.Wrap(ctx.Err(), "failed to wait for session description")
|
||||
}
|
||||
|
||||
return sesdesc, nil
|
||||
}
|
||||
|
||||
// Send .
|
||||
// Send sends a payload to the Gateway with the default timeout.
|
||||
func (c *Gateway) Send(code OPCode, v interface{}) error {
|
||||
return c.send(code, v)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SendCtx(ctx, code, v)
|
||||
}
|
||||
|
||||
// send .
|
||||
func (c *Gateway) send(code OPCode, v interface{}) error {
|
||||
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
if c.ws == nil {
|
||||
return errors.New("tried to send data to a connection without a Websocket")
|
||||
}
|
||||
|
@ -303,5 +327,5 @@ func (c *Gateway) send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return c.ws.Send(b)
|
||||
return c.ws.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue