{Voice,}Gateway: Fixed various race conditions

This commit fixes race conditions in both package voice, package
voicegateway and package gateway.

Originally, several race conditions exist when both the user's and the
pacemaker's goroutines both want to do several things to the websocket
connection. For example, the user's goroutine could be writing, and the
pacemaker's goroutine could trigger a reconnection. This is racey.

This issue is partially fixed by removing the pacer loop from package
heart and combining the ticker into the event (pacemaker) loop itself.

Technically, a race condition could still be triggered with care, but
the API itself never guaranteed any of those. As events are handled
using an internal loop into a channel, a race condition will not be
triggered just by handling events and writing to the websocket.
This commit is contained in:
diamondburned 2020-10-22 10:47:27 -07:00
parent 91ee92e9d5
commit 6c332ac145
11 changed files with 377 additions and 289 deletions

View File

@ -161,40 +161,38 @@ func (g *Gateway) AddIntent(i Intents) {
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
func (g *Gateway) Close() (err error) {
wsutil.WSDebug("Trying to close.")
// Check if the WS is already closed:
if g.waitGroup == nil && g.PacerLoop.Stopped() {
if g.PacerLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
g.AfterClose(nil)
return nil
return err
}
// Trigger the close callback on exit.
defer func() { g.AfterClose(err) }()
// If the pacemaker is running:
if !g.PacerLoop.Stopped() {
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
// Stop the pacemaker and the event handler.
g.PacerLoop.Stop()
wsutil.WSDebug("Stopped pacemaker.")
}
wsutil.WSDebug("Closing the websocket...")
err = g.WS.Close()
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
g.waitGroup.Wait()
// Mark g.waitGroup as empty:
g.waitGroup = nil
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := g.WS.Close()
g.AfterClose(err)
return err
}

View File

@ -46,7 +46,9 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
defer cancel()
// Server requesting a heartbeat.
return g.PacerLoop.Pace(ctx)
if err := g.PacerLoop.Pace(ctx); err != nil {
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
}
case ReconnectOP:
// Server requests to reconnect, die and retry.
@ -54,7 +56,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Exit with the ReconnectOP error to force the heartbeat event loop to
// reconnect synchronously. Not really a fatal error.
return ErrReconnectRequest
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
case InvalidSessionOP:
// Discord expects us to sleep for no reason
@ -66,7 +68,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Invalid session, try and Identify.
if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, reconnect.
go g.Reconnect()
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
}
return nil

1
go.sum
View File

@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s=

View File

@ -3,7 +3,6 @@ package heart
import (
"context"
"sync"
"sync/atomic"
"time"
@ -36,23 +35,30 @@ type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
ticker time.Ticker
Ticks <-chan time.Time
// Time in nanoseconds, guarded by atomic read/writes.
SentBeat AtomicTime
EchoBeat AtomicTime
// Any callback that returns an error will stop the pacer.
Pace func(context.Context) error
stop chan struct{}
once sync.Once
death chan error
Pacer func(context.Context) error
}
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) *Pacemaker {
return &Pacemaker{
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
p := Pacemaker{
Heartrate: heartrate,
Pace: pacer,
Pacer: pacer,
ticker: *time.NewTicker(heartrate),
}
p.Ticks = p.ticker.C
// Reset states to its old position.
now := time.Now()
p.EchoBeat.Set(now)
p.SentBeat.Set(now)
return p
}
func (p *Pacemaker) Echo() {
@ -62,14 +68,6 @@ func (p *Pacemaker) Echo() {
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
/* Deprecated
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
return false
}
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
*/
var (
echo = p.EchoBeat.Get()
sent = p.SentBeat.Get()
@ -84,75 +82,84 @@ func (p *Pacemaker) Dead() bool {
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
func (p *Pacemaker) Stop() {
Debug("(*Pacemaker).stop is trying sync.Once.")
p.once.Do(func() {
Debug("(*Pacemaker).stop closed the channel.")
close(p.stop)
})
p.ticker.Stop()
}
// pace sends a heartbeat with the appropriate timeout for the context.
func (p *Pacemaker) pace() error {
func (p *Pacemaker) Pace() error {
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate)
defer cancel()
return p.Pace(ctx)
return p.PaceCtx(ctx)
}
func (p *Pacemaker) start() error {
// Reset states to its old position.
p.EchoBeat.Set(time.Time{})
p.SentBeat.Set(time.Time{})
// Create a new ticker.
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
// Echo at least once
p.Echo()
for {
if err := p.pace(); err != nil {
return errors.Wrap(err, "failed to pace")
}
// Paced, save:
p.SentBeat.Set(time.Now())
if p.Dead() {
return ErrDead
}
select {
case <-p.stop:
return nil
case <-tick.C:
}
}
}
// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
p.death = make(chan error)
p.stop = make(chan struct{})
p.once = sync.Once{}
if wg != nil {
wg.Add(1)
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
if err := p.Pacer(ctx); err != nil {
return err
}
go func() {
p.death <- p.start()
// Debug.
Debug("Pacemaker returned.")
p.SentBeat.Set(time.Now())
// Mark the pacemaker loop as done.
if wg != nil {
wg.Done()
}
}()
if p.Dead() {
return ErrDead
}
return p.death
return nil
}
// func (p *Pacemaker) start() error {
// // Reset states to its old position.
// p.EchoBeat.Set(time.Time{})
// p.SentBeat.Set(time.Time{})
// // Create a new ticker.
// tick := time.NewTicker(p.Heartrate)
// defer tick.Stop()
// // Echo at least once
// p.Echo()
// for {
// if err := p.pace(); err != nil {
// return errors.Wrap(err, "failed to pace")
// }
// // Paced, save:
// p.SentBeat.Set(time.Now())
// if p.Dead() {
// return ErrDead
// }
// select {
// case <-p.stop:
// return nil
// case <-tick.C:
// }
// }
// }
// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
// p.death = make(chan error)
// p.stop = make(chan struct{})
// p.once = sync.Once{}
// if wg != nil {
// wg.Add(1)
// }
// go func() {
// p.death <- p.start()
// // Debug.
// Debug("Pacemaker returned.")
// // Mark the pacemaker loop as done.
// if wg != nil {
// wg.Done()
// }
// }()
// return p.death
// }

View File

@ -2,6 +2,7 @@ package wsutil
import (
"context"
"runtime/debug"
"time"
"github.com/pkg/errors"
@ -10,6 +11,34 @@ import (
"github.com/diamondburned/arikawa/internal/moreatomic"
)
type errBrokenConnection struct {
underneath error
}
// Error formats the broken connection error with the message "explicit
// connection break."
func (err errBrokenConnection) Error() string {
return "explicit connection break: " + err.underneath.Error()
}
// Unwrap returns the underlying error.
func (err errBrokenConnection) Unwrap() error {
return err.underneath
}
// ErrBrokenConnection marks the given error as a broken connection error. This
// error will cause the pacemaker loop to break and return the error. The error,
// when stringified, will say "explicit connection break."
func ErrBrokenConnection(err error) error {
return errBrokenConnection{underneath: err}
}
// IsBrokenConnection returns true if the error is a broken connection error.
func IsBrokenConnection(err error) bool {
var broken *errBrokenConnection
return errors.As(err, &broken)
}
// TODO API
type EventLoopHandler interface {
EventHandler
@ -19,14 +48,15 @@ type EventLoopHandler interface {
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
// is a valid instance only when RunAsync is called first.
type PacemakerLoop struct {
pacemaker *heart.Pacemaker // let's not copy this
pacedeath chan error
heart.Pacemaker
running moreatomic.Bool
stop chan struct{}
events <-chan Event
handler func(*OP) error
stack []byte
Extras ExtraHandlers
ErrorLog func(error)
@ -43,17 +73,19 @@ func (p *PacemakerLoop) errorLog(err error) {
// Pace calls the pacemaker's Pace function.
func (p *PacemakerLoop) Pace(ctx context.Context) error {
return p.pacemaker.Pace(ctx)
return p.Pacemaker.PaceCtx(ctx)
}
// Echo calls the pacemaker's Echo function.
func (p *PacemakerLoop) Echo() {
p.pacemaker.Echo()
}
// Stop calls the pacemaker's Stop function.
// Stop stops the pacer loop. It does nothing if the loop is already stopped.
func (p *PacemakerLoop) Stop() {
p.pacemaker.Stop()
if p.Stopped() {
return
}
// Despite p.running and p.stop being thread-safe on their own, this entire
// block is actually not thread-safe.
p.Pacemaker.Stop()
close(p.stop)
}
func (p *PacemakerLoop) Stopped() bool {
@ -65,12 +97,12 @@ func (p *PacemakerLoop) RunAsync(
WSDebug("Starting the pacemaker loop.")
p.pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
p.events = evs
p.Pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
p.handler = evl.HandleOP
p.events = evs
p.stack = debug.Stack()
p.stop = make(chan struct{})
// callers should explicitly handle waitgroups.
p.pacedeath = p.pacemaker.StartAsync(nil)
p.running.Set(true)
go func() {
@ -81,21 +113,27 @@ func (p *PacemakerLoop) RunAsync(
func (p *PacemakerLoop) startLoop() error {
defer WSDebug("Pacemaker loop has exited.")
defer p.running.Set(false)
defer p.Pacemaker.Stop()
for {
select {
case err := <-p.pacedeath:
WSDebug("Pacedeath returned with error:", err)
return errors.Wrap(err, "pacemaker died, reconnecting")
case <-p.stop:
WSDebug("Stop requested; exiting.")
return nil
case <-p.Pacemaker.Ticks:
if err := p.Pacemaker.Pace(); err != nil {
return errors.Wrap(err, "pace failed, reconnecting")
}
case ev, ok := <-p.events:
if !ok {
WSDebug("Events channel closed, stopping pacemaker.")
defer WSDebug("Pacemaker stopped automatically.")
// Events channel is closed. Kill the pacemaker manually and
// die.
p.pacemaker.Stop()
return <-p.pacedeath
return nil
}
if ev.Error != nil {
return errors.Wrap(ev.Error, "event returned error")
}
o, err := DecodeOP(ev)
@ -108,7 +146,11 @@ func (p *PacemakerLoop) startLoop() error {
// Handle the event
if err := p.handler(o); err != nil {
return errors.Wrap(err, "handler failed")
if IsBrokenConnection(err) {
return errors.Wrap(err, "handler failed")
}
p.errorLog(err)
}
}
}

View File

@ -79,9 +79,6 @@ func (ws *Websocket) Dial(ctx context.Context) error {
return errors.Wrap(err, "failed to dial")
}
// Reset the SendLimiter:
ws.SendLimiter = NewSendLimiter()
return nil
}

View File

@ -3,6 +3,7 @@
package voice
import (
"context"
"encoding/binary"
"io"
"log"
@ -18,73 +19,6 @@ import (
"github.com/diamondburned/arikawa/voice/voicegateway"
)
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
t.Cleanup(func() {
f.Close()
})
var lenbuf [4]byte
for {
if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) {
return
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) {
return
}
}
}
func catchRead(t *testing.T, err error) bool {
t.Helper()
if err == io.EOF {
return false
}
if err != nil {
t.Fatal("Failed to read:", err)
}
return true
}
func TestIntegration(t *testing.T) {
config := mustConfig(t)
@ -150,12 +84,76 @@ func TestIntegration(t *testing.T) {
finish("sending the speaking command")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := vs.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err)
}
// Copy the audio?
nicoReadTo(t, vs)
finish("copying the audio")
}
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
t.Helper()
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
defer f.Close()
var lenbuf [4]byte
for {
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
if err == io.EOF {
break
}
t.Fatal("failed to read:", err)
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
}
}
// simple shitty benchmark thing
func timer() func(finished string) {
var then = time.Now()

View File

@ -110,14 +110,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
}
}
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
func (s *Session) JoinChannel(
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
}
func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
// Acquire the mutex during join, locking during IO as well.
s.mut.Lock()
defer s.mut.Unlock()
@ -211,8 +215,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
return errors.Wrap(err, "failed to select protocol")
}
// Start the UDP loop.
go s.voiceUDP.Start(&d.SecretKey)
s.voiceUDP.UseSecret(d.SecretKey)
return nil
}
@ -237,6 +240,18 @@ func (s *Session) StopSpeaking() error {
return nil
}
// UseContext tells the UDP voice connection to write with the given mutex.
func (s *Session) UseContext(ctx context.Context) error {
s.mut.RLock()
defer s.mut.RUnlock()
if s.voiceUDP == nil {
return ErrCannotSend
}
return s.voiceUDP.UseContext(ctx)
}
// Write writes into the UDP voice connection WITHOUT a timeout.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)

View File

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"golang.org/x/crypto/nacl/secretbox"
"golang.org/x/time/rate"
)
// Dialer is the default dialer that this package uses for all its dialing.
@ -17,31 +18,29 @@ var Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
// ErrClosed is returned if a Write was called on a closed connection.
var ErrClosed = errors.New("UDP connection closed")
type Connection struct {
GatewayIP string
GatewayPort uint16
ssrc uint32
mutex chan struct{} // for ctx
context context.Context
conn net.Conn
ssrc uint32
frequency rate.Limiter
packet [12]byte
secret [32]byte
sequence uint16
timestamp uint32
nonce [24]byte
conn net.Conn
close chan struct{}
closed chan struct{}
send chan []byte
reply chan error
}
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// // Resolve the host.
// a, err := net.ResolveUDPAddr("udp", addr)
// if err != nil {
// return nil, errors.Wrap(err, "failed to resolve host")
// }
// Create a new UDP connection.
conn, err := Dialer.DialContext(ctx, "udp", addr)
if err != nil {
@ -78,20 +77,6 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
ip := ipbody[:nullPos]
port := binary.LittleEndian.Uint16(ipBuffer[68:70])
return &Connection{
GatewayIP: string(ip),
GatewayPort: port,
ssrc: ssrc,
conn: conn,
send: make(chan []byte),
reply: make(chan error),
close: make(chan struct{}),
closed: make(chan struct{}),
}, nil
}
func (c *Connection) Start(secret *[32]byte) {
// https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice
packet := [12]byte{
0: 0x80, // Version + Flags
@ -101,81 +86,118 @@ func (c *Connection) Start(secret *[32]byte) {
}
// Write SSRC to the header.
binary.BigEndian.PutUint32(packet[8:12], c.ssrc) // SSRC
binary.BigEndian.PutUint32(packet[8:12], ssrc) // SSRC
// 50 sends per second, 960 samples each at 48kHz
frequency := time.NewTicker(time.Millisecond * 20)
defer frequency.Stop()
return &Connection{
GatewayIP: string(ip),
GatewayPort: port,
// 50 sends per second, 960 samples each at 48kHz
frequency: *rate.NewLimiter(rate.Every(20*time.Millisecond), 1),
context: context.Background(),
mutex: make(chan struct{}, 1),
packet: packet,
ssrc: ssrc,
conn: conn,
}, nil
}
var b []byte
var ok bool
// UseSecret uses the given secret. This method is not thread-safe, so it should
// only be used right after initialization.
func (c *Connection) UseSecret(secret [32]byte) {
c.secret = secret
}
// Close these channels at the end so Write() doesn't block.
defer func() {
close(c.send)
close(c.closed)
}()
// UseContext lets the connection use the given context for its Write method.
// WriteCtx will override this context.
func (c *Connection) UseContext(ctx context.Context) error {
c.mutex <- struct{}{}
defer func() { <-c.mutex }()
for {
select {
case b, ok = <-c.send:
if !ok {
return
}
case <-c.close:
return
}
return c.useContext(ctx)
}
// Write a new sequence.
binary.BigEndian.PutUint16(packet[2:4], c.sequence)
c.sequence++
func (c *Connection) useContext(ctx context.Context) error {
if c.conn == nil {
return ErrClosed
}
binary.BigEndian.PutUint32(packet[4:8], c.timestamp)
c.timestamp += 960 // Samples
if c.context == ctx {
return nil
}
copy(c.nonce[:], packet[:])
c.context = ctx
toSend := secretbox.Seal(packet[:], b, &c.nonce, secret)
select {
case <-frequency.C:
case <-c.close:
// Prevent Write() from stalling before exiting.
c.reply <- nil
return
}
_, err := c.conn.Write(toSend)
c.reply <- err
if deadline, ok := c.context.Deadline(); ok {
return c.conn.SetWriteDeadline(deadline)
} else {
return c.conn.SetWriteDeadline(time.Time{})
}
}
func (c *Connection) Close() error {
close(c.close)
<-c.closed
return c.conn.Close()
c.mutex <- struct{}{}
err := c.conn.Close()
c.conn = nil
<-c.mutex
return err
}
// Write sends bytes into the voice UDP connection.
func (c *Connection) Write(b []byte) (int, error) {
return c.WriteCtx(context.Background(), b)
select {
case c.mutex <- struct{}{}:
defer func() { <-c.mutex }()
case <-c.context.Done():
return 0, c.context.Err()
}
if c.conn == nil {
return 0, ErrClosed
}
return c.write(b)
}
// WriteCtx sends bytes into the voice UDP connection with a timeout.
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
select {
case c.send <- b:
break
case c.mutex <- struct{}{}:
defer func() { <-c.mutex }()
case <-c.context.Done():
return 0, c.context.Err()
case <-ctx.Done():
return 0, ctx.Err()
}
select {
case err := <-c.reply:
return len(b), err
case <-ctx.Done():
return len(b), ctx.Err()
if err := c.useContext(ctx); err != nil {
return 0, errors.Wrap(err, "failed to use context")
}
return c.write(b)
}
// write is thread-unsafe.
func (c *Connection) write(b []byte) (int, error) {
// Write a new sequence.
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
c.sequence++
binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp)
c.timestamp += 960 // Samples
copy(c.nonce[:], c.packet[:])
if err := c.frequency.Wait(c.context); err != nil {
return 0, errors.Wrap(err, "failed to wait for frequency tick")
}
toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret)
n, err := c.conn.Write(toSend)
if err != nil {
return n, errors.Wrap(err, "failed to write to UDP connection")
}
// We're not really returning everything, since we're "sealing" the bytes.
return len(b), nil
}

View File

@ -5,9 +5,11 @@
package voice
import (
"context"
"log"
"strconv"
"sync"
"time"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
@ -170,9 +172,13 @@ func (v *Voice) Close() error {
}
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
log.Println("closing", gID)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
if dErr := s.DisconnectCtx(ctx); dErr != nil {
err.SessionErrors[gID] = dErr
}
cancel()
log.Println("closed", gID)
}
err.StateErr = v.State.Close()

View File

@ -51,7 +51,7 @@ type Gateway struct {
mutex sync.RWMutex
ready ReadyEvent
ws *wsutil.Websocket
WS *wsutil.Websocket
Timeout time.Duration
reconnect moreatomic.Bool
@ -96,14 +96,14 @@ func (c *Gateway) OpenCtx(ctx context.Context) error {
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")")
c.ws = wsutil.New(endpoint)
c.WS = wsutil.New(endpoint)
// Create a new context with a timeout for the connection.
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
defer cancel()
// Connect to the Gateway Gateway.
if err := c.ws.Dial(ctx); err != nil {
if err := c.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to connect to voice gateway")
}
@ -138,7 +138,7 @@ func (c *Gateway) __start(ctx context.Context) error {
// Make a new WaitGroup for use in background loops:
c.waitGroup = new(sync.WaitGroup)
ch := c.ws.Listen()
ch := c.WS.Listen()
// Wait for hello.
wsutil.WSDebug("Waiting for Hello..")
@ -205,38 +205,38 @@ func (c *Gateway) __start(ctx context.Context) error {
}
// Close .
func (c *Gateway) Close() error {
// Check if the WS is already closed:
if c.waitGroup == nil && c.EventLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
func (c *Gateway) Close() (err error) {
wsutil.WSDebug("Trying to close.")
c.AfterClose(nil)
return nil
// Check if the WS is already closed:
if c.EventLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
return err
}
// Trigger the close callback on exit.
defer func() { c.AfterClose(err) }()
// If the pacemaker is running:
if !c.EventLoop.Stopped() {
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
// Stop the pacemaker and the event handler.
c.EventLoop.Stop()
wsutil.WSDebug("Stopped pacemaker.")
}
wsutil.WSDebug("Closing the websocket...")
err = c.WS.Close()
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
c.waitGroup.Wait()
// Mark g.waitGroup as empty:
c.waitGroup = nil
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := c.ws.Close()
c.AfterClose(err)
return err
}
@ -308,11 +308,11 @@ func (c *Gateway) Send(code OPCode, v interface{}) error {
}
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
if c.ws == nil {
if c.WS == nil {
return errors.New("tried to send data to a connection without a Websocket")
}
if c.ws.Conn == nil {
if c.WS.Conn == nil {
return errors.New("tried to send data to a connection with a closed Websocket")
}
@ -335,5 +335,5 @@ func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error
}
// WS should already be thread-safe.
return c.ws.SendCtx(ctx, b)
return c.WS.SendCtx(ctx, b)
}