Gateway: Added intent helpers and more context API support

This commit is contained in:
diamondburned 2020-07-11 12:50:32 -07:00
parent f33b4ff7d8
commit edb8a46ef2
15 changed files with 441 additions and 131 deletions

View File

@ -140,7 +140,8 @@ type Context struct {
// Start quickly starts a bot with the given command. It will prepend "Bot" // Start quickly starts a bot with the given command. It will prepend "Bot"
// into the token automatically. Refer to example/ for usage. // into the token automatically. Refer to example/ for usage.
func Start(token string, cmd interface{}, func Start(
token string, cmd interface{},
opts func(*Context) error) (wait func() error, err error) { opts func(*Context) error) (wait func() error, err error) {
s, err := state.New("Bot " + token) s, err := state.New("Bot " + token)
@ -227,6 +228,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
return ctx, nil return ctx, nil
} }
// AddIntent adds the given Gateway Intent into the Gateway. This is a
// convenient function that calls Gateway's AddIntent.
func (ctx *Context) AddIntent(i gateway.Intents) {
ctx.Gateway.AddIntent(i)
}
// Subcommands returns the slice of subcommands. To add subcommands, use // Subcommands returns the slice of subcommands. To add subcommands, use
// RegisterSubcommand(). // RegisterSubcommand().
func (ctx *Context) Subcommands() []*Subcommand { func (ctx *Context) Subcommands() []*Subcommand {

View File

@ -15,11 +15,18 @@ func (g *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel() defer cancel()
return g.IdentifyCtx(ctx)
}
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, g.WSTimeout)
defer cancel()
if err := g.Identifier.Wait(ctx); err != nil { if err := g.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for identify()") return errors.Wrap(err, "can't wait for identify()")
} }
return g.Send(IdentifyOP, g.Identifier) return g.SendCtx(ctx, IdentifyOP, g.Identifier)
} }
type ResumeData struct { type ResumeData struct {
@ -31,6 +38,15 @@ type ResumeData struct {
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume // Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection. // from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error { func (g *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ResumeCtx(ctx)
}
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) ResumeCtx(ctx context.Context) error {
var ( var (
ses = g.SessionID ses = g.SessionID
seq = g.Sequence.Get() seq = g.Sequence.Get()
@ -40,7 +56,7 @@ func (g *Gateway) Resume() error {
return ErrMissingForResume return ErrMissingForResume
} }
return g.Send(ResumeOP, ResumeData{ return g.SendCtx(ctx, ResumeOP, ResumeData{
Token: g.Identifier.Token, Token: g.Identifier.Token,
SessionID: ses, SessionID: ses,
Sequence: seq, Sequence: seq,
@ -51,7 +67,14 @@ func (g *Gateway) Resume() error {
type HeartbeatData int type HeartbeatData int
func (g *Gateway) Heartbeat() error { func (g *Gateway) Heartbeat() error {
return g.Send(HeartbeatOP, g.Sequence.Get()) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.HeartbeatCtx(ctx)
}
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
} }
type RequestGuildMembersData struct { type RequestGuildMembersData struct {
@ -61,10 +84,20 @@ type RequestGuildMembersData struct {
Query string `json:"query,omitempty"` Query string `json:"query,omitempty"`
Limit uint `json:"limit"` Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"` Presences bool `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
} }
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error { func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
return g.Send(RequestGuildMembersOP, data) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.RequestGuildMembersCtx(ctx, data)
}
func (g *Gateway) RequestGuildMembersCtx(
ctx context.Context, data RequestGuildMembersData) error {
return g.SendCtx(ctx, RequestGuildMembersOP, data)
} }
type UpdateVoiceStateData struct { type UpdateVoiceStateData struct {
@ -75,7 +108,16 @@ type UpdateVoiceStateData struct {
} }
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error { func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
return g.Send(VoiceStateUpdateOP, data) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateVoiceStateCtx(ctx, data)
}
func (g *Gateway) UpdateVoiceStateCtx(
ctx context.Context, data UpdateVoiceStateData) error {
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
} }
type UpdateStatusData struct { type UpdateStatusData struct {
@ -90,7 +132,14 @@ type UpdateStatusData struct {
} }
func (g *Gateway) UpdateStatus(data UpdateStatusData) error { func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
return g.Send(StatusUpdateOP, data) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateStatusCtx(ctx, data)
}
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
return g.SendCtx(ctx, StatusUpdateOP, data)
} }
// Undocumented // Undocumented
@ -104,5 +153,12 @@ type GuildSubscribeData struct {
} }
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error { func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
return g.Send(GuildSubscriptionsOP, data) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.GuildSubscribeCtx(ctx, data)
}
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
} }

View File

@ -99,11 +99,15 @@ type (
GuildID discord.Snowflake `json:"guild_id"` GuildID discord.Snowflake `json:"guild_id"`
Members []discord.Member `json:"members"` Members []discord.Member `json:"members"`
ChunkIndex int `json:"chunk_index"`
ChunkCount int `json:"chunk_count"`
// Whatever's not found goes here // Whatever's not found goes here
NotFound []string `json:"not_found,omitempty"` NotFound []string `json:"not_found,omitempty"`
// Only filled if requested // Only filled if requested
Presences []discord.Presence `json:"presences,omitempty"` Presences []discord.Presence `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
} }
// GuildMemberListUpdate is an undocumented event. It's received when the // GuildMemberListUpdate is an undocumented event. It's received when the

View File

@ -107,8 +107,23 @@ type Gateway struct {
waitGroup *sync.WaitGroup waitGroup *sync.WaitGroup
} }
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more // NewGatewayWithIntents creates a new Gateway with the given intents and the
// information, refer to NewGatewayWithDriver. // default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
g, err := NewGateway(token)
if err != nil {
return nil, err
}
for _, intent := range intents {
g.AddIntent(intent)
}
return g, nil
}
// NewGateway creates a new Gateway with the default stdlib JSON driver. For
// more information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) { func NewGateway(token string) (*Gateway, error) {
URL, err := URL() URL, err := URL()
if err != nil { if err != nil {
@ -141,6 +156,12 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
} }
} }
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
// such, this function will only work before Open() is called.
func (g *Gateway) AddIntent(i Intents) {
g.Identifier.Intents |= i
}
// Close closes the underlying Websocket connection. // Close closes the underlying Websocket connection.
func (g *Gateway) Close() error { func (g *Gateway) Close() error {
wsutil.WSDebug("Trying to close.") wsutil.WSDebug("Trying to close.")
@ -182,10 +203,13 @@ func (g *Gateway) Close() error {
// Reconnect tries to reconnect forever. It will resume the connection if // Reconnect tries to reconnect forever. It will resume the connection if
// possible. If an Invalid Session is received, it will start a fresh one. // possible. If an Invalid Session is received, it will start a fresh one.
func (g *Gateway) Reconnect() error { func (g *Gateway) Reconnect() error {
return g.ReconnectContext(context.Background()) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ReconnectCtx(ctx)
} }
func (g *Gateway) ReconnectContext(ctx context.Context) error { func (g *Gateway) ReconnectCtx(ctx context.Context) error {
wsutil.WSDebug("Reconnecting...") wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're // Guarantee the gateway is already closed. Ignore its error, as we're
@ -212,9 +236,15 @@ func (g *Gateway) ReconnectContext(ctx context.Context) error {
// Open connects to the Websocket and authenticate it. You should usually use // Open connects to the Websocket and authenticate it. You should usually use
// this function over Start(). // this function over Start().
func (g *Gateway) Open() error { func (g *Gateway) Open() error {
return g.OpenContext(context.Background()) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.OpenContext(ctx)
} }
// OpenContext connects to the Websocket and authenticates it. Yuo should
// usually use this function over Start(). The given context provides
// cancellation and timeout.
func (g *Gateway) OpenContext(ctx context.Context) error { func (g *Gateway) OpenContext(ctx context.Context) error {
// Reconnect to the Gateway // Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil { if err := g.WS.Dial(ctx); err != nil {
@ -224,7 +254,7 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
wsutil.WSDebug("Trying to start...") wsutil.WSDebug("Trying to start...")
// Try to resume the connection // Try to resume the connection
if err := g.Start(); err != nil { if err := g.StartCtx(ctx); err != nil {
return err return err
} }
@ -232,14 +262,19 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
return nil return nil
} }
// Start authenticates with the websocket, or resume from a dead Websocket // Start calls StartCtx with a background context. You wouldn't usually use this
// connection. This function doesn't block. You wouldn't usually use this
// function, but Open() instead. // function, but Open() instead.
func (g *Gateway) Start() error { func (g *Gateway) Start() error {
// g.available.Lock() ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
// defer g.available.Unlock() defer cancel()
if err := g.start(); err != nil { return g.StartCtx(ctx)
}
// StartCtx authenticates with the websocket, or resume from a dead Websocket
// connection. You wouldn't usually use this function, but OpenCtx() instead.
func (g *Gateway) StartCtx(ctx context.Context) error {
if err := g.start(ctx); err != nil {
wsutil.WSDebug("Start failed:", err) wsutil.WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the // Close can be called with the mutex still acquired here, as the
@ -249,31 +284,41 @@ func (g *Gateway) Start() error {
} }
return err return err
} }
return nil return nil
} }
func (g *Gateway) start() error { func (g *Gateway) start(ctx context.Context) error {
// This is where we'll get our events // This is where we'll get our events
ch := g.WS.Listen() ch := g.WS.Listen()
// Make a new WaitGroup for use in background loops: // Make a new WaitGroup for use in background loops:
g.waitGroup = new(sync.WaitGroup) g.waitGroup = new(sync.WaitGroup)
// Wait for an OP 10 Hello // Create a new Hello event and wait for it.
var hello HelloEvent var hello HelloEvent
if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil { // Wait for an OP 10 Hello.
return errors.Wrap(err, "error at Hello") select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
} }
// Send Discord either the Identify packet (if it's a fresh connection), or // Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection). // a Resume packet (if it's a dead connection).
if g.SessionID == "" { if g.SessionID == "" {
// SessionID is empty, so this is a completely new session. // SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil { if err := g.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify") return errors.Wrap(err, "failed to identify")
} }
} else { } else {
if err := g.Resume(); err != nil { if err := g.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume") return errors.Wrap(err, "failed to resume")
} }
} }
@ -282,7 +327,7 @@ func (g *Gateway) start() error {
wsutil.WSDebug("Waiting for either READY or RESUMED.") wsutil.WSDebug("Waiting for either READY or RESUMED.")
// WaitForEvent should // WaitForEvent should
err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool { err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
switch op.EventName { switch op.EventName {
case "READY": case "READY":
wsutil.WSDebug("Found READY event.") wsutil.WSDebug("Found READY event.")
@ -319,7 +364,9 @@ func (g *Gateway) start() error {
return nil return nil
} }
func (g *Gateway) Send(code OPCode, v interface{}) error { // SendCtx is a low-level function to send an OP payload to the Gateway. Most
// users shouldn't touch this, unless they know what they're doing.
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
var op = wsutil.OP{ var op = wsutil.OP{
Code: code, Code: code,
} }
@ -339,5 +386,5 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
} }
// WS should already be thread-safe. // WS should already be thread-safe.
return g.WS.Send(b) return g.WS.SendCtx(ctx, b)
} }

View File

@ -55,7 +55,7 @@ func (i *IdentifyData) SetShard(id, num int) {
i.Shard[0], i.Shard[1] = id, num i.Shard[0], i.Shard[1] = id, num
} }
// Intents is a new Discord API feature that's documented at // Intents for the new Discord API feature, documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents. // https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32 type Intents uint32

View File

@ -107,13 +107,15 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
select { select {
case ev := <-evCh: case ev := <-evCh:
return ev return ev
case <-time.After(10 * time.Second): case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for event") t.Fatal("Timed out waiting for event")
return nil return nil
} }
} }
func gotimeout(t *testing.T, fn func()) { func gotimeout(t *testing.T, fn func()) {
t.Helper()
var done = make(chan struct{}) var done = make(chan struct{})
go func() { go func() {
fn() fn()
@ -121,7 +123,7 @@ func gotimeout(t *testing.T, fn func()) {
}() }()
select { select {
case <-time.After(10 * time.Second): case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for function.") t.Fatal("Timed out waiting for function.")
case <-done: case <-done:
return return

View File

@ -1,6 +1,7 @@
package gateway package gateway
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"time" "time"
@ -36,15 +37,21 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
g.PacerLoop.Echo() g.PacerLoop.Echo()
case HeartbeatOP: case HeartbeatOP:
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Server requesting a heartbeat. // Server requesting a heartbeat.
return g.PacerLoop.Pace() return g.PacerLoop.Pace(ctx)
case ReconnectOP: case ReconnectOP:
// Server requests to reconnect, die and retry. // Server requests to reconnect, die and retry.
wsutil.WSDebug("ReconnectOP received.") wsutil.WSDebug("ReconnectOP received.")
// We must reconnect in another goroutine, as running Reconnect // We must reconnect in another goroutine, as running Reconnect
// synchronously would prevent the main event loop from exiting. // synchronously would prevent the main event loop from exiting.
go g.Reconnect() ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
go func() { g.ReconnectCtx(ctx); cancel() }()
// Gracefully exit with a nil let the event handler take the signal from // Gracefully exit with a nil let the event handler take the signal from
// the pacemaker. // the pacemaker.
return nil return nil
@ -53,11 +60,16 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Discord expects us to sleep for no reason // Discord expects us to sleep for no reason
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second) time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Invalid session, try and Identify. // Invalid session, try and Identify.
if err := g.Identify(); err != nil { if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, reconnect. // Can't identify, reconnect.
go g.Reconnect() ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
go func() { g.ReconnectCtx(ctx); cancel() }()
} }
return nil return nil
case HelloOP: case HelloOP:

View File

@ -41,6 +41,17 @@ type Session struct {
hstop chan struct{} hstop chan struct{}
} }
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
g, err := gateway.NewGatewayWithIntents(token, intents...)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
return NewWithGateway(g), nil
}
// New creates a new session from a given token. Most bots should be using
// NewWithIntents instead.
func New(token string) (*Session, error) { func New(token string) (*Session, error) {
// Create a gateway // Create a gateway
g, err := gateway.NewGateway(token) g, err := gateway.NewGateway(token)
@ -48,7 +59,7 @@ func New(token string) (*Session, error) {
return nil, errors.Wrap(err, "failed to connect to Gateway") return nil, errors.Wrap(err, "failed to connect to Gateway")
} }
return NewWithGateway(g), err return NewWithGateway(g), nil
} }
// Login tries to log in as a normal user account; MFA is optional. // Login tries to log in as a normal user account; MFA is optional.

View File

@ -97,10 +97,22 @@ type State struct {
unreadyGuilds *moreatomic.SnowflakeSet unreadyGuilds *moreatomic.SnowflakeSet
} }
// New creates a new state.
func New(token string) (*State, error) { func New(token string) (*State, error) {
return NewWithStore(token, NewDefaultStore(nil)) return NewWithStore(token, NewDefaultStore(nil))
} }
// NewWithIntents creates a new state with the given gateway intents. For more
// information, refer to gateway.Intents.
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
s, err := session.NewWithIntents(token, intents...)
if err != nil {
return nil, err
}
return NewFromSession(s, NewDefaultStore(nil))
}
func NewWithStore(token string, store Store) (*State, error) { func NewWithStore(token string, store Store) (*State, error) {
s, err := session.New(token) s, err := session.New(token)
if err != nil { if err != nil {

View File

@ -13,7 +13,7 @@ import (
"time" "time"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/state" "github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/utils/wsutil" "github.com/diamondburned/arikawa/utils/wsutil"
"github.com/diamondburned/arikawa/voice/voicegateway" "github.com/diamondburned/arikawa/voice/voicegateway"
) )
@ -94,24 +94,23 @@ func TestIntegration(t *testing.T) {
log.Println(append([]interface{}{caller}, v...)...) log.Println(append([]interface{}{caller}, v...)...)
} }
// heart.Debug = func(v ...interface{}) { v, err := NewVoiceFromToken("Bot " + config.BotToken)
// log.Println(append([]interface{}{"Pacemaker:"}, v...)...)
// }
s, err := state.New("Bot " + config.BotToken)
if err != nil { if err != nil {
t.Fatal("Failed to create a new session:", err) t.Fatal("Failed to create a new voice session:", err)
}
v.Gateway.AddIntent(gateway.IntentGuildVoiceStates)
v.ErrorLog = func(err error) {
t.Error(err)
} }
v := NewVoice(s) if err := v.Open(); err != nil {
if err := s.Open(); err != nil {
t.Fatal("Failed to connect:", err) t.Fatal("Failed to connect:", err)
} }
defer s.Close() defer v.Close()
// Validate the given voice channel. // Validate the given voice channel.
c, err := s.Channel(config.VoiceChID) c, err := v.Channel(config.VoiceChID)
if err != nil { if err != nil {
t.Fatal("Failed to get channel:", err) t.Fatal("Failed to get channel:", err)
} }
@ -119,6 +118,8 @@ func TestIntegration(t *testing.T) {
t.Fatal("Channel isn't a guild voice channel.") t.Fatal("Channel isn't a guild voice channel.")
} }
log.Println("The voice channel's name is", c.Name)
// Grab a timer to benchmark things. // Grab a timer to benchmark things.
finish := timer() finish := timer()

View File

@ -1,7 +1,9 @@
package voice package voice
import ( import (
"context"
"sync" "sync"
"time"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/gateway"
@ -17,6 +19,11 @@ const Protocol = "xsalsa20_poly1305"
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE} var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
var WSTimeout = 10 * time.Second
type Session struct { type Session struct {
session *session.Session session *session.Session
state voicegateway.State state voicegateway.State
@ -52,11 +59,16 @@ func NewSession(ses *session.Session, userID discord.Snowflake) *Session {
UserID: userID, UserID: userID,
}, },
ErrorLog: func(err error) {}, ErrorLog: func(err error) {},
incoming: make(chan struct{}), incoming: make(chan struct{}, 2),
} }
} }
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) { func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
if s.state.GuildID != ev.GuildID {
// Not our state.
return
}
// If this is true, then mutex is acquired already. // If this is true, then mutex is acquired already.
if s.joining.Get() { if s.joining.Get() {
s.state.Endpoint = ev.Endpoint s.state.Endpoint = ev.Endpoint
@ -73,7 +85,10 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
s.state.Endpoint = ev.Endpoint s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token s.state.Token = ev.Token
if err := s.reconnect(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
if err := s.reconnectCtx(ctx); err != nil {
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update")) s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
} }
} }
@ -95,6 +110,16 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
} }
func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error { func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
}
func (s *Session) JoinChannelCtx(
ctx context.Context,
gID, cID discord.Snowflake, muted, deafened bool) error {
// Acquire the mutex during join, locking during IO as well. // Acquire the mutex during join, locking during IO as well.
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
@ -103,7 +128,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
s.joining.Set(true) s.joining.Set(true)
defer s.joining.Set(false) // reset when done defer s.joining.Set(false) // reset when done
// ensure gateeway and voiceUDP is already closed. // Ensure gateway and voiceUDP are already closed.
s.ensureClosed() s.ensureClosed()
// Set the state. // Set the state.
@ -122,7 +147,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
// https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information // https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway. // Send a Voice State Update event to the gateway.
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{ err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: gID, GuildID: gID,
ChannelID: channelID, ChannelID: channelID,
SelfMute: muted, SelfMute: muted,
@ -132,23 +157,37 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
return errors.Wrap(err, "failed to send Voice State Update event") return errors.Wrap(err, "failed to send Voice State Update event")
} }
// Wait for replies. The above command should reply with these 2 events. // Wait for 2 replies. The above command should reply with these 2 events.
<-s.incoming if err := s.waitForIncoming(ctx, 2); err != nil {
<-s.incoming return errors.Wrap(err, "failed to wait for needed gateway events")
}
// These 2 methods should've updated s.state before sending into these // These 2 methods should've updated s.state before sending into these
// channels. Since s.state is already filled, we can go ahead and connect. // channels. Since s.state is already filled, we can go ahead and connect.
return s.reconnect() return s.reconnectCtx(ctx)
}
func (s *Session) waitForIncoming(ctx context.Context, n int) error {
for i := 0; i < n; i++ {
select {
case <-s.incoming:
continue
case <-ctx.Done():
return ctx.Err()
}
}
return nil
} }
// reconnect uses the current state to reconnect to a new gateway and UDP // reconnect uses the current state to reconnect to a new gateway and UDP
// connection. // connection.
func (s *Session) reconnect() (err error) { func (s *Session) reconnectCtx(ctx context.Context) (err error) {
s.gateway = voicegateway.New(s.state) s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received. // Open the voice gateway. The function will block until Ready is received.
if err := s.gateway.Open(); err != nil { if err := s.gateway.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to open voice gateway") return errors.Wrap(err, "failed to open voice gateway")
} }
@ -156,13 +195,13 @@ func (s *Session) reconnect() (err error) {
voiceReady := s.gateway.Ready() voiceReady := s.gateway.Ready()
// Prepare the UDP voice connection. // Prepare the UDP voice connection.
s.voiceUDP, err = udp.DialConnection(voiceReady.Addr(), voiceReady.SSRC) s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection") return errors.Wrap(err, "failed to open voice UDP connection")
} }
// Get the session description from the voice gateway. // Get the session description from the voice gateway.
d, err := s.gateway.SessionDescription(voicegateway.SelectProtocol{ d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
Protocol: "udp", Protocol: "udp",
Data: voicegateway.SelectProtocolData{ Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP, Address: s.voiceUDP.GatewayIP,
@ -200,17 +239,31 @@ func (s *Session) StopSpeaking() error {
return nil return nil
} }
// Write writes into the UDP voice connection WITHOUT a timeout.
func (s *Session) Write(b []byte) (int, error) { func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)
}
// WriteCtx writes into the UDP voice connection with a context for timeout.
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
if s.voiceUDP == nil { if s.voiceUDP == nil {
return 0, ErrCannotSend return 0, ErrCannotSend
} }
return s.voiceUDP.Write(b)
return s.voiceUDP.WriteCtx(ctx, b)
} }
func (s *Session) Disconnect() error { func (s *Session) Disconnect() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.DisconnectCtx(ctx)
}
func (s *Session) DisconnectCtx(ctx context.Context) error {
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
@ -223,7 +276,7 @@ func (s *Session) Disconnect() error {
// VoiceStateUpdateEvent, in which our handler will promptly remove the // VoiceStateUpdateEvent, in which our handler will promptly remove the
// session from the map. // session from the map.
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{ err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: s.state.GuildID, GuildID: s.state.GuildID,
ChannelID: discord.NullSnowflake, ChannelID: discord.NullSnowflake,
SelfMute: true, SelfMute: true,

View File

@ -2,6 +2,7 @@ package udp
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
@ -11,6 +12,11 @@ import (
"golang.org/x/crypto/nacl/secretbox" "golang.org/x/crypto/nacl/secretbox"
) )
// Dialer is the default dialer that this package uses for all its dialing.
var Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
type Connection struct { type Connection struct {
GatewayIP string GatewayIP string
GatewayPort uint16 GatewayPort uint16
@ -21,7 +27,7 @@ type Connection struct {
timestamp uint32 timestamp uint32
nonce [24]byte nonce [24]byte
conn *net.UDPConn conn net.Conn
close chan struct{} close chan struct{}
closed chan struct{} closed chan struct{}
@ -29,15 +35,15 @@ type Connection struct {
reply chan error reply chan error
} }
func DialConnection(addr string, ssrc uint32) (*Connection, error) { func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// Resolve the host. // // Resolve the host.
a, err := net.ResolveUDPAddr("udp", addr) // a, err := net.ResolveUDPAddr("udp", addr)
if err != nil { // if err != nil {
return nil, errors.Wrap(err, "failed to resolve host") // return nil, errors.Wrap(err, "failed to resolve host")
} // }
// Create a new UDP connection. // Create a new UDP connection.
conn, err := net.DialUDP("udp", nil, a) conn, err := Dialer.DialContext(ctx, "udp", addr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to dial host") return nil, errors.Wrap(err, "failed to dial host")
} }
@ -154,9 +160,22 @@ func (c *Connection) Close() error {
// Write sends bytes into the voice UDP connection. // Write sends bytes into the voice UDP connection.
func (c *Connection) Write(b []byte) (int, error) { func (c *Connection) Write(b []byte) (int, error) {
c.send <- b return c.WriteCtx(context.Background(), b)
if err := <-c.reply; err != nil { }
return 0, err
} // WriteCtx sends bytes into the voice UDP connection with a timeout.
return len(b), nil func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
select {
case c.send <- b:
break
case <-ctx.Done():
return 0, ctx.Err()
}
select {
case err := <-c.reply:
return len(b), err
case <-ctx.Done():
return len(b), ctx.Err()
}
} }

View File

@ -31,11 +31,25 @@ type Voice struct {
mapmutex sync.Mutex mapmutex sync.Mutex
sessions map[discord.Snowflake]*Session // guildID:Session sessions map[discord.Snowflake]*Session // guildID:Session
// Callbacks to remove the handlers.
closers []func()
// ErrorLog will be called when an error occurs (defaults to log.Println) // ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error) ErrorLog func(err error)
} }
// NewVoice creates a new Voice repository wrapped around a state. // NewVoiceFromToken creates a new voice session from the given token.
func NewVoiceFromToken(token string) (*Voice, error) {
s, err := state.New(token)
if err != nil {
return nil, errors.Wrap(err, "failed to create a new session")
}
return NewVoice(s), nil
}
// NewVoice creates a new Voice repository wrapped around a state. The function
// will also automatically add the GuildVoiceStates intent, as that is required.
func NewVoice(s *state.State) *Voice { func NewVoice(s *state.State) *Voice {
v := &Voice{ v := &Voice{
State: s, State: s,
@ -44,8 +58,10 @@ func NewVoice(s *state.State) *Voice {
} }
// Add the required event handlers to the session. // Add the required event handlers to the session.
s.AddHandler(v.onVoiceStateUpdate) v.closers = []func(){
s.AddHandler(v.onVoiceServerUpdate) s.AddHandler(v.onVoiceStateUpdate),
s.AddHandler(v.onVoiceServerUpdate),
}
return v return v
} }
@ -129,6 +145,7 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
} }
conn = NewSession(v.Session, u.ID) conn = NewSession(v.Session, u.ID)
conn.ErrorLog = v.ErrorLog
v.mapmutex.Lock() v.mapmutex.Lock()
v.sessions[gID] = conn v.sessions[gID] = conn
@ -139,6 +156,33 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
return conn, conn.JoinChannel(gID, cID, muted, deafened) return conn, conn.JoinChannel(gID, cID, muted, deafened)
} }
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.Snowflake]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// Remove all callback handlers.
for _, fn := range v.closers {
fn()
}
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
err.SessionErrors[gID] = dErr
}
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}
type CloseError struct { type CloseError struct {
SessionErrors map[discord.Snowflake]error SessionErrors map[discord.Snowflake]error
StateErr error StateErr error
@ -163,25 +207,3 @@ func (e *CloseError) Error() string {
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect" return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
} }
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.Snowflake]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
err.SessionErrors[gID] = dErr
}
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}

View File

@ -1,6 +1,7 @@
package voicegateway package voicegateway
import ( import (
"context"
"time" "time"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
@ -26,6 +27,14 @@ type IdentifyData struct {
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway. // Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) Identify() error { func (c *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.IdentifyCtx(ctx)
}
// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) IdentifyCtx(ctx context.Context) error {
guildID := c.state.GuildID guildID := c.state.GuildID
userID := c.state.UserID userID := c.state.UserID
sessionID := c.state.SessionID sessionID := c.state.SessionID
@ -35,7 +44,7 @@ func (c *Gateway) Identify() error {
return ErrMissingForIdentify return ErrMissingForIdentify
} }
return c.Send(IdentifyOP, IdentifyData{ return c.SendCtx(ctx, IdentifyOP, IdentifyData{
GuildID: guildID, GuildID: guildID,
UserID: userID, UserID: userID,
SessionID: sessionID, SessionID: sessionID,
@ -58,16 +67,32 @@ type SelectProtocolData struct {
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway. // SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocol(data SelectProtocol) error { func (c *Gateway) SelectProtocol(data SelectProtocol) error {
return c.Send(SelectProtocolOP, data) ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SelectProtocolCtx(ctx, data)
}
// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error {
return c.SendCtx(ctx, SelectProtocolOP, data)
} }
// OPCode 3 // OPCode 3
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload // https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
type Heartbeat uint64 // type Heartbeat uint64
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway. // Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) Heartbeat() error { func (c *Gateway) Heartbeat() error {
return c.Send(HeartbeatOP, time.Now().UnixNano()) ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.HeartbeatCtx(ctx)
}
// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano())
} }
// https://discordapp.com/developers/docs/topics/voice-connections#speaking // https://discordapp.com/developers/docs/topics/voice-connections#speaking
@ -89,10 +114,18 @@ type SpeakingData struct {
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. // Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) Speaking(flag SpeakingFlag) error { func (c *Gateway) Speaking(flag SpeakingFlag) error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SpeakingCtx(ctx, flag)
}
// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error {
// How do we allow a user to stop speaking? // How do we allow a user to stop speaking?
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation // Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
return c.Send(SpeakingOP, SpeakingData{ return c.SendCtx(ctx, SpeakingOP, SpeakingData{
Speaking: flag, Speaking: flag,
Delay: 0, Delay: 0,
SSRC: c.ready.SSRC, SSRC: c.ready.SSRC,
@ -109,6 +142,13 @@ type ResumeData struct {
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway. // Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) Resume() error { func (c *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.ResumeCtx(ctx)
}
// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) ResumeCtx(ctx context.Context) error {
guildID := c.state.GuildID guildID := c.state.GuildID
sessionID := c.state.SessionID sessionID := c.state.SessionID
token := c.state.Token token := c.state.Token
@ -117,7 +157,7 @@ func (c *Gateway) Resume() error {
return ErrMissingForResume return ErrMissingForResume
} }
return c.Send(ResumeOP, ResumeData{ return c.SendCtx(ctx, ResumeOP, ResumeData{
GuildID: guildID, GuildID: guildID,
SessionID: sessionID, SessionID: sessionID,
Token: token, Token: token,

View File

@ -85,8 +85,12 @@ func (c *Gateway) Ready() ReadyEvent {
return c.ready return c.ready
} }
// Open shouldn't be used, but JoinServer instead. // OpenCtx shouldn't be used, but JoinServer instead.
func (c *Gateway) Open() error { func (c *Gateway) OpenCtx(ctx context.Context) error {
if c.state.Endpoint == "" {
return errors.New("missing endpoint in state")
}
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection // https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
@ -94,7 +98,7 @@ func (c *Gateway) Open() error {
c.ws = wsutil.New(endpoint) c.ws = wsutil.New(endpoint)
// Create a new context with a timeout for the connection. // Create a new context with a timeout for the connection.
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) ctx, cancel := context.WithTimeout(ctx, c.Timeout)
defer cancel() defer cancel()
// Connect to the Gateway Gateway. // Connect to the Gateway Gateway.
@ -105,7 +109,7 @@ func (c *Gateway) Open() error {
wsutil.WSDebug("Trying to start...") wsutil.WSDebug("Trying to start...")
// Try to start or resume the connection. // Try to start or resume the connection.
if err := c.start(); err != nil { if err := c.start(ctx); err != nil {
return err return err
} }
@ -113,8 +117,8 @@ func (c *Gateway) Open() error {
} }
// Start . // Start .
func (c *Gateway) start() error { func (c *Gateway) start(ctx context.Context) error {
if err := c.__start(); err != nil { if err := c.__start(ctx); err != nil {
wsutil.WSDebug("Start failed: ", err) wsutil.WSDebug("Start failed: ", err)
// Close can be called with the mutex still acquired here, as the // Close can be called with the mutex still acquired here, as the
@ -129,7 +133,7 @@ func (c *Gateway) start() error {
} }
// this function blocks until READY. // this function blocks until READY.
func (c *Gateway) __start() error { func (c *Gateway) __start(ctx context.Context) error {
// Make a new WaitGroup for use in background loops: // Make a new WaitGroup for use in background loops:
c.waitGroup = new(sync.WaitGroup) c.waitGroup = new(sync.WaitGroup)
@ -139,9 +143,17 @@ func (c *Gateway) __start() error {
wsutil.WSDebug("Waiting for Hello..") wsutil.WSDebug("Waiting for Hello..")
var hello *HelloEvent var hello *HelloEvent
_, err := wsutil.AssertEvent(<-ch, HelloOP, &hello) // Wait for the Hello event; return if it times out.
if err != nil { select {
return errors.Wrap(err, "error at Hello") case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
} }
wsutil.WSDebug("Received Hello") wsutil.WSDebug("Received Hello")
@ -149,11 +161,11 @@ func (c *Gateway) __start() error {
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection // https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
// Turns out Hello is sent right away on connection start. // Turns out Hello is sent right away on connection start.
if !c.reconnect.Get() { if !c.reconnect.Get() {
if err := c.Identify(); err != nil { if err := c.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify") return errors.Wrap(err, "failed to identify")
} }
} else { } else {
if err := c.Resume(); err != nil { if err := c.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume") return errors.Wrap(err, "failed to resume")
} }
} }
@ -161,7 +173,7 @@ func (c *Gateway) __start() error {
c.reconnect.Set(false) c.reconnect.Set(false)
// Wait for either Ready or Resumed. // Wait for either Ready or Resumed.
err = wsutil.WaitForEvent(c, ch, func(op *wsutil.OP) bool { err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool {
return op.Code == ReadyOP || op.Code == ResumedOP return op.Code == ReadyOP || op.Code == ResumedOP
}) })
if err != nil { if err != nil {
@ -180,7 +192,7 @@ func (c *Gateway) __start() error {
if err != nil { if err != nil {
c.ErrorLog(err) c.ErrorLog(err)
c.Reconnect() c.ReconnectCtx(ctx)
// Reconnect should spawn another eventLoop in its Start function. // Reconnect should spawn another eventLoop in its Start function.
} }
}) })
@ -226,7 +238,7 @@ func (c *Gateway) Close() error {
return err return err
} }
func (c *Gateway) Reconnect() error { func (c *Gateway) ReconnectCtx(ctx context.Context) error {
wsutil.WSDebug("Reconnecting...") wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're // Guarantee the gateway is already closed. Ignore its error, as we're
@ -239,7 +251,7 @@ func (c *Gateway) Reconnect() error {
// If the connection is rate limited (documented behavior): // If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting // https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err := c.Open(); err != nil { if err := c.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to reopen gateway") return errors.Wrap(err, "failed to reopen gateway")
} }
@ -248,34 +260,46 @@ func (c *Gateway) Reconnect() error {
return nil return nil
} }
func (c *Gateway) SessionDescription(sp SelectProtocol) (*SessionDescriptionEvent, error) { func (c *Gateway) SessionDescriptionCtx(
ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) {
// Add the handler first. // Add the handler first.
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool { ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
return op.Code == SessionDescriptionOP return op.Code == SessionDescriptionOP
}) })
defer cancel() defer cancel()
if err := c.SelectProtocol(sp); err != nil { if err := c.SelectProtocolCtx(ctx, sp); err != nil {
return nil, err return nil, err
} }
var sesdesc *SessionDescriptionEvent var sesdesc *SessionDescriptionEvent
// Wait for SessionDescriptionOP packet. // Wait for SessionDescriptionOP packet.
if err := (<-ch).UnmarshalData(&sesdesc); err != nil { select {
return nil, errors.Wrap(err, "failed to unmarshal session description") case e, ok := <-ch:
if !ok {
return nil, errors.New("unexpected close waiting for session description")
}
if err := e.UnmarshalData(&sesdesc); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal session description")
}
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "failed to wait for session description")
} }
return sesdesc, nil return sesdesc, nil
} }
// Send . // Send sends a payload to the Gateway with the default timeout.
func (c *Gateway) Send(code OPCode, v interface{}) error { func (c *Gateway) Send(code OPCode, v interface{}) error {
return c.send(code, v) ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SendCtx(ctx, code, v)
} }
// send . func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
func (c *Gateway) send(code OPCode, v interface{}) error {
if c.ws == nil { if c.ws == nil {
return errors.New("tried to send data to a connection without a Websocket") return errors.New("tried to send data to a connection without a Websocket")
} }
@ -303,5 +327,5 @@ func (c *Gateway) send(code OPCode, v interface{}) error {
} }
// WS should already be thread-safe. // WS should already be thread-safe.
return c.ws.Send(b) return c.ws.SendCtx(ctx, b)
} }