voice: Fix init inconsistencies
This commit fixes a few subtle bugs in the voice package. It slightly refactors the connecting and reconnecting of voice sessions.
This commit is contained in:
parent
accb2fc52b
commit
123f8bc41f
216
voice/session.go
216
voice/session.go
|
@ -41,22 +41,26 @@ type Session struct {
|
||||||
ErrorLog func(err error)
|
ErrorLog func(err error)
|
||||||
|
|
||||||
session *session.Session
|
session *session.Session
|
||||||
cancels []func()
|
|
||||||
looper *handleloop.Loop
|
looper *handleloop.Loop
|
||||||
|
detach func()
|
||||||
|
|
||||||
|
mut sync.RWMutex
|
||||||
|
state voicegateway.State // guarded except UserID
|
||||||
|
// TODO: expose getters mutex-guarded.
|
||||||
|
gateway *voicegateway.Gateway
|
||||||
|
voiceUDP *udp.Connection
|
||||||
|
// end of mutex
|
||||||
|
|
||||||
|
WSTimeout time.Duration // global WSTimeout
|
||||||
|
WSMaxRetry int // 2
|
||||||
|
WSRetryDelay time.Duration // 2s
|
||||||
|
WSWaitDuration time.Duration // 5s
|
||||||
|
|
||||||
// joining determines the behavior of incoming event callbacks (Update).
|
// joining determines the behavior of incoming event callbacks (Update).
|
||||||
// If this is true, incoming events will just send into Updated channels. If
|
// If this is true, incoming events will just send into Updated channels. If
|
||||||
// false, events will trigger a reconnection.
|
// false, events will trigger a reconnection.
|
||||||
joining moreatomic.Bool
|
joining moreatomic.Bool
|
||||||
incoming chan struct{} // used only when joining == true
|
connected bool
|
||||||
|
|
||||||
mut sync.RWMutex
|
|
||||||
|
|
||||||
state voicegateway.State // guarded except UserID
|
|
||||||
|
|
||||||
// TODO: expose getters mutex-guarded.
|
|
||||||
gateway *voicegateway.Gateway
|
|
||||||
voiceUDP *udp.Connection
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession creates a new voice session for the current user.
|
// NewSession creates a new voice session for the current user.
|
||||||
|
@ -81,35 +85,27 @@ func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
|
||||||
state: voicegateway.State{
|
state: voicegateway.State{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
},
|
},
|
||||||
ErrorLog: func(err error) {},
|
ErrorLog: func(err error) {},
|
||||||
incoming: make(chan struct{}, 2),
|
WSTimeout: WSTimeout,
|
||||||
}
|
WSMaxRetry: 2,
|
||||||
session.cancels = []func(){
|
WSRetryDelay: 2 * time.Second,
|
||||||
ses.AddHandler(session.updateServer),
|
WSWaitDuration: 5 * time.Second,
|
||||||
ses.AddHandler(session.updateState),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateServer is specifically used to monitor for reconnects.
|
||||||
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||||
// If this is true, then mutex is acquired already.
|
|
||||||
if s.joining.Get() {
|
if s.joining.Get() {
|
||||||
if s.state.GuildID != ev.GuildID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.state.Endpoint = ev.Endpoint
|
|
||||||
s.state.Token = ev.Token
|
|
||||||
|
|
||||||
s.incoming <- struct{}{}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
|
||||||
if s.state.GuildID != ev.GuildID {
|
// Ignore if we haven't connected yet or we're still joining.
|
||||||
|
if !s.connected || s.state.GuildID != ev.GuildID {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,26 +122,8 @@ func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
|
// JoinChannel joins a voice channel with the default WS timeout. See
|
||||||
if s.state.UserID != ev.UserID { // constant so no mutex
|
// JoinChannelCtx for more information.
|
||||||
// Not our state.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this is true, then mutex is acquired already.
|
|
||||||
if s.joining.Get() {
|
|
||||||
if s.state.GuildID != ev.GuildID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.state.SessionID = ev.SessionID
|
|
||||||
s.state.ChannelID = ev.ChannelID
|
|
||||||
|
|
||||||
s.incoming <- struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -153,30 +131,33 @@ func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute,
|
||||||
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
|
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type waitEventChs struct {
|
||||||
|
serverUpdate chan *gateway.VoiceServerUpdateEvent
|
||||||
|
stateUpdate chan *gateway.VoiceStateUpdateEvent
|
||||||
|
}
|
||||||
|
|
||||||
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
||||||
// directly, but rather Voice's. This method shouldn't ever be called
|
// directly, but rather Voice's. This method shouldn't ever be called
|
||||||
// concurrently.
|
// concurrently.
|
||||||
func (s *Session) JoinChannelCtx(
|
func (s *Session) JoinChannelCtx(
|
||||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||||
|
|
||||||
if s.joining.Get() {
|
|
||||||
return ErrAlreadyConnecting
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire the mutex during join, locking during IO as well.
|
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
|
||||||
// Set that we're joining.
|
// Error out if we're already joining. JoinChannel shouldn't be called
|
||||||
s.joining.Set(true)
|
// concurrently.
|
||||||
defer s.joining.Set(false) // reset when done
|
if s.joining.Get() {
|
||||||
|
return errors.New("JoinChannel working elsewhere")
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure gateway and voiceUDP are already closed.
|
s.joining.Set(true)
|
||||||
s.ensureClosed()
|
defer s.joining.Set(false)
|
||||||
|
|
||||||
// Set the state.
|
// Set the state.
|
||||||
s.state.ChannelID = cID
|
s.state.ChannelID = cID
|
||||||
s.state.GuildID = gID
|
s.state.GuildID = gID
|
||||||
|
s.detach = s.session.AddHandler(s.updateServer)
|
||||||
|
|
||||||
// Ensure that if `cID` is zero that it passes null to the update event.
|
// Ensure that if `cID` is zero that it passes null to the update event.
|
||||||
channelID := discord.NullChannelID
|
channelID := discord.NullChannelID
|
||||||
|
@ -184,34 +165,121 @@ func (s *Session) JoinChannelCtx(
|
||||||
channelID = cID
|
channelID = cID
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
chs := waitEventChs{
|
||||||
// Send a Voice State Update event to the gateway.
|
serverUpdate: make(chan *gateway.VoiceServerUpdateEvent),
|
||||||
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
stateUpdate: make(chan *gateway.VoiceStateUpdateEvent),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the handlers.
|
||||||
|
cancels := []func(){
|
||||||
|
s.session.AddHandler(chs.serverUpdate),
|
||||||
|
s.session.AddHandler(chs.stateUpdate),
|
||||||
|
}
|
||||||
|
// Disconnects the handlers once the function exits.
|
||||||
|
defer func() {
|
||||||
|
for _, cancel := range cancels {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Ensure gateway and voiceUDP are already closed.
|
||||||
|
s.ensureClosed()
|
||||||
|
|
||||||
|
data := gateway.UpdateVoiceStateData{
|
||||||
GuildID: gID,
|
GuildID: gID,
|
||||||
ChannelID: channelID,
|
ChannelID: channelID,
|
||||||
SelfMute: mute,
|
SelfMute: mute,
|
||||||
SelfDeaf: deaf,
|
SelfDeaf: deaf,
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "failed to send Voice State Update event")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for 2 replies. The above command should reply with these 2 events.
|
var err error
|
||||||
if err := s.waitForIncoming(ctx, 2); err != nil {
|
|
||||||
return errors.Wrap(err, "failed to wait for needed gateway events")
|
var timer *time.Timer
|
||||||
|
|
||||||
|
// Retry 3 times maximum.
|
||||||
|
for i := 0; i < s.WSMaxRetry; i++ {
|
||||||
|
if err = s.askDiscord(ctx, data, chs); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is the first attempt and the context timed out, it's
|
||||||
|
// probably the context that's waiting for gateway events. Retry the
|
||||||
|
// loop.
|
||||||
|
if i == 0 && errors.Is(err, ctx.Err()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if timer == nil {
|
||||||
|
// Set up a timer.
|
||||||
|
timer = time.NewTimer(s.WSRetryDelay)
|
||||||
|
defer timer.Stop()
|
||||||
|
} else {
|
||||||
|
timer.Reset(s.WSRetryDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// These 2 methods should've updated s.state before sending into these
|
// These 2 methods should've updated s.state before sending into these
|
||||||
// channels. Since s.state is already filled, we can go ahead and connect.
|
// channels. Since s.state is already filled, we can go ahead and connect.
|
||||||
|
|
||||||
|
// Mark the session as connected and move on. This allows one of the
|
||||||
|
// connected handlers to reconnect on its own.
|
||||||
|
s.connected = true
|
||||||
|
|
||||||
return s.reconnectCtx(ctx)
|
return s.reconnectCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) waitForIncoming(ctx context.Context, n int) error {
|
func (s *Session) askDiscord(
|
||||||
for i := 0; i < n; i++ {
|
ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error {
|
||||||
|
|
||||||
|
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
||||||
|
// Send a Voice State Update event to the gateway.
|
||||||
|
if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to send Voice State Update event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for 2 replies. The above command should reply with these 2 events.
|
||||||
|
if err := s.waitForIncoming(ctx, chs); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to wait for needed gateway events")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, s.WSWaitDuration)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
state := false
|
||||||
|
// server is true when we already have the token and endpoint, meaning that
|
||||||
|
// we don't have to wait for another such event.
|
||||||
|
server := s.state.Token != "" && s.state.Endpoint != ""
|
||||||
|
|
||||||
|
// Loop until timeout or until we have all the information that we need.
|
||||||
|
for !(server && state) {
|
||||||
select {
|
select {
|
||||||
case <-s.incoming:
|
case ev := <-chs.serverUpdate:
|
||||||
continue
|
if s.state.GuildID != ev.GuildID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.state.Endpoint = ev.Endpoint
|
||||||
|
s.state.Token = ev.Token
|
||||||
|
server = true
|
||||||
|
|
||||||
|
case ev := <-chs.stateUpdate:
|
||||||
|
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.state.SessionID = ev.SessionID
|
||||||
|
s.state.ChannelID = ev.ChannelID
|
||||||
|
state = true
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
@ -328,6 +396,14 @@ func (s *Session) LeaveCtx(ctx context.Context) error {
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
|
||||||
|
s.connected = false
|
||||||
|
|
||||||
|
// Unbind the handlers.
|
||||||
|
if s.detach != nil {
|
||||||
|
s.detach()
|
||||||
|
s.detach = nil
|
||||||
|
}
|
||||||
|
|
||||||
// If we're already closed.
|
// If we're already closed.
|
||||||
if s.gateway == nil && s.voiceUDP == nil {
|
if s.gateway == nil && s.voiceUDP == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -31,7 +31,7 @@ func TestIntegration(t *testing.T) {
|
||||||
|
|
||||||
s, err := state.New("Bot " + config.BotToken)
|
s, err := state.New("Bot " + config.BotToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to create a new state:", err)
|
t.Fatal("failed to create a new state:", err)
|
||||||
}
|
}
|
||||||
AddIntents(s.Gateway)
|
AddIntents(s.Gateway)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ func TestIntegration(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := s.Open(ctx); err != nil {
|
if err := s.Open(ctx); err != nil {
|
||||||
t.Fatal("Failed to connect:", err)
|
t.Fatal("failed to connect:", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -49,17 +49,26 @@ func TestIntegration(t *testing.T) {
|
||||||
// Validate the given voice channel.
|
// Validate the given voice channel.
|
||||||
c, err := s.Channel(config.VoiceChID)
|
c, err := s.Channel(config.VoiceChID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to get channel:", err)
|
t.Fatal("failed to get channel:", err)
|
||||||
}
|
}
|
||||||
if c.Type != discord.GuildVoice {
|
if c.Type != discord.GuildVoice {
|
||||||
t.Fatal("Channel isn't a guild voice channel.")
|
t.Fatal("channel isn't a guild voice channel.")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("The voice channel's name is", c.Name)
|
log.Println("The voice channel's name is", c.Name)
|
||||||
|
|
||||||
|
testVoice(t, s, c)
|
||||||
|
|
||||||
|
// BUG: Discord doesn't want to send the second VoiceServerUpdateEvent. I
|
||||||
|
// have no idea why.
|
||||||
|
|
||||||
|
// testVoice(t, s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
|
||||||
v, err := NewSession(s)
|
v, err := NewSession(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to create a new voice session:", err)
|
t.Fatal("failed to create a new voice session:", err)
|
||||||
}
|
}
|
||||||
v.ErrorLog = func(err error) { t.Error(err) }
|
v.ErrorLog = func(err error) { t.Error(err) }
|
||||||
|
|
||||||
|
@ -71,10 +80,9 @@ func TestIntegration(t *testing.T) {
|
||||||
finish("receiving voice speaking event")
|
finish("receiving voice speaking event")
|
||||||
})
|
})
|
||||||
|
|
||||||
// Join the voice channel concurrently.
|
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
|
||||||
raceMe(t, "failed to join voice channel", func() (interface{}, error) {
|
t.Fatal("failed to join voice:", err)
|
||||||
return nil, v.JoinChannel(c.GuildID, c.ID, false, false)
|
}
|
||||||
})
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
log.Println("Leaving the voice channel concurrently.")
|
log.Println("Leaving the voice channel concurrently.")
|
||||||
|
@ -104,7 +112,7 @@ func TestIntegration(t *testing.T) {
|
||||||
|
|
||||||
f, err := os.Open("testdata/nico.dca")
|
f, err := os.Open("testdata/nico.dca")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to open nico.dca:", err)
|
t.Fatal("failed to open nico.dca:", err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue