diff --git a/gateway/gateway.go b/gateway/gateway.go index b044584..4f70faf 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -161,40 +161,38 @@ func (g *Gateway) AddIntent(i Intents) { } // Close closes the underlying Websocket connection. -func (g *Gateway) Close() error { +func (g *Gateway) Close() (err error) { wsutil.WSDebug("Trying to close.") // Check if the WS is already closed: - if g.waitGroup == nil && g.PacerLoop.Stopped() { + if g.PacerLoop.Stopped() { wsutil.WSDebug("Gateway is already closed.") - - g.AfterClose(nil) - return nil + return err } + // Trigger the close callback on exit. + defer func() { g.AfterClose(err) }() + // If the pacemaker is running: if !g.PacerLoop.Stopped() { wsutil.WSDebug("Stopping pacemaker...") - // Stop the pacemaker and the event handler + // Stop the pacemaker and the event handler. g.PacerLoop.Stop() wsutil.WSDebug("Stopped pacemaker.") } + wsutil.WSDebug("Closing the websocket...") + err = g.WS.Close() + wsutil.WSDebug("Waiting for WaitGroup to be done.") // This should work, since Pacemaker should signal its loop to stop, which // would also exit our event loop. Both would be 2. g.waitGroup.Wait() - // Mark g.waitGroup as empty: - g.waitGroup = nil - wsutil.WSDebug("WaitGroup is done. Closing the websocket.") - - err := g.WS.Close() - g.AfterClose(err) return err } diff --git a/gateway/op.go b/gateway/op.go index 954e2b2..dc0258b 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -46,7 +46,9 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error { defer cancel() // Server requesting a heartbeat. - return g.PacerLoop.Pace(ctx) + if err := g.PacerLoop.Pace(ctx); err != nil { + return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace")) + } case ReconnectOP: // Server requests to reconnect, die and retry. @@ -54,7 +56,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error { // Exit with the ReconnectOP error to force the heartbeat event loop to // reconnect synchronously. Not really a fatal error. - return ErrReconnectRequest + return wsutil.ErrBrokenConnection(ErrReconnectRequest) case InvalidSessionOP: // Discord expects us to sleep for no reason @@ -66,7 +68,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error { // Invalid session, try and Identify. if err := g.IdentifyCtx(ctx); err != nil { // Can't identify, reconnect. - go g.Reconnect() + return wsutil.ErrBrokenConnection(ErrReconnectRequest) } return nil diff --git a/go.sum b/go.sum index 4e87908..4a17a66 100644 --- a/go.sum +++ b/go.sum @@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI= golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= diff --git a/internal/heart/heart.go b/internal/heart/heart.go index 50b2be3..ec43eb0 100644 --- a/internal/heart/heart.go +++ b/internal/heart/heart.go @@ -3,7 +3,6 @@ package heart import ( "context" - "sync" "sync/atomic" "time" @@ -36,23 +35,30 @@ type Pacemaker struct { // Heartrate is the received duration between heartbeats. Heartrate time.Duration + ticker time.Ticker + Ticks <-chan time.Time + // Time in nanoseconds, guarded by atomic read/writes. SentBeat AtomicTime EchoBeat AtomicTime // Any callback that returns an error will stop the pacer. - Pace func(context.Context) error - - stop chan struct{} - once sync.Once - death chan error + Pacer func(context.Context) error } -func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) *Pacemaker { - return &Pacemaker{ +func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker { + p := Pacemaker{ Heartrate: heartrate, - Pace: pacer, + Pacer: pacer, + ticker: *time.NewTicker(heartrate), } + p.Ticks = p.ticker.C + // Reset states to its old position. + now := time.Now() + p.EchoBeat.Set(now) + p.SentBeat.Set(now) + + return p } func (p *Pacemaker) Echo() { @@ -62,14 +68,6 @@ func (p *Pacemaker) Echo() { // Dead, if true, will have Pace return an ErrDead. func (p *Pacemaker) Dead() bool { - /* Deprecated - if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() { - return false - } - - return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2 - */ - var ( echo = p.EchoBeat.Get() sent = p.SentBeat.Get() @@ -84,75 +82,84 @@ func (p *Pacemaker) Dead() bool { // Stop stops the pacemaker, or it does nothing if the pacemaker is not started. func (p *Pacemaker) Stop() { - Debug("(*Pacemaker).stop is trying sync.Once.") - - p.once.Do(func() { - Debug("(*Pacemaker).stop closed the channel.") - close(p.stop) - }) + p.ticker.Stop() } // pace sends a heartbeat with the appropriate timeout for the context. -func (p *Pacemaker) pace() error { +func (p *Pacemaker) Pace() error { ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate) defer cancel() - return p.Pace(ctx) + return p.PaceCtx(ctx) } -func (p *Pacemaker) start() error { - // Reset states to its old position. - p.EchoBeat.Set(time.Time{}) - p.SentBeat.Set(time.Time{}) - - // Create a new ticker. - tick := time.NewTicker(p.Heartrate) - defer tick.Stop() - - // Echo at least once - p.Echo() - - for { - if err := p.pace(); err != nil { - return errors.Wrap(err, "failed to pace") - } - - // Paced, save: - p.SentBeat.Set(time.Now()) - - if p.Dead() { - return ErrDead - } - - select { - case <-p.stop: - return nil - - case <-tick.C: - } - } -} - -// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional. -func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { - p.death = make(chan error) - p.stop = make(chan struct{}) - p.once = sync.Once{} - - if wg != nil { - wg.Add(1) +func (p *Pacemaker) PaceCtx(ctx context.Context) error { + if err := p.Pacer(ctx); err != nil { + return err } - go func() { - p.death <- p.start() - // Debug. - Debug("Pacemaker returned.") + p.SentBeat.Set(time.Now()) - // Mark the pacemaker loop as done. - if wg != nil { - wg.Done() - } - }() + if p.Dead() { + return ErrDead + } - return p.death + return nil } + +// func (p *Pacemaker) start() error { +// // Reset states to its old position. +// p.EchoBeat.Set(time.Time{}) +// p.SentBeat.Set(time.Time{}) + +// // Create a new ticker. +// tick := time.NewTicker(p.Heartrate) +// defer tick.Stop() + +// // Echo at least once +// p.Echo() + +// for { +// if err := p.pace(); err != nil { +// return errors.Wrap(err, "failed to pace") +// } + +// // Paced, save: +// p.SentBeat.Set(time.Now()) + +// if p.Dead() { +// return ErrDead +// } + +// select { +// case <-p.stop: +// return nil + +// case <-tick.C: +// } +// } +// } + +// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional. +// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { +// p.death = make(chan error) +// p.stop = make(chan struct{}) +// p.once = sync.Once{} + +// if wg != nil { +// wg.Add(1) +// } + +// go func() { +// p.death <- p.start() +// // Debug. +// Debug("Pacemaker returned.") + +// // Mark the pacemaker loop as done. +// if wg != nil { +// wg.Done() +// } +// }() + +// return p.death +// } diff --git a/utils/wsutil/heart.go b/utils/wsutil/heart.go index 6fc1ec0..f27212e 100644 --- a/utils/wsutil/heart.go +++ b/utils/wsutil/heart.go @@ -2,6 +2,7 @@ package wsutil import ( "context" + "runtime/debug" "time" "github.com/pkg/errors" @@ -10,6 +11,34 @@ import ( "github.com/diamondburned/arikawa/internal/moreatomic" ) +type errBrokenConnection struct { + underneath error +} + +// Error formats the broken connection error with the message "explicit +// connection break." +func (err errBrokenConnection) Error() string { + return "explicit connection break: " + err.underneath.Error() +} + +// Unwrap returns the underlying error. +func (err errBrokenConnection) Unwrap() error { + return err.underneath +} + +// ErrBrokenConnection marks the given error as a broken connection error. This +// error will cause the pacemaker loop to break and return the error. The error, +// when stringified, will say "explicit connection break." +func ErrBrokenConnection(err error) error { + return errBrokenConnection{underneath: err} +} + +// IsBrokenConnection returns true if the error is a broken connection error. +func IsBrokenConnection(err error) bool { + var broken *errBrokenConnection + return errors.As(err, &broken) +} + // TODO API type EventLoopHandler interface { EventHandler @@ -19,14 +48,15 @@ type EventLoopHandler interface { // PacemakerLoop provides an event loop with a pacemaker. A zero-value instance // is a valid instance only when RunAsync is called first. type PacemakerLoop struct { - pacemaker *heart.Pacemaker // let's not copy this - pacedeath chan error - + heart.Pacemaker running moreatomic.Bool + stop chan struct{} events <-chan Event handler func(*OP) error + stack []byte + Extras ExtraHandlers ErrorLog func(error) @@ -43,17 +73,19 @@ func (p *PacemakerLoop) errorLog(err error) { // Pace calls the pacemaker's Pace function. func (p *PacemakerLoop) Pace(ctx context.Context) error { - return p.pacemaker.Pace(ctx) + return p.Pacemaker.PaceCtx(ctx) } -// Echo calls the pacemaker's Echo function. -func (p *PacemakerLoop) Echo() { - p.pacemaker.Echo() -} - -// Stop calls the pacemaker's Stop function. +// Stop stops the pacer loop. It does nothing if the loop is already stopped. func (p *PacemakerLoop) Stop() { - p.pacemaker.Stop() + if p.Stopped() { + return + } + + // Despite p.running and p.stop being thread-safe on their own, this entire + // block is actually not thread-safe. + p.Pacemaker.Stop() + close(p.stop) } func (p *PacemakerLoop) Stopped() bool { @@ -65,12 +97,12 @@ func (p *PacemakerLoop) RunAsync( WSDebug("Starting the pacemaker loop.") - p.pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx) - p.events = evs + p.Pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx) p.handler = evl.HandleOP + p.events = evs + p.stack = debug.Stack() + p.stop = make(chan struct{}) - // callers should explicitly handle waitgroups. - p.pacedeath = p.pacemaker.StartAsync(nil) p.running.Set(true) go func() { @@ -81,21 +113,27 @@ func (p *PacemakerLoop) RunAsync( func (p *PacemakerLoop) startLoop() error { defer WSDebug("Pacemaker loop has exited.") defer p.running.Set(false) + defer p.Pacemaker.Stop() for { select { - case err := <-p.pacedeath: - WSDebug("Pacedeath returned with error:", err) - return errors.Wrap(err, "pacemaker died, reconnecting") + case <-p.stop: + WSDebug("Stop requested; exiting.") + return nil + + case <-p.Pacemaker.Ticks: + if err := p.Pacemaker.Pace(); err != nil { + return errors.Wrap(err, "pace failed, reconnecting") + } case ev, ok := <-p.events: if !ok { WSDebug("Events channel closed, stopping pacemaker.") - defer WSDebug("Pacemaker stopped automatically.") - // Events channel is closed. Kill the pacemaker manually and - // die. - p.pacemaker.Stop() - return <-p.pacedeath + return nil + } + + if ev.Error != nil { + return errors.Wrap(ev.Error, "event returned error") } o, err := DecodeOP(ev) @@ -108,7 +146,11 @@ func (p *PacemakerLoop) startLoop() error { // Handle the event if err := p.handler(o); err != nil { - return errors.Wrap(err, "handler failed") + if IsBrokenConnection(err) { + return errors.Wrap(err, "handler failed") + } + + p.errorLog(err) } } } diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go index d54a625..dcbb54d 100644 --- a/utils/wsutil/ws.go +++ b/utils/wsutil/ws.go @@ -79,9 +79,6 @@ func (ws *Websocket) Dial(ctx context.Context) error { return errors.Wrap(err, "failed to dial") } - // Reset the SendLimiter: - ws.SendLimiter = NewSendLimiter() - return nil } diff --git a/voice/integration_test.go b/voice/integration_test.go index aa60a43..5d929f3 100644 --- a/voice/integration_test.go +++ b/voice/integration_test.go @@ -3,6 +3,7 @@ package voice import ( + "context" "encoding/binary" "io" "log" @@ -18,73 +19,6 @@ import ( "github.com/diamondburned/arikawa/voice/voicegateway" ) -type testConfig struct { - BotToken string - VoiceChID discord.ChannelID -} - -func mustConfig(t *testing.T) testConfig { - var token = os.Getenv("BOT_TOKEN") - if token == "" { - t.Fatal("Missing $BOT_TOKEN") - } - - var sid = os.Getenv("VOICE_ID") - if sid == "" { - t.Fatal("Missing $VOICE_ID") - } - - id, err := discord.ParseSnowflake(sid) - if err != nil { - t.Fatal("Invalid $VOICE_ID:", err) - } - - return testConfig{ - BotToken: token, - VoiceChID: discord.ChannelID(id), - } -} - -// file is only a few bytes lolmao -func nicoReadTo(t *testing.T, dst io.Writer) { - f, err := os.Open("testdata/nico.dca") - if err != nil { - t.Fatal("Failed to open nico.dca:", err) - } - - t.Cleanup(func() { - f.Close() - }) - - var lenbuf [4]byte - - for { - if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) { - return - } - - // Read the integer - framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) - - // Copy the frame. - if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) { - return - } - } -} - -func catchRead(t *testing.T, err error) bool { - t.Helper() - - if err == io.EOF { - return false - } - if err != nil { - t.Fatal("Failed to read:", err) - } - return true -} - func TestIntegration(t *testing.T) { config := mustConfig(t) @@ -150,12 +84,76 @@ func TestIntegration(t *testing.T) { finish("sending the speaking command") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := vs.UseContext(ctx); err != nil { + t.Fatal("failed to set ctx into vs:", err) + } + // Copy the audio? nicoReadTo(t, vs) finish("copying the audio") } +type testConfig struct { + BotToken string + VoiceChID discord.ChannelID +} + +func mustConfig(t *testing.T) testConfig { + var token = os.Getenv("BOT_TOKEN") + if token == "" { + t.Fatal("Missing $BOT_TOKEN") + } + + var sid = os.Getenv("VOICE_ID") + if sid == "" { + t.Fatal("Missing $VOICE_ID") + } + + id, err := discord.ParseSnowflake(sid) + if err != nil { + t.Fatal("Invalid $VOICE_ID:", err) + } + + return testConfig{ + BotToken: token, + VoiceChID: discord.ChannelID(id), + } +} + +// file is only a few bytes lolmao +func nicoReadTo(t *testing.T, dst io.Writer) { + t.Helper() + + f, err := os.Open("testdata/nico.dca") + if err != nil { + t.Fatal("Failed to open nico.dca:", err) + } + defer f.Close() + + var lenbuf [4]byte + + for { + if _, err := io.ReadFull(f, lenbuf[:]); err != nil { + if err == io.EOF { + break + } + t.Fatal("failed to read:", err) + } + + // Read the integer + framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) + + // Copy the frame. + if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF { + t.Fatal("failed to write:", err) + } + } +} + // simple shitty benchmark thing func timer() func(finished string) { var then = time.Now() diff --git a/voice/session.go b/voice/session.go index 291aadc..3d513f0 100644 --- a/voice/session.go +++ b/voice/session.go @@ -110,14 +110,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) { } } -func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error { +func (s *Session) JoinChannel( + gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error { + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) defer cancel() return s.JoinChannelCtx(ctx, gID, cID, muted, deafened) } -func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error { +func (s *Session) JoinChannelCtx( + ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error { + // Acquire the mutex during join, locking during IO as well. s.mut.Lock() defer s.mut.Unlock() @@ -211,8 +215,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) { return errors.Wrap(err, "failed to select protocol") } - // Start the UDP loop. - go s.voiceUDP.Start(&d.SecretKey) + s.voiceUDP.UseSecret(d.SecretKey) return nil } @@ -237,6 +240,18 @@ func (s *Session) StopSpeaking() error { return nil } +// UseContext tells the UDP voice connection to write with the given mutex. +func (s *Session) UseContext(ctx context.Context) error { + s.mut.RLock() + defer s.mut.RUnlock() + + if s.voiceUDP == nil { + return ErrCannotSend + } + + return s.voiceUDP.UseContext(ctx) +} + // Write writes into the UDP voice connection WITHOUT a timeout. func (s *Session) Write(b []byte) (int, error) { return s.WriteCtx(context.Background(), b) diff --git a/voice/udp/udp.go b/voice/udp/udp.go index 7772870..ee2c018 100644 --- a/voice/udp/udp.go +++ b/voice/udp/udp.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "golang.org/x/crypto/nacl/secretbox" + "golang.org/x/time/rate" ) // Dialer is the default dialer that this package uses for all its dialing. @@ -17,31 +18,29 @@ var Dialer = net.Dialer{ Timeout: 10 * time.Second, } +// ErrClosed is returned if a Write was called on a closed connection. +var ErrClosed = errors.New("UDP connection closed") + type Connection struct { GatewayIP string GatewayPort uint16 - ssrc uint32 + mutex chan struct{} // for ctx + + context context.Context + conn net.Conn + ssrc uint32 + + frequency rate.Limiter + packet [12]byte + secret [32]byte sequence uint16 timestamp uint32 nonce [24]byte - - conn net.Conn - close chan struct{} - closed chan struct{} - - send chan []byte - reply chan error } func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { - // // Resolve the host. - // a, err := net.ResolveUDPAddr("udp", addr) - // if err != nil { - // return nil, errors.Wrap(err, "failed to resolve host") - // } - // Create a new UDP connection. conn, err := Dialer.DialContext(ctx, "udp", addr) if err != nil { @@ -78,20 +77,6 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti ip := ipbody[:nullPos] port := binary.LittleEndian.Uint16(ipBuffer[68:70]) - return &Connection{ - GatewayIP: string(ip), - GatewayPort: port, - - ssrc: ssrc, - conn: conn, - send: make(chan []byte), - reply: make(chan error), - close: make(chan struct{}), - closed: make(chan struct{}), - }, nil -} - -func (c *Connection) Start(secret *[32]byte) { // https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice packet := [12]byte{ 0: 0x80, // Version + Flags @@ -101,81 +86,118 @@ func (c *Connection) Start(secret *[32]byte) { } // Write SSRC to the header. - binary.BigEndian.PutUint32(packet[8:12], c.ssrc) // SSRC + binary.BigEndian.PutUint32(packet[8:12], ssrc) // SSRC - // 50 sends per second, 960 samples each at 48kHz - frequency := time.NewTicker(time.Millisecond * 20) - defer frequency.Stop() + return &Connection{ + GatewayIP: string(ip), + GatewayPort: port, + // 50 sends per second, 960 samples each at 48kHz + frequency: *rate.NewLimiter(rate.Every(20*time.Millisecond), 1), + context: context.Background(), + mutex: make(chan struct{}, 1), + packet: packet, + ssrc: ssrc, + conn: conn, + }, nil +} - var b []byte - var ok bool +// UseSecret uses the given secret. This method is not thread-safe, so it should +// only be used right after initialization. +func (c *Connection) UseSecret(secret [32]byte) { + c.secret = secret +} - // Close these channels at the end so Write() doesn't block. - defer func() { - close(c.send) - close(c.closed) - }() +// UseContext lets the connection use the given context for its Write method. +// WriteCtx will override this context. +func (c *Connection) UseContext(ctx context.Context) error { + c.mutex <- struct{}{} + defer func() { <-c.mutex }() - for { - select { - case b, ok = <-c.send: - if !ok { - return - } - case <-c.close: - return - } + return c.useContext(ctx) +} - // Write a new sequence. - binary.BigEndian.PutUint16(packet[2:4], c.sequence) - c.sequence++ +func (c *Connection) useContext(ctx context.Context) error { + if c.conn == nil { + return ErrClosed + } - binary.BigEndian.PutUint32(packet[4:8], c.timestamp) - c.timestamp += 960 // Samples + if c.context == ctx { + return nil + } - copy(c.nonce[:], packet[:]) + c.context = ctx - toSend := secretbox.Seal(packet[:], b, &c.nonce, secret) - - select { - case <-frequency.C: - case <-c.close: - // Prevent Write() from stalling before exiting. - c.reply <- nil - - return - } - - _, err := c.conn.Write(toSend) - c.reply <- err + if deadline, ok := c.context.Deadline(); ok { + return c.conn.SetWriteDeadline(deadline) + } else { + return c.conn.SetWriteDeadline(time.Time{}) } } func (c *Connection) Close() error { - close(c.close) - <-c.closed - - return c.conn.Close() + c.mutex <- struct{}{} + err := c.conn.Close() + c.conn = nil + <-c.mutex + return err } // Write sends bytes into the voice UDP connection. func (c *Connection) Write(b []byte) (int, error) { - return c.WriteCtx(context.Background(), b) + select { + case c.mutex <- struct{}{}: + defer func() { <-c.mutex }() + case <-c.context.Done(): + return 0, c.context.Err() + } + + if c.conn == nil { + return 0, ErrClosed + } + + return c.write(b) } // WriteCtx sends bytes into the voice UDP connection with a timeout. func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) { select { - case c.send <- b: - break + case c.mutex <- struct{}{}: + defer func() { <-c.mutex }() + case <-c.context.Done(): + return 0, c.context.Err() case <-ctx.Done(): return 0, ctx.Err() } - select { - case err := <-c.reply: - return len(b), err - case <-ctx.Done(): - return len(b), ctx.Err() + if err := c.useContext(ctx); err != nil { + return 0, errors.Wrap(err, "failed to use context") } + + return c.write(b) +} + +// write is thread-unsafe. +func (c *Connection) write(b []byte) (int, error) { + // Write a new sequence. + binary.BigEndian.PutUint16(c.packet[2:4], c.sequence) + c.sequence++ + + binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp) + c.timestamp += 960 // Samples + + copy(c.nonce[:], c.packet[:]) + + if err := c.frequency.Wait(c.context); err != nil { + return 0, errors.Wrap(err, "failed to wait for frequency tick") + } + + toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret) + + n, err := c.conn.Write(toSend) + if err != nil { + return n, errors.Wrap(err, "failed to write to UDP connection") + } + + // We're not really returning everything, since we're "sealing" the bytes. + return len(b), nil } diff --git a/voice/voice.go b/voice/voice.go index a804f3a..33038de 100644 --- a/voice/voice.go +++ b/voice/voice.go @@ -5,9 +5,11 @@ package voice import ( + "context" "log" "strconv" "sync" + "time" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" @@ -170,9 +172,13 @@ func (v *Voice) Close() error { } for gID, s := range v.sessions { - if dErr := s.Disconnect(); dErr != nil { + log.Println("closing", gID) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + if dErr := s.DisconnectCtx(ctx); dErr != nil { err.SessionErrors[gID] = dErr } + cancel() + log.Println("closed", gID) } err.StateErr = v.State.Close() diff --git a/voice/voicegateway/gateway.go b/voice/voicegateway/gateway.go index d01186a..0e871c9 100644 --- a/voice/voicegateway/gateway.go +++ b/voice/voicegateway/gateway.go @@ -51,7 +51,7 @@ type Gateway struct { mutex sync.RWMutex ready ReadyEvent - ws *wsutil.Websocket + WS *wsutil.Websocket Timeout time.Duration reconnect moreatomic.Bool @@ -96,14 +96,14 @@ func (c *Gateway) OpenCtx(ctx context.Context) error { var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")") - c.ws = wsutil.New(endpoint) + c.WS = wsutil.New(endpoint) // Create a new context with a timeout for the connection. ctx, cancel := context.WithTimeout(ctx, c.Timeout) defer cancel() // Connect to the Gateway Gateway. - if err := c.ws.Dial(ctx); err != nil { + if err := c.WS.Dial(ctx); err != nil { return errors.Wrap(err, "failed to connect to voice gateway") } @@ -138,7 +138,7 @@ func (c *Gateway) __start(ctx context.Context) error { // Make a new WaitGroup for use in background loops: c.waitGroup = new(sync.WaitGroup) - ch := c.ws.Listen() + ch := c.WS.Listen() // Wait for hello. wsutil.WSDebug("Waiting for Hello..") @@ -205,38 +205,38 @@ func (c *Gateway) __start(ctx context.Context) error { } // Close . -func (c *Gateway) Close() error { - // Check if the WS is already closed: - if c.waitGroup == nil && c.EventLoop.Stopped() { - wsutil.WSDebug("Gateway is already closed.") +func (c *Gateway) Close() (err error) { + wsutil.WSDebug("Trying to close.") - c.AfterClose(nil) - return nil + // Check if the WS is already closed: + if c.EventLoop.Stopped() { + wsutil.WSDebug("Gateway is already closed.") + return err } + // Trigger the close callback on exit. + defer func() { c.AfterClose(err) }() + // If the pacemaker is running: if !c.EventLoop.Stopped() { wsutil.WSDebug("Stopping pacemaker...") - // Stop the pacemaker and the event handler + // Stop the pacemaker and the event handler. c.EventLoop.Stop() wsutil.WSDebug("Stopped pacemaker.") } + wsutil.WSDebug("Closing the websocket...") + err = c.WS.Close() + wsutil.WSDebug("Waiting for WaitGroup to be done.") // This should work, since Pacemaker should signal its loop to stop, which // would also exit our event loop. Both would be 2. c.waitGroup.Wait() - // Mark g.waitGroup as empty: - c.waitGroup = nil - wsutil.WSDebug("WaitGroup is done. Closing the websocket.") - - err := c.ws.Close() - c.AfterClose(err) return err } @@ -308,11 +308,11 @@ func (c *Gateway) Send(code OPCode, v interface{}) error { } func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error { - if c.ws == nil { + if c.WS == nil { return errors.New("tried to send data to a connection without a Websocket") } - if c.ws.Conn == nil { + if c.WS.Conn == nil { return errors.New("tried to send data to a connection with a closed Websocket") } @@ -335,5 +335,5 @@ func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error } // WS should already be thread-safe. - return c.ws.SendCtx(ctx, b) + return c.WS.SendCtx(ctx, b) }