From f5ae68c7814db8e12ad169cd269ce1a6b86a2c9d Mon Sep 17 00:00:00 2001 From: diamondburned Date: Fri, 2 Jul 2021 02:42:00 -0700 Subject: [PATCH] voice: Refactor and fix up This commit refactors a lot of voice's internals to be more stable and handle more edge cases from Discord's voice servers. It should result in an overall more stable voice connection. A few helper functions have been added into voice.Session. Some fields will have been broken and changed to accomodate for the refactor, as well. Below are some commits that have been squashed in: voice: Fix Speaking() panic on closed voice: StopSpeaking should not error out The rationale is added as a comment into the Speaking() method. voice: Add TestKickedOut voice: Fix region change disconnecting --- utils/ws/conn.go | 4 + utils/ws/gateway.go | 2 + utils/ws/ws.go | 5 +- voice/session.go | 166 +++++++++----------- voice/session_example_test.go | 10 +- voice/session_test.go | 227 +++++++++++++++++++++------- voice/testdata/testdata.go | 48 ++++++ voice/udp/{udp.go => connection.go} | 126 +++++++++------ voice/udp/manager.go | 221 +++++++++++++++++++++++++++ voice/voice.go | 6 +- voice/voicegateway/gateway.go | 18 +-- 11 files changed, 615 insertions(+), 218 deletions(-) create mode 100644 voice/testdata/testdata.go rename voice/udp/{udp.go => connection.go} (71%) create mode 100644 voice/udp/manager.go diff --git a/utils/ws/conn.go b/utils/ws/conn.go index 0027bca..822c627 100644 --- a/utils/ws/conn.go +++ b/utils/ws/conn.go @@ -180,6 +180,10 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { conn := c.conn c.mut.Unlock() + if conn == nil || conn.Conn == nil { + return ErrWebsocketClosed + } + select { case conn.wrmut <- struct{}{}: defer func() { <-conn.wrmut }() diff --git a/utils/ws/gateway.go b/utils/ws/gateway.go index d361c62..69c06a8 100644 --- a/utils/ws/gateway.go +++ b/utils/ws/gateway.go @@ -150,6 +150,8 @@ func (g *Gateway) Send(ctx context.Context, data Event) error { Data: data, } + WSDebug("sending command Op", op.Code, "type", op.Type) + b, err := json.Marshal(op) if err != nil { return errors.Wrap(err, "failed to encode payload") diff --git a/utils/ws/ws.go b/utils/ws/ws.go index 774dd0e..dd5965d 100644 --- a/utils/ws/ws.go +++ b/utils/ws/ws.go @@ -70,10 +70,9 @@ func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) { // Send sends b over the Websocket with a deadline. It closes the internal // Websocket if the Send method errors out. func (ws *Websocket) Send(ctx context.Context, b []byte) error { - WSDebug("Acquiring the websoccket mutex for sending.") + WSDebug("Acquiring the websocket mutex for sending.") ws.mutex.Lock() - WSDebug("Mutex lock acquired.") sendLimiter := ws.sendLimiter conn := ws.conn ws.mutex.Unlock() @@ -85,7 +84,7 @@ func (ws *Websocket) Send(ctx context.Context, b []byte) error { return errors.Wrap(err, "SendLimiter failed") } - WSDebug("Send has passed the rate limiting. Waiting on mutex.") + WSDebug("Send has passed the rate limiting.") return conn.Send(ctx, b) } diff --git a/voice/session.go b/voice/session.go index 6a0e7db..8003c38 100644 --- a/voice/session.go +++ b/voice/session.go @@ -14,7 +14,6 @@ import ( "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/internal/lazytime" "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/diamondburned/arikawa/v3/session" "github.com/diamondburned/arikawa/v3/utils/ws" @@ -86,14 +85,14 @@ type Session struct { detachReconnect []func() - voiceUDP *udp.Connection + // udpManager is the manager for a UDP connection. The user can use this to + // plug in a custom UDP dialer. + udpManager *udp.Manager + gateway *voicegateway.Gateway gwCancel context.CancelFunc gwDone <-chan struct{} - // DialUDP is the custom function for dialing up a UDP connection. - DialUDP UDPDialer - WSTimeout time.Duration // global WSTimeout WSMaxRetry int // 2 WSRetryDelay time.Duration // 2s @@ -130,7 +129,7 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session { state: voicegateway.State{ UserID: userID, }, - DialUDP: udp.DialConnection, + udpManager: udp.NewManager(), WSTimeout: WSTimeout, WSMaxRetry: 2, WSRetryDelay: 2 * time.Second, @@ -145,6 +144,12 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session { return session } +// SetUDPDialer sets the given dialer to be used for dialing UDP voice +// connections. +func (s *Session) SetUDPDialer(d *net.Dialer) { + s.udpManager.SetDialer(d) +} + func (s *Session) acquireUpdate(f func()) bool { if s.joining.Get() { return false @@ -154,11 +159,8 @@ func (s *Session) acquireUpdate(f func()) bool { defer s.mut.Unlock() // Ignore if we haven't connected yet or we're still joining. - select { - case <-s.disconnected: + if s.udpManager.IsClosed() { return false - default: - // ok } f() @@ -200,6 +202,15 @@ func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) { }) } +// JoinChannelAndSpeak is a convenient function that calls JoinChannel then +// Speaking. +func (s *Session) JoinChannelAndSpeak(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error { + if err := s.JoinChannel(ctx, chID, mute, deaf); err != nil { + return errors.Wrap(err, "cannot join channel") + } + return s.Speaking(ctx, voicegateway.Microphone) +} + type waitEventChs struct { serverUpdate chan *gateway.VoiceServerUpdateEvent stateUpdate chan *gateway.VoiceStateUpdateEvent @@ -222,11 +233,10 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, // Error out if we're already joining. JoinChannel shouldn't be called // concurrently. - if s.joining.Get() { + if !s.joining.Acquire() { return errors.New("JoinChannel working elsewhere") } - s.joining.Set(true) defer s.joining.Set(false) // Set the state. @@ -303,7 +313,7 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, case <-timer.C: continue case <-ctx.Done(): - return err + return errors.Wrap(err, "cannot ask Discord for events") } } @@ -375,6 +385,15 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error { func (s *Session) reconnectCtx(ctx context.Context) error { ws.WSDebug("Sending stop handle.") + if err := s.udpManager.Pause(ctx); err != nil { + return errors.Wrap(err, "cannot pause UDP manager") + } + defer func() { + if !s.udpManager.Continue() { + panic("UDP manager continued but invalid lock ownership") + } + }() + s.ensureClosed() ws.WSDebug("Start gateway.") @@ -385,8 +404,10 @@ func (s *Session) reconnectCtx(ctx context.Context) error { s.gwCancel = gwcancel gwch := s.gateway.Connect(gwctx) + ws.WSDebug("Voice Gateway connected") if err := s.spinGateway(ctx, gwch); err != nil { + ws.WSDebug("Voice spinGateway error:", err) // Early cancel the gateway. gwcancel() // Nil this so future reconnects don't use the invalid gwDone. @@ -394,17 +415,20 @@ func (s *Session) reconnectCtx(ctx context.Context) error { // Emit the error. It's fine to do this here since this is the only // place that can error out. s.Handler.Call(&ReconnectError{err}) - return err + return errors.Wrap(err, "cannot wait for event sequence from voice gateway") } // Start dispatching. s.gwDone = ophandler.Loop(gwch, s.Handler) + ws.WSDebug("Voice reconnectCtx finished with no error") + return nil } func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { var err error + var conn *udp.Connection for { select { @@ -412,7 +436,7 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { return ctx.Err() case ev, ok := <-gwch: if !ok { - return s.gateway.LastError() + return errors.Wrap(s.gateway.LastError(), "voice gateway error") } switch data := ev.Data.(type) { @@ -420,8 +444,10 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { return errors.Wrap(err, "voice gateway error") case *voicegateway.ReadyEvent: + ws.WSDebug("Got ready from voice gateway, SSRC:", data.SSRC) + // Prepare the UDP voice connection. - s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC) + conn, err = s.udpManager.Dial(ctx, data.Addr(), data.SSRC) if err != nil { return errors.Wrap(err, "failed to open voice UDP connection") } @@ -429,8 +455,8 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{ Protocol: "udp", Data: voicegateway.SelectProtocolData{ - Address: s.voiceUDP.GatewayIP, - Port: s.voiceUDP.GatewayPort, + Address: conn.GatewayIP, + Port: conn.GatewayPort, Mode: Protocol, }, }); err != nil { @@ -438,8 +464,14 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { } case *voicegateway.SessionDescriptionEvent: + if conn == nil { + return errors.New("server bug: SessionDescription before Ready") + } + + ws.WSDebug("Received secret key from voice gateway") + // We're done. - s.voiceUDP.UseSecret(data.SecretKey) + conn.UseSecret(data.SecretKey) return nil } @@ -451,88 +483,35 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { // Speaking tells Discord we're speaking. This method should not be called // concurrently. +// +// If only NotSpeaking (0) is given, then even if the gateway cannot be reached, +// a nil error will be returned. This is because sending Discord a not-speaking +// event is a destruction command that doesn't affect the outcome of anything +// done after whatsoever. func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error { s.mut.Lock() gateway := s.gateway s.mut.Unlock() - return gateway.Speaking(ctx, flag) -} - -func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) { - const maxAttempts = 5 - const retryDelay = 250 * time.Millisecond // adds up to about 1.25s - - var lazyWait lazytime.Timer - - // Hack: loop until we no longer get an error closed or until the connection - // is dead. This is a workaround for when the session is trying to reconnect - // itself in the background, which would drop the UDP connection. - for i := 0; i < maxAttempts; i++ { - s.mut.RLock() - voiceUDP := s.voiceUDP - disconnected := s.disconnected - s.mut.RUnlock() - - select { - case <-disconnected: - return net.ErrClosed - default: - if voiceUDP == nil { - // Session is still connected, but our voice UDP connection is - // nil, so we're probably in the process of reconnecting - // already. - goto retry - } - } - - if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) { - // Session is still connected, but our UDP connection is somehow - // closed, so we're probably waiting for the server to ask us to - // reconnect with a new session. - goto retry - } - - // Unknown error or none at all; exit. + if err := gateway.Speaking(ctx, flag); err != nil && flag != 0 { return err - - retry: - // Wait a slight bit. We can probably make the caller wait a couple - // milliseconds without a wait. - lazyWait.Reset(retryDelay) - select { - case <-lazyWait.C: - continue - case <-disconnected: - return net.ErrClosed - } } - return + return nil } // Write writes into the UDP voice connection. This method is thread safe as far // as calling other methods of Session goes; HOWEVER it is not thread safe to // call Write itself concurrently. func (s *Session) Write(b []byte) (int, error) { - var n int - err := s.useUDP(func(c *udp.Connection) (err error) { - n, err = c.Write(b) - return - }) - return n, err + return s.udpManager.Write(b) } // ReadPacket reads a single packet from the UDP connection. This is NOT at all // thread safe, and must be used very carefully. The backing buffer is always // reused. func (s *Session) ReadPacket() (*udp.Packet, error) { - var p *udp.Packet - err := s.useUDP(func(c *udp.Connection) (err error) { - p, err = c.ReadPacket() - return - }) - return p, err + return s.udpManager.ReadPacket() } // Leave disconnects the current voice session from the currently connected @@ -552,12 +531,12 @@ func (s *Session) Leave(ctx context.Context) error { } // If we're already closed. - if s.gateway == nil && s.voiceUDP == nil { + if s.gateway == nil && s.udpManager.IsClosed() { return nil } // Notify Discord that we're leaving. - err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{ + sendErr := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{ GuildID: s.state.GuildID, ChannelID: discord.ChannelID(discord.NullSnowflake), SelfMute: true, @@ -570,8 +549,8 @@ func (s *Session) Leave(ctx context.Context) error { return err } - if err != nil { - return errors.Wrap(err, "failed to update voice state") + if sendErr != nil { + return errors.Wrap(sendErr, "failed to update voice state") } return nil @@ -592,18 +571,15 @@ func (s *Session) cancelGateway(ctx context.Context) error { return nil } +const ( + permanentClose = true + temporaryClose = false +) + // close ensures everything is closed. It does not acquire the mutex. func (s *Session) ensureClosed() { - // Disconnect the UDP connection. - if s.voiceUDP != nil { - s.voiceUDP.Close() - s.voiceUDP = nil - } - - if !s.disconnectClosed { - close(s.disconnected) - s.disconnectClosed = true - } + // Disconnect the UDP connection. If not permanent, then pause. + s.udpManager.Close() if s.gwCancel != nil { s.gwCancel() diff --git a/voice/session_example_test.go b/voice/session_example_test.go index e96bbd6..120bb20 100644 --- a/voice/session_example_test.go +++ b/voice/session_example_test.go @@ -2,7 +2,6 @@ package voice_test import ( "context" - "io" "log" "testing" @@ -10,6 +9,7 @@ import ( "github.com/diamondburned/arikawa/v3/internal/testenv" "github.com/diamondburned/arikawa/v3/state" "github.com/diamondburned/arikawa/v3/voice" + "github.com/diamondburned/arikawa/v3/voice/testdata" ) var ( @@ -25,9 +25,6 @@ func init() { } } -// pseudo function for example -func writeOpusInto(w io.Writer) {} - // make godoc not show the full file func TestNoop(t *testing.T) { t.Skip("noop") @@ -54,8 +51,7 @@ func ExampleSession() { } defer v.Leave(context.TODO()) - // Start writing Opus frames. - for { - writeOpusInto(v) + if err := testdata.WriteOpus(v, "testdata/nico.dca"); err != nil { + log.Fatalln("failed to write opus:", err) } } diff --git a/voice/session_test.go b/voice/session_test.go index 7ea8a64..f4f86d8 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -2,9 +2,8 @@ package voice import ( "context" - "encoding/binary" - "io" "log" + "math/rand" "os" "runtime" "strconv" @@ -12,23 +11,37 @@ import ( "testing" "time" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/internal/testenv" "github.com/diamondburned/arikawa/v3/state" + "github.com/diamondburned/arikawa/v3/utils/json/option" "github.com/diamondburned/arikawa/v3/utils/ws" + "github.com/diamondburned/arikawa/v3/voice/testdata" + "github.com/diamondburned/arikawa/v3/voice/udp" "github.com/diamondburned/arikawa/v3/voice/voicegateway" "github.com/pkg/errors" ) -func TestIntegration(t *testing.T) { - config := testenv.Must(t) - +func TestMain(m *testing.M) { ws.WSDebug = func(v ...interface{}) { _, file, line, _ := runtime.Caller(1) caller := file + ":" + strconv.Itoa(line) log.Println(append([]interface{}{caller}, v...)...) } + code := m.Run() + os.Exit(code) +} + +type testState struct { + *state.State + channel *discord.Channel +} + +func testOpen(t *testing.T) *testState { + config := testenv.Must(t) + s := state.New("Bot " + config.BotToken) AddIntents(s) @@ -52,17 +65,22 @@ func TestIntegration(t *testing.T) { t.Fatal("channel isn't a guild voice channel.") } - log.Println("The voice channel's name is", c.Name) + t.Log("The voice channel's name is", c.Name) - testVoice(t, s, c) - - // BUG: Discord doesn't want to send the second VoiceServerUpdateEvent. I - // have no idea why. - - // testVoice(t, s, c) + return &testState{ + State: s, + channel: c, + } } -func testVoice(t *testing.T, s *state.State, c *discord.Channel) { +func TestIntegration(t *testing.T) { + state := testOpen(t) + + t.Run("1st", func(t *testing.T) { testIntegrationOnce(t, state) }) + t.Run("2nd", func(t *testing.T) { testIntegrationOnce(t, state) }) +} + +func testIntegrationOnce(t *testing.T, s *testState) { v, err := NewSession(s) if err != nil { t.Fatal("failed to create a new voice session:", err) @@ -79,59 +97,43 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) { ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) t.Cleanup(cancel) - if err := v.JoinChannel(ctx, c.ID, false, false); err != nil { + if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil { t.Fatal("failed to join voice:", err) } t.Cleanup(func() { - log.Println("Leaving the voice channel concurrently.") + t.Log("Leaving the voice channel concurrently.") - raceMe(t, "failed to leave voice channel", func() (interface{}, error) { - return nil, v.Leave(ctx) + raceMe(t, "failed to leave voice channel", func() error { + return v.Leave(ctx) }) }) finish("joining the voice channel") - // Trigger speaking. - if err := v.Speaking(ctx, voicegateway.Microphone); err != nil { - t.Fatal("failed to start speaking:", err) - } + t.Cleanup(func() {}) finish("sending the speaking command") - f, err := os.Open("testdata/nico.dca") - if err != nil { - t.Fatal("failed to open nico.dca:", err) - } - defer f.Close() - - var lenbuf [4]byte - - // Copy the audio? - for { - if _, err := io.ReadFull(f, lenbuf[:]); err != nil { - if err == io.EOF { - break - } - t.Fatal("failed to read:", err) + doneCh := make(chan struct{}) + go func() { + if err := testdata.WriteOpus(v, testdata.Nico); err != nil { + t.Error(err) } + doneCh <- struct{}{} + }() - // Read the integer - framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) - - // Copy the frame. - if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF { - t.Fatal("failed to write:", err) - } + select { + case <-ctx.Done(): + t.Error("timed out waiting for voice to be done") + case <-doneCh: + finish("copying the audio") } - - finish("copying the audio") } // raceMe intentionally calls fn multiple times in goroutines to ensure it's not // racy. -func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interface{} { +func raceMe(t *testing.T, wrapErr string, fn func() error) { const n = 3 // run 3 times t.Helper() @@ -139,21 +141,21 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf var wgr sync.WaitGroup var mut sync.Mutex - var val interface{} var err error for i := 0; i < n; i++ { wgr.Add(1) go func() { - v, e := fn() + e := fn() mut.Lock() - val = v - err = e + if e != nil { + err = e + } mut.Unlock() if e != nil { - log.Println("Potential race test error:", e) + t.Log("Potential race test error:", e) } wgr.Done() @@ -163,10 +165,8 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf wgr.Wait() if err != nil { - t.Fatal("Race test failed:", errors.Wrap(err, wrapErr)) + t.Fatal("race test failed:", errors.Wrap(err, wrapErr)) } - - return val } // simple shitty benchmark thing @@ -179,3 +179,122 @@ func timer() func(finished string) { then = now } } + +func TestKickedOut(t *testing.T) { + err := testReconnect(t, func(s *testState) { + me, err := s.Me() + if err != nil { + t.Fatal("cannot get me") + } + + if err := s.ModifyMember(s.channel.GuildID, me.ID, api.ModifyMemberData{ + // Kick the bot out. + VoiceChannel: discord.NullChannelID, + }); err != nil { + t.Error("cannot kick the bot out:", err) + } + }) + + if !errors.Is(err, udp.ErrManagerClosed) { + t.Error("unexpected error while sending nico.dca:", err) + } +} + +func TestRegionChange(t *testing.T) { + var state *testState + err := testReconnect(t, func(s *testState) { + state = s + t.Log("got voice region", s.channel.RTCRegionID) + + regions, err := s.VoiceRegionsGuild(s.channel.GuildID) + if err != nil { + t.Error("cannot get voice region:", err) + return + } + + rand.Shuffle(len(regions), func(i, j int) { + regions[i], regions[j] = regions[j], regions[i] + }) + + var anyRegion string + for _, region := range regions { + if region.ID != s.channel.RTCRegionID { + anyRegion = region.ID + break + } + } + + t.Log("changing voice region to", anyRegion) + + if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{ + RTCRegionID: option.NewNullableString(anyRegion), + }); err != nil { + t.Error("cannot change voice region:", err) + } + }) + + if err != nil { + t.Error("unexpected error while sending nico.dca:", err) + } + + s := state + + // Change voice region back. + if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{ + RTCRegionID: option.NewNullableString(s.channel.RTCRegionID), + }); err != nil { + t.Error("cannot change voice region back:", err) + } + + t.Log("changed voice region back to", s.channel.RTCRegionID) +} + +func testReconnect(t *testing.T, interrupt func(*testState)) error { + s := testOpen(t) + + v, err := NewSession(s) + if err != nil { + t.Fatal("cannot") + } + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + t.Cleanup(cancel) + + if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil { + t.Fatal("failed to join voice:", err) + } + + t.Cleanup(func() { + if err := v.Speaking(ctx, voicegateway.NotSpeaking); err != nil { + t.Error("cannot stop speaking:", err) + } + if err := v.Leave(ctx); err != nil { + t.Error("cannot leave voice:", err) + } + }) + + // Ensure the channel is buffered so we can send into it. Write may not be + // called often enough to immediately receive a tick from the unbuffered + // timer. + oneSec := make(chan struct{}, 1) + go func() { + <-time.After(450 * time.Millisecond) + oneSec <- struct{}{} + }() + + // Use a WriterFunc so we can interrupt the writing. + // Give 1s for the function to write before interrupting it; we already know + // that the saved dca file is longer than 1s, so we're fine doing this. + interruptWriter := testdata.WriterFunc(func(b []byte) (int, error) { + select { + case <-oneSec: + interrupt(s) + default: + // ok + } + + return v.Write(b) + }) + + return testdata.WriteOpus(interruptWriter, testdata.Nico) +} diff --git a/voice/testdata/testdata.go b/voice/testdata/testdata.go new file mode 100644 index 0000000..7f13a49 --- /dev/null +++ b/voice/testdata/testdata.go @@ -0,0 +1,48 @@ +package testdata + +import ( + "encoding/binary" + "io" + "os" + + "github.com/pkg/errors" +) + +const Nico = "testdata/nico.dca" + +// WriteOpus reads the given file containing the Opus frames into the give +// io.Writer. +func WriteOpus(w io.Writer, file string) error { + f, err := os.Open(file) + if err != nil { + return errors.Wrap(err, "failed to open "+file) + } + defer f.Close() + + var lenbuf [4]byte + for { + _, err := io.ReadFull(f, lenbuf[:]) + if err != nil { + if err == io.EOF { + return nil + } + return errors.Wrap(err, "failed to read "+file) + } + + // Read the integer + framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) + + // Copy the frame. + _, err = io.CopyN(w, f, framelen) + if err != nil && err != io.EOF { + return errors.Wrap(err, "failed to write") + } + } +} + +// WriterFunc wraps f to be an io.Writer. +type WriterFunc func([]byte) (int, error) + +func (w WriterFunc) Write(b []byte) (int, error) { + return w(b) +} diff --git a/voice/udp/udp.go b/voice/udp/connection.go similarity index 71% rename from voice/udp/udp.go rename to voice/udp/connection.go index 00f0034..ca4a3e7 100644 --- a/voice/udp/udp.go +++ b/voice/udp/connection.go @@ -6,33 +6,20 @@ import ( "encoding/binary" "io" "net" + "sync" "time" - "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/pkg/errors" "golang.org/x/crypto/nacl/secretbox" ) -const ( - packetHeaderSize = 12 -) +// ErrDecryptionFailed is returned from ReadPacket if the received packet fails +// to decrypt. +var ErrDecryptionFailed = errors.New("decryption failed") // Dialer is the default dialer that this package uses for all its dialing. -var ( - ErrDecryptionFailed = errors.New("decryption failed") - Dialer = net.Dialer{ - Timeout: 10 * time.Second, - } -) - -// Packet represents a voice packet. It is not thread-safe. -type Packet struct { - VersionFlags byte - Type byte - SSRC uint32 - Sequence uint16 - Timestamp uint32 - Opus []byte +var Dialer = net.Dialer{ + Timeout: 10 * time.Second, } // Connection represents a voice connection. It is not thread-safe. @@ -61,13 +48,21 @@ type Connection struct { recvOpus []byte // len 1400 recvPacket *Packet // uses recvOpus' backing array - closed moreatomic.Bool + closed sync.Once } -// DialConnection dials a UDP connection. +// DialConnection dials the UDP connection using the given address and SSRC +// number. func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { + return DialConnectionCustom(ctx, &Dialer, addr, ssrc) +} + +// DialConnectionCustom dials the UDP connection with a custom dialer. +func DialConnectionCustom( + ctx context.Context, dialer *net.Dialer, addr string, ssrc uint32) (*Connection, error) { + // Create a new UDP connection. - conn, err := Dialer.DialContext(ctx, "udp", addr) + conn, err := dialer.DialContext(ctx, "udp", addr) if err != nil { return nil, errors.Wrap(err, "failed to dial host") } @@ -173,18 +168,20 @@ func (c *Connection) SetReadDeadline(deadline time.Time) { c.conn.SetReadDeadline(deadline) } +// Close closes the connection. func (c *Connection) Close() error { - if c.closed.Acquire() { + c.closed.Do(func() { // Be sure to only run this ONCE. c.frequency.Stop() close(c.stopFreq) - } + }) return c.conn.Close() } -// Write sends a packet of audio into the voice UDP connection using the preset -// context. +// Write sends a packet of audio into the voice UDP connection. It is made to be +// stream-compatible: the internal frequency clock will slow Write down to match +// the real playback time. func (c *Connection) Write(b []byte) (int, error) { // Write a new sequence. binary.BigEndian.PutUint16(c.packet[2:4], c.sequence) @@ -193,53 +190,86 @@ func (c *Connection) Write(b []byte) (int, error) { binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp) c.timestamp += c.timeIncr - copy(c.nonce[:], c.packet[:]) + // Copy the first 12 bytes from the packet into the nonce. + copy(c.nonce[:12], c.packet[:]) - toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret) + // Seal the message, but reuse the packet buffer. We pass in the first 12 + // bytes of the packet, but allow it to reuse the whole packet buffer + toSend := secretbox.Seal(c.packet[:12], b, &c.nonce, &c.secret) select { case <-c.frequency.C: // ok case <-c.stopFreq: - return 0, net.ErrClosed + return 0, errors.Wrap(net.ErrClosed, "frequency ticker stopped") } - n, err := c.conn.Write(toSend) + _, err := c.conn.Write(toSend) if err != nil { - return n, errors.Wrap(err, "failed to write to UDP connection") + return 0, err } - // We're not really returning everything, since we're "sealing" the bytes. return len(b), nil } -// ReadPacket reads the UDP connection and returns a packet if successful. This -// packet is not thread-safe to use, as it shares recvBuf's buffer. Byte slices -// inside it must be copied or used before the next call to ReadPacket happens. -func (c *Connection) ReadPacket() (*Packet, error) { - for { - rlen, err := c.conn.Read(c.recvBuf) +// Packet represents a voice packet. +type Packet struct { + header []byte + Opus []byte +} +// VersionFlags returns the version flags of the current packet. +func (p *Packet) VersionFlags() byte { return p.header[0] } + +// Type returns the packet type. +func (p *Packet) Type() byte { return p.header[1] } + +// Sequence returns the packet sequence. +func (p *Packet) Sequence() uint16 { return binary.BigEndian.Uint16(p.header[2:4]) } + +// Timestamp returns the packet's timestamp. +func (p *Packet) Timestamp() uint32 { return binary.BigEndian.Uint32(p.header[4:8]) } + +// SSRC returns the packet's SSRC number. +func (p *Packet) SSRC() uint32 { return binary.BigEndian.Uint32(p.header[8:12]) } + +// Copy copies the current packet into the given packet. +func (p *Packet) Copy(dst *Packet) { + dst.header = append(dst.header[:0], p.header...) + dst.Opus = append(dst.Opus[:0], p.Opus...) +} + +const packetHeaderSize = 12 + +// ReadPacket reads the UDP connection and returns a packet if successful. The +// returned packet is invalidated once ReadPacket is called again. To avoid +// this, manually Copy the packet. +func (c *Connection) ReadPacket() (*Packet, error) { + if c.recvPacket.header == nil { + // Initialize the recvPacket's header. + c.recvPacket.header = c.recvBuf[:12] + } + + for { + i, err := c.conn.Read(c.recvBuf) if err != nil { return nil, err } - if rlen < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) { + if i < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) { continue } - c.recvPacket.VersionFlags = c.recvBuf[0] - c.recvPacket.Type = c.recvBuf[1] - c.recvPacket.Sequence = binary.BigEndian.Uint16(c.recvBuf[2:4]) - c.recvPacket.Timestamp = binary.BigEndian.Uint32(c.recvBuf[4:8]) - c.recvPacket.SSRC = binary.BigEndian.Uint32(c.recvBuf[8:12]) - + // Copy the nonce to be read. + // TODO: once Go 1.17 is released, we can remove recvNonce and directly + // cast it as (*[packetHeaderSize]byte)(c.recvBuf). copy(c.recvNonce[:], c.recvBuf[0:packetHeaderSize]) var ok bool + // Open (decrypt) the rest of the received bytes. c.recvPacket.Opus, ok = secretbox.Open( - c.recvOpus[:0], c.recvBuf[packetHeaderSize:rlen], &c.recvNonce, &c.secret) + c.recvOpus[:0], c.recvBuf[packetHeaderSize:i], &c.recvNonce, &c.secret) if !ok { return nil, ErrDecryptionFailed } @@ -267,7 +297,7 @@ func (c *Connection) ReadPacket() (*Packet, error) { // exactly one header extension, with a format defined in Section // 5.3.1. // - isExtension := c.recvPacket.VersionFlags&0x10 == 0x10 + isExtension := c.recvPacket.VersionFlags()&0x10 == 0x10 // We then check for whether or not the marker bit (9th bit) is set. The // 9th bit is carried over to the second byte (Type), so we check its @@ -291,7 +321,7 @@ func (c *Connection) ReadPacket() (*Packet, error) { // This implies that, when the marker bit is 1, the received packet is // an RTCP packet and NOT an RTP packet; therefore, we must ignore the // unknown sections, so we do a (NOT isMarker) check below. - isMarker := c.recvPacket.Type&0x80 != 0x0 + isMarker := c.recvPacket.Type()&0x80 != 0x0 if isExtension && !isMarker { extLen := binary.BigEndian.Uint16(c.recvPacket.Opus[2:4]) diff --git a/voice/udp/manager.go b/voice/udp/manager.go new file mode 100644 index 0000000..fa14243 --- /dev/null +++ b/voice/udp/manager.go @@ -0,0 +1,221 @@ +package udp + +import ( + "context" + "net" + "sync" + "time" + + "github.com/diamondburned/arikawa/v3/utils/ws" + "github.com/pkg/errors" +) + +// ErrManagerClosed is returned when a Manager that is already closed is dialed, +// written to or read from. +var ErrManagerClosed = errors.New("UDP connection manager is closed") + +// ErrDialWhileUnpaused is returned if Dial is called on the Manager without +// pausing it first +var ErrDialWhileUnpaused = errors.New("dial is called while manager is not paused") + +// Manager manages a UDP connection. It allows reconnecting. A Manager instance +// is thread-safe, meaning it can be used concurrently. +type Manager struct { + dialer *net.Dialer + + stopMu sync.Mutex + stopConn chan struct{} + stopDial context.CancelFunc + + // conn state + conn *Connection + connLock chan struct{} + + frequency time.Duration + timeIncr uint32 +} + +// NewManager creates a new UDP connection manager with the defalt dialer. +func NewManager() *Manager { + return &Manager{ + dialer: &Dialer, + stopConn: make(chan struct{}), + connLock: make(chan struct{}, 1), + } +} + +// SetDialer sets the manager's dialer. Calling this function while the Manager +// is working will cause a panic. Only call this method directly after +// construction. +func (m *Manager) SetDialer(d *net.Dialer) { + select { + case m.connLock <- struct{}{}: + m.dialer = d + <-m.connLock + default: + panic("SetDialer called while Manager is working") + } +} + +// Pause explicitly pauses the manager. It blocks until the Manager is paused or +// the context expires. +func (m *Manager) Pause(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case m.connLock <- struct{}{}: + return nil + } +} + +// Close closes the current connection. If the connection is already closed, +// then nothing is done and ErrManagerClosed is returned. Close does not pause +// the connection; calls to Close while the user is using the connection will +// result in the user getting ErrManagerClosed. +func (m *Manager) Close() (err error) { + // Acquire the mutex first. + m.stopMu.Lock() + defer m.stopMu.Unlock() + + // Cancel the dialing. + if m.stopDial != nil { + m.stopDial() + m.stopDial = nil + } + + // Stop existing Manager users. + select { + case <-m.stopConn: + // m.stopConn already closed + ws.WSDebug("UDP manager already closed") + return ErrManagerClosed + default: + close(m.stopConn) + ws.WSDebug("UDP manager closed") + } + + return nil +} + +// IsClosed returns true if the connection is closed. +func (m *Manager) IsClosed() bool { + return m.acquireConn() == nil +} + +// Continue unpauses and resumes the active user. If the manager has been +// successfully resumed, then true is returned, otherwise if it's already +// continued, then false is returned. +func (m *Manager) Continue() bool { + ws.WSDebug("UDP continued") + + select { + case <-m.connLock: + return true + default: + return false + } +} + +// Dial dials the internal connection to the given address and SSRC number. If +// the Manager is not Paused, then an error is returned. The caller must call +// Dial after Pause and before Unpause. If the Manager is already being dialed +// elsewhere, then ErrAlreadyDialing is returned. +func (m *Manager) Dial(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { + select { + case m.connLock <- struct{}{}: + return nil, errors.New("Dial called on unpaused Manager") + default: + // ok + } + + m.stopMu.Lock() + // Allow cancelling from another goroutine with this context. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + m.stopDial = cancel + m.stopMu.Unlock() + + conn, err := DialConnectionCustom(ctx, m.dialer, addr, ssrc) + if err != nil { + // Unlock if we failed. + <-m.connLock + return nil, errors.Wrap(err, "failed to dial") + } + + if m.frequency > 0 && m.timeIncr > 0 { + conn.ResetFrequency(m.frequency, m.timeIncr) + } + + m.stopMu.Lock() + ws.WSDebug("setting UDP conn to one w/ gateway address", conn.GatewayIP) + m.conn = conn + m.stopDial = nil + m.stopConn = make(chan struct{}) + m.stopMu.Unlock() + + return conn, nil +} + +// ResetFrequency sets the current connection and future connections' write +// frequency. Note that calling this method while Connection is being used in a +// different goroutine is not thread-safe. +func (m *Manager) ResetFrequency(frameDuration time.Duration, timeIncr uint32) { + m.connLock <- struct{}{} + defer func() { <-m.connLock }() + + m.frequency = frameDuration + m.timeIncr = timeIncr + + if m.conn != nil { + m.conn.ResetFrequency(frameDuration, timeIncr) + } +} + +// ReadPacket reads the current packet. It blocks until a packet arrives or +// the Manager is closed. +func (m *Manager) ReadPacket() (p *Packet, err error) { + conn := m.acquireConn() + if conn == nil { + return nil, ErrManagerClosed + } + + return conn.ReadPacket() +} + +// Write writes to the current connection in the manager. It blocks if the +// connection is being re-established. +func (m *Manager) Write(b []byte) (n int, err error) { + conn := m.acquireConn() + if conn == nil { + return 0, ErrManagerClosed + } + + return conn.Write(b) +} + +// acquireConn acquires the current connection and releases the lock, returning +// the connection at that point in time. Nil is returned if Manager is closed. +func (m *Manager) acquireConn() *Connection { + // Acquire the pause lock first. We must only rely on the stopConn being + // closed once we have this. + m.connLock <- struct{}{} + defer func() { <-m.connLock }() + + m.stopMu.Lock() + defer m.stopMu.Unlock() + + select { + case <-m.stopConn: + ws.WSDebug("UDP acquisition got stopped conn") + return nil + default: + // ok + } + + if m.conn == nil { + ws.WSDebug("UDP acquisition got nil conn") + } + + return m.conn +} diff --git a/voice/voice.go b/voice/voice.go index 4141213..955f8c7 100644 --- a/voice/voice.go +++ b/voice/voice.go @@ -6,9 +6,11 @@ package voice import "github.com/diamondburned/arikawa/v3/gateway" +// Intents are the intents needed for voice to work properly. +const Intents = gateway.IntentGuilds | gateway.IntentGuildVoiceStates + // AddIntents adds the needed voice intents into gw. Bots should always call // this before Open if voice is required. func AddIntents(gw interface{ AddIntents(gateway.Intents) }) { - gw.AddIntents(gateway.IntentGuilds) - gw.AddIntents(gateway.IntentGuildVoiceStates) + gw.AddIntents(Intents) } diff --git a/voice/voicegateway/gateway.go b/voice/voicegateway/gateway.go index cba21ee..f8d7399 100644 --- a/voice/voicegateway/gateway.go +++ b/voice/voicegateway/gateway.go @@ -21,18 +21,14 @@ import ( "github.com/diamondburned/arikawa/v3/utils/ws" ) -const ( - // Version represents the current version of the Discord Gateway Gateway this package uses. - Version = "4" -) +// Version represents the current version of the Discord Gateway Gateway this package uses. +const Version = "4" var ( ErrNoSessionID = errors.New("no sessionID was received") ErrNoEndpoint = errors.New("no endpoint was received") ) -type Event = interface{} - // State contains state information of a voice gateway. type State struct { UserID discord.UserID // constant @@ -81,7 +77,7 @@ var DefaultGatewayOpts = ws.GatewayOpts{ // New creates a new voice gateway. func New(state State) *Gateway { // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection - var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version + endpoint := "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version gw := ws.NewGateway( ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint), @@ -117,13 +113,17 @@ func (g *Gateway) Send(ctx context.Context, data ws.Event) error { // Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error { g.mutex.RLock() - ssrc := g.ready.SSRC + ready := g.ready g.mutex.RUnlock() + if ready == nil { + return errors.New("Speaking called before gateway was ready") + } + return g.gateway.Send(ctx, &SpeakingEvent{ Speaking: flag, Delay: 0, - SSRC: ssrc, + SSRC: ready.SSRC, }) }