diff --git a/utils/ws/conn.go b/utils/ws/conn.go index 0027bca..822c627 100644 --- a/utils/ws/conn.go +++ b/utils/ws/conn.go @@ -180,6 +180,10 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { conn := c.conn c.mut.Unlock() + if conn == nil || conn.Conn == nil { + return ErrWebsocketClosed + } + select { case conn.wrmut <- struct{}{}: defer func() { <-conn.wrmut }() diff --git a/utils/ws/gateway.go b/utils/ws/gateway.go index d361c62..69c06a8 100644 --- a/utils/ws/gateway.go +++ b/utils/ws/gateway.go @@ -150,6 +150,8 @@ func (g *Gateway) Send(ctx context.Context, data Event) error { Data: data, } + WSDebug("sending command Op", op.Code, "type", op.Type) + b, err := json.Marshal(op) if err != nil { return errors.Wrap(err, "failed to encode payload") diff --git a/utils/ws/ws.go b/utils/ws/ws.go index 774dd0e..dd5965d 100644 --- a/utils/ws/ws.go +++ b/utils/ws/ws.go @@ -70,10 +70,9 @@ func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) { // Send sends b over the Websocket with a deadline. It closes the internal // Websocket if the Send method errors out. func (ws *Websocket) Send(ctx context.Context, b []byte) error { - WSDebug("Acquiring the websoccket mutex for sending.") + WSDebug("Acquiring the websocket mutex for sending.") ws.mutex.Lock() - WSDebug("Mutex lock acquired.") sendLimiter := ws.sendLimiter conn := ws.conn ws.mutex.Unlock() @@ -85,7 +84,7 @@ func (ws *Websocket) Send(ctx context.Context, b []byte) error { return errors.Wrap(err, "SendLimiter failed") } - WSDebug("Send has passed the rate limiting. Waiting on mutex.") + WSDebug("Send has passed the rate limiting.") return conn.Send(ctx, b) } diff --git a/voice/session.go b/voice/session.go index 6a0e7db..8003c38 100644 --- a/voice/session.go +++ b/voice/session.go @@ -14,7 +14,6 @@ import ( "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/internal/lazytime" "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/diamondburned/arikawa/v3/session" "github.com/diamondburned/arikawa/v3/utils/ws" @@ -86,14 +85,14 @@ type Session struct { detachReconnect []func() - voiceUDP *udp.Connection + // udpManager is the manager for a UDP connection. The user can use this to + // plug in a custom UDP dialer. + udpManager *udp.Manager + gateway *voicegateway.Gateway gwCancel context.CancelFunc gwDone <-chan struct{} - // DialUDP is the custom function for dialing up a UDP connection. - DialUDP UDPDialer - WSTimeout time.Duration // global WSTimeout WSMaxRetry int // 2 WSRetryDelay time.Duration // 2s @@ -130,7 +129,7 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session { state: voicegateway.State{ UserID: userID, }, - DialUDP: udp.DialConnection, + udpManager: udp.NewManager(), WSTimeout: WSTimeout, WSMaxRetry: 2, WSRetryDelay: 2 * time.Second, @@ -145,6 +144,12 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session { return session } +// SetUDPDialer sets the given dialer to be used for dialing UDP voice +// connections. +func (s *Session) SetUDPDialer(d *net.Dialer) { + s.udpManager.SetDialer(d) +} + func (s *Session) acquireUpdate(f func()) bool { if s.joining.Get() { return false @@ -154,11 +159,8 @@ func (s *Session) acquireUpdate(f func()) bool { defer s.mut.Unlock() // Ignore if we haven't connected yet or we're still joining. - select { - case <-s.disconnected: + if s.udpManager.IsClosed() { return false - default: - // ok } f() @@ -200,6 +202,15 @@ func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) { }) } +// JoinChannelAndSpeak is a convenient function that calls JoinChannel then +// Speaking. +func (s *Session) JoinChannelAndSpeak(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error { + if err := s.JoinChannel(ctx, chID, mute, deaf); err != nil { + return errors.Wrap(err, "cannot join channel") + } + return s.Speaking(ctx, voicegateway.Microphone) +} + type waitEventChs struct { serverUpdate chan *gateway.VoiceServerUpdateEvent stateUpdate chan *gateway.VoiceStateUpdateEvent @@ -222,11 +233,10 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, // Error out if we're already joining. JoinChannel shouldn't be called // concurrently. - if s.joining.Get() { + if !s.joining.Acquire() { return errors.New("JoinChannel working elsewhere") } - s.joining.Set(true) defer s.joining.Set(false) // Set the state. @@ -303,7 +313,7 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, case <-timer.C: continue case <-ctx.Done(): - return err + return errors.Wrap(err, "cannot ask Discord for events") } } @@ -375,6 +385,15 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error { func (s *Session) reconnectCtx(ctx context.Context) error { ws.WSDebug("Sending stop handle.") + if err := s.udpManager.Pause(ctx); err != nil { + return errors.Wrap(err, "cannot pause UDP manager") + } + defer func() { + if !s.udpManager.Continue() { + panic("UDP manager continued but invalid lock ownership") + } + }() + s.ensureClosed() ws.WSDebug("Start gateway.") @@ -385,8 +404,10 @@ func (s *Session) reconnectCtx(ctx context.Context) error { s.gwCancel = gwcancel gwch := s.gateway.Connect(gwctx) + ws.WSDebug("Voice Gateway connected") if err := s.spinGateway(ctx, gwch); err != nil { + ws.WSDebug("Voice spinGateway error:", err) // Early cancel the gateway. gwcancel() // Nil this so future reconnects don't use the invalid gwDone. @@ -394,17 +415,20 @@ func (s *Session) reconnectCtx(ctx context.Context) error { // Emit the error. It's fine to do this here since this is the only // place that can error out. s.Handler.Call(&ReconnectError{err}) - return err + return errors.Wrap(err, "cannot wait for event sequence from voice gateway") } // Start dispatching. s.gwDone = ophandler.Loop(gwch, s.Handler) + ws.WSDebug("Voice reconnectCtx finished with no error") + return nil } func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { var err error + var conn *udp.Connection for { select { @@ -412,7 +436,7 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { return ctx.Err() case ev, ok := <-gwch: if !ok { - return s.gateway.LastError() + return errors.Wrap(s.gateway.LastError(), "voice gateway error") } switch data := ev.Data.(type) { @@ -420,8 +444,10 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { return errors.Wrap(err, "voice gateway error") case *voicegateway.ReadyEvent: + ws.WSDebug("Got ready from voice gateway, SSRC:", data.SSRC) + // Prepare the UDP voice connection. - s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC) + conn, err = s.udpManager.Dial(ctx, data.Addr(), data.SSRC) if err != nil { return errors.Wrap(err, "failed to open voice UDP connection") } @@ -429,8 +455,8 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{ Protocol: "udp", Data: voicegateway.SelectProtocolData{ - Address: s.voiceUDP.GatewayIP, - Port: s.voiceUDP.GatewayPort, + Address: conn.GatewayIP, + Port: conn.GatewayPort, Mode: Protocol, }, }); err != nil { @@ -438,8 +464,14 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { } case *voicegateway.SessionDescriptionEvent: + if conn == nil { + return errors.New("server bug: SessionDescription before Ready") + } + + ws.WSDebug("Received secret key from voice gateway") + // We're done. - s.voiceUDP.UseSecret(data.SecretKey) + conn.UseSecret(data.SecretKey) return nil } @@ -451,88 +483,35 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { // Speaking tells Discord we're speaking. This method should not be called // concurrently. +// +// If only NotSpeaking (0) is given, then even if the gateway cannot be reached, +// a nil error will be returned. This is because sending Discord a not-speaking +// event is a destruction command that doesn't affect the outcome of anything +// done after whatsoever. func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error { s.mut.Lock() gateway := s.gateway s.mut.Unlock() - return gateway.Speaking(ctx, flag) -} - -func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) { - const maxAttempts = 5 - const retryDelay = 250 * time.Millisecond // adds up to about 1.25s - - var lazyWait lazytime.Timer - - // Hack: loop until we no longer get an error closed or until the connection - // is dead. This is a workaround for when the session is trying to reconnect - // itself in the background, which would drop the UDP connection. - for i := 0; i < maxAttempts; i++ { - s.mut.RLock() - voiceUDP := s.voiceUDP - disconnected := s.disconnected - s.mut.RUnlock() - - select { - case <-disconnected: - return net.ErrClosed - default: - if voiceUDP == nil { - // Session is still connected, but our voice UDP connection is - // nil, so we're probably in the process of reconnecting - // already. - goto retry - } - } - - if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) { - // Session is still connected, but our UDP connection is somehow - // closed, so we're probably waiting for the server to ask us to - // reconnect with a new session. - goto retry - } - - // Unknown error or none at all; exit. + if err := gateway.Speaking(ctx, flag); err != nil && flag != 0 { return err - - retry: - // Wait a slight bit. We can probably make the caller wait a couple - // milliseconds without a wait. - lazyWait.Reset(retryDelay) - select { - case <-lazyWait.C: - continue - case <-disconnected: - return net.ErrClosed - } } - return + return nil } // Write writes into the UDP voice connection. This method is thread safe as far // as calling other methods of Session goes; HOWEVER it is not thread safe to // call Write itself concurrently. func (s *Session) Write(b []byte) (int, error) { - var n int - err := s.useUDP(func(c *udp.Connection) (err error) { - n, err = c.Write(b) - return - }) - return n, err + return s.udpManager.Write(b) } // ReadPacket reads a single packet from the UDP connection. This is NOT at all // thread safe, and must be used very carefully. The backing buffer is always // reused. func (s *Session) ReadPacket() (*udp.Packet, error) { - var p *udp.Packet - err := s.useUDP(func(c *udp.Connection) (err error) { - p, err = c.ReadPacket() - return - }) - return p, err + return s.udpManager.ReadPacket() } // Leave disconnects the current voice session from the currently connected @@ -552,12 +531,12 @@ func (s *Session) Leave(ctx context.Context) error { } // If we're already closed. - if s.gateway == nil && s.voiceUDP == nil { + if s.gateway == nil && s.udpManager.IsClosed() { return nil } // Notify Discord that we're leaving. - err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{ + sendErr := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{ GuildID: s.state.GuildID, ChannelID: discord.ChannelID(discord.NullSnowflake), SelfMute: true, @@ -570,8 +549,8 @@ func (s *Session) Leave(ctx context.Context) error { return err } - if err != nil { - return errors.Wrap(err, "failed to update voice state") + if sendErr != nil { + return errors.Wrap(sendErr, "failed to update voice state") } return nil @@ -592,18 +571,15 @@ func (s *Session) cancelGateway(ctx context.Context) error { return nil } +const ( + permanentClose = true + temporaryClose = false +) + // close ensures everything is closed. It does not acquire the mutex. func (s *Session) ensureClosed() { - // Disconnect the UDP connection. - if s.voiceUDP != nil { - s.voiceUDP.Close() - s.voiceUDP = nil - } - - if !s.disconnectClosed { - close(s.disconnected) - s.disconnectClosed = true - } + // Disconnect the UDP connection. If not permanent, then pause. + s.udpManager.Close() if s.gwCancel != nil { s.gwCancel() diff --git a/voice/session_example_test.go b/voice/session_example_test.go index e96bbd6..120bb20 100644 --- a/voice/session_example_test.go +++ b/voice/session_example_test.go @@ -2,7 +2,6 @@ package voice_test import ( "context" - "io" "log" "testing" @@ -10,6 +9,7 @@ import ( "github.com/diamondburned/arikawa/v3/internal/testenv" "github.com/diamondburned/arikawa/v3/state" "github.com/diamondburned/arikawa/v3/voice" + "github.com/diamondburned/arikawa/v3/voice/testdata" ) var ( @@ -25,9 +25,6 @@ func init() { } } -// pseudo function for example -func writeOpusInto(w io.Writer) {} - // make godoc not show the full file func TestNoop(t *testing.T) { t.Skip("noop") @@ -54,8 +51,7 @@ func ExampleSession() { } defer v.Leave(context.TODO()) - // Start writing Opus frames. - for { - writeOpusInto(v) + if err := testdata.WriteOpus(v, "testdata/nico.dca"); err != nil { + log.Fatalln("failed to write opus:", err) } } diff --git a/voice/session_test.go b/voice/session_test.go index 7ea8a64..f4f86d8 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -2,9 +2,8 @@ package voice import ( "context" - "encoding/binary" - "io" "log" + "math/rand" "os" "runtime" "strconv" @@ -12,23 +11,37 @@ import ( "testing" "time" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/internal/testenv" "github.com/diamondburned/arikawa/v3/state" + "github.com/diamondburned/arikawa/v3/utils/json/option" "github.com/diamondburned/arikawa/v3/utils/ws" + "github.com/diamondburned/arikawa/v3/voice/testdata" + "github.com/diamondburned/arikawa/v3/voice/udp" "github.com/diamondburned/arikawa/v3/voice/voicegateway" "github.com/pkg/errors" ) -func TestIntegration(t *testing.T) { - config := testenv.Must(t) - +func TestMain(m *testing.M) { ws.WSDebug = func(v ...interface{}) { _, file, line, _ := runtime.Caller(1) caller := file + ":" + strconv.Itoa(line) log.Println(append([]interface{}{caller}, v...)...) } + code := m.Run() + os.Exit(code) +} + +type testState struct { + *state.State + channel *discord.Channel +} + +func testOpen(t *testing.T) *testState { + config := testenv.Must(t) + s := state.New("Bot " + config.BotToken) AddIntents(s) @@ -52,17 +65,22 @@ func TestIntegration(t *testing.T) { t.Fatal("channel isn't a guild voice channel.") } - log.Println("The voice channel's name is", c.Name) + t.Log("The voice channel's name is", c.Name) - testVoice(t, s, c) - - // BUG: Discord doesn't want to send the second VoiceServerUpdateEvent. I - // have no idea why. - - // testVoice(t, s, c) + return &testState{ + State: s, + channel: c, + } } -func testVoice(t *testing.T, s *state.State, c *discord.Channel) { +func TestIntegration(t *testing.T) { + state := testOpen(t) + + t.Run("1st", func(t *testing.T) { testIntegrationOnce(t, state) }) + t.Run("2nd", func(t *testing.T) { testIntegrationOnce(t, state) }) +} + +func testIntegrationOnce(t *testing.T, s *testState) { v, err := NewSession(s) if err != nil { t.Fatal("failed to create a new voice session:", err) @@ -79,59 +97,43 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) { ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) t.Cleanup(cancel) - if err := v.JoinChannel(ctx, c.ID, false, false); err != nil { + if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil { t.Fatal("failed to join voice:", err) } t.Cleanup(func() { - log.Println("Leaving the voice channel concurrently.") + t.Log("Leaving the voice channel concurrently.") - raceMe(t, "failed to leave voice channel", func() (interface{}, error) { - return nil, v.Leave(ctx) + raceMe(t, "failed to leave voice channel", func() error { + return v.Leave(ctx) }) }) finish("joining the voice channel") - // Trigger speaking. - if err := v.Speaking(ctx, voicegateway.Microphone); err != nil { - t.Fatal("failed to start speaking:", err) - } + t.Cleanup(func() {}) finish("sending the speaking command") - f, err := os.Open("testdata/nico.dca") - if err != nil { - t.Fatal("failed to open nico.dca:", err) - } - defer f.Close() - - var lenbuf [4]byte - - // Copy the audio? - for { - if _, err := io.ReadFull(f, lenbuf[:]); err != nil { - if err == io.EOF { - break - } - t.Fatal("failed to read:", err) + doneCh := make(chan struct{}) + go func() { + if err := testdata.WriteOpus(v, testdata.Nico); err != nil { + t.Error(err) } + doneCh <- struct{}{} + }() - // Read the integer - framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) - - // Copy the frame. - if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF { - t.Fatal("failed to write:", err) - } + select { + case <-ctx.Done(): + t.Error("timed out waiting for voice to be done") + case <-doneCh: + finish("copying the audio") } - - finish("copying the audio") } // raceMe intentionally calls fn multiple times in goroutines to ensure it's not // racy. -func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interface{} { +func raceMe(t *testing.T, wrapErr string, fn func() error) { const n = 3 // run 3 times t.Helper() @@ -139,21 +141,21 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf var wgr sync.WaitGroup var mut sync.Mutex - var val interface{} var err error for i := 0; i < n; i++ { wgr.Add(1) go func() { - v, e := fn() + e := fn() mut.Lock() - val = v - err = e + if e != nil { + err = e + } mut.Unlock() if e != nil { - log.Println("Potential race test error:", e) + t.Log("Potential race test error:", e) } wgr.Done() @@ -163,10 +165,8 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf wgr.Wait() if err != nil { - t.Fatal("Race test failed:", errors.Wrap(err, wrapErr)) + t.Fatal("race test failed:", errors.Wrap(err, wrapErr)) } - - return val } // simple shitty benchmark thing @@ -179,3 +179,122 @@ func timer() func(finished string) { then = now } } + +func TestKickedOut(t *testing.T) { + err := testReconnect(t, func(s *testState) { + me, err := s.Me() + if err != nil { + t.Fatal("cannot get me") + } + + if err := s.ModifyMember(s.channel.GuildID, me.ID, api.ModifyMemberData{ + // Kick the bot out. + VoiceChannel: discord.NullChannelID, + }); err != nil { + t.Error("cannot kick the bot out:", err) + } + }) + + if !errors.Is(err, udp.ErrManagerClosed) { + t.Error("unexpected error while sending nico.dca:", err) + } +} + +func TestRegionChange(t *testing.T) { + var state *testState + err := testReconnect(t, func(s *testState) { + state = s + t.Log("got voice region", s.channel.RTCRegionID) + + regions, err := s.VoiceRegionsGuild(s.channel.GuildID) + if err != nil { + t.Error("cannot get voice region:", err) + return + } + + rand.Shuffle(len(regions), func(i, j int) { + regions[i], regions[j] = regions[j], regions[i] + }) + + var anyRegion string + for _, region := range regions { + if region.ID != s.channel.RTCRegionID { + anyRegion = region.ID + break + } + } + + t.Log("changing voice region to", anyRegion) + + if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{ + RTCRegionID: option.NewNullableString(anyRegion), + }); err != nil { + t.Error("cannot change voice region:", err) + } + }) + + if err != nil { + t.Error("unexpected error while sending nico.dca:", err) + } + + s := state + + // Change voice region back. + if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{ + RTCRegionID: option.NewNullableString(s.channel.RTCRegionID), + }); err != nil { + t.Error("cannot change voice region back:", err) + } + + t.Log("changed voice region back to", s.channel.RTCRegionID) +} + +func testReconnect(t *testing.T, interrupt func(*testState)) error { + s := testOpen(t) + + v, err := NewSession(s) + if err != nil { + t.Fatal("cannot") + } + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + t.Cleanup(cancel) + + if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil { + t.Fatal("failed to join voice:", err) + } + + t.Cleanup(func() { + if err := v.Speaking(ctx, voicegateway.NotSpeaking); err != nil { + t.Error("cannot stop speaking:", err) + } + if err := v.Leave(ctx); err != nil { + t.Error("cannot leave voice:", err) + } + }) + + // Ensure the channel is buffered so we can send into it. Write may not be + // called often enough to immediately receive a tick from the unbuffered + // timer. + oneSec := make(chan struct{}, 1) + go func() { + <-time.After(450 * time.Millisecond) + oneSec <- struct{}{} + }() + + // Use a WriterFunc so we can interrupt the writing. + // Give 1s for the function to write before interrupting it; we already know + // that the saved dca file is longer than 1s, so we're fine doing this. + interruptWriter := testdata.WriterFunc(func(b []byte) (int, error) { + select { + case <-oneSec: + interrupt(s) + default: + // ok + } + + return v.Write(b) + }) + + return testdata.WriteOpus(interruptWriter, testdata.Nico) +} diff --git a/voice/testdata/testdata.go b/voice/testdata/testdata.go new file mode 100644 index 0000000..7f13a49 --- /dev/null +++ b/voice/testdata/testdata.go @@ -0,0 +1,48 @@ +package testdata + +import ( + "encoding/binary" + "io" + "os" + + "github.com/pkg/errors" +) + +const Nico = "testdata/nico.dca" + +// WriteOpus reads the given file containing the Opus frames into the give +// io.Writer. +func WriteOpus(w io.Writer, file string) error { + f, err := os.Open(file) + if err != nil { + return errors.Wrap(err, "failed to open "+file) + } + defer f.Close() + + var lenbuf [4]byte + for { + _, err := io.ReadFull(f, lenbuf[:]) + if err != nil { + if err == io.EOF { + return nil + } + return errors.Wrap(err, "failed to read "+file) + } + + // Read the integer + framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) + + // Copy the frame. + _, err = io.CopyN(w, f, framelen) + if err != nil && err != io.EOF { + return errors.Wrap(err, "failed to write") + } + } +} + +// WriterFunc wraps f to be an io.Writer. +type WriterFunc func([]byte) (int, error) + +func (w WriterFunc) Write(b []byte) (int, error) { + return w(b) +} diff --git a/voice/udp/udp.go b/voice/udp/connection.go similarity index 71% rename from voice/udp/udp.go rename to voice/udp/connection.go index 00f0034..ca4a3e7 100644 --- a/voice/udp/udp.go +++ b/voice/udp/connection.go @@ -6,33 +6,20 @@ import ( "encoding/binary" "io" "net" + "sync" "time" - "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/pkg/errors" "golang.org/x/crypto/nacl/secretbox" ) -const ( - packetHeaderSize = 12 -) +// ErrDecryptionFailed is returned from ReadPacket if the received packet fails +// to decrypt. +var ErrDecryptionFailed = errors.New("decryption failed") // Dialer is the default dialer that this package uses for all its dialing. -var ( - ErrDecryptionFailed = errors.New("decryption failed") - Dialer = net.Dialer{ - Timeout: 10 * time.Second, - } -) - -// Packet represents a voice packet. It is not thread-safe. -type Packet struct { - VersionFlags byte - Type byte - SSRC uint32 - Sequence uint16 - Timestamp uint32 - Opus []byte +var Dialer = net.Dialer{ + Timeout: 10 * time.Second, } // Connection represents a voice connection. It is not thread-safe. @@ -61,13 +48,21 @@ type Connection struct { recvOpus []byte // len 1400 recvPacket *Packet // uses recvOpus' backing array - closed moreatomic.Bool + closed sync.Once } -// DialConnection dials a UDP connection. +// DialConnection dials the UDP connection using the given address and SSRC +// number. func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { + return DialConnectionCustom(ctx, &Dialer, addr, ssrc) +} + +// DialConnectionCustom dials the UDP connection with a custom dialer. +func DialConnectionCustom( + ctx context.Context, dialer *net.Dialer, addr string, ssrc uint32) (*Connection, error) { + // Create a new UDP connection. - conn, err := Dialer.DialContext(ctx, "udp", addr) + conn, err := dialer.DialContext(ctx, "udp", addr) if err != nil { return nil, errors.Wrap(err, "failed to dial host") } @@ -173,18 +168,20 @@ func (c *Connection) SetReadDeadline(deadline time.Time) { c.conn.SetReadDeadline(deadline) } +// Close closes the connection. func (c *Connection) Close() error { - if c.closed.Acquire() { + c.closed.Do(func() { // Be sure to only run this ONCE. c.frequency.Stop() close(c.stopFreq) - } + }) return c.conn.Close() } -// Write sends a packet of audio into the voice UDP connection using the preset -// context. +// Write sends a packet of audio into the voice UDP connection. It is made to be +// stream-compatible: the internal frequency clock will slow Write down to match +// the real playback time. func (c *Connection) Write(b []byte) (int, error) { // Write a new sequence. binary.BigEndian.PutUint16(c.packet[2:4], c.sequence) @@ -193,53 +190,86 @@ func (c *Connection) Write(b []byte) (int, error) { binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp) c.timestamp += c.timeIncr - copy(c.nonce[:], c.packet[:]) + // Copy the first 12 bytes from the packet into the nonce. + copy(c.nonce[:12], c.packet[:]) - toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret) + // Seal the message, but reuse the packet buffer. We pass in the first 12 + // bytes of the packet, but allow it to reuse the whole packet buffer + toSend := secretbox.Seal(c.packet[:12], b, &c.nonce, &c.secret) select { case <-c.frequency.C: // ok case <-c.stopFreq: - return 0, net.ErrClosed + return 0, errors.Wrap(net.ErrClosed, "frequency ticker stopped") } - n, err := c.conn.Write(toSend) + _, err := c.conn.Write(toSend) if err != nil { - return n, errors.Wrap(err, "failed to write to UDP connection") + return 0, err } - // We're not really returning everything, since we're "sealing" the bytes. return len(b), nil } -// ReadPacket reads the UDP connection and returns a packet if successful. This -// packet is not thread-safe to use, as it shares recvBuf's buffer. Byte slices -// inside it must be copied or used before the next call to ReadPacket happens. -func (c *Connection) ReadPacket() (*Packet, error) { - for { - rlen, err := c.conn.Read(c.recvBuf) +// Packet represents a voice packet. +type Packet struct { + header []byte + Opus []byte +} +// VersionFlags returns the version flags of the current packet. +func (p *Packet) VersionFlags() byte { return p.header[0] } + +// Type returns the packet type. +func (p *Packet) Type() byte { return p.header[1] } + +// Sequence returns the packet sequence. +func (p *Packet) Sequence() uint16 { return binary.BigEndian.Uint16(p.header[2:4]) } + +// Timestamp returns the packet's timestamp. +func (p *Packet) Timestamp() uint32 { return binary.BigEndian.Uint32(p.header[4:8]) } + +// SSRC returns the packet's SSRC number. +func (p *Packet) SSRC() uint32 { return binary.BigEndian.Uint32(p.header[8:12]) } + +// Copy copies the current packet into the given packet. +func (p *Packet) Copy(dst *Packet) { + dst.header = append(dst.header[:0], p.header...) + dst.Opus = append(dst.Opus[:0], p.Opus...) +} + +const packetHeaderSize = 12 + +// ReadPacket reads the UDP connection and returns a packet if successful. The +// returned packet is invalidated once ReadPacket is called again. To avoid +// this, manually Copy the packet. +func (c *Connection) ReadPacket() (*Packet, error) { + if c.recvPacket.header == nil { + // Initialize the recvPacket's header. + c.recvPacket.header = c.recvBuf[:12] + } + + for { + i, err := c.conn.Read(c.recvBuf) if err != nil { return nil, err } - if rlen < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) { + if i < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) { continue } - c.recvPacket.VersionFlags = c.recvBuf[0] - c.recvPacket.Type = c.recvBuf[1] - c.recvPacket.Sequence = binary.BigEndian.Uint16(c.recvBuf[2:4]) - c.recvPacket.Timestamp = binary.BigEndian.Uint32(c.recvBuf[4:8]) - c.recvPacket.SSRC = binary.BigEndian.Uint32(c.recvBuf[8:12]) - + // Copy the nonce to be read. + // TODO: once Go 1.17 is released, we can remove recvNonce and directly + // cast it as (*[packetHeaderSize]byte)(c.recvBuf). copy(c.recvNonce[:], c.recvBuf[0:packetHeaderSize]) var ok bool + // Open (decrypt) the rest of the received bytes. c.recvPacket.Opus, ok = secretbox.Open( - c.recvOpus[:0], c.recvBuf[packetHeaderSize:rlen], &c.recvNonce, &c.secret) + c.recvOpus[:0], c.recvBuf[packetHeaderSize:i], &c.recvNonce, &c.secret) if !ok { return nil, ErrDecryptionFailed } @@ -267,7 +297,7 @@ func (c *Connection) ReadPacket() (*Packet, error) { // exactly one header extension, with a format defined in Section // 5.3.1. // - isExtension := c.recvPacket.VersionFlags&0x10 == 0x10 + isExtension := c.recvPacket.VersionFlags()&0x10 == 0x10 // We then check for whether or not the marker bit (9th bit) is set. The // 9th bit is carried over to the second byte (Type), so we check its @@ -291,7 +321,7 @@ func (c *Connection) ReadPacket() (*Packet, error) { // This implies that, when the marker bit is 1, the received packet is // an RTCP packet and NOT an RTP packet; therefore, we must ignore the // unknown sections, so we do a (NOT isMarker) check below. - isMarker := c.recvPacket.Type&0x80 != 0x0 + isMarker := c.recvPacket.Type()&0x80 != 0x0 if isExtension && !isMarker { extLen := binary.BigEndian.Uint16(c.recvPacket.Opus[2:4]) diff --git a/voice/udp/manager.go b/voice/udp/manager.go new file mode 100644 index 0000000..fa14243 --- /dev/null +++ b/voice/udp/manager.go @@ -0,0 +1,221 @@ +package udp + +import ( + "context" + "net" + "sync" + "time" + + "github.com/diamondburned/arikawa/v3/utils/ws" + "github.com/pkg/errors" +) + +// ErrManagerClosed is returned when a Manager that is already closed is dialed, +// written to or read from. +var ErrManagerClosed = errors.New("UDP connection manager is closed") + +// ErrDialWhileUnpaused is returned if Dial is called on the Manager without +// pausing it first +var ErrDialWhileUnpaused = errors.New("dial is called while manager is not paused") + +// Manager manages a UDP connection. It allows reconnecting. A Manager instance +// is thread-safe, meaning it can be used concurrently. +type Manager struct { + dialer *net.Dialer + + stopMu sync.Mutex + stopConn chan struct{} + stopDial context.CancelFunc + + // conn state + conn *Connection + connLock chan struct{} + + frequency time.Duration + timeIncr uint32 +} + +// NewManager creates a new UDP connection manager with the defalt dialer. +func NewManager() *Manager { + return &Manager{ + dialer: &Dialer, + stopConn: make(chan struct{}), + connLock: make(chan struct{}, 1), + } +} + +// SetDialer sets the manager's dialer. Calling this function while the Manager +// is working will cause a panic. Only call this method directly after +// construction. +func (m *Manager) SetDialer(d *net.Dialer) { + select { + case m.connLock <- struct{}{}: + m.dialer = d + <-m.connLock + default: + panic("SetDialer called while Manager is working") + } +} + +// Pause explicitly pauses the manager. It blocks until the Manager is paused or +// the context expires. +func (m *Manager) Pause(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case m.connLock <- struct{}{}: + return nil + } +} + +// Close closes the current connection. If the connection is already closed, +// then nothing is done and ErrManagerClosed is returned. Close does not pause +// the connection; calls to Close while the user is using the connection will +// result in the user getting ErrManagerClosed. +func (m *Manager) Close() (err error) { + // Acquire the mutex first. + m.stopMu.Lock() + defer m.stopMu.Unlock() + + // Cancel the dialing. + if m.stopDial != nil { + m.stopDial() + m.stopDial = nil + } + + // Stop existing Manager users. + select { + case <-m.stopConn: + // m.stopConn already closed + ws.WSDebug("UDP manager already closed") + return ErrManagerClosed + default: + close(m.stopConn) + ws.WSDebug("UDP manager closed") + } + + return nil +} + +// IsClosed returns true if the connection is closed. +func (m *Manager) IsClosed() bool { + return m.acquireConn() == nil +} + +// Continue unpauses and resumes the active user. If the manager has been +// successfully resumed, then true is returned, otherwise if it's already +// continued, then false is returned. +func (m *Manager) Continue() bool { + ws.WSDebug("UDP continued") + + select { + case <-m.connLock: + return true + default: + return false + } +} + +// Dial dials the internal connection to the given address and SSRC number. If +// the Manager is not Paused, then an error is returned. The caller must call +// Dial after Pause and before Unpause. If the Manager is already being dialed +// elsewhere, then ErrAlreadyDialing is returned. +func (m *Manager) Dial(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { + select { + case m.connLock <- struct{}{}: + return nil, errors.New("Dial called on unpaused Manager") + default: + // ok + } + + m.stopMu.Lock() + // Allow cancelling from another goroutine with this context. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + m.stopDial = cancel + m.stopMu.Unlock() + + conn, err := DialConnectionCustom(ctx, m.dialer, addr, ssrc) + if err != nil { + // Unlock if we failed. + <-m.connLock + return nil, errors.Wrap(err, "failed to dial") + } + + if m.frequency > 0 && m.timeIncr > 0 { + conn.ResetFrequency(m.frequency, m.timeIncr) + } + + m.stopMu.Lock() + ws.WSDebug("setting UDP conn to one w/ gateway address", conn.GatewayIP) + m.conn = conn + m.stopDial = nil + m.stopConn = make(chan struct{}) + m.stopMu.Unlock() + + return conn, nil +} + +// ResetFrequency sets the current connection and future connections' write +// frequency. Note that calling this method while Connection is being used in a +// different goroutine is not thread-safe. +func (m *Manager) ResetFrequency(frameDuration time.Duration, timeIncr uint32) { + m.connLock <- struct{}{} + defer func() { <-m.connLock }() + + m.frequency = frameDuration + m.timeIncr = timeIncr + + if m.conn != nil { + m.conn.ResetFrequency(frameDuration, timeIncr) + } +} + +// ReadPacket reads the current packet. It blocks until a packet arrives or +// the Manager is closed. +func (m *Manager) ReadPacket() (p *Packet, err error) { + conn := m.acquireConn() + if conn == nil { + return nil, ErrManagerClosed + } + + return conn.ReadPacket() +} + +// Write writes to the current connection in the manager. It blocks if the +// connection is being re-established. +func (m *Manager) Write(b []byte) (n int, err error) { + conn := m.acquireConn() + if conn == nil { + return 0, ErrManagerClosed + } + + return conn.Write(b) +} + +// acquireConn acquires the current connection and releases the lock, returning +// the connection at that point in time. Nil is returned if Manager is closed. +func (m *Manager) acquireConn() *Connection { + // Acquire the pause lock first. We must only rely on the stopConn being + // closed once we have this. + m.connLock <- struct{}{} + defer func() { <-m.connLock }() + + m.stopMu.Lock() + defer m.stopMu.Unlock() + + select { + case <-m.stopConn: + ws.WSDebug("UDP acquisition got stopped conn") + return nil + default: + // ok + } + + if m.conn == nil { + ws.WSDebug("UDP acquisition got nil conn") + } + + return m.conn +} diff --git a/voice/voice.go b/voice/voice.go index 4141213..955f8c7 100644 --- a/voice/voice.go +++ b/voice/voice.go @@ -6,9 +6,11 @@ package voice import "github.com/diamondburned/arikawa/v3/gateway" +// Intents are the intents needed for voice to work properly. +const Intents = gateway.IntentGuilds | gateway.IntentGuildVoiceStates + // AddIntents adds the needed voice intents into gw. Bots should always call // this before Open if voice is required. func AddIntents(gw interface{ AddIntents(gateway.Intents) }) { - gw.AddIntents(gateway.IntentGuilds) - gw.AddIntents(gateway.IntentGuildVoiceStates) + gw.AddIntents(Intents) } diff --git a/voice/voicegateway/gateway.go b/voice/voicegateway/gateway.go index cba21ee..f8d7399 100644 --- a/voice/voicegateway/gateway.go +++ b/voice/voicegateway/gateway.go @@ -21,18 +21,14 @@ import ( "github.com/diamondburned/arikawa/v3/utils/ws" ) -const ( - // Version represents the current version of the Discord Gateway Gateway this package uses. - Version = "4" -) +// Version represents the current version of the Discord Gateway Gateway this package uses. +const Version = "4" var ( ErrNoSessionID = errors.New("no sessionID was received") ErrNoEndpoint = errors.New("no endpoint was received") ) -type Event = interface{} - // State contains state information of a voice gateway. type State struct { UserID discord.UserID // constant @@ -81,7 +77,7 @@ var DefaultGatewayOpts = ws.GatewayOpts{ // New creates a new voice gateway. func New(state State) *Gateway { // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection - var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version + endpoint := "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version gw := ws.NewGateway( ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint), @@ -117,13 +113,17 @@ func (g *Gateway) Send(ctx context.Context, data ws.Event) error { // Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error { g.mutex.RLock() - ssrc := g.ready.SSRC + ready := g.ready g.mutex.RUnlock() + if ready == nil { + return errors.New("Speaking called before gateway was ready") + } + return g.gateway.Send(ctx, &SpeakingEvent{ Speaking: flag, Delay: 0, - SSRC: ssrc, + SSRC: ready.SSRC, }) }