mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-11-23 13:14:16 +00:00
Voice: Minor concurrency improvements
This commit adds multiple thread safe guards to do better the concurrency promises. However, it also omits completely those guarantees in certain places that don't make sense to call concurrently. This is mostly documented. This commit also adds a small piece of code to concurrently run things with the race detector.
This commit is contained in:
parent
1b8af1513e
commit
f4750292eb
|
|
@ -10,6 +10,7 @@ import (
|
|||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ import (
|
|||
"github.com/diamondburned/arikawa/v2/gateway"
|
||||
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
|
|
@ -41,7 +43,7 @@ func TestIntegration(t *testing.T) {
|
|||
if err := v.Open(); err != nil {
|
||||
t.Fatal("Failed to connect:", err)
|
||||
}
|
||||
defer v.Close()
|
||||
t.Cleanup(func() { v.Close() })
|
||||
|
||||
// Validate the given voice channel.
|
||||
c, err := v.Channel(config.VoiceChID)
|
||||
|
|
@ -57,17 +59,19 @@ func TestIntegration(t *testing.T) {
|
|||
// Grab a timer to benchmark things.
|
||||
finish := timer()
|
||||
|
||||
// Join the voice channel.
|
||||
vs, err := v.JoinChannel(c.GuildID, c.ID, false, false)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to join channel:", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Println("Disconnecting from the voice channel.")
|
||||
if err := vs.Disconnect(); err != nil {
|
||||
t.Fatal("Failed to disconnect:", err)
|
||||
}
|
||||
}()
|
||||
// Join the voice channel concurrently.
|
||||
raceValue := raceMe(t, "failed to join voice channel", func() (interface{}, error) {
|
||||
return v.JoinChannel(c.GuildID, c.ID, false, false)
|
||||
})
|
||||
vs := raceValue.(*Session)
|
||||
|
||||
t.Cleanup(func() {
|
||||
log.Println("Disconnecting from the voice channel concurrently.")
|
||||
|
||||
raceMe(t, "failed to disconnect", func() (interface{}, error) {
|
||||
return nil, vs.Disconnect()
|
||||
})
|
||||
})
|
||||
|
||||
finish("joining the voice channel")
|
||||
|
||||
|
|
@ -76,32 +80,97 @@ func TestIntegration(t *testing.T) {
|
|||
finish("received voice speaking event")
|
||||
})
|
||||
|
||||
// Create a context and only cancel it AFTER we're done sending silence
|
||||
// frames.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Trigger speaking.
|
||||
if err := vs.Speaking(voicegateway.Microphone); err != nil {
|
||||
t.Fatal("Failed to start speaking:", err)
|
||||
t.Fatal("failed to start speaking:", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Println("Stopping speaking.") // sounds grammatically wrong
|
||||
t.Cleanup(func() {
|
||||
log.Println("Stopping speaking.")
|
||||
if err := vs.StopSpeaking(); err != nil {
|
||||
t.Fatal("Failed to stop speaking:", err)
|
||||
t.Fatal("failed to stop speaking:", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
finish("sending the speaking command")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := vs.UseContext(ctx); err != nil {
|
||||
t.Fatal("failed to set ctx into vs:", err)
|
||||
}
|
||||
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
// Copy the audio?
|
||||
nicoReadTo(t, vs)
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Fatal("failed to read:", err)
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(vs, f, framelen); err != nil && err != io.EOF {
|
||||
t.Fatal("failed to write:", err)
|
||||
}
|
||||
}
|
||||
|
||||
finish("copying the audio")
|
||||
}
|
||||
|
||||
// raceMe intentionally calls fn multiple times in goroutines to ensure it's not
|
||||
// racy.
|
||||
func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interface{} {
|
||||
const n = 3 // run 3 times
|
||||
t.Helper()
|
||||
|
||||
// It is very ironic how this method itself is racy.
|
||||
|
||||
var wgr sync.WaitGroup
|
||||
var mut sync.Mutex
|
||||
var val interface{}
|
||||
var err error
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
wgr.Add(1)
|
||||
go func() {
|
||||
v, e := fn()
|
||||
|
||||
mut.Lock()
|
||||
val = v
|
||||
err = e
|
||||
mut.Unlock()
|
||||
|
||||
if e != nil {
|
||||
log.Println("Potential race test error:", e)
|
||||
}
|
||||
|
||||
wgr.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wgr.Wait()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("Race test failed:", errors.Wrap(err, wrapErr))
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.ChannelID
|
||||
|
|
@ -133,30 +202,6 @@ func mustConfig(t *testing.T) testConfig {
|
|||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Fatal("failed to read:", err)
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
|
||||
t.Fatal("failed to write:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simple shitty benchmark thing
|
||||
|
|
|
|||
132
voice/session.go
132
voice/session.go
|
|
@ -2,14 +2,16 @@ package voice
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/diamondburned/arikawa/v2/utils/handler"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v2/utils/handler"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v2/discord"
|
||||
"github.com/diamondburned/arikawa/v2/gateway"
|
||||
"github.com/diamondburned/arikawa/v2/internal/handleloop"
|
||||
"github.com/diamondburned/arikawa/v2/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v2/session"
|
||||
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
||||
|
|
@ -21,6 +23,9 @@ const Protocol = "xsalsa20_poly1305"
|
|||
|
||||
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
|
||||
|
||||
// ErrAlreadyConnecting is returned when the session is already connecting.
|
||||
var ErrAlreadyConnecting = errors.New("already connecting")
|
||||
|
||||
// WSTimeout is the duration to wait for a gateway operation including Session
|
||||
// to complete before erroring out. This only applies to functions that don't
|
||||
// take in a context already.
|
||||
|
|
@ -28,56 +33,49 @@ var WSTimeout = 10 * time.Second
|
|||
|
||||
type Session struct {
|
||||
*handler.Handler
|
||||
|
||||
session *session.Session
|
||||
state voicegateway.State
|
||||
|
||||
ErrorLog func(err error)
|
||||
|
||||
// Filled by events.
|
||||
// sessionID string
|
||||
// token string
|
||||
// endpoint string
|
||||
session *session.Session
|
||||
looper *handleloop.Loop
|
||||
|
||||
// joining determines the behavior of incoming event callbacks (Update).
|
||||
// If this is true, incoming events will just send into Updated channels. If
|
||||
// false, events will trigger a reconnection.
|
||||
joining moreatomic.Bool
|
||||
incoming chan struct{} // used only when joining == true
|
||||
hstop chan struct{}
|
||||
|
||||
mut sync.RWMutex
|
||||
|
||||
state voicegateway.State // guarded except UserID
|
||||
|
||||
// TODO: expose getters mutex-guarded.
|
||||
gateway *voicegateway.Gateway
|
||||
voiceUDP *udp.Connection
|
||||
|
||||
muted bool
|
||||
deafened bool
|
||||
speaking bool
|
||||
}
|
||||
|
||||
func NewSession(ses *session.Session, userID discord.UserID) *Session {
|
||||
handler := handler.New()
|
||||
looper := handleloop.NewLoop(handler)
|
||||
|
||||
return &Session{
|
||||
Handler: handler.New(),
|
||||
Handler: handler,
|
||||
looper: looper,
|
||||
session: ses,
|
||||
state: voicegateway.State{
|
||||
UserID: userID,
|
||||
},
|
||||
ErrorLog: func(err error) {},
|
||||
incoming: make(chan struct{}, 2),
|
||||
hstop: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||
if s.state.GuildID != ev.GuildID {
|
||||
// Not our state.
|
||||
return
|
||||
}
|
||||
|
||||
// If this is true, then mutex is acquired already.
|
||||
if s.joining.Get() {
|
||||
if s.state.GuildID != ev.GuildID {
|
||||
return
|
||||
}
|
||||
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
|
|
@ -85,10 +83,15 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
|||
return
|
||||
}
|
||||
|
||||
// Reconnect.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
if s.state.GuildID != ev.GuildID {
|
||||
return
|
||||
}
|
||||
|
||||
// Reconnect.
|
||||
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
|
|
@ -101,7 +104,7 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
|||
}
|
||||
|
||||
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||
if s.state.UserID != ev.UserID {
|
||||
if s.state.UserID != ev.UserID { // constant so no mutex
|
||||
// Not our state.
|
||||
return
|
||||
}
|
||||
|
|
@ -125,9 +128,16 @@ func (s *Session) JoinChannel(
|
|||
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
|
||||
}
|
||||
|
||||
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
||||
// directly, but rather Voice's. If this method is called concurrently, an error
|
||||
// will be returned.
|
||||
func (s *Session) JoinChannelCtx(
|
||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
|
||||
if s.joining.Get() {
|
||||
return ErrAlreadyConnecting
|
||||
}
|
||||
|
||||
// Acquire the mutex during join, locking during IO as well.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
|
@ -143,10 +153,6 @@ func (s *Session) JoinChannelCtx(
|
|||
s.state.ChannelID = cID
|
||||
s.state.GuildID = gID
|
||||
|
||||
s.muted = muted
|
||||
s.deafened = deafened
|
||||
s.speaking = false
|
||||
|
||||
// Ensure that if `cID` is zero that it passes null to the update event.
|
||||
channelID := discord.NullChannelID
|
||||
if cID.IsValid() {
|
||||
|
|
@ -192,13 +198,10 @@ func (s *Session) waitForIncoming(ctx context.Context, n int) error {
|
|||
// reconnect uses the current state to reconnect to a new gateway and UDP
|
||||
// connection.
|
||||
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
||||
wsutil.WSDebug("Sending stop handle")
|
||||
// Stop the existing handler
|
||||
close(s.hstop)
|
||||
wsutil.WSDebug("Sending stop handle.")
|
||||
s.looper.Stop()
|
||||
|
||||
s.hstop = make(chan struct{})
|
||||
|
||||
wsutil.WSDebug("Start gateway")
|
||||
wsutil.WSDebug("Start gateway.")
|
||||
s.gateway = voicegateway.New(s.state)
|
||||
|
||||
// Open the voice gateway. The function will block until Ready is received.
|
||||
|
|
@ -207,7 +210,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// Start the handler dispatching
|
||||
go s.startHandler()
|
||||
s.looper.Start(s.gateway.Events)
|
||||
|
||||
// Get the Ready event.
|
||||
voiceReady := s.gateway.Ready()
|
||||
|
|
@ -237,29 +240,36 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// Speaking tells Discord we're speaking. This calls
|
||||
// (*voicegateway.Gateway).Speaking().
|
||||
// (*voicegateway.Gateway).Speaking(). This method should not be called
|
||||
// concurrently.
|
||||
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
|
||||
// TODO: maybe we don't need to mutex protect IO.
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
gateway := s.gateway
|
||||
s.mut.RUnlock()
|
||||
|
||||
return s.gateway.Speaking(flag)
|
||||
return gateway.Speaking(flag)
|
||||
}
|
||||
|
||||
// StopSpeaking sends 5 frames of silence over the UDP connection. Since the UDP
|
||||
// connection itself is not concurrently safe, this method should not be called
|
||||
// as such.
|
||||
func (s *Session) StopSpeaking() error {
|
||||
udp := s.VoiceUDPConn()
|
||||
|
||||
// Send 5 frames of silence.
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := s.Write(OpusSilence[:]); err != nil {
|
||||
if _, err := udp.Write(OpusSilence[:]); err != nil {
|
||||
return errors.Wrapf(err, "failed to send frame %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseContext tells the UDP voice connection to write with the given mutex.
|
||||
func (s *Session) UseContext(ctx context.Context) error {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return ErrCannotSend
|
||||
|
|
@ -268,21 +278,32 @@ func (s *Session) UseContext(ctx context.Context) error {
|
|||
return s.voiceUDP.UseContext(ctx)
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout.
|
||||
// VoiceUDPConn gets a voice UDP connection. The caller could use this method to
|
||||
// circumvent the rapid mutex-read-lock acquire inside Write.
|
||||
func (s *Session) VoiceUDPConn() *udp.Connection {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
return s.voiceUDP
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout. Refer to
|
||||
// WriteCtx for more information.
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
return s.WriteCtx(context.Background(), b)
|
||||
}
|
||||
|
||||
// WriteCtx writes into the UDP voice connection with a context for timeout.
|
||||
// This method is thread safe as far as calling other methods of Session goes;
|
||||
// HOWEVER it is not thread safe to call Write itself concurrently.
|
||||
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
voiceUDP := s.VoiceUDPConn()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
if voiceUDP == nil {
|
||||
return 0, ErrCannotSend
|
||||
}
|
||||
|
||||
return s.voiceUDP.WriteCtx(ctx, b)
|
||||
return voiceUDP.WriteCtx(ctx, b)
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
|
|
@ -301,6 +322,8 @@ func (s *Session) DisconnectCtx(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
s.looper.Stop()
|
||||
|
||||
// Notify Discord that we're leaving. This will send a
|
||||
// VoiceStateUpdateEvent, in which our handler will promptly remove the
|
||||
// session from the map.
|
||||
|
|
@ -319,10 +342,7 @@ func (s *Session) DisconnectCtx(ctx context.Context) error {
|
|||
|
||||
// close ensures everything is closed. It does not acquire the mutex.
|
||||
func (s *Session) ensureClosed() {
|
||||
// If we're already closed.
|
||||
if s.gateway == nil && s.voiceUDP == nil {
|
||||
return
|
||||
}
|
||||
s.looper.Stop()
|
||||
|
||||
// Disconnect the UDP connection.
|
||||
if s.voiceUDP != nil {
|
||||
|
|
@ -338,15 +358,3 @@ func (s *Session) ensureClosed() {
|
|||
s.gateway = nil
|
||||
}
|
||||
}
|
||||
|
||||
// startHandler processes events from the gateway into event handlers.
|
||||
func (s *Session) startHandler() {
|
||||
for {
|
||||
select {
|
||||
case <-s.hstop:
|
||||
return
|
||||
case ev := <-s.gateway.Events:
|
||||
s.Call(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -126,21 +126,25 @@ func (c *Connection) Close() error {
|
|||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Write sends bytes into the voice UDP connection.
|
||||
// Write sends bytes into the voice UDP connection using the preset context.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout.
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout using the
|
||||
// given context. It ignores the context inside the connection, but will restore
|
||||
// the deadline after this call is done.
|
||||
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
if err := c.useContext(ctx); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to use context")
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
ctx := c.context
|
||||
defer c.useContext(ctx) // restore after we're done
|
||||
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// write is thread-unsafe.
|
||||
func (c *Connection) write(b []byte) (int, error) {
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
|
||||
|
|
|
|||
Loading…
Reference in a new issue