mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-12-21 12:47:16 +00:00
{Voice,}Gateway: Fixed various race conditions
This commit fixes race conditions in both package voice, package voicegateway and package gateway. Originally, several race conditions exist when both the user's and the pacemaker's goroutines both want to do several things to the websocket connection. For example, the user's goroutine could be writing, and the pacemaker's goroutine could trigger a reconnection. This is racey. This issue is partially fixed by removing the pacer loop from package heart and combining the ticker into the event (pacemaker) loop itself. Technically, a race condition could still be triggered with care, but the API itself never guaranteed any of those. As events are handled using an internal loop into a channel, a race condition will not be triggered just by handling events and writing to the websocket.
This commit is contained in:
parent
91ee92e9d5
commit
6c332ac145
|
@ -161,40 +161,38 @@ func (g *Gateway) AddIntent(i Intents) {
|
|||
}
|
||||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
func (g *Gateway) Close() (err error) {
|
||||
wsutil.WSDebug("Trying to close.")
|
||||
|
||||
// Check if the WS is already closed:
|
||||
if g.waitGroup == nil && g.PacerLoop.Stopped() {
|
||||
if g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
|
||||
g.AfterClose(nil)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger the close callback on exit.
|
||||
defer func() { g.AfterClose(err) }()
|
||||
|
||||
// If the pacemaker is running:
|
||||
if !g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
// Stop the pacemaker and the event handler.
|
||||
g.PacerLoop.Stop()
|
||||
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Closing the websocket...")
|
||||
err = g.WS.Close()
|
||||
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
g.waitGroup.Wait()
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
g.waitGroup = nil
|
||||
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := g.WS.Close()
|
||||
g.AfterClose(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -46,7 +46,9 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
defer cancel()
|
||||
|
||||
// Server requesting a heartbeat.
|
||||
return g.PacerLoop.Pace(ctx)
|
||||
if err := g.PacerLoop.Pace(ctx); err != nil {
|
||||
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
|
||||
}
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
|
@ -54,7 +56,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
|
||||
// Exit with the ReconnectOP error to force the heartbeat event loop to
|
||||
// reconnect synchronously. Not really a fatal error.
|
||||
return ErrReconnectRequest
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
|
||||
case InvalidSessionOP:
|
||||
// Discord expects us to sleep for no reason
|
||||
|
@ -66,7 +68,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
// Invalid session, try and Identify.
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
// Can't identify, reconnect.
|
||||
go g.Reconnect()
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
1
go.sum
1
go.sum
|
@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
|
||||
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s=
|
||||
|
|
|
@ -3,7 +3,6 @@ package heart
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -36,23 +35,30 @@ type Pacemaker struct {
|
|||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate time.Duration
|
||||
|
||||
ticker time.Ticker
|
||||
Ticks <-chan time.Time
|
||||
|
||||
// Time in nanoseconds, guarded by atomic read/writes.
|
||||
SentBeat AtomicTime
|
||||
EchoBeat AtomicTime
|
||||
|
||||
// Any callback that returns an error will stop the pacer.
|
||||
Pace func(context.Context) error
|
||||
|
||||
stop chan struct{}
|
||||
once sync.Once
|
||||
death chan error
|
||||
Pacer func(context.Context) error
|
||||
}
|
||||
|
||||
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) *Pacemaker {
|
||||
return &Pacemaker{
|
||||
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
|
||||
p := Pacemaker{
|
||||
Heartrate: heartrate,
|
||||
Pace: pacer,
|
||||
Pacer: pacer,
|
||||
ticker: *time.NewTicker(heartrate),
|
||||
}
|
||||
p.Ticks = p.ticker.C
|
||||
// Reset states to its old position.
|
||||
now := time.Now()
|
||||
p.EchoBeat.Set(now)
|
||||
p.SentBeat.Set(now)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
|
@ -62,14 +68,6 @@ func (p *Pacemaker) Echo() {
|
|||
|
||||
// Dead, if true, will have Pace return an ErrDead.
|
||||
func (p *Pacemaker) Dead() bool {
|
||||
/* Deprecated
|
||||
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
|
||||
*/
|
||||
|
||||
var (
|
||||
echo = p.EchoBeat.Get()
|
||||
sent = p.SentBeat.Get()
|
||||
|
@ -84,75 +82,84 @@ func (p *Pacemaker) Dead() bool {
|
|||
|
||||
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
|
||||
func (p *Pacemaker) Stop() {
|
||||
Debug("(*Pacemaker).stop is trying sync.Once.")
|
||||
|
||||
p.once.Do(func() {
|
||||
Debug("(*Pacemaker).stop closed the channel.")
|
||||
close(p.stop)
|
||||
})
|
||||
p.ticker.Stop()
|
||||
}
|
||||
|
||||
// pace sends a heartbeat with the appropriate timeout for the context.
|
||||
func (p *Pacemaker) pace() error {
|
||||
func (p *Pacemaker) Pace() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate)
|
||||
defer cancel()
|
||||
|
||||
return p.Pace(ctx)
|
||||
return p.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
func (p *Pacemaker) start() error {
|
||||
// Reset states to its old position.
|
||||
p.EchoBeat.Set(time.Time{})
|
||||
p.SentBeat.Set(time.Time{})
|
||||
|
||||
// Create a new ticker.
|
||||
tick := time.NewTicker(p.Heartrate)
|
||||
defer tick.Stop()
|
||||
|
||||
// Echo at least once
|
||||
p.Echo()
|
||||
|
||||
for {
|
||||
if err := p.pace(); err != nil {
|
||||
return errors.Wrap(err, "failed to pace")
|
||||
}
|
||||
|
||||
// Paced, save:
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stop:
|
||||
return nil
|
||||
|
||||
case <-tick.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
|
||||
func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||
p.death = make(chan error)
|
||||
p.stop = make(chan struct{})
|
||||
p.once = sync.Once{}
|
||||
|
||||
if wg != nil {
|
||||
wg.Add(1)
|
||||
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
|
||||
if err := p.Pacer(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.death <- p.start()
|
||||
// Debug.
|
||||
Debug("Pacemaker returned.")
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
// Mark the pacemaker loop as done.
|
||||
if wg != nil {
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
}
|
||||
|
||||
return p.death
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (p *Pacemaker) start() error {
|
||||
// // Reset states to its old position.
|
||||
// p.EchoBeat.Set(time.Time{})
|
||||
// p.SentBeat.Set(time.Time{})
|
||||
|
||||
// // Create a new ticker.
|
||||
// tick := time.NewTicker(p.Heartrate)
|
||||
// defer tick.Stop()
|
||||
|
||||
// // Echo at least once
|
||||
// p.Echo()
|
||||
|
||||
// for {
|
||||
// if err := p.pace(); err != nil {
|
||||
// return errors.Wrap(err, "failed to pace")
|
||||
// }
|
||||
|
||||
// // Paced, save:
|
||||
// p.SentBeat.Set(time.Now())
|
||||
|
||||
// if p.Dead() {
|
||||
// return ErrDead
|
||||
// }
|
||||
|
||||
// select {
|
||||
// case <-p.stop:
|
||||
// return nil
|
||||
|
||||
// case <-tick.C:
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
|
||||
// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||
// p.death = make(chan error)
|
||||
// p.stop = make(chan struct{})
|
||||
// p.once = sync.Once{}
|
||||
|
||||
// if wg != nil {
|
||||
// wg.Add(1)
|
||||
// }
|
||||
|
||||
// go func() {
|
||||
// p.death <- p.start()
|
||||
// // Debug.
|
||||
// Debug("Pacemaker returned.")
|
||||
|
||||
// // Mark the pacemaker loop as done.
|
||||
// if wg != nil {
|
||||
// wg.Done()
|
||||
// }
|
||||
// }()
|
||||
|
||||
// return p.death
|
||||
// }
|
||||
|
|
|
@ -2,6 +2,7 @@ package wsutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -10,6 +11,34 @@ import (
|
|||
"github.com/diamondburned/arikawa/internal/moreatomic"
|
||||
)
|
||||
|
||||
type errBrokenConnection struct {
|
||||
underneath error
|
||||
}
|
||||
|
||||
// Error formats the broken connection error with the message "explicit
|
||||
// connection break."
|
||||
func (err errBrokenConnection) Error() string {
|
||||
return "explicit connection break: " + err.underneath.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (err errBrokenConnection) Unwrap() error {
|
||||
return err.underneath
|
||||
}
|
||||
|
||||
// ErrBrokenConnection marks the given error as a broken connection error. This
|
||||
// error will cause the pacemaker loop to break and return the error. The error,
|
||||
// when stringified, will say "explicit connection break."
|
||||
func ErrBrokenConnection(err error) error {
|
||||
return errBrokenConnection{underneath: err}
|
||||
}
|
||||
|
||||
// IsBrokenConnection returns true if the error is a broken connection error.
|
||||
func IsBrokenConnection(err error) bool {
|
||||
var broken *errBrokenConnection
|
||||
return errors.As(err, &broken)
|
||||
}
|
||||
|
||||
// TODO API
|
||||
type EventLoopHandler interface {
|
||||
EventHandler
|
||||
|
@ -19,14 +48,15 @@ type EventLoopHandler interface {
|
|||
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
|
||||
// is a valid instance only when RunAsync is called first.
|
||||
type PacemakerLoop struct {
|
||||
pacemaker *heart.Pacemaker // let's not copy this
|
||||
pacedeath chan error
|
||||
|
||||
heart.Pacemaker
|
||||
running moreatomic.Bool
|
||||
|
||||
stop chan struct{}
|
||||
events <-chan Event
|
||||
handler func(*OP) error
|
||||
|
||||
stack []byte
|
||||
|
||||
Extras ExtraHandlers
|
||||
|
||||
ErrorLog func(error)
|
||||
|
@ -43,17 +73,19 @@ func (p *PacemakerLoop) errorLog(err error) {
|
|||
|
||||
// Pace calls the pacemaker's Pace function.
|
||||
func (p *PacemakerLoop) Pace(ctx context.Context) error {
|
||||
return p.pacemaker.Pace(ctx)
|
||||
return p.Pacemaker.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
// Echo calls the pacemaker's Echo function.
|
||||
func (p *PacemakerLoop) Echo() {
|
||||
p.pacemaker.Echo()
|
||||
}
|
||||
|
||||
// Stop calls the pacemaker's Stop function.
|
||||
// Stop stops the pacer loop. It does nothing if the loop is already stopped.
|
||||
func (p *PacemakerLoop) Stop() {
|
||||
p.pacemaker.Stop()
|
||||
if p.Stopped() {
|
||||
return
|
||||
}
|
||||
|
||||
// Despite p.running and p.stop being thread-safe on their own, this entire
|
||||
// block is actually not thread-safe.
|
||||
p.Pacemaker.Stop()
|
||||
close(p.stop)
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Stopped() bool {
|
||||
|
@ -65,12 +97,12 @@ func (p *PacemakerLoop) RunAsync(
|
|||
|
||||
WSDebug("Starting the pacemaker loop.")
|
||||
|
||||
p.pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
|
||||
p.events = evs
|
||||
p.Pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
|
||||
p.handler = evl.HandleOP
|
||||
p.events = evs
|
||||
p.stack = debug.Stack()
|
||||
p.stop = make(chan struct{})
|
||||
|
||||
// callers should explicitly handle waitgroups.
|
||||
p.pacedeath = p.pacemaker.StartAsync(nil)
|
||||
p.running.Set(true)
|
||||
|
||||
go func() {
|
||||
|
@ -81,21 +113,27 @@ func (p *PacemakerLoop) RunAsync(
|
|||
func (p *PacemakerLoop) startLoop() error {
|
||||
defer WSDebug("Pacemaker loop has exited.")
|
||||
defer p.running.Set(false)
|
||||
defer p.Pacemaker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-p.pacedeath:
|
||||
WSDebug("Pacedeath returned with error:", err)
|
||||
return errors.Wrap(err, "pacemaker died, reconnecting")
|
||||
case <-p.stop:
|
||||
WSDebug("Stop requested; exiting.")
|
||||
return nil
|
||||
|
||||
case <-p.Pacemaker.Ticks:
|
||||
if err := p.Pacemaker.Pace(); err != nil {
|
||||
return errors.Wrap(err, "pace failed, reconnecting")
|
||||
}
|
||||
|
||||
case ev, ok := <-p.events:
|
||||
if !ok {
|
||||
WSDebug("Events channel closed, stopping pacemaker.")
|
||||
defer WSDebug("Pacemaker stopped automatically.")
|
||||
// Events channel is closed. Kill the pacemaker manually and
|
||||
// die.
|
||||
p.pacemaker.Stop()
|
||||
return <-p.pacedeath
|
||||
return nil
|
||||
}
|
||||
|
||||
if ev.Error != nil {
|
||||
return errors.Wrap(ev.Error, "event returned error")
|
||||
}
|
||||
|
||||
o, err := DecodeOP(ev)
|
||||
|
@ -108,7 +146,11 @@ func (p *PacemakerLoop) startLoop() error {
|
|||
|
||||
// Handle the event
|
||||
if err := p.handler(o); err != nil {
|
||||
return errors.Wrap(err, "handler failed")
|
||||
if IsBrokenConnection(err) {
|
||||
return errors.Wrap(err, "handler failed")
|
||||
}
|
||||
|
||||
p.errorLog(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,9 +79,6 @@ func (ws *Websocket) Dial(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to dial")
|
||||
}
|
||||
|
||||
// Reset the SendLimiter:
|
||||
ws.SendLimiter = NewSendLimiter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -18,73 +19,6 @@ import (
|
|||
"github.com/diamondburned/arikawa/voice/voicegateway"
|
||||
)
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.ChannelID
|
||||
}
|
||||
|
||||
func mustConfig(t *testing.T) testConfig {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
t.Fatal("Missing $VOICE_ID")
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
t.Fatal("Invalid $VOICE_ID:", err)
|
||||
}
|
||||
|
||||
return testConfig{
|
||||
BotToken: token,
|
||||
VoiceChID: discord.ChannelID(id),
|
||||
}
|
||||
}
|
||||
|
||||
// file is only a few bytes lolmao
|
||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
f.Close()
|
||||
})
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func catchRead(t *testing.T, err error) bool {
|
||||
t.Helper()
|
||||
|
||||
if err == io.EOF {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("Failed to read:", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
config := mustConfig(t)
|
||||
|
||||
|
@ -150,12 +84,76 @@ func TestIntegration(t *testing.T) {
|
|||
|
||||
finish("sending the speaking command")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := vs.UseContext(ctx); err != nil {
|
||||
t.Fatal("failed to set ctx into vs:", err)
|
||||
}
|
||||
|
||||
// Copy the audio?
|
||||
nicoReadTo(t, vs)
|
||||
|
||||
finish("copying the audio")
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.ChannelID
|
||||
}
|
||||
|
||||
func mustConfig(t *testing.T) testConfig {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
t.Fatal("Missing $VOICE_ID")
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
t.Fatal("Invalid $VOICE_ID:", err)
|
||||
}
|
||||
|
||||
return testConfig{
|
||||
BotToken: token,
|
||||
VoiceChID: discord.ChannelID(id),
|
||||
}
|
||||
}
|
||||
|
||||
// file is only a few bytes lolmao
|
||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Fatal("failed to read:", err)
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
|
||||
t.Fatal("failed to write:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simple shitty benchmark thing
|
||||
func timer() func(finished string) {
|
||||
var then = time.Now()
|
||||
|
|
|
@ -110,14 +110,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
func (s *Session) JoinChannel(
|
||||
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
func (s *Session) JoinChannelCtx(
|
||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
|
||||
// Acquire the mutex during join, locking during IO as well.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
@ -211,8 +215,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
|||
return errors.Wrap(err, "failed to select protocol")
|
||||
}
|
||||
|
||||
// Start the UDP loop.
|
||||
go s.voiceUDP.Start(&d.SecretKey)
|
||||
s.voiceUDP.UseSecret(d.SecretKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -237,6 +240,18 @@ func (s *Session) StopSpeaking() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UseContext tells the UDP voice connection to write with the given mutex.
|
||||
func (s *Session) UseContext(ctx context.Context) error {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return ErrCannotSend
|
||||
}
|
||||
|
||||
return s.voiceUDP.UseContext(ctx)
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout.
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
return s.WriteCtx(context.Background(), b)
|
||||
|
|
180
voice/udp/udp.go
180
voice/udp/udp.go
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Dialer is the default dialer that this package uses for all its dialing.
|
||||
|
@ -17,31 +18,29 @@ var Dialer = net.Dialer{
|
|||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// ErrClosed is returned if a Write was called on a closed connection.
|
||||
var ErrClosed = errors.New("UDP connection closed")
|
||||
|
||||
type Connection struct {
|
||||
GatewayIP string
|
||||
GatewayPort uint16
|
||||
|
||||
ssrc uint32
|
||||
mutex chan struct{} // for ctx
|
||||
|
||||
context context.Context
|
||||
conn net.Conn
|
||||
ssrc uint32
|
||||
|
||||
frequency rate.Limiter
|
||||
packet [12]byte
|
||||
secret [32]byte
|
||||
|
||||
sequence uint16
|
||||
timestamp uint32
|
||||
nonce [24]byte
|
||||
|
||||
conn net.Conn
|
||||
close chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
send chan []byte
|
||||
reply chan error
|
||||
}
|
||||
|
||||
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
|
||||
// // Resolve the host.
|
||||
// a, err := net.ResolveUDPAddr("udp", addr)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "failed to resolve host")
|
||||
// }
|
||||
|
||||
// Create a new UDP connection.
|
||||
conn, err := Dialer.DialContext(ctx, "udp", addr)
|
||||
if err != nil {
|
||||
|
@ -78,20 +77,6 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
|
|||
ip := ipbody[:nullPos]
|
||||
port := binary.LittleEndian.Uint16(ipBuffer[68:70])
|
||||
|
||||
return &Connection{
|
||||
GatewayIP: string(ip),
|
||||
GatewayPort: port,
|
||||
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
send: make(chan []byte),
|
||||
reply: make(chan error),
|
||||
close: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Start(secret *[32]byte) {
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice
|
||||
packet := [12]byte{
|
||||
0: 0x80, // Version + Flags
|
||||
|
@ -101,81 +86,118 @@ func (c *Connection) Start(secret *[32]byte) {
|
|||
}
|
||||
|
||||
// Write SSRC to the header.
|
||||
binary.BigEndian.PutUint32(packet[8:12], c.ssrc) // SSRC
|
||||
binary.BigEndian.PutUint32(packet[8:12], ssrc) // SSRC
|
||||
|
||||
// 50 sends per second, 960 samples each at 48kHz
|
||||
frequency := time.NewTicker(time.Millisecond * 20)
|
||||
defer frequency.Stop()
|
||||
return &Connection{
|
||||
GatewayIP: string(ip),
|
||||
GatewayPort: port,
|
||||
// 50 sends per second, 960 samples each at 48kHz
|
||||
frequency: *rate.NewLimiter(rate.Every(20*time.Millisecond), 1),
|
||||
context: context.Background(),
|
||||
mutex: make(chan struct{}, 1),
|
||||
packet: packet,
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var b []byte
|
||||
var ok bool
|
||||
// UseSecret uses the given secret. This method is not thread-safe, so it should
|
||||
// only be used right after initialization.
|
||||
func (c *Connection) UseSecret(secret [32]byte) {
|
||||
c.secret = secret
|
||||
}
|
||||
|
||||
// Close these channels at the end so Write() doesn't block.
|
||||
defer func() {
|
||||
close(c.send)
|
||||
close(c.closed)
|
||||
}()
|
||||
// UseContext lets the connection use the given context for its Write method.
|
||||
// WriteCtx will override this context.
|
||||
func (c *Connection) UseContext(ctx context.Context) error {
|
||||
c.mutex <- struct{}{}
|
||||
defer func() { <-c.mutex }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case b, ok = <-c.send:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-c.close:
|
||||
return
|
||||
}
|
||||
return c.useContext(ctx)
|
||||
}
|
||||
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(packet[2:4], c.sequence)
|
||||
c.sequence++
|
||||
func (c *Connection) useContext(ctx context.Context) error {
|
||||
if c.conn == nil {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(packet[4:8], c.timestamp)
|
||||
c.timestamp += 960 // Samples
|
||||
if c.context == ctx {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy(c.nonce[:], packet[:])
|
||||
c.context = ctx
|
||||
|
||||
toSend := secretbox.Seal(packet[:], b, &c.nonce, secret)
|
||||
|
||||
select {
|
||||
case <-frequency.C:
|
||||
case <-c.close:
|
||||
// Prevent Write() from stalling before exiting.
|
||||
c.reply <- nil
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, err := c.conn.Write(toSend)
|
||||
c.reply <- err
|
||||
if deadline, ok := c.context.Deadline(); ok {
|
||||
return c.conn.SetWriteDeadline(deadline)
|
||||
} else {
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
close(c.close)
|
||||
<-c.closed
|
||||
|
||||
return c.conn.Close()
|
||||
c.mutex <- struct{}{}
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
<-c.mutex
|
||||
return err
|
||||
}
|
||||
|
||||
// Write sends bytes into the voice UDP connection.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
return c.WriteCtx(context.Background(), b)
|
||||
select {
|
||||
case c.mutex <- struct{}{}:
|
||||
defer func() { <-c.mutex }()
|
||||
case <-c.context.Done():
|
||||
return 0, c.context.Err()
|
||||
}
|
||||
|
||||
if c.conn == nil {
|
||||
return 0, ErrClosed
|
||||
}
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout.
|
||||
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
select {
|
||||
case c.send <- b:
|
||||
break
|
||||
case c.mutex <- struct{}{}:
|
||||
defer func() { <-c.mutex }()
|
||||
case <-c.context.Done():
|
||||
return 0, c.context.Err()
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-c.reply:
|
||||
return len(b), err
|
||||
case <-ctx.Done():
|
||||
return len(b), ctx.Err()
|
||||
if err := c.useContext(ctx); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to use context")
|
||||
}
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// write is thread-unsafe.
|
||||
func (c *Connection) write(b []byte) (int, error) {
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
|
||||
c.sequence++
|
||||
|
||||
binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp)
|
||||
c.timestamp += 960 // Samples
|
||||
|
||||
copy(c.nonce[:], c.packet[:])
|
||||
|
||||
if err := c.frequency.Wait(c.context); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to wait for frequency tick")
|
||||
}
|
||||
|
||||
toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret)
|
||||
|
||||
n, err := c.conn.Write(toSend)
|
||||
if err != nil {
|
||||
return n, errors.Wrap(err, "failed to write to UDP connection")
|
||||
}
|
||||
|
||||
// We're not really returning everything, since we're "sealing" the bytes.
|
||||
return len(b), nil
|
||||
}
|
||||
|
|
|
@ -5,9 +5,11 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
|
@ -170,9 +172,13 @@ func (v *Voice) Close() error {
|
|||
}
|
||||
|
||||
for gID, s := range v.sessions {
|
||||
if dErr := s.Disconnect(); dErr != nil {
|
||||
log.Println("closing", gID)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
if dErr := s.DisconnectCtx(ctx); dErr != nil {
|
||||
err.SessionErrors[gID] = dErr
|
||||
}
|
||||
cancel()
|
||||
log.Println("closed", gID)
|
||||
}
|
||||
|
||||
err.StateErr = v.State.Close()
|
||||
|
|
|
@ -51,7 +51,7 @@ type Gateway struct {
|
|||
mutex sync.RWMutex
|
||||
ready ReadyEvent
|
||||
|
||||
ws *wsutil.Websocket
|
||||
WS *wsutil.Websocket
|
||||
|
||||
Timeout time.Duration
|
||||
reconnect moreatomic.Bool
|
||||
|
@ -96,14 +96,14 @@ func (c *Gateway) OpenCtx(ctx context.Context) error {
|
|||
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")")
|
||||
c.ws = wsutil.New(endpoint)
|
||||
c.WS = wsutil.New(endpoint)
|
||||
|
||||
// Create a new context with a timeout for the connection.
|
||||
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the Gateway Gateway.
|
||||
if err := c.ws.Dial(ctx); err != nil {
|
||||
if err := c.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to connect to voice gateway")
|
||||
}
|
||||
|
||||
|
@ -138,7 +138,7 @@ func (c *Gateway) __start(ctx context.Context) error {
|
|||
// Make a new WaitGroup for use in background loops:
|
||||
c.waitGroup = new(sync.WaitGroup)
|
||||
|
||||
ch := c.ws.Listen()
|
||||
ch := c.WS.Listen()
|
||||
|
||||
// Wait for hello.
|
||||
wsutil.WSDebug("Waiting for Hello..")
|
||||
|
@ -205,38 +205,38 @@ func (c *Gateway) __start(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Close .
|
||||
func (c *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if c.waitGroup == nil && c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
func (c *Gateway) Close() (err error) {
|
||||
wsutil.WSDebug("Trying to close.")
|
||||
|
||||
c.AfterClose(nil)
|
||||
return nil
|
||||
// Check if the WS is already closed:
|
||||
if c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger the close callback on exit.
|
||||
defer func() { c.AfterClose(err) }()
|
||||
|
||||
// If the pacemaker is running:
|
||||
if !c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
// Stop the pacemaker and the event handler.
|
||||
c.EventLoop.Stop()
|
||||
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Closing the websocket...")
|
||||
err = c.WS.Close()
|
||||
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
c.waitGroup.Wait()
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
c.waitGroup = nil
|
||||
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := c.ws.Close()
|
||||
c.AfterClose(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -308,11 +308,11 @@ func (c *Gateway) Send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
if c.ws == nil {
|
||||
if c.WS == nil {
|
||||
return errors.New("tried to send data to a connection without a Websocket")
|
||||
}
|
||||
|
||||
if c.ws.Conn == nil {
|
||||
if c.WS.Conn == nil {
|
||||
return errors.New("tried to send data to a connection with a closed Websocket")
|
||||
}
|
||||
|
||||
|
@ -335,5 +335,5 @@ func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error
|
|||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return c.ws.SendCtx(ctx, b)
|
||||
return c.WS.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue