Gateway: Added intent helpers and more context API support

This commit is contained in:
diamondburned 2020-07-11 12:50:32 -07:00
parent f33b4ff7d8
commit edb8a46ef2
15 changed files with 441 additions and 131 deletions

View File

@ -140,7 +140,8 @@ type Context struct {
// Start quickly starts a bot with the given command. It will prepend "Bot"
// into the token automatically. Refer to example/ for usage.
func Start(token string, cmd interface{},
func Start(
token string, cmd interface{},
opts func(*Context) error) (wait func() error, err error) {
s, err := state.New("Bot " + token)
@ -227,6 +228,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
return ctx, nil
}
// AddIntent adds the given Gateway Intent into the Gateway. This is a
// convenient function that calls Gateway's AddIntent.
func (ctx *Context) AddIntent(i gateway.Intents) {
ctx.Gateway.AddIntent(i)
}
// Subcommands returns the slice of subcommands. To add subcommands, use
// RegisterSubcommand().
func (ctx *Context) Subcommands() []*Subcommand {

View File

@ -15,11 +15,18 @@ func (g *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.IdentifyCtx(ctx)
}
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, g.WSTimeout)
defer cancel()
if err := g.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for identify()")
}
return g.Send(IdentifyOP, g.Identifier)
return g.SendCtx(ctx, IdentifyOP, g.Identifier)
}
type ResumeData struct {
@ -31,6 +38,15 @@ type ResumeData struct {
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ResumeCtx(ctx)
}
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) ResumeCtx(ctx context.Context) error {
var (
ses = g.SessionID
seq = g.Sequence.Get()
@ -40,7 +56,7 @@ func (g *Gateway) Resume() error {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
return g.SendCtx(ctx, ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
@ -51,7 +67,14 @@ func (g *Gateway) Resume() error {
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
return g.Send(HeartbeatOP, g.Sequence.Get())
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.HeartbeatCtx(ctx)
}
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
}
type RequestGuildMembersData struct {
@ -61,10 +84,20 @@ type RequestGuildMembersData struct {
Query string `json:"query,omitempty"`
Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
return g.Send(RequestGuildMembersOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.RequestGuildMembersCtx(ctx, data)
}
func (g *Gateway) RequestGuildMembersCtx(
ctx context.Context, data RequestGuildMembersData) error {
return g.SendCtx(ctx, RequestGuildMembersOP, data)
}
type UpdateVoiceStateData struct {
@ -75,7 +108,16 @@ type UpdateVoiceStateData struct {
}
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
return g.Send(VoiceStateUpdateOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateVoiceStateCtx(ctx, data)
}
func (g *Gateway) UpdateVoiceStateCtx(
ctx context.Context, data UpdateVoiceStateData) error {
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
}
type UpdateStatusData struct {
@ -90,7 +132,14 @@ type UpdateStatusData struct {
}
func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
return g.Send(StatusUpdateOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateStatusCtx(ctx, data)
}
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
return g.SendCtx(ctx, StatusUpdateOP, data)
}
// Undocumented
@ -104,5 +153,12 @@ type GuildSubscribeData struct {
}
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
return g.Send(GuildSubscriptionsOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.GuildSubscribeCtx(ctx, data)
}
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
}

View File

@ -99,11 +99,15 @@ type (
GuildID discord.Snowflake `json:"guild_id"`
Members []discord.Member `json:"members"`
ChunkIndex int `json:"chunk_index"`
ChunkCount int `json:"chunk_count"`
// Whatever's not found goes here
NotFound []string `json:"not_found,omitempty"`
// Only filled if requested
Presences []discord.Presence `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
// GuildMemberListUpdate is an undocumented event. It's received when the

View File

@ -107,8 +107,23 @@ type Gateway struct {
waitGroup *sync.WaitGroup
}
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
// information, refer to NewGatewayWithDriver.
// NewGatewayWithIntents creates a new Gateway with the given intents and the
// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
g, err := NewGateway(token)
if err != nil {
return nil, err
}
for _, intent := range intents {
g.AddIntent(intent)
}
return g, nil
}
// NewGateway creates a new Gateway with the default stdlib JSON driver. For
// more information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) {
URL, err := URL()
if err != nil {
@ -141,6 +156,12 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
}
}
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
// such, this function will only work before Open() is called.
func (g *Gateway) AddIntent(i Intents) {
g.Identifier.Intents |= i
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
wsutil.WSDebug("Trying to close.")
@ -182,10 +203,13 @@ func (g *Gateway) Close() error {
// Reconnect tries to reconnect forever. It will resume the connection if
// possible. If an Invalid Session is received, it will start a fresh one.
func (g *Gateway) Reconnect() error {
return g.ReconnectContext(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ReconnectCtx(ctx)
}
func (g *Gateway) ReconnectContext(ctx context.Context) error {
func (g *Gateway) ReconnectCtx(ctx context.Context) error {
wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're
@ -212,9 +236,15 @@ func (g *Gateway) ReconnectContext(ctx context.Context) error {
// Open connects to the Websocket and authenticate it. You should usually use
// this function over Start().
func (g *Gateway) Open() error {
return g.OpenContext(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.OpenContext(ctx)
}
// OpenContext connects to the Websocket and authenticates it. Yuo should
// usually use this function over Start(). The given context provides
// cancellation and timeout.
func (g *Gateway) OpenContext(ctx context.Context) error {
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
@ -224,7 +254,7 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
wsutil.WSDebug("Trying to start...")
// Try to resume the connection
if err := g.Start(); err != nil {
if err := g.StartCtx(ctx); err != nil {
return err
}
@ -232,14 +262,19 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
return nil
}
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block. You wouldn't usually use this
// Start calls StartCtx with a background context. You wouldn't usually use this
// function, but Open() instead.
func (g *Gateway) Start() error {
// g.available.Lock()
// defer g.available.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
if err := g.start(); err != nil {
return g.StartCtx(ctx)
}
// StartCtx authenticates with the websocket, or resume from a dead Websocket
// connection. You wouldn't usually use this function, but OpenCtx() instead.
func (g *Gateway) StartCtx(ctx context.Context) error {
if err := g.start(ctx); err != nil {
wsutil.WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the
@ -249,31 +284,41 @@ func (g *Gateway) Start() error {
}
return err
}
return nil
}
func (g *Gateway) start() error {
func (g *Gateway) start(ctx context.Context) error {
// This is where we'll get our events
ch := g.WS.Listen()
// Make a new WaitGroup for use in background loops:
g.waitGroup = new(sync.WaitGroup)
// Wait for an OP 10 Hello
// Create a new Hello event and wait for it.
var hello HelloEvent
if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
// Wait for an OP 10 Hello.
select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
}
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID == "" {
// SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil {
if err := g.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify")
}
} else {
if err := g.Resume(); err != nil {
if err := g.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume")
}
}
@ -282,7 +327,7 @@ func (g *Gateway) start() error {
wsutil.WSDebug("Waiting for either READY or RESUMED.")
// WaitForEvent should
err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool {
err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
switch op.EventName {
case "READY":
wsutil.WSDebug("Found READY event.")
@ -319,7 +364,9 @@ func (g *Gateway) start() error {
return nil
}
func (g *Gateway) Send(code OPCode, v interface{}) error {
// SendCtx is a low-level function to send an OP payload to the Gateway. Most
// users shouldn't touch this, unless they know what they're doing.
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
var op = wsutil.OP{
Code: code,
}
@ -339,5 +386,5 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
}
// WS should already be thread-safe.
return g.WS.Send(b)
return g.WS.SendCtx(ctx, b)
}

View File

@ -55,7 +55,7 @@ func (i *IdentifyData) SetShard(id, num int) {
i.Shard[0], i.Shard[1] = id, num
}
// Intents is a new Discord API feature that's documented at
// Intents for the new Discord API feature, documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32

View File

@ -107,13 +107,15 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
select {
case ev := <-evCh:
return ev
case <-time.After(10 * time.Second):
case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for event")
return nil
}
}
func gotimeout(t *testing.T, fn func()) {
t.Helper()
var done = make(chan struct{})
go func() {
fn()
@ -121,7 +123,7 @@ func gotimeout(t *testing.T, fn func()) {
}()
select {
case <-time.After(10 * time.Second):
case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for function.")
case <-done:
return

View File

@ -1,6 +1,7 @@
package gateway
import (
"context"
"fmt"
"math/rand"
"time"
@ -36,15 +37,21 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
g.PacerLoop.Echo()
case HeartbeatOP:
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Server requesting a heartbeat.
return g.PacerLoop.Pace()
return g.PacerLoop.Pace(ctx)
case ReconnectOP:
// Server requests to reconnect, die and retry.
wsutil.WSDebug("ReconnectOP received.")
// We must reconnect in another goroutine, as running Reconnect
// synchronously would prevent the main event loop from exiting.
go g.Reconnect()
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
go func() { g.ReconnectCtx(ctx); cancel() }()
// Gracefully exit with a nil let the event handler take the signal from
// the pacemaker.
return nil
@ -53,11 +60,16 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Discord expects us to sleep for no reason
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Invalid session, try and Identify.
if err := g.Identify(); err != nil {
if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, reconnect.
go g.Reconnect()
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
go func() { g.ReconnectCtx(ctx); cancel() }()
}
return nil
case HelloOP:

View File

@ -41,6 +41,17 @@ type Session struct {
hstop chan struct{}
}
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
g, err := gateway.NewGatewayWithIntents(token, intents...)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
return NewWithGateway(g), nil
}
// New creates a new session from a given token. Most bots should be using
// NewWithIntents instead.
func New(token string) (*Session, error) {
// Create a gateway
g, err := gateway.NewGateway(token)
@ -48,7 +59,7 @@ func New(token string) (*Session, error) {
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
return NewWithGateway(g), err
return NewWithGateway(g), nil
}
// Login tries to log in as a normal user account; MFA is optional.

View File

@ -97,10 +97,22 @@ type State struct {
unreadyGuilds *moreatomic.SnowflakeSet
}
// New creates a new state.
func New(token string) (*State, error) {
return NewWithStore(token, NewDefaultStore(nil))
}
// NewWithIntents creates a new state with the given gateway intents. For more
// information, refer to gateway.Intents.
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
s, err := session.NewWithIntents(token, intents...)
if err != nil {
return nil, err
}
return NewFromSession(s, NewDefaultStore(nil))
}
func NewWithStore(token string, store Store) (*State, error) {
s, err := session.New(token)
if err != nil {

View File

@ -13,7 +13,7 @@ import (
"time"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/state"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/utils/wsutil"
"github.com/diamondburned/arikawa/voice/voicegateway"
)
@ -94,24 +94,23 @@ func TestIntegration(t *testing.T) {
log.Println(append([]interface{}{caller}, v...)...)
}
// heart.Debug = func(v ...interface{}) {
// log.Println(append([]interface{}{"Pacemaker:"}, v...)...)
// }
s, err := state.New("Bot " + config.BotToken)
v, err := NewVoiceFromToken("Bot " + config.BotToken)
if err != nil {
t.Fatal("Failed to create a new session:", err)
t.Fatal("Failed to create a new voice session:", err)
}
v.Gateway.AddIntent(gateway.IntentGuildVoiceStates)
v.ErrorLog = func(err error) {
t.Error(err)
}
v := NewVoice(s)
if err := s.Open(); err != nil {
if err := v.Open(); err != nil {
t.Fatal("Failed to connect:", err)
}
defer s.Close()
defer v.Close()
// Validate the given voice channel.
c, err := s.Channel(config.VoiceChID)
c, err := v.Channel(config.VoiceChID)
if err != nil {
t.Fatal("Failed to get channel:", err)
}
@ -119,6 +118,8 @@ func TestIntegration(t *testing.T) {
t.Fatal("Channel isn't a guild voice channel.")
}
log.Println("The voice channel's name is", c.Name)
// Grab a timer to benchmark things.
finish := timer()

View File

@ -1,7 +1,9 @@
package voice
import (
"context"
"sync"
"time"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
@ -17,6 +19,11 @@ const Protocol = "xsalsa20_poly1305"
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
var WSTimeout = 10 * time.Second
type Session struct {
session *session.Session
state voicegateway.State
@ -52,11 +59,16 @@ func NewSession(ses *session.Session, userID discord.Snowflake) *Session {
UserID: userID,
},
ErrorLog: func(err error) {},
incoming: make(chan struct{}),
incoming: make(chan struct{}, 2),
}
}
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
if s.state.GuildID != ev.GuildID {
// Not our state.
return
}
// If this is true, then mutex is acquired already.
if s.joining.Get() {
s.state.Endpoint = ev.Endpoint
@ -73,7 +85,10 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
if err := s.reconnect(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
if err := s.reconnectCtx(ctx); err != nil {
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
}
}
@ -95,6 +110,16 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
}
func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
}
func (s *Session) JoinChannelCtx(
ctx context.Context,
gID, cID discord.Snowflake, muted, deafened bool) error {
// Acquire the mutex during join, locking during IO as well.
s.mut.Lock()
defer s.mut.Unlock()
@ -103,7 +128,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
s.joining.Set(true)
defer s.joining.Set(false) // reset when done
// ensure gateeway and voiceUDP is already closed.
// Ensure gateway and voiceUDP are already closed.
s.ensureClosed()
// Set the state.
@ -122,7 +147,7 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
// https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: gID,
ChannelID: channelID,
SelfMute: muted,
@ -132,23 +157,37 @@ func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool)
return errors.Wrap(err, "failed to send Voice State Update event")
}
// Wait for replies. The above command should reply with these 2 events.
<-s.incoming
<-s.incoming
// Wait for 2 replies. The above command should reply with these 2 events.
if err := s.waitForIncoming(ctx, 2); err != nil {
return errors.Wrap(err, "failed to wait for needed gateway events")
}
// These 2 methods should've updated s.state before sending into these
// channels. Since s.state is already filled, we can go ahead and connect.
return s.reconnect()
return s.reconnectCtx(ctx)
}
func (s *Session) waitForIncoming(ctx context.Context, n int) error {
for i := 0; i < n; i++ {
select {
case <-s.incoming:
continue
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// reconnect uses the current state to reconnect to a new gateway and UDP
// connection.
func (s *Session) reconnect() (err error) {
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received.
if err := s.gateway.Open(); err != nil {
if err := s.gateway.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to open voice gateway")
}
@ -156,13 +195,13 @@ func (s *Session) reconnect() (err error) {
voiceReady := s.gateway.Ready()
// Prepare the UDP voice connection.
s.voiceUDP, err = udp.DialConnection(voiceReady.Addr(), voiceReady.SSRC)
s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection")
}
// Get the session description from the voice gateway.
d, err := s.gateway.SessionDescription(voicegateway.SelectProtocol{
d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP,
@ -200,17 +239,31 @@ func (s *Session) StopSpeaking() error {
return nil
}
// Write writes into the UDP voice connection WITHOUT a timeout.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)
}
// WriteCtx writes into the UDP voice connection with a context for timeout.
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
s.mut.RLock()
defer s.mut.RUnlock()
if s.voiceUDP == nil {
return 0, ErrCannotSend
}
return s.voiceUDP.Write(b)
return s.voiceUDP.WriteCtx(ctx, b)
}
func (s *Session) Disconnect() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.DisconnectCtx(ctx)
}
func (s *Session) DisconnectCtx(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -223,7 +276,7 @@ func (s *Session) Disconnect() error {
// VoiceStateUpdateEvent, in which our handler will promptly remove the
// session from the map.
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: s.state.GuildID,
ChannelID: discord.NullSnowflake,
SelfMute: true,

View File

@ -2,6 +2,7 @@ package udp
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
@ -11,6 +12,11 @@ import (
"golang.org/x/crypto/nacl/secretbox"
)
// Dialer is the default dialer that this package uses for all its dialing.
var Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
type Connection struct {
GatewayIP string
GatewayPort uint16
@ -21,7 +27,7 @@ type Connection struct {
timestamp uint32
nonce [24]byte
conn *net.UDPConn
conn net.Conn
close chan struct{}
closed chan struct{}
@ -29,15 +35,15 @@ type Connection struct {
reply chan error
}
func DialConnection(addr string, ssrc uint32) (*Connection, error) {
// Resolve the host.
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, errors.Wrap(err, "failed to resolve host")
}
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// // Resolve the host.
// a, err := net.ResolveUDPAddr("udp", addr)
// if err != nil {
// return nil, errors.Wrap(err, "failed to resolve host")
// }
// Create a new UDP connection.
conn, err := net.DialUDP("udp", nil, a)
conn, err := Dialer.DialContext(ctx, "udp", addr)
if err != nil {
return nil, errors.Wrap(err, "failed to dial host")
}
@ -154,9 +160,22 @@ func (c *Connection) Close() error {
// Write sends bytes into the voice UDP connection.
func (c *Connection) Write(b []byte) (int, error) {
c.send <- b
if err := <-c.reply; err != nil {
return 0, err
}
return len(b), nil
return c.WriteCtx(context.Background(), b)
}
// WriteCtx sends bytes into the voice UDP connection with a timeout.
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
select {
case c.send <- b:
break
case <-ctx.Done():
return 0, ctx.Err()
}
select {
case err := <-c.reply:
return len(b), err
case <-ctx.Done():
return len(b), ctx.Err()
}
}

View File

@ -31,11 +31,25 @@ type Voice struct {
mapmutex sync.Mutex
sessions map[discord.Snowflake]*Session // guildID:Session
// Callbacks to remove the handlers.
closers []func()
// ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error)
}
// NewVoice creates a new Voice repository wrapped around a state.
// NewVoiceFromToken creates a new voice session from the given token.
func NewVoiceFromToken(token string) (*Voice, error) {
s, err := state.New(token)
if err != nil {
return nil, errors.Wrap(err, "failed to create a new session")
}
return NewVoice(s), nil
}
// NewVoice creates a new Voice repository wrapped around a state. The function
// will also automatically add the GuildVoiceStates intent, as that is required.
func NewVoice(s *state.State) *Voice {
v := &Voice{
State: s,
@ -44,8 +58,10 @@ func NewVoice(s *state.State) *Voice {
}
// Add the required event handlers to the session.
s.AddHandler(v.onVoiceStateUpdate)
s.AddHandler(v.onVoiceServerUpdate)
v.closers = []func(){
s.AddHandler(v.onVoiceStateUpdate),
s.AddHandler(v.onVoiceServerUpdate),
}
return v
}
@ -129,6 +145,7 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
}
conn = NewSession(v.Session, u.ID)
conn.ErrorLog = v.ErrorLog
v.mapmutex.Lock()
v.sessions[gID] = conn
@ -139,6 +156,33 @@ func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*
return conn, conn.JoinChannel(gID, cID, muted, deafened)
}
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.Snowflake]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// Remove all callback handlers.
for _, fn := range v.closers {
fn()
}
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
err.SessionErrors[gID] = dErr
}
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}
type CloseError struct {
SessionErrors map[discord.Snowflake]error
StateErr error
@ -163,25 +207,3 @@ func (e *CloseError) Error() string {
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
}
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.Snowflake]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
err.SessionErrors[gID] = dErr
}
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}

View File

@ -1,6 +1,7 @@
package voicegateway
import (
"context"
"time"
"github.com/diamondburned/arikawa/discord"
@ -26,6 +27,14 @@ type IdentifyData struct {
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.IdentifyCtx(ctx)
}
// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) IdentifyCtx(ctx context.Context) error {
guildID := c.state.GuildID
userID := c.state.UserID
sessionID := c.state.SessionID
@ -35,7 +44,7 @@ func (c *Gateway) Identify() error {
return ErrMissingForIdentify
}
return c.Send(IdentifyOP, IdentifyData{
return c.SendCtx(ctx, IdentifyOP, IdentifyData{
GuildID: guildID,
UserID: userID,
SessionID: sessionID,
@ -58,16 +67,32 @@ type SelectProtocolData struct {
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocol(data SelectProtocol) error {
return c.Send(SelectProtocolOP, data)
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SelectProtocolCtx(ctx, data)
}
// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error {
return c.SendCtx(ctx, SelectProtocolOP, data)
}
// OPCode 3
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
type Heartbeat uint64
// type Heartbeat uint64
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) Heartbeat() error {
return c.Send(HeartbeatOP, time.Now().UnixNano())
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.HeartbeatCtx(ctx)
}
// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano())
}
// https://discordapp.com/developers/docs/topics/voice-connections#speaking
@ -89,10 +114,18 @@ type SpeakingData struct {
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) Speaking(flag SpeakingFlag) error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SpeakingCtx(ctx, flag)
}
// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error {
// How do we allow a user to stop speaking?
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
return c.Send(SpeakingOP, SpeakingData{
return c.SendCtx(ctx, SpeakingOP, SpeakingData{
Speaking: flag,
Delay: 0,
SSRC: c.ready.SSRC,
@ -109,6 +142,13 @@ type ResumeData struct {
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.ResumeCtx(ctx)
}
// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) ResumeCtx(ctx context.Context) error {
guildID := c.state.GuildID
sessionID := c.state.SessionID
token := c.state.Token
@ -117,7 +157,7 @@ func (c *Gateway) Resume() error {
return ErrMissingForResume
}
return c.Send(ResumeOP, ResumeData{
return c.SendCtx(ctx, ResumeOP, ResumeData{
GuildID: guildID,
SessionID: sessionID,
Token: token,

View File

@ -85,8 +85,12 @@ func (c *Gateway) Ready() ReadyEvent {
return c.ready
}
// Open shouldn't be used, but JoinServer instead.
func (c *Gateway) Open() error {
// OpenCtx shouldn't be used, but JoinServer instead.
func (c *Gateway) OpenCtx(ctx context.Context) error {
if c.state.Endpoint == "" {
return errors.New("missing endpoint in state")
}
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
@ -94,7 +98,7 @@ func (c *Gateway) Open() error {
c.ws = wsutil.New(endpoint)
// Create a new context with a timeout for the connection.
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
defer cancel()
// Connect to the Gateway Gateway.
@ -105,7 +109,7 @@ func (c *Gateway) Open() error {
wsutil.WSDebug("Trying to start...")
// Try to start or resume the connection.
if err := c.start(); err != nil {
if err := c.start(ctx); err != nil {
return err
}
@ -113,8 +117,8 @@ func (c *Gateway) Open() error {
}
// Start .
func (c *Gateway) start() error {
if err := c.__start(); err != nil {
func (c *Gateway) start(ctx context.Context) error {
if err := c.__start(ctx); err != nil {
wsutil.WSDebug("Start failed: ", err)
// Close can be called with the mutex still acquired here, as the
@ -129,7 +133,7 @@ func (c *Gateway) start() error {
}
// this function blocks until READY.
func (c *Gateway) __start() error {
func (c *Gateway) __start(ctx context.Context) error {
// Make a new WaitGroup for use in background loops:
c.waitGroup = new(sync.WaitGroup)
@ -139,9 +143,17 @@ func (c *Gateway) __start() error {
wsutil.WSDebug("Waiting for Hello..")
var hello *HelloEvent
_, err := wsutil.AssertEvent(<-ch, HelloOP, &hello)
if err != nil {
return errors.Wrap(err, "error at Hello")
// Wait for the Hello event; return if it times out.
select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
}
wsutil.WSDebug("Received Hello")
@ -149,11 +161,11 @@ func (c *Gateway) __start() error {
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
// Turns out Hello is sent right away on connection start.
if !c.reconnect.Get() {
if err := c.Identify(); err != nil {
if err := c.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify")
}
} else {
if err := c.Resume(); err != nil {
if err := c.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume")
}
}
@ -161,7 +173,7 @@ func (c *Gateway) __start() error {
c.reconnect.Set(false)
// Wait for either Ready or Resumed.
err = wsutil.WaitForEvent(c, ch, func(op *wsutil.OP) bool {
err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool {
return op.Code == ReadyOP || op.Code == ResumedOP
})
if err != nil {
@ -180,7 +192,7 @@ func (c *Gateway) __start() error {
if err != nil {
c.ErrorLog(err)
c.Reconnect()
c.ReconnectCtx(ctx)
// Reconnect should spawn another eventLoop in its Start function.
}
})
@ -226,7 +238,7 @@ func (c *Gateway) Close() error {
return err
}
func (c *Gateway) Reconnect() error {
func (c *Gateway) ReconnectCtx(ctx context.Context) error {
wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're
@ -239,7 +251,7 @@ func (c *Gateway) Reconnect() error {
// If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err := c.Open(); err != nil {
if err := c.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to reopen gateway")
}
@ -248,34 +260,46 @@ func (c *Gateway) Reconnect() error {
return nil
}
func (c *Gateway) SessionDescription(sp SelectProtocol) (*SessionDescriptionEvent, error) {
func (c *Gateway) SessionDescriptionCtx(
ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) {
// Add the handler first.
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
return op.Code == SessionDescriptionOP
})
defer cancel()
if err := c.SelectProtocol(sp); err != nil {
if err := c.SelectProtocolCtx(ctx, sp); err != nil {
return nil, err
}
var sesdesc *SessionDescriptionEvent
// Wait for SessionDescriptionOP packet.
if err := (<-ch).UnmarshalData(&sesdesc); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal session description")
select {
case e, ok := <-ch:
if !ok {
return nil, errors.New("unexpected close waiting for session description")
}
if err := e.UnmarshalData(&sesdesc); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal session description")
}
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "failed to wait for session description")
}
return sesdesc, nil
}
// Send .
// Send sends a payload to the Gateway with the default timeout.
func (c *Gateway) Send(code OPCode, v interface{}) error {
return c.send(code, v)
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SendCtx(ctx, code, v)
}
// send .
func (c *Gateway) send(code OPCode, v interface{}) error {
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
if c.ws == nil {
return errors.New("tried to send data to a connection without a Websocket")
}
@ -303,5 +327,5 @@ func (c *Gateway) send(code OPCode, v interface{}) error {
}
// WS should already be thread-safe.
return c.ws.Send(b)
return c.ws.SendCtx(ctx, b)
}