diff --git a/bot/ctx.go b/bot/ctx.go index 67d81b3..f82199f 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -140,7 +140,8 @@ type Context struct { // Start quickly starts a bot with the given command. It will prepend "Bot" // into the token automatically. Refer to example/ for usage. -func Start(token string, cmd interface{}, +func Start( + token string, cmd interface{}, opts func(*Context) error) (wait func() error, err error) { s, err := state.New("Bot " + token) @@ -227,6 +228,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) { return ctx, nil } +// AddIntent adds the given Gateway Intent into the Gateway. This is a +// convenient function that calls Gateway's AddIntent. +func (ctx *Context) AddIntent(i gateway.Intents) { + ctx.Gateway.AddIntent(i) +} + // Subcommands returns the slice of subcommands. To add subcommands, use // RegisterSubcommand(). func (ctx *Context) Subcommands() []*Subcommand { diff --git a/gateway/commands.go b/gateway/commands.go index a6a2283..e95697a 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -15,11 +15,18 @@ func (g *Gateway) Identify() error { ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() + return g.IdentifyCtx(ctx) +} + +func (g *Gateway) IdentifyCtx(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, g.WSTimeout) + defer cancel() + if err := g.Identifier.Wait(ctx); err != nil { return errors.Wrap(err, "can't wait for identify()") } - return g.Send(IdentifyOP, g.Identifier) + return g.SendCtx(ctx, IdentifyOP, g.Identifier) } type ResumeData struct { @@ -31,6 +38,15 @@ type ResumeData struct { // Resume sends to the Websocket a Resume OP, but it doesn't actually resume // from a dead connection. Start() resumes from a dead connection. func (g *Gateway) Resume() error { + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.ResumeCtx(ctx) +} + +// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume +// from a dead connection. Start() resumes from a dead connection. +func (g *Gateway) ResumeCtx(ctx context.Context) error { var ( ses = g.SessionID seq = g.Sequence.Get() @@ -40,7 +56,7 @@ func (g *Gateway) Resume() error { return ErrMissingForResume } - return g.Send(ResumeOP, ResumeData{ + return g.SendCtx(ctx, ResumeOP, ResumeData{ Token: g.Identifier.Token, SessionID: ses, Sequence: seq, @@ -51,7 +67,14 @@ func (g *Gateway) Resume() error { type HeartbeatData int func (g *Gateway) Heartbeat() error { - return g.Send(HeartbeatOP, g.Sequence.Get()) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.HeartbeatCtx(ctx) +} + +func (g *Gateway) HeartbeatCtx(ctx context.Context) error { + return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get()) } type RequestGuildMembersData struct { @@ -61,10 +84,20 @@ type RequestGuildMembersData struct { Query string `json:"query,omitempty"` Limit uint `json:"limit"` Presences bool `json:"presences,omitempty"` + Nonce string `json:"nonce,omitempty"` } func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error { - return g.Send(RequestGuildMembersOP, data) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.RequestGuildMembersCtx(ctx, data) +} + +func (g *Gateway) RequestGuildMembersCtx( + ctx context.Context, data RequestGuildMembersData) error { + + return g.SendCtx(ctx, RequestGuildMembersOP, data) } type UpdateVoiceStateData struct { @@ -75,7 +108,16 @@ type UpdateVoiceStateData struct { } func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error { - return g.Send(VoiceStateUpdateOP, data) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.UpdateVoiceStateCtx(ctx, data) +} + +func (g *Gateway) UpdateVoiceStateCtx( + ctx context.Context, data UpdateVoiceStateData) error { + + return g.SendCtx(ctx, VoiceStateUpdateOP, data) } type UpdateStatusData struct { @@ -90,7 +132,14 @@ type UpdateStatusData struct { } func (g *Gateway) UpdateStatus(data UpdateStatusData) error { - return g.Send(StatusUpdateOP, data) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.UpdateStatusCtx(ctx, data) +} + +func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error { + return g.SendCtx(ctx, StatusUpdateOP, data) } // Undocumented @@ -104,5 +153,12 @@ type GuildSubscribeData struct { } func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error { - return g.Send(GuildSubscriptionsOP, data) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.GuildSubscribeCtx(ctx, data) +} + +func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error { + return g.SendCtx(ctx, GuildSubscriptionsOP, data) } diff --git a/gateway/events.go b/gateway/events.go index f824983..76c223a 100644 --- a/gateway/events.go +++ b/gateway/events.go @@ -99,11 +99,15 @@ type ( GuildID discord.Snowflake `json:"guild_id"` Members []discord.Member `json:"members"` + ChunkIndex int `json:"chunk_index"` + ChunkCount int `json:"chunk_count"` + // Whatever's not found goes here NotFound []string `json:"not_found,omitempty"` // Only filled if requested Presences []discord.Presence `json:"presences,omitempty"` + Nonce string `json:"nonce,omitempty"` } // GuildMemberListUpdate is an undocumented event. It's received when the diff --git a/gateway/gateway.go b/gateway/gateway.go index 7efba3d..876b4b6 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -107,8 +107,23 @@ type Gateway struct { waitGroup *sync.WaitGroup } -// NewGateway starts a new Gateway with the default stdlib JSON driver. For more -// information, refer to NewGatewayWithDriver. +// NewGatewayWithIntents creates a new Gateway with the given intents and the +// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents. +func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) { + g, err := NewGateway(token) + if err != nil { + return nil, err + } + + for _, intent := range intents { + g.AddIntent(intent) + } + + return g, nil +} + +// NewGateway creates a new Gateway with the default stdlib JSON driver. For +// more information, refer to NewGatewayWithDriver. func NewGateway(token string) (*Gateway, error) { URL, err := URL() if err != nil { @@ -141,6 +156,12 @@ func NewCustomGateway(gatewayURL, token string) *Gateway { } } +// AddIntent adds a Gateway Intent before connecting to the Gateway. As +// such, this function will only work before Open() is called. +func (g *Gateway) AddIntent(i Intents) { + g.Identifier.Intents |= i +} + // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { wsutil.WSDebug("Trying to close.") @@ -182,10 +203,13 @@ func (g *Gateway) Close() error { // Reconnect tries to reconnect forever. It will resume the connection if // possible. If an Invalid Session is received, it will start a fresh one. func (g *Gateway) Reconnect() error { - return g.ReconnectContext(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.ReconnectCtx(ctx) } -func (g *Gateway) ReconnectContext(ctx context.Context) error { +func (g *Gateway) ReconnectCtx(ctx context.Context) error { wsutil.WSDebug("Reconnecting...") // Guarantee the gateway is already closed. Ignore its error, as we're @@ -212,9 +236,15 @@ func (g *Gateway) ReconnectContext(ctx context.Context) error { // Open connects to the Websocket and authenticate it. You should usually use // this function over Start(). func (g *Gateway) Open() error { - return g.OpenContext(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.OpenContext(ctx) } +// OpenContext connects to the Websocket and authenticates it. Yuo should +// usually use this function over Start(). The given context provides +// cancellation and timeout. func (g *Gateway) OpenContext(ctx context.Context) error { // Reconnect to the Gateway if err := g.WS.Dial(ctx); err != nil { @@ -224,7 +254,7 @@ func (g *Gateway) OpenContext(ctx context.Context) error { wsutil.WSDebug("Trying to start...") // Try to resume the connection - if err := g.Start(); err != nil { + if err := g.StartCtx(ctx); err != nil { return err } @@ -232,14 +262,19 @@ func (g *Gateway) OpenContext(ctx context.Context) error { return nil } -// Start authenticates with the websocket, or resume from a dead Websocket -// connection. This function doesn't block. You wouldn't usually use this +// Start calls StartCtx with a background context. You wouldn't usually use this // function, but Open() instead. func (g *Gateway) Start() error { - // g.available.Lock() - // defer g.available.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() - if err := g.start(); err != nil { + return g.StartCtx(ctx) +} + +// StartCtx authenticates with the websocket, or resume from a dead Websocket +// connection. You wouldn't usually use this function, but OpenCtx() instead. +func (g *Gateway) StartCtx(ctx context.Context) error { + if err := g.start(ctx); err != nil { wsutil.WSDebug("Start failed:", err) // Close can be called with the mutex still acquired here, as the @@ -249,31 +284,41 @@ func (g *Gateway) Start() error { } return err } + return nil } -func (g *Gateway) start() error { +func (g *Gateway) start(ctx context.Context) error { // This is where we'll get our events ch := g.WS.Listen() // Make a new WaitGroup for use in background loops: g.waitGroup = new(sync.WaitGroup) - // Wait for an OP 10 Hello + // Create a new Hello event and wait for it. var hello HelloEvent - if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil { - return errors.Wrap(err, "error at Hello") + // Wait for an OP 10 Hello. + select { + case e, ok := <-ch: + if !ok { + return errors.New("unexpected ws close while waiting for Hello") + } + if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil { + return errors.Wrap(err, "error at Hello") + } + case <-ctx.Done(): + return errors.Wrap(ctx.Err(), "failed to wait for Hello event") } // Send Discord either the Identify packet (if it's a fresh connection), or // a Resume packet (if it's a dead connection). if g.SessionID == "" { // SessionID is empty, so this is a completely new session. - if err := g.Identify(); err != nil { + if err := g.IdentifyCtx(ctx); err != nil { return errors.Wrap(err, "failed to identify") } } else { - if err := g.Resume(); err != nil { + if err := g.ResumeCtx(ctx); err != nil { return errors.Wrap(err, "failed to resume") } } @@ -282,7 +327,7 @@ func (g *Gateway) start() error { wsutil.WSDebug("Waiting for either READY or RESUMED.") // WaitForEvent should - err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool { + err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool { switch op.EventName { case "READY": wsutil.WSDebug("Found READY event.") @@ -319,7 +364,9 @@ func (g *Gateway) start() error { return nil } -func (g *Gateway) Send(code OPCode, v interface{}) error { +// SendCtx is a low-level function to send an OP payload to the Gateway. Most +// users shouldn't touch this, unless they know what they're doing. +func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error { var op = wsutil.OP{ Code: code, } @@ -339,5 +386,5 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { } // WS should already be thread-safe. - return g.WS.Send(b) + return g.WS.SendCtx(ctx, b) } diff --git a/gateway/identify.go b/gateway/identify.go index 4031b26..e5d7ed7 100644 --- a/gateway/identify.go +++ b/gateway/identify.go @@ -55,7 +55,7 @@ func (i *IdentifyData) SetShard(id, num int) { i.Shard[0], i.Shard[1] = id, num } -// Intents is a new Discord API feature that's documented at +// Intents for the new Discord API feature, documented at // https://discordapp.com/developers/docs/topics/gateway#gateway-intents. type Intents uint32 diff --git a/gateway/integration_test.go b/gateway/integration_test.go index b845faa..5e05aee 100644 --- a/gateway/integration_test.go +++ b/gateway/integration_test.go @@ -107,13 +107,15 @@ func wait(t *testing.T, evCh chan interface{}) interface{} { select { case ev := <-evCh: return ev - case <-time.After(10 * time.Second): + case <-time.After(20 * time.Second): t.Fatal("Timed out waiting for event") return nil } } func gotimeout(t *testing.T, fn func()) { + t.Helper() + var done = make(chan struct{}) go func() { fn() @@ -121,7 +123,7 @@ func gotimeout(t *testing.T, fn func()) { }() select { - case <-time.After(10 * time.Second): + case <-time.After(20 * time.Second): t.Fatal("Timed out waiting for function.") case <-done: return diff --git a/gateway/op.go b/gateway/op.go index d862307..e591b01 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -1,6 +1,7 @@ package gateway import ( + "context" "fmt" "math/rand" "time" @@ -36,15 +37,21 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error { g.PacerLoop.Echo() case HeartbeatOP: + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + // Server requesting a heartbeat. - return g.PacerLoop.Pace() + return g.PacerLoop.Pace(ctx) case ReconnectOP: // Server requests to reconnect, die and retry. wsutil.WSDebug("ReconnectOP received.") + // We must reconnect in another goroutine, as running Reconnect // synchronously would prevent the main event loop from exiting. - go g.Reconnect() + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + go func() { g.ReconnectCtx(ctx); cancel() }() + // Gracefully exit with a nil let the event handler take the signal from // the pacemaker. return nil @@ -53,11 +60,16 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error { // Discord expects us to sleep for no reason time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + // Invalid session, try and Identify. - if err := g.Identify(); err != nil { + if err := g.IdentifyCtx(ctx); err != nil { // Can't identify, reconnect. - go g.Reconnect() + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + go func() { g.ReconnectCtx(ctx); cancel() }() } + return nil case HelloOP: diff --git a/session/session.go b/session/session.go index 3ecb1dd..89e9d3c 100644 --- a/session/session.go +++ b/session/session.go @@ -41,6 +41,17 @@ type Session struct { hstop chan struct{} } +func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) { + g, err := gateway.NewGatewayWithIntents(token, intents...) + if err != nil { + return nil, errors.Wrap(err, "failed to connect to Gateway") + } + + return NewWithGateway(g), nil +} + +// New creates a new session from a given token. Most bots should be using +// NewWithIntents instead. func New(token string) (*Session, error) { // Create a gateway g, err := gateway.NewGateway(token) @@ -48,7 +59,7 @@ func New(token string) (*Session, error) { return nil, errors.Wrap(err, "failed to connect to Gateway") } - return NewWithGateway(g), err + return NewWithGateway(g), nil } // Login tries to log in as a normal user account; MFA is optional. diff --git a/state/state.go b/state/state.go index 12e3b26..6e8181b 100644 --- a/state/state.go +++ b/state/state.go @@ -97,10 +97,22 @@ type State struct { unreadyGuilds *moreatomic.SnowflakeSet } +// New creates a new state. func New(token string) (*State, error) { return NewWithStore(token, NewDefaultStore(nil)) } +// NewWithIntents creates a new state with the given gateway intents. For more +// information, refer to gateway.Intents. +func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { + s, err := session.NewWithIntents(token, intents...) + if err != nil { + return nil, err + } + + return NewFromSession(s, NewDefaultStore(nil)) +} + func NewWithStore(token string, store Store) (*State, error) { s, err := session.New(token) if err != nil { diff --git a/voice/integration_test.go b/voice/integration_test.go index 2805df0..beda26c 100644 --- a/voice/integration_test.go +++ b/voice/integration_test.go @@ -13,7 +13,7 @@ import ( "time" "github.com/diamondburned/arikawa/discord" - "github.com/diamondburned/arikawa/state" + "github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/utils/wsutil" "github.com/diamondburned/arikawa/voice/voicegateway" ) @@ -94,24 +94,23 @@ func TestIntegration(t *testing.T) { log.Println(append([]interface{}{caller}, v...)...) } - // heart.Debug = func(v ...interface{}) { - // log.Println(append([]interface{}{"Pacemaker:"}, v...)...) - // } - - s, err := state.New("Bot " + config.BotToken) + v, err := NewVoiceFromToken("Bot " + config.BotToken) if err != nil { - t.Fatal("Failed to create a new session:", err) + t.Fatal("Failed to create a new voice session:", err) + } + v.Gateway.AddIntent(gateway.IntentGuildVoiceStates) + + v.ErrorLog = func(err error) { + t.Error(err) } - v := NewVoice(s) - - if err := s.Open(); err != nil { + if err := v.Open(); err != nil { t.Fatal("Failed to connect:", err) } - defer s.Close() + defer v.Close() // Validate the given voice channel. - c, err := s.Channel(config.VoiceChID) + c, err := v.Channel(config.VoiceChID) if err != nil { t.Fatal("Failed to get channel:", err) } @@ -119,6 +118,8 @@ func TestIntegration(t *testing.T) { t.Fatal("Channel isn't a guild voice channel.") } + log.Println("The voice channel's name is", c.Name) + // Grab a timer to benchmark things. finish := timer() diff --git a/voice/session.go b/voice/session.go index 7b29348..3e6b06f 100644 --- a/voice/session.go +++ b/voice/session.go @@ -1,7 +1,9 @@ package voice import ( + "context" "sync" + "time" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" @@ -17,6 +19,11 @@ const Protocol = "xsalsa20_poly1305" var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE} +// WSTimeout is the duration to wait for a gateway operation including Session +// to complete before erroring out. This only applies to functions that don't +// take in a context already. +var WSTimeout = 10 * time.Second + type Session struct { session *session.Session state voicegateway.State @@ -52,11 +59,16 @@ func NewSession(ses *session.Session, userID discord.Snowflake) *Session { UserID: userID, }, ErrorLog: func(err error) {}, - incoming: make(chan struct{}), + incoming: make(chan struct{}, 2), } } func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) { + if s.state.GuildID != ev.GuildID { + // Not our state. + return + } + // If this is true, then mutex is acquired already. if s.joining.Get() { s.state.Endpoint = ev.Endpoint @@ -73,7 +85,10 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) { s.state.Endpoint = ev.Endpoint s.state.Token = ev.Token - if err := s.reconnect(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() + + if err := s.reconnectCtx(ctx); err != nil { s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update")) } } @@ -95,6 +110,16 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) { } func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error { + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() + + return s.JoinChannelCtx(ctx, gID, cID, muted, deafened) +} + +func (s *Session) JoinChannelCtx( + ctx context.Context, + gID, cID discord.Snowflake, muted, deafened bool) error { + // Acquire the mutex during join, locking during IO as well. s.mut.Lock() defer s.mut.Unlock() @@ -103,7 +128,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) s.joining.Set(true) defer s.joining.Set(false) // reset when done - // ensure gateeway and voiceUDP is already closed. + // Ensure gateway and voiceUDP are already closed. s.ensureClosed() // Set the state. @@ -122,7 +147,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) // https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information // Send a Voice State Update event to the gateway. - err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{ + err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{ GuildID: gID, ChannelID: channelID, SelfMute: muted, @@ -132,23 +157,37 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) return errors.Wrap(err, "failed to send Voice State Update event") } - // Wait for replies. The above command should reply with these 2 events. - <-s.incoming - <-s.incoming + // Wait for 2 replies. The above command should reply with these 2 events. + if err := s.waitForIncoming(ctx, 2); err != nil { + return errors.Wrap(err, "failed to wait for needed gateway events") + } // These 2 methods should've updated s.state before sending into these // channels. Since s.state is already filled, we can go ahead and connect. - return s.reconnect() + return s.reconnectCtx(ctx) +} + +func (s *Session) waitForIncoming(ctx context.Context, n int) error { + for i := 0; i < n; i++ { + select { + case <-s.incoming: + continue + case <-ctx.Done(): + return ctx.Err() + } + } + + return nil } // reconnect uses the current state to reconnect to a new gateway and UDP // connection. -func (s *Session) reconnect() (err error) { +func (s *Session) reconnectCtx(ctx context.Context) (err error) { s.gateway = voicegateway.New(s.state) // Open the voice gateway. The function will block until Ready is received. - if err := s.gateway.Open(); err != nil { + if err := s.gateway.OpenCtx(ctx); err != nil { return errors.Wrap(err, "failed to open voice gateway") } @@ -156,13 +195,13 @@ func (s *Session) reconnect() (err error) { voiceReady := s.gateway.Ready() // Prepare the UDP voice connection. - s.voiceUDP, err = udp.DialConnection(voiceReady.Addr(), voiceReady.SSRC) + s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC) if err != nil { return errors.Wrap(err, "failed to open voice UDP connection") } // Get the session description from the voice gateway. - d, err := s.gateway.SessionDescription(voicegateway.SelectProtocol{ + d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{ Protocol: "udp", Data: voicegateway.SelectProtocolData{ Address: s.voiceUDP.GatewayIP, @@ -200,17 +239,31 @@ func (s *Session) StopSpeaking() error { return nil } +// Write writes into the UDP voice connection WITHOUT a timeout. func (s *Session) Write(b []byte) (int, error) { + return s.WriteCtx(context.Background(), b) +} + +// WriteCtx writes into the UDP voice connection with a context for timeout. +func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) { s.mut.RLock() defer s.mut.RUnlock() if s.voiceUDP == nil { return 0, ErrCannotSend } - return s.voiceUDP.Write(b) + + return s.voiceUDP.WriteCtx(ctx, b) } func (s *Session) Disconnect() error { + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() + + return s.DisconnectCtx(ctx) +} + +func (s *Session) DisconnectCtx(ctx context.Context) error { s.mut.Lock() defer s.mut.Unlock() @@ -223,7 +276,7 @@ func (s *Session) Disconnect() error { // VoiceStateUpdateEvent, in which our handler will promptly remove the // session from the map. - err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{ + err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{ GuildID: s.state.GuildID, ChannelID: discord.NullSnowflake, SelfMute: true, diff --git a/voice/udp/udp.go b/voice/udp/udp.go index f4ba5e1..7772870 100644 --- a/voice/udp/udp.go +++ b/voice/udp/udp.go @@ -2,6 +2,7 @@ package udp import ( "bytes" + "context" "encoding/binary" "io" "net" @@ -11,6 +12,11 @@ import ( "golang.org/x/crypto/nacl/secretbox" ) +// Dialer is the default dialer that this package uses for all its dialing. +var Dialer = net.Dialer{ + Timeout: 10 * time.Second, +} + type Connection struct { GatewayIP string GatewayPort uint16 @@ -21,7 +27,7 @@ type Connection struct { timestamp uint32 nonce [24]byte - conn *net.UDPConn + conn net.Conn close chan struct{} closed chan struct{} @@ -29,15 +35,15 @@ type Connection struct { reply chan error } -func DialConnection(addr string, ssrc uint32) (*Connection, error) { - // Resolve the host. - a, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, errors.Wrap(err, "failed to resolve host") - } +func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { + // // Resolve the host. + // a, err := net.ResolveUDPAddr("udp", addr) + // if err != nil { + // return nil, errors.Wrap(err, "failed to resolve host") + // } // Create a new UDP connection. - conn, err := net.DialUDP("udp", nil, a) + conn, err := Dialer.DialContext(ctx, "udp", addr) if err != nil { return nil, errors.Wrap(err, "failed to dial host") } @@ -154,9 +160,22 @@ func (c *Connection) Close() error { // Write sends bytes into the voice UDP connection. func (c *Connection) Write(b []byte) (int, error) { - c.send <- b - if err := <-c.reply; err != nil { - return 0, err - } - return len(b), nil + return c.WriteCtx(context.Background(), b) +} + +// WriteCtx sends bytes into the voice UDP connection with a timeout. +func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) { + select { + case c.send <- b: + break + case <-ctx.Done(): + return 0, ctx.Err() + } + + select { + case err := <-c.reply: + return len(b), err + case <-ctx.Done(): + return len(b), ctx.Err() + } } diff --git a/voice/voice.go b/voice/voice.go index c59207e..b52b422 100644 --- a/voice/voice.go +++ b/voice/voice.go @@ -31,11 +31,25 @@ type Voice struct { mapmutex sync.Mutex sessions map[discord.Snowflake]*Session // guildID:Session + // Callbacks to remove the handlers. + closers []func() + // ErrorLog will be called when an error occurs (defaults to log.Println) ErrorLog func(err error) } -// NewVoice creates a new Voice repository wrapped around a state. +// NewVoiceFromToken creates a new voice session from the given token. +func NewVoiceFromToken(token string) (*Voice, error) { + s, err := state.New(token) + if err != nil { + return nil, errors.Wrap(err, "failed to create a new session") + } + + return NewVoice(s), nil +} + +// NewVoice creates a new Voice repository wrapped around a state. The function +// will also automatically add the GuildVoiceStates intent, as that is required. func NewVoice(s *state.State) *Voice { v := &Voice{ State: s, @@ -44,8 +58,10 @@ func NewVoice(s *state.State) *Voice { } // Add the required event handlers to the session. - s.AddHandler(v.onVoiceStateUpdate) - s.AddHandler(v.onVoiceServerUpdate) + v.closers = []func(){ + s.AddHandler(v.onVoiceStateUpdate), + s.AddHandler(v.onVoiceServerUpdate), + } return v } @@ -129,6 +145,7 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (* } conn = NewSession(v.Session, u.ID) + conn.ErrorLog = v.ErrorLog v.mapmutex.Lock() v.sessions[gID] = conn @@ -139,6 +156,33 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (* return conn, conn.JoinChannel(gID, cID, muted, deafened) } +func (v *Voice) Close() error { + err := &CloseError{ + SessionErrors: make(map[discord.Snowflake]error), + } + + v.mapmutex.Lock() + defer v.mapmutex.Unlock() + + // Remove all callback handlers. + for _, fn := range v.closers { + fn() + } + + for gID, s := range v.sessions { + if dErr := s.Disconnect(); dErr != nil { + err.SessionErrors[gID] = dErr + } + } + + err.StateErr = v.State.Close() + if err.HasError() { + return err + } + + return nil +} + type CloseError struct { SessionErrors map[discord.Snowflake]error StateErr error @@ -163,25 +207,3 @@ func (e *CloseError) Error() string { return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect" } - -func (v *Voice) Close() error { - err := &CloseError{ - SessionErrors: make(map[discord.Snowflake]error), - } - - v.mapmutex.Lock() - defer v.mapmutex.Unlock() - - for gID, s := range v.sessions { - if dErr := s.Disconnect(); dErr != nil { - err.SessionErrors[gID] = dErr - } - } - - err.StateErr = v.State.Close() - if err.HasError() { - return err - } - - return nil -} diff --git a/voice/voicegateway/commands.go b/voice/voicegateway/commands.go index 8adb97e..ea0038b 100644 --- a/voice/voicegateway/commands.go +++ b/voice/voicegateway/commands.go @@ -1,6 +1,7 @@ package voicegateway import ( + "context" "time" "github.com/diamondburned/arikawa/discord" @@ -26,6 +27,14 @@ type IdentifyData struct { // Identify sends an Identify operation (opcode 0) to the Gateway Gateway. func (c *Gateway) Identify() error { + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + + return c.IdentifyCtx(ctx) +} + +// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway. +func (c *Gateway) IdentifyCtx(ctx context.Context) error { guildID := c.state.GuildID userID := c.state.UserID sessionID := c.state.SessionID @@ -35,7 +44,7 @@ func (c *Gateway) Identify() error { return ErrMissingForIdentify } - return c.Send(IdentifyOP, IdentifyData{ + return c.SendCtx(ctx, IdentifyOP, IdentifyData{ GuildID: guildID, UserID: userID, SessionID: sessionID, @@ -58,16 +67,32 @@ type SelectProtocolData struct { // SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway. func (c *Gateway) SelectProtocol(data SelectProtocol) error { - return c.Send(SelectProtocolOP, data) + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + + return c.SelectProtocolCtx(ctx, data) +} + +// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway. +func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error { + return c.SendCtx(ctx, SelectProtocolOP, data) } // OPCode 3 // https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload -type Heartbeat uint64 +// type Heartbeat uint64 // Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway. func (c *Gateway) Heartbeat() error { - return c.Send(HeartbeatOP, time.Now().UnixNano()) + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + + return c.HeartbeatCtx(ctx) +} + +// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway. +func (c *Gateway) HeartbeatCtx(ctx context.Context) error { + return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano()) } // https://discordapp.com/developers/docs/topics/voice-connections#speaking @@ -89,10 +114,18 @@ type SpeakingData struct { // Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. func (c *Gateway) Speaking(flag SpeakingFlag) error { + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + + return c.SpeakingCtx(ctx, flag) +} + +// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway. +func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error { // How do we allow a user to stop speaking? // Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation - return c.Send(SpeakingOP, SpeakingData{ + return c.SendCtx(ctx, SpeakingOP, SpeakingData{ Speaking: flag, Delay: 0, SSRC: c.ready.SSRC, @@ -109,6 +142,13 @@ type ResumeData struct { // Resume sends a Resume operation (opcode 7) to the Gateway Gateway. func (c *Gateway) Resume() error { + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + return c.ResumeCtx(ctx) +} + +// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway. +func (c *Gateway) ResumeCtx(ctx context.Context) error { guildID := c.state.GuildID sessionID := c.state.SessionID token := c.state.Token @@ -117,7 +157,7 @@ func (c *Gateway) Resume() error { return ErrMissingForResume } - return c.Send(ResumeOP, ResumeData{ + return c.SendCtx(ctx, ResumeOP, ResumeData{ GuildID: guildID, SessionID: sessionID, Token: token, diff --git a/voice/voicegateway/gateway.go b/voice/voicegateway/gateway.go index 563839d..7147a44 100644 --- a/voice/voicegateway/gateway.go +++ b/voice/voicegateway/gateway.go @@ -85,8 +85,12 @@ func (c *Gateway) Ready() ReadyEvent { return c.ready } -// Open shouldn't be used, but JoinServer instead. -func (c *Gateway) Open() error { +// OpenCtx shouldn't be used, but JoinServer instead. +func (c *Gateway) OpenCtx(ctx context.Context) error { + if c.state.Endpoint == "" { + return errors.New("missing endpoint in state") + } + // https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version @@ -94,7 +98,7 @@ func (c *Gateway) Open() error { c.ws = wsutil.New(endpoint) // Create a new context with a timeout for the connection. - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + ctx, cancel := context.WithTimeout(ctx, c.Timeout) defer cancel() // Connect to the Gateway Gateway. @@ -105,7 +109,7 @@ func (c *Gateway) Open() error { wsutil.WSDebug("Trying to start...") // Try to start or resume the connection. - if err := c.start(); err != nil { + if err := c.start(ctx); err != nil { return err } @@ -113,8 +117,8 @@ func (c *Gateway) Open() error { } // Start . -func (c *Gateway) start() error { - if err := c.__start(); err != nil { +func (c *Gateway) start(ctx context.Context) error { + if err := c.__start(ctx); err != nil { wsutil.WSDebug("Start failed: ", err) // Close can be called with the mutex still acquired here, as the @@ -129,7 +133,7 @@ func (c *Gateway) start() error { } // this function blocks until READY. -func (c *Gateway) __start() error { +func (c *Gateway) __start(ctx context.Context) error { // Make a new WaitGroup for use in background loops: c.waitGroup = new(sync.WaitGroup) @@ -139,9 +143,17 @@ func (c *Gateway) __start() error { wsutil.WSDebug("Waiting for Hello..") var hello *HelloEvent - _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello) - if err != nil { - return errors.Wrap(err, "error at Hello") + // Wait for the Hello event; return if it times out. + select { + case e, ok := <-ch: + if !ok { + return errors.New("unexpected ws close while waiting for Hello") + } + if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil { + return errors.Wrap(err, "error at Hello") + } + case <-ctx.Done(): + return errors.Wrap(ctx.Err(), "failed to wait for Hello event") } wsutil.WSDebug("Received Hello") @@ -149,11 +161,11 @@ func (c *Gateway) __start() error { // https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection // Turns out Hello is sent right away on connection start. if !c.reconnect.Get() { - if err := c.Identify(); err != nil { + if err := c.IdentifyCtx(ctx); err != nil { return errors.Wrap(err, "failed to identify") } } else { - if err := c.Resume(); err != nil { + if err := c.ResumeCtx(ctx); err != nil { return errors.Wrap(err, "failed to resume") } } @@ -161,7 +173,7 @@ func (c *Gateway) __start() error { c.reconnect.Set(false) // Wait for either Ready or Resumed. - err = wsutil.WaitForEvent(c, ch, func(op *wsutil.OP) bool { + err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool { return op.Code == ReadyOP || op.Code == ResumedOP }) if err != nil { @@ -180,7 +192,7 @@ func (c *Gateway) __start() error { if err != nil { c.ErrorLog(err) - c.Reconnect() + c.ReconnectCtx(ctx) // Reconnect should spawn another eventLoop in its Start function. } }) @@ -226,7 +238,7 @@ func (c *Gateway) Close() error { return err } -func (c *Gateway) Reconnect() error { +func (c *Gateway) ReconnectCtx(ctx context.Context) error { wsutil.WSDebug("Reconnecting...") // Guarantee the gateway is already closed. Ignore its error, as we're @@ -239,7 +251,7 @@ func (c *Gateway) Reconnect() error { // If the connection is rate limited (documented behavior): // https://discordapp.com/developers/docs/topics/gateway#rate-limiting - if err := c.Open(); err != nil { + if err := c.OpenCtx(ctx); err != nil { return errors.Wrap(err, "failed to reopen gateway") } @@ -248,34 +260,46 @@ func (c *Gateway) Reconnect() error { return nil } -func (c *Gateway) SessionDescription(sp SelectProtocol) (*SessionDescriptionEvent, error) { +func (c *Gateway) SessionDescriptionCtx( + ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) { + // Add the handler first. ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool { return op.Code == SessionDescriptionOP }) defer cancel() - if err := c.SelectProtocol(sp); err != nil { + if err := c.SelectProtocolCtx(ctx, sp); err != nil { return nil, err } var sesdesc *SessionDescriptionEvent // Wait for SessionDescriptionOP packet. - if err := (<-ch).UnmarshalData(&sesdesc); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal session description") + select { + case e, ok := <-ch: + if !ok { + return nil, errors.New("unexpected close waiting for session description") + } + if err := e.UnmarshalData(&sesdesc); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal session description") + } + case <-ctx.Done(): + return nil, errors.Wrap(ctx.Err(), "failed to wait for session description") } return sesdesc, nil } -// Send . +// Send sends a payload to the Gateway with the default timeout. func (c *Gateway) Send(code OPCode, v interface{}) error { - return c.send(code, v) + ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) + defer cancel() + + return c.SendCtx(ctx, code, v) } -// send . -func (c *Gateway) send(code OPCode, v interface{}) error { +func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error { if c.ws == nil { return errors.New("tried to send data to a connection without a Websocket") } @@ -303,5 +327,5 @@ func (c *Gateway) send(code OPCode, v interface{}) error { } // WS should already be thread-safe. - return c.ws.Send(b) + return c.ws.SendCtx(ctx, b) }