voice: Refactor and fix up

This commit refactors a lot of voice's internals to be more stable and
handle more edge cases from Discord's voice servers. It should result in
an overall more stable voice connection.

A few helper functions have been added into voice.Session. Some fields
will have been broken and changed to accomodate for the refactor, as
well.

Below are some commits that have been squashed in:

    voice: Fix Speaking() panic on closed
    voice: StopSpeaking should not error out
        The rationale is added as a comment into the Speaking() method.
    voice: Add TestKickedOut
    voice: Fix region change disconnecting
This commit is contained in:
diamondburned 2021-07-02 02:42:00 -07:00 committed by diamondburned
parent f5e713dee5
commit f5ae68c781
No known key found for this signature in database
GPG Key ID: D78C4471CE776659
11 changed files with 615 additions and 218 deletions

View File

@ -180,6 +180,10 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
conn := c.conn
c.mut.Unlock()
if conn == nil || conn.Conn == nil {
return ErrWebsocketClosed
}
select {
case conn.wrmut <- struct{}{}:
defer func() { <-conn.wrmut }()

View File

@ -150,6 +150,8 @@ func (g *Gateway) Send(ctx context.Context, data Event) error {
Data: data,
}
WSDebug("sending command Op", op.Code, "type", op.Type)
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "failed to encode payload")

View File

@ -70,10 +70,9 @@ func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) {
// Send sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
WSDebug("Acquiring the websoccket mutex for sending.")
WSDebug("Acquiring the websocket mutex for sending.")
ws.mutex.Lock()
WSDebug("Mutex lock acquired.")
sendLimiter := ws.sendLimiter
conn := ws.conn
ws.mutex.Unlock()
@ -85,7 +84,7 @@ func (ws *Websocket) Send(ctx context.Context, b []byte) error {
return errors.Wrap(err, "SendLimiter failed")
}
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
WSDebug("Send has passed the rate limiting.")
return conn.Send(ctx, b)
}

View File

@ -14,7 +14,6 @@ import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/lazytime"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/utils/ws"
@ -86,14 +85,14 @@ type Session struct {
detachReconnect []func()
voiceUDP *udp.Connection
// udpManager is the manager for a UDP connection. The user can use this to
// plug in a custom UDP dialer.
udpManager *udp.Manager
gateway *voicegateway.Gateway
gwCancel context.CancelFunc
gwDone <-chan struct{}
// DialUDP is the custom function for dialing up a UDP connection.
DialUDP UDPDialer
WSTimeout time.Duration // global WSTimeout
WSMaxRetry int // 2
WSRetryDelay time.Duration // 2s
@ -130,7 +129,7 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
state: voicegateway.State{
UserID: userID,
},
DialUDP: udp.DialConnection,
udpManager: udp.NewManager(),
WSTimeout: WSTimeout,
WSMaxRetry: 2,
WSRetryDelay: 2 * time.Second,
@ -145,6 +144,12 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
return session
}
// SetUDPDialer sets the given dialer to be used for dialing UDP voice
// connections.
func (s *Session) SetUDPDialer(d *net.Dialer) {
s.udpManager.SetDialer(d)
}
func (s *Session) acquireUpdate(f func()) bool {
if s.joining.Get() {
return false
@ -154,11 +159,8 @@ func (s *Session) acquireUpdate(f func()) bool {
defer s.mut.Unlock()
// Ignore if we haven't connected yet or we're still joining.
select {
case <-s.disconnected:
if s.udpManager.IsClosed() {
return false
default:
// ok
}
f()
@ -200,6 +202,15 @@ func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
})
}
// JoinChannelAndSpeak is a convenient function that calls JoinChannel then
// Speaking.
func (s *Session) JoinChannelAndSpeak(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error {
if err := s.JoinChannel(ctx, chID, mute, deaf); err != nil {
return errors.Wrap(err, "cannot join channel")
}
return s.Speaking(ctx, voicegateway.Microphone)
}
type waitEventChs struct {
serverUpdate chan *gateway.VoiceServerUpdateEvent
stateUpdate chan *gateway.VoiceStateUpdateEvent
@ -222,11 +233,10 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute,
// Error out if we're already joining. JoinChannel shouldn't be called
// concurrently.
if s.joining.Get() {
if !s.joining.Acquire() {
return errors.New("JoinChannel working elsewhere")
}
s.joining.Set(true)
defer s.joining.Set(false)
// Set the state.
@ -303,7 +313,7 @@ func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute,
case <-timer.C:
continue
case <-ctx.Done():
return err
return errors.Wrap(err, "cannot ask Discord for events")
}
}
@ -375,6 +385,15 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
func (s *Session) reconnectCtx(ctx context.Context) error {
ws.WSDebug("Sending stop handle.")
if err := s.udpManager.Pause(ctx); err != nil {
return errors.Wrap(err, "cannot pause UDP manager")
}
defer func() {
if !s.udpManager.Continue() {
panic("UDP manager continued but invalid lock ownership")
}
}()
s.ensureClosed()
ws.WSDebug("Start gateway.")
@ -385,8 +404,10 @@ func (s *Session) reconnectCtx(ctx context.Context) error {
s.gwCancel = gwcancel
gwch := s.gateway.Connect(gwctx)
ws.WSDebug("Voice Gateway connected")
if err := s.spinGateway(ctx, gwch); err != nil {
ws.WSDebug("Voice spinGateway error:", err)
// Early cancel the gateway.
gwcancel()
// Nil this so future reconnects don't use the invalid gwDone.
@ -394,17 +415,20 @@ func (s *Session) reconnectCtx(ctx context.Context) error {
// Emit the error. It's fine to do this here since this is the only
// place that can error out.
s.Handler.Call(&ReconnectError{err})
return err
return errors.Wrap(err, "cannot wait for event sequence from voice gateway")
}
// Start dispatching.
s.gwDone = ophandler.Loop(gwch, s.Handler)
ws.WSDebug("Voice reconnectCtx finished with no error")
return nil
}
func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
var err error
var conn *udp.Connection
for {
select {
@ -412,7 +436,7 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
return ctx.Err()
case ev, ok := <-gwch:
if !ok {
return s.gateway.LastError()
return errors.Wrap(s.gateway.LastError(), "voice gateway error")
}
switch data := ev.Data.(type) {
@ -420,8 +444,10 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
return errors.Wrap(err, "voice gateway error")
case *voicegateway.ReadyEvent:
ws.WSDebug("Got ready from voice gateway, SSRC:", data.SSRC)
// Prepare the UDP voice connection.
s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC)
conn, err = s.udpManager.Dial(ctx, data.Addr(), data.SSRC)
if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection")
}
@ -429,8 +455,8 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP,
Port: s.voiceUDP.GatewayPort,
Address: conn.GatewayIP,
Port: conn.GatewayPort,
Mode: Protocol,
},
}); err != nil {
@ -438,8 +464,14 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
}
case *voicegateway.SessionDescriptionEvent:
if conn == nil {
return errors.New("server bug: SessionDescription before Ready")
}
ws.WSDebug("Received secret key from voice gateway")
// We're done.
s.voiceUDP.UseSecret(data.SecretKey)
conn.UseSecret(data.SecretKey)
return nil
}
@ -451,88 +483,35 @@ func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
// Speaking tells Discord we're speaking. This method should not be called
// concurrently.
//
// If only NotSpeaking (0) is given, then even if the gateway cannot be reached,
// a nil error will be returned. This is because sending Discord a not-speaking
// event is a destruction command that doesn't affect the outcome of anything
// done after whatsoever.
func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error {
s.mut.Lock()
gateway := s.gateway
s.mut.Unlock()
return gateway.Speaking(ctx, flag)
}
func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) {
const maxAttempts = 5
const retryDelay = 250 * time.Millisecond // adds up to about 1.25s
var lazyWait lazytime.Timer
// Hack: loop until we no longer get an error closed or until the connection
// is dead. This is a workaround for when the session is trying to reconnect
// itself in the background, which would drop the UDP connection.
for i := 0; i < maxAttempts; i++ {
s.mut.RLock()
voiceUDP := s.voiceUDP
disconnected := s.disconnected
s.mut.RUnlock()
select {
case <-disconnected:
return net.ErrClosed
default:
if voiceUDP == nil {
// Session is still connected, but our voice UDP connection is
// nil, so we're probably in the process of reconnecting
// already.
goto retry
}
}
if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) {
// Session is still connected, but our UDP connection is somehow
// closed, so we're probably waiting for the server to ask us to
// reconnect with a new session.
goto retry
}
// Unknown error or none at all; exit.
if err := gateway.Speaking(ctx, flag); err != nil && flag != 0 {
return err
retry:
// Wait a slight bit. We can probably make the caller wait a couple
// milliseconds without a wait.
lazyWait.Reset(retryDelay)
select {
case <-lazyWait.C:
continue
case <-disconnected:
return net.ErrClosed
}
}
return
return nil
}
// Write writes into the UDP voice connection. This method is thread safe as far
// as calling other methods of Session goes; HOWEVER it is not thread safe to
// call Write itself concurrently.
func (s *Session) Write(b []byte) (int, error) {
var n int
err := s.useUDP(func(c *udp.Connection) (err error) {
n, err = c.Write(b)
return
})
return n, err
return s.udpManager.Write(b)
}
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
var p *udp.Packet
err := s.useUDP(func(c *udp.Connection) (err error) {
p, err = c.ReadPacket()
return
})
return p, err
return s.udpManager.ReadPacket()
}
// Leave disconnects the current voice session from the currently connected
@ -552,12 +531,12 @@ func (s *Session) Leave(ctx context.Context) error {
}
// If we're already closed.
if s.gateway == nil && s.voiceUDP == nil {
if s.gateway == nil && s.udpManager.IsClosed() {
return nil
}
// Notify Discord that we're leaving.
err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{
sendErr := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{
GuildID: s.state.GuildID,
ChannelID: discord.ChannelID(discord.NullSnowflake),
SelfMute: true,
@ -570,8 +549,8 @@ func (s *Session) Leave(ctx context.Context) error {
return err
}
if err != nil {
return errors.Wrap(err, "failed to update voice state")
if sendErr != nil {
return errors.Wrap(sendErr, "failed to update voice state")
}
return nil
@ -592,18 +571,15 @@ func (s *Session) cancelGateway(ctx context.Context) error {
return nil
}
const (
permanentClose = true
temporaryClose = false
)
// close ensures everything is closed. It does not acquire the mutex.
func (s *Session) ensureClosed() {
// Disconnect the UDP connection.
if s.voiceUDP != nil {
s.voiceUDP.Close()
s.voiceUDP = nil
}
if !s.disconnectClosed {
close(s.disconnected)
s.disconnectClosed = true
}
// Disconnect the UDP connection. If not permanent, then pause.
s.udpManager.Close()
if s.gwCancel != nil {
s.gwCancel()

View File

@ -2,7 +2,6 @@ package voice_test
import (
"context"
"io"
"log"
"testing"
@ -10,6 +9,7 @@ import (
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/voice"
"github.com/diamondburned/arikawa/v3/voice/testdata"
)
var (
@ -25,9 +25,6 @@ func init() {
}
}
// pseudo function for example
func writeOpusInto(w io.Writer) {}
// make godoc not show the full file
func TestNoop(t *testing.T) {
t.Skip("noop")
@ -54,8 +51,7 @@ func ExampleSession() {
}
defer v.Leave(context.TODO())
// Start writing Opus frames.
for {
writeOpusInto(v)
if err := testdata.WriteOpus(v, "testdata/nico.dca"); err != nil {
log.Fatalln("failed to write opus:", err)
}
}

View File

@ -2,9 +2,8 @@ package voice
import (
"context"
"encoding/binary"
"io"
"log"
"math/rand"
"os"
"runtime"
"strconv"
@ -12,23 +11,37 @@ import (
"testing"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/diamondburned/arikawa/v3/voice/testdata"
"github.com/diamondburned/arikawa/v3/voice/udp"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
"github.com/pkg/errors"
)
func TestIntegration(t *testing.T) {
config := testenv.Must(t)
func TestMain(m *testing.M) {
ws.WSDebug = func(v ...interface{}) {
_, file, line, _ := runtime.Caller(1)
caller := file + ":" + strconv.Itoa(line)
log.Println(append([]interface{}{caller}, v...)...)
}
code := m.Run()
os.Exit(code)
}
type testState struct {
*state.State
channel *discord.Channel
}
func testOpen(t *testing.T) *testState {
config := testenv.Must(t)
s := state.New("Bot " + config.BotToken)
AddIntents(s)
@ -52,17 +65,22 @@ func TestIntegration(t *testing.T) {
t.Fatal("channel isn't a guild voice channel.")
}
log.Println("The voice channel's name is", c.Name)
t.Log("The voice channel's name is", c.Name)
testVoice(t, s, c)
// BUG: Discord doesn't want to send the second VoiceServerUpdateEvent. I
// have no idea why.
// testVoice(t, s, c)
return &testState{
State: s,
channel: c,
}
}
func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
func TestIntegration(t *testing.T) {
state := testOpen(t)
t.Run("1st", func(t *testing.T) { testIntegrationOnce(t, state) })
t.Run("2nd", func(t *testing.T) { testIntegrationOnce(t, state) })
}
func testIntegrationOnce(t *testing.T, s *testState) {
v, err := NewSession(s)
if err != nil {
t.Fatal("failed to create a new voice session:", err)
@ -79,59 +97,43 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
t.Cleanup(cancel)
if err := v.JoinChannel(ctx, c.ID, false, false); err != nil {
if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil {
t.Fatal("failed to join voice:", err)
}
t.Cleanup(func() {
log.Println("Leaving the voice channel concurrently.")
t.Log("Leaving the voice channel concurrently.")
raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
return nil, v.Leave(ctx)
raceMe(t, "failed to leave voice channel", func() error {
return v.Leave(ctx)
})
})
finish("joining the voice channel")
// Trigger speaking.
if err := v.Speaking(ctx, voicegateway.Microphone); err != nil {
t.Fatal("failed to start speaking:", err)
}
t.Cleanup(func() {})
finish("sending the speaking command")
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("failed to open nico.dca:", err)
}
defer f.Close()
var lenbuf [4]byte
// Copy the audio?
for {
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
if err == io.EOF {
break
}
t.Fatal("failed to read:", err)
doneCh := make(chan struct{})
go func() {
if err := testdata.WriteOpus(v, testdata.Nico); err != nil {
t.Error(err)
}
doneCh <- struct{}{}
}()
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
select {
case <-ctx.Done():
t.Error("timed out waiting for voice to be done")
case <-doneCh:
finish("copying the audio")
}
finish("copying the audio")
}
// raceMe intentionally calls fn multiple times in goroutines to ensure it's not
// racy.
func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interface{} {
func raceMe(t *testing.T, wrapErr string, fn func() error) {
const n = 3 // run 3 times
t.Helper()
@ -139,21 +141,21 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf
var wgr sync.WaitGroup
var mut sync.Mutex
var val interface{}
var err error
for i := 0; i < n; i++ {
wgr.Add(1)
go func() {
v, e := fn()
e := fn()
mut.Lock()
val = v
err = e
if e != nil {
err = e
}
mut.Unlock()
if e != nil {
log.Println("Potential race test error:", e)
t.Log("Potential race test error:", e)
}
wgr.Done()
@ -163,10 +165,8 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf
wgr.Wait()
if err != nil {
t.Fatal("Race test failed:", errors.Wrap(err, wrapErr))
t.Fatal("race test failed:", errors.Wrap(err, wrapErr))
}
return val
}
// simple shitty benchmark thing
@ -179,3 +179,122 @@ func timer() func(finished string) {
then = now
}
}
func TestKickedOut(t *testing.T) {
err := testReconnect(t, func(s *testState) {
me, err := s.Me()
if err != nil {
t.Fatal("cannot get me")
}
if err := s.ModifyMember(s.channel.GuildID, me.ID, api.ModifyMemberData{
// Kick the bot out.
VoiceChannel: discord.NullChannelID,
}); err != nil {
t.Error("cannot kick the bot out:", err)
}
})
if !errors.Is(err, udp.ErrManagerClosed) {
t.Error("unexpected error while sending nico.dca:", err)
}
}
func TestRegionChange(t *testing.T) {
var state *testState
err := testReconnect(t, func(s *testState) {
state = s
t.Log("got voice region", s.channel.RTCRegionID)
regions, err := s.VoiceRegionsGuild(s.channel.GuildID)
if err != nil {
t.Error("cannot get voice region:", err)
return
}
rand.Shuffle(len(regions), func(i, j int) {
regions[i], regions[j] = regions[j], regions[i]
})
var anyRegion string
for _, region := range regions {
if region.ID != s.channel.RTCRegionID {
anyRegion = region.ID
break
}
}
t.Log("changing voice region to", anyRegion)
if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{
RTCRegionID: option.NewNullableString(anyRegion),
}); err != nil {
t.Error("cannot change voice region:", err)
}
})
if err != nil {
t.Error("unexpected error while sending nico.dca:", err)
}
s := state
// Change voice region back.
if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{
RTCRegionID: option.NewNullableString(s.channel.RTCRegionID),
}); err != nil {
t.Error("cannot change voice region back:", err)
}
t.Log("changed voice region back to", s.channel.RTCRegionID)
}
func testReconnect(t *testing.T, interrupt func(*testState)) error {
s := testOpen(t)
v, err := NewSession(s)
if err != nil {
t.Fatal("cannot")
}
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
t.Cleanup(cancel)
if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil {
t.Fatal("failed to join voice:", err)
}
t.Cleanup(func() {
if err := v.Speaking(ctx, voicegateway.NotSpeaking); err != nil {
t.Error("cannot stop speaking:", err)
}
if err := v.Leave(ctx); err != nil {
t.Error("cannot leave voice:", err)
}
})
// Ensure the channel is buffered so we can send into it. Write may not be
// called often enough to immediately receive a tick from the unbuffered
// timer.
oneSec := make(chan struct{}, 1)
go func() {
<-time.After(450 * time.Millisecond)
oneSec <- struct{}{}
}()
// Use a WriterFunc so we can interrupt the writing.
// Give 1s for the function to write before interrupting it; we already know
// that the saved dca file is longer than 1s, so we're fine doing this.
interruptWriter := testdata.WriterFunc(func(b []byte) (int, error) {
select {
case <-oneSec:
interrupt(s)
default:
// ok
}
return v.Write(b)
})
return testdata.WriteOpus(interruptWriter, testdata.Nico)
}

48
voice/testdata/testdata.go vendored Normal file
View File

@ -0,0 +1,48 @@
package testdata
import (
"encoding/binary"
"io"
"os"
"github.com/pkg/errors"
)
const Nico = "testdata/nico.dca"
// WriteOpus reads the given file containing the Opus frames into the give
// io.Writer.
func WriteOpus(w io.Writer, file string) error {
f, err := os.Open(file)
if err != nil {
return errors.Wrap(err, "failed to open "+file)
}
defer f.Close()
var lenbuf [4]byte
for {
_, err := io.ReadFull(f, lenbuf[:])
if err != nil {
if err == io.EOF {
return nil
}
return errors.Wrap(err, "failed to read "+file)
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
_, err = io.CopyN(w, f, framelen)
if err != nil && err != io.EOF {
return errors.Wrap(err, "failed to write")
}
}
}
// WriterFunc wraps f to be an io.Writer.
type WriterFunc func([]byte) (int, error)
func (w WriterFunc) Write(b []byte) (int, error) {
return w(b)
}

View File

@ -6,33 +6,20 @@ import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/pkg/errors"
"golang.org/x/crypto/nacl/secretbox"
)
const (
packetHeaderSize = 12
)
// ErrDecryptionFailed is returned from ReadPacket if the received packet fails
// to decrypt.
var ErrDecryptionFailed = errors.New("decryption failed")
// Dialer is the default dialer that this package uses for all its dialing.
var (
ErrDecryptionFailed = errors.New("decryption failed")
Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
)
// Packet represents a voice packet. It is not thread-safe.
type Packet struct {
VersionFlags byte
Type byte
SSRC uint32
Sequence uint16
Timestamp uint32
Opus []byte
var Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
// Connection represents a voice connection. It is not thread-safe.
@ -61,13 +48,21 @@ type Connection struct {
recvOpus []byte // len 1400
recvPacket *Packet // uses recvOpus' backing array
closed moreatomic.Bool
closed sync.Once
}
// DialConnection dials a UDP connection.
// DialConnection dials the UDP connection using the given address and SSRC
// number.
func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
return DialConnectionCustom(ctx, &Dialer, addr, ssrc)
}
// DialConnectionCustom dials the UDP connection with a custom dialer.
func DialConnectionCustom(
ctx context.Context, dialer *net.Dialer, addr string, ssrc uint32) (*Connection, error) {
// Create a new UDP connection.
conn, err := Dialer.DialContext(ctx, "udp", addr)
conn, err := dialer.DialContext(ctx, "udp", addr)
if err != nil {
return nil, errors.Wrap(err, "failed to dial host")
}
@ -173,18 +168,20 @@ func (c *Connection) SetReadDeadline(deadline time.Time) {
c.conn.SetReadDeadline(deadline)
}
// Close closes the connection.
func (c *Connection) Close() error {
if c.closed.Acquire() {
c.closed.Do(func() {
// Be sure to only run this ONCE.
c.frequency.Stop()
close(c.stopFreq)
}
})
return c.conn.Close()
}
// Write sends a packet of audio into the voice UDP connection using the preset
// context.
// Write sends a packet of audio into the voice UDP connection. It is made to be
// stream-compatible: the internal frequency clock will slow Write down to match
// the real playback time.
func (c *Connection) Write(b []byte) (int, error) {
// Write a new sequence.
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
@ -193,53 +190,86 @@ func (c *Connection) Write(b []byte) (int, error) {
binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp)
c.timestamp += c.timeIncr
copy(c.nonce[:], c.packet[:])
// Copy the first 12 bytes from the packet into the nonce.
copy(c.nonce[:12], c.packet[:])
toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret)
// Seal the message, but reuse the packet buffer. We pass in the first 12
// bytes of the packet, but allow it to reuse the whole packet buffer
toSend := secretbox.Seal(c.packet[:12], b, &c.nonce, &c.secret)
select {
case <-c.frequency.C:
// ok
case <-c.stopFreq:
return 0, net.ErrClosed
return 0, errors.Wrap(net.ErrClosed, "frequency ticker stopped")
}
n, err := c.conn.Write(toSend)
_, err := c.conn.Write(toSend)
if err != nil {
return n, errors.Wrap(err, "failed to write to UDP connection")
return 0, err
}
// We're not really returning everything, since we're "sealing" the bytes.
return len(b), nil
}
// ReadPacket reads the UDP connection and returns a packet if successful. This
// packet is not thread-safe to use, as it shares recvBuf's buffer. Byte slices
// inside it must be copied or used before the next call to ReadPacket happens.
func (c *Connection) ReadPacket() (*Packet, error) {
for {
rlen, err := c.conn.Read(c.recvBuf)
// Packet represents a voice packet.
type Packet struct {
header []byte
Opus []byte
}
// VersionFlags returns the version flags of the current packet.
func (p *Packet) VersionFlags() byte { return p.header[0] }
// Type returns the packet type.
func (p *Packet) Type() byte { return p.header[1] }
// Sequence returns the packet sequence.
func (p *Packet) Sequence() uint16 { return binary.BigEndian.Uint16(p.header[2:4]) }
// Timestamp returns the packet's timestamp.
func (p *Packet) Timestamp() uint32 { return binary.BigEndian.Uint32(p.header[4:8]) }
// SSRC returns the packet's SSRC number.
func (p *Packet) SSRC() uint32 { return binary.BigEndian.Uint32(p.header[8:12]) }
// Copy copies the current packet into the given packet.
func (p *Packet) Copy(dst *Packet) {
dst.header = append(dst.header[:0], p.header...)
dst.Opus = append(dst.Opus[:0], p.Opus...)
}
const packetHeaderSize = 12
// ReadPacket reads the UDP connection and returns a packet if successful. The
// returned packet is invalidated once ReadPacket is called again. To avoid
// this, manually Copy the packet.
func (c *Connection) ReadPacket() (*Packet, error) {
if c.recvPacket.header == nil {
// Initialize the recvPacket's header.
c.recvPacket.header = c.recvBuf[:12]
}
for {
i, err := c.conn.Read(c.recvBuf)
if err != nil {
return nil, err
}
if rlen < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) {
if i < packetHeaderSize || (c.recvBuf[0] != 0x80 && c.recvBuf[0] != 0x90) {
continue
}
c.recvPacket.VersionFlags = c.recvBuf[0]
c.recvPacket.Type = c.recvBuf[1]
c.recvPacket.Sequence = binary.BigEndian.Uint16(c.recvBuf[2:4])
c.recvPacket.Timestamp = binary.BigEndian.Uint32(c.recvBuf[4:8])
c.recvPacket.SSRC = binary.BigEndian.Uint32(c.recvBuf[8:12])
// Copy the nonce to be read.
// TODO: once Go 1.17 is released, we can remove recvNonce and directly
// cast it as (*[packetHeaderSize]byte)(c.recvBuf).
copy(c.recvNonce[:], c.recvBuf[0:packetHeaderSize])
var ok bool
// Open (decrypt) the rest of the received bytes.
c.recvPacket.Opus, ok = secretbox.Open(
c.recvOpus[:0], c.recvBuf[packetHeaderSize:rlen], &c.recvNonce, &c.secret)
c.recvOpus[:0], c.recvBuf[packetHeaderSize:i], &c.recvNonce, &c.secret)
if !ok {
return nil, ErrDecryptionFailed
}
@ -267,7 +297,7 @@ func (c *Connection) ReadPacket() (*Packet, error) {
// exactly one header extension, with a format defined in Section
// 5.3.1.
//
isExtension := c.recvPacket.VersionFlags&0x10 == 0x10
isExtension := c.recvPacket.VersionFlags()&0x10 == 0x10
// We then check for whether or not the marker bit (9th bit) is set. The
// 9th bit is carried over to the second byte (Type), so we check its
@ -291,7 +321,7 @@ func (c *Connection) ReadPacket() (*Packet, error) {
// This implies that, when the marker bit is 1, the received packet is
// an RTCP packet and NOT an RTP packet; therefore, we must ignore the
// unknown sections, so we do a (NOT isMarker) check below.
isMarker := c.recvPacket.Type&0x80 != 0x0
isMarker := c.recvPacket.Type()&0x80 != 0x0
if isExtension && !isMarker {
extLen := binary.BigEndian.Uint16(c.recvPacket.Opus[2:4])

221
voice/udp/manager.go Normal file
View File

@ -0,0 +1,221 @@
package udp
import (
"context"
"net"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/pkg/errors"
)
// ErrManagerClosed is returned when a Manager that is already closed is dialed,
// written to or read from.
var ErrManagerClosed = errors.New("UDP connection manager is closed")
// ErrDialWhileUnpaused is returned if Dial is called on the Manager without
// pausing it first
var ErrDialWhileUnpaused = errors.New("dial is called while manager is not paused")
// Manager manages a UDP connection. It allows reconnecting. A Manager instance
// is thread-safe, meaning it can be used concurrently.
type Manager struct {
dialer *net.Dialer
stopMu sync.Mutex
stopConn chan struct{}
stopDial context.CancelFunc
// conn state
conn *Connection
connLock chan struct{}
frequency time.Duration
timeIncr uint32
}
// NewManager creates a new UDP connection manager with the defalt dialer.
func NewManager() *Manager {
return &Manager{
dialer: &Dialer,
stopConn: make(chan struct{}),
connLock: make(chan struct{}, 1),
}
}
// SetDialer sets the manager's dialer. Calling this function while the Manager
// is working will cause a panic. Only call this method directly after
// construction.
func (m *Manager) SetDialer(d *net.Dialer) {
select {
case m.connLock <- struct{}{}:
m.dialer = d
<-m.connLock
default:
panic("SetDialer called while Manager is working")
}
}
// Pause explicitly pauses the manager. It blocks until the Manager is paused or
// the context expires.
func (m *Manager) Pause(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case m.connLock <- struct{}{}:
return nil
}
}
// Close closes the current connection. If the connection is already closed,
// then nothing is done and ErrManagerClosed is returned. Close does not pause
// the connection; calls to Close while the user is using the connection will
// result in the user getting ErrManagerClosed.
func (m *Manager) Close() (err error) {
// Acquire the mutex first.
m.stopMu.Lock()
defer m.stopMu.Unlock()
// Cancel the dialing.
if m.stopDial != nil {
m.stopDial()
m.stopDial = nil
}
// Stop existing Manager users.
select {
case <-m.stopConn:
// m.stopConn already closed
ws.WSDebug("UDP manager already closed")
return ErrManagerClosed
default:
close(m.stopConn)
ws.WSDebug("UDP manager closed")
}
return nil
}
// IsClosed returns true if the connection is closed.
func (m *Manager) IsClosed() bool {
return m.acquireConn() == nil
}
// Continue unpauses and resumes the active user. If the manager has been
// successfully resumed, then true is returned, otherwise if it's already
// continued, then false is returned.
func (m *Manager) Continue() bool {
ws.WSDebug("UDP continued")
select {
case <-m.connLock:
return true
default:
return false
}
}
// Dial dials the internal connection to the given address and SSRC number. If
// the Manager is not Paused, then an error is returned. The caller must call
// Dial after Pause and before Unpause. If the Manager is already being dialed
// elsewhere, then ErrAlreadyDialing is returned.
func (m *Manager) Dial(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
select {
case m.connLock <- struct{}{}:
return nil, errors.New("Dial called on unpaused Manager")
default:
// ok
}
m.stopMu.Lock()
// Allow cancelling from another goroutine with this context.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
m.stopDial = cancel
m.stopMu.Unlock()
conn, err := DialConnectionCustom(ctx, m.dialer, addr, ssrc)
if err != nil {
// Unlock if we failed.
<-m.connLock
return nil, errors.Wrap(err, "failed to dial")
}
if m.frequency > 0 && m.timeIncr > 0 {
conn.ResetFrequency(m.frequency, m.timeIncr)
}
m.stopMu.Lock()
ws.WSDebug("setting UDP conn to one w/ gateway address", conn.GatewayIP)
m.conn = conn
m.stopDial = nil
m.stopConn = make(chan struct{})
m.stopMu.Unlock()
return conn, nil
}
// ResetFrequency sets the current connection and future connections' write
// frequency. Note that calling this method while Connection is being used in a
// different goroutine is not thread-safe.
func (m *Manager) ResetFrequency(frameDuration time.Duration, timeIncr uint32) {
m.connLock <- struct{}{}
defer func() { <-m.connLock }()
m.frequency = frameDuration
m.timeIncr = timeIncr
if m.conn != nil {
m.conn.ResetFrequency(frameDuration, timeIncr)
}
}
// ReadPacket reads the current packet. It blocks until a packet arrives or
// the Manager is closed.
func (m *Manager) ReadPacket() (p *Packet, err error) {
conn := m.acquireConn()
if conn == nil {
return nil, ErrManagerClosed
}
return conn.ReadPacket()
}
// Write writes to the current connection in the manager. It blocks if the
// connection is being re-established.
func (m *Manager) Write(b []byte) (n int, err error) {
conn := m.acquireConn()
if conn == nil {
return 0, ErrManagerClosed
}
return conn.Write(b)
}
// acquireConn acquires the current connection and releases the lock, returning
// the connection at that point in time. Nil is returned if Manager is closed.
func (m *Manager) acquireConn() *Connection {
// Acquire the pause lock first. We must only rely on the stopConn being
// closed once we have this.
m.connLock <- struct{}{}
defer func() { <-m.connLock }()
m.stopMu.Lock()
defer m.stopMu.Unlock()
select {
case <-m.stopConn:
ws.WSDebug("UDP acquisition got stopped conn")
return nil
default:
// ok
}
if m.conn == nil {
ws.WSDebug("UDP acquisition got nil conn")
}
return m.conn
}

View File

@ -6,9 +6,11 @@ package voice
import "github.com/diamondburned/arikawa/v3/gateway"
// Intents are the intents needed for voice to work properly.
const Intents = gateway.IntentGuilds | gateway.IntentGuildVoiceStates
// AddIntents adds the needed voice intents into gw. Bots should always call
// this before Open if voice is required.
func AddIntents(gw interface{ AddIntents(gateway.Intents) }) {
gw.AddIntents(gateway.IntentGuilds)
gw.AddIntents(gateway.IntentGuildVoiceStates)
gw.AddIntents(Intents)
}

View File

@ -21,18 +21,14 @@ import (
"github.com/diamondburned/arikawa/v3/utils/ws"
)
const (
// Version represents the current version of the Discord Gateway Gateway this package uses.
Version = "4"
)
// Version represents the current version of the Discord Gateway Gateway this package uses.
const Version = "4"
var (
ErrNoSessionID = errors.New("no sessionID was received")
ErrNoEndpoint = errors.New("no endpoint was received")
)
type Event = interface{}
// State contains state information of a voice gateway.
type State struct {
UserID discord.UserID // constant
@ -81,7 +77,7 @@ var DefaultGatewayOpts = ws.GatewayOpts{
// New creates a new voice gateway.
func New(state State) *Gateway {
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version
endpoint := "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version
gw := ws.NewGateway(
ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint),
@ -117,13 +113,17 @@ func (g *Gateway) Send(ctx context.Context, data ws.Event) error {
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error {
g.mutex.RLock()
ssrc := g.ready.SSRC
ready := g.ready
g.mutex.RUnlock()
if ready == nil {
return errors.New("Speaking called before gateway was ready")
}
return g.gateway.Send(ctx, &SpeakingEvent{
Speaking: flag,
Delay: 0,
SSRC: ssrc,
SSRC: ready.SSRC,
})
}