1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-11-23 13:14:16 +00:00

Voice: Minor concurrency improvements

This commit adds multiple thread safe guards to do better the
concurrency promises. However, it also omits completely those guarantees
in certain places that don't make sense to call concurrently. This is
mostly documented.

This commit also adds a small piece of code to concurrently run things
with the race detector.
This commit is contained in:
diamondburned 2020-11-17 12:09:15 -08:00 committed by diamondburned
parent 1b8af1513e
commit f4750292eb
3 changed files with 169 additions and 112 deletions

View file

@ -10,6 +10,7 @@ import (
"os"
"runtime"
"strconv"
"sync"
"testing"
"time"
@ -17,6 +18,7 @@ import (
"github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/utils/wsutil"
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
"github.com/pkg/errors"
)
func TestIntegration(t *testing.T) {
@ -41,7 +43,7 @@ func TestIntegration(t *testing.T) {
if err := v.Open(); err != nil {
t.Fatal("Failed to connect:", err)
}
defer v.Close()
t.Cleanup(func() { v.Close() })
// Validate the given voice channel.
c, err := v.Channel(config.VoiceChID)
@ -57,17 +59,19 @@ func TestIntegration(t *testing.T) {
// Grab a timer to benchmark things.
finish := timer()
// Join the voice channel.
vs, err := v.JoinChannel(c.GuildID, c.ID, false, false)
if err != nil {
t.Fatal("Failed to join channel:", err)
}
defer func() {
log.Println("Disconnecting from the voice channel.")
if err := vs.Disconnect(); err != nil {
t.Fatal("Failed to disconnect:", err)
}
}()
// Join the voice channel concurrently.
raceValue := raceMe(t, "failed to join voice channel", func() (interface{}, error) {
return v.JoinChannel(c.GuildID, c.ID, false, false)
})
vs := raceValue.(*Session)
t.Cleanup(func() {
log.Println("Disconnecting from the voice channel concurrently.")
raceMe(t, "failed to disconnect", func() (interface{}, error) {
return nil, vs.Disconnect()
})
})
finish("joining the voice channel")
@ -76,32 +80,97 @@ func TestIntegration(t *testing.T) {
finish("received voice speaking event")
})
// Create a context and only cancel it AFTER we're done sending silence
// frames.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
// Trigger speaking.
if err := vs.Speaking(voicegateway.Microphone); err != nil {
t.Fatal("Failed to start speaking:", err)
t.Fatal("failed to start speaking:", err)
}
defer func() {
log.Println("Stopping speaking.") // sounds grammatically wrong
t.Cleanup(func() {
log.Println("Stopping speaking.")
if err := vs.StopSpeaking(); err != nil {
t.Fatal("Failed to stop speaking:", err)
t.Fatal("failed to stop speaking:", err)
}
}()
})
finish("sending the speaking command")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := vs.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err)
}
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
defer f.Close()
var lenbuf [4]byte
// Copy the audio?
nicoReadTo(t, vs)
for {
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
if err == io.EOF {
break
}
t.Fatal("failed to read:", err)
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(vs, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
}
finish("copying the audio")
}
// raceMe intentionally calls fn multiple times in goroutines to ensure it's not
// racy.
func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interface{} {
const n = 3 // run 3 times
t.Helper()
// It is very ironic how this method itself is racy.
var wgr sync.WaitGroup
var mut sync.Mutex
var val interface{}
var err error
for i := 0; i < n; i++ {
wgr.Add(1)
go func() {
v, e := fn()
mut.Lock()
val = v
err = e
mut.Unlock()
if e != nil {
log.Println("Potential race test error:", e)
}
wgr.Done()
}()
}
wgr.Wait()
if err != nil {
t.Fatal("Race test failed:", errors.Wrap(err, wrapErr))
}
return val
}
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
@ -133,30 +202,6 @@ func mustConfig(t *testing.T) testConfig {
func nicoReadTo(t *testing.T, dst io.Writer) {
t.Helper()
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
defer f.Close()
var lenbuf [4]byte
for {
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
if err == io.EOF {
break
}
t.Fatal("failed to read:", err)
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
}
}
// simple shitty benchmark thing

View file

@ -2,14 +2,16 @@ package voice
import (
"context"
"github.com/diamondburned/arikawa/v2/utils/handler"
"sync"
"time"
"github.com/diamondburned/arikawa/v2/utils/handler"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/internal/handleloop"
"github.com/diamondburned/arikawa/v2/internal/moreatomic"
"github.com/diamondburned/arikawa/v2/session"
"github.com/diamondburned/arikawa/v2/utils/wsutil"
@ -21,6 +23,9 @@ const Protocol = "xsalsa20_poly1305"
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
// ErrAlreadyConnecting is returned when the session is already connecting.
var ErrAlreadyConnecting = errors.New("already connecting")
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
@ -28,56 +33,49 @@ var WSTimeout = 10 * time.Second
type Session struct {
*handler.Handler
session *session.Session
state voicegateway.State
ErrorLog func(err error)
// Filled by events.
// sessionID string
// token string
// endpoint string
session *session.Session
looper *handleloop.Loop
// joining determines the behavior of incoming event callbacks (Update).
// If this is true, incoming events will just send into Updated channels. If
// false, events will trigger a reconnection.
joining moreatomic.Bool
incoming chan struct{} // used only when joining == true
hstop chan struct{}
mut sync.RWMutex
state voicegateway.State // guarded except UserID
// TODO: expose getters mutex-guarded.
gateway *voicegateway.Gateway
voiceUDP *udp.Connection
muted bool
deafened bool
speaking bool
}
func NewSession(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
looper := handleloop.NewLoop(handler)
return &Session{
Handler: handler.New(),
Handler: handler,
looper: looper,
session: ses,
state: voicegateway.State{
UserID: userID,
},
ErrorLog: func(err error) {},
incoming: make(chan struct{}, 2),
hstop: make(chan struct{}),
}
}
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
if s.state.GuildID != ev.GuildID {
// Not our state.
return
}
// If this is true, then mutex is acquired already.
if s.joining.Get() {
if s.state.GuildID != ev.GuildID {
return
}
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
@ -85,10 +83,15 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
return
}
// Reconnect.
s.mut.Lock()
defer s.mut.Unlock()
if s.state.GuildID != ev.GuildID {
return
}
// Reconnect.
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
@ -101,7 +104,7 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
}
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
if s.state.UserID != ev.UserID {
if s.state.UserID != ev.UserID { // constant so no mutex
// Not our state.
return
}
@ -125,9 +128,16 @@ func (s *Session) JoinChannel(
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
}
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
// directly, but rather Voice's. If this method is called concurrently, an error
// will be returned.
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
if s.joining.Get() {
return ErrAlreadyConnecting
}
// Acquire the mutex during join, locking during IO as well.
s.mut.Lock()
defer s.mut.Unlock()
@ -143,10 +153,6 @@ func (s *Session) JoinChannelCtx(
s.state.ChannelID = cID
s.state.GuildID = gID
s.muted = muted
s.deafened = deafened
s.speaking = false
// Ensure that if `cID` is zero that it passes null to the update event.
channelID := discord.NullChannelID
if cID.IsValid() {
@ -192,13 +198,10 @@ func (s *Session) waitForIncoming(ctx context.Context, n int) error {
// reconnect uses the current state to reconnect to a new gateway and UDP
// connection.
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Sending stop handle")
// Stop the existing handler
close(s.hstop)
wsutil.WSDebug("Sending stop handle.")
s.looper.Stop()
s.hstop = make(chan struct{})
wsutil.WSDebug("Start gateway")
wsutil.WSDebug("Start gateway.")
s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received.
@ -207,7 +210,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
}
// Start the handler dispatching
go s.startHandler()
s.looper.Start(s.gateway.Events)
// Get the Ready event.
voiceReady := s.gateway.Ready()
@ -237,29 +240,36 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
}
// Speaking tells Discord we're speaking. This calls
// (*voicegateway.Gateway).Speaking().
// (*voicegateway.Gateway).Speaking(). This method should not be called
// concurrently.
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
// TODO: maybe we don't need to mutex protect IO.
s.mut.RLock()
defer s.mut.RUnlock()
gateway := s.gateway
s.mut.RUnlock()
return s.gateway.Speaking(flag)
return gateway.Speaking(flag)
}
// StopSpeaking sends 5 frames of silence over the UDP connection. Since the UDP
// connection itself is not concurrently safe, this method should not be called
// as such.
func (s *Session) StopSpeaking() error {
udp := s.VoiceUDPConn()
// Send 5 frames of silence.
for i := 0; i < 5; i++ {
if _, err := s.Write(OpusSilence[:]); err != nil {
if _, err := udp.Write(OpusSilence[:]); err != nil {
return errors.Wrapf(err, "failed to send frame %d", i)
}
}
return nil
}
// UseContext tells the UDP voice connection to write with the given mutex.
func (s *Session) UseContext(ctx context.Context) error {
s.mut.RLock()
defer s.mut.RUnlock()
s.mut.Lock()
defer s.mut.Unlock()
if s.voiceUDP == nil {
return ErrCannotSend
@ -268,21 +278,32 @@ func (s *Session) UseContext(ctx context.Context) error {
return s.voiceUDP.UseContext(ctx)
}
// Write writes into the UDP voice connection WITHOUT a timeout.
// VoiceUDPConn gets a voice UDP connection. The caller could use this method to
// circumvent the rapid mutex-read-lock acquire inside Write.
func (s *Session) VoiceUDPConn() *udp.Connection {
s.mut.RLock()
defer s.mut.RUnlock()
return s.voiceUDP
}
// Write writes into the UDP voice connection WITHOUT a timeout. Refer to
// WriteCtx for more information.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)
}
// WriteCtx writes into the UDP voice connection with a context for timeout.
// This method is thread safe as far as calling other methods of Session goes;
// HOWEVER it is not thread safe to call Write itself concurrently.
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
s.mut.RLock()
defer s.mut.RUnlock()
voiceUDP := s.VoiceUDPConn()
if s.voiceUDP == nil {
if voiceUDP == nil {
return 0, ErrCannotSend
}
return s.voiceUDP.WriteCtx(ctx, b)
return voiceUDP.WriteCtx(ctx, b)
}
func (s *Session) Disconnect() error {
@ -301,6 +322,8 @@ func (s *Session) DisconnectCtx(ctx context.Context) error {
return nil
}
s.looper.Stop()
// Notify Discord that we're leaving. This will send a
// VoiceStateUpdateEvent, in which our handler will promptly remove the
// session from the map.
@ -319,10 +342,7 @@ func (s *Session) DisconnectCtx(ctx context.Context) error {
// close ensures everything is closed. It does not acquire the mutex.
func (s *Session) ensureClosed() {
// If we're already closed.
if s.gateway == nil && s.voiceUDP == nil {
return
}
s.looper.Stop()
// Disconnect the UDP connection.
if s.voiceUDP != nil {
@ -338,15 +358,3 @@ func (s *Session) ensureClosed() {
s.gateway = nil
}
}
// startHandler processes events from the gateway into event handlers.
func (s *Session) startHandler() {
for {
select {
case <-s.hstop:
return
case ev := <-s.gateway.Events:
s.Call(ev)
}
}
}

View file

@ -126,21 +126,25 @@ func (c *Connection) Close() error {
return c.conn.Close()
}
// Write sends bytes into the voice UDP connection.
// Write sends bytes into the voice UDP connection using the preset context.
func (c *Connection) Write(b []byte) (int, error) {
return c.write(b)
}
// WriteCtx sends bytes into the voice UDP connection with a timeout.
// WriteCtx sends bytes into the voice UDP connection with a timeout using the
// given context. It ignores the context inside the connection, but will restore
// the deadline after this call is done.
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
if err := c.useContext(ctx); err != nil {
return 0, errors.Wrap(err, "failed to use context")
if deadline, ok := ctx.Deadline(); ok {
ctx := c.context
defer c.useContext(ctx) // restore after we're done
c.conn.SetWriteDeadline(deadline)
}
return c.write(b)
}
// write is thread-unsafe.
func (c *Connection) write(b []byte) (int, error) {
// Write a new sequence.
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)