diff --git a/voice/session.go b/voice/session.go index 4790427..26c0dcc 100644 --- a/voice/session.go +++ b/voice/session.go @@ -41,22 +41,26 @@ type Session struct { ErrorLog func(err error) session *session.Session - cancels []func() looper *handleloop.Loop + detach func() + + mut sync.RWMutex + state voicegateway.State // guarded except UserID + // TODO: expose getters mutex-guarded. + gateway *voicegateway.Gateway + voiceUDP *udp.Connection + // end of mutex + + WSTimeout time.Duration // global WSTimeout + WSMaxRetry int // 2 + WSRetryDelay time.Duration // 2s + WSWaitDuration time.Duration // 5s // joining determines the behavior of incoming event callbacks (Update). // If this is true, incoming events will just send into Updated channels. If // false, events will trigger a reconnection. - joining moreatomic.Bool - incoming chan struct{} // used only when joining == true - - mut sync.RWMutex - - state voicegateway.State // guarded except UserID - - // TODO: expose getters mutex-guarded. - gateway *voicegateway.Gateway - voiceUDP *udp.Connection + joining moreatomic.Bool + connected bool } // NewSession creates a new voice session for the current user. @@ -81,35 +85,27 @@ func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session { state: voicegateway.State{ UserID: userID, }, - ErrorLog: func(err error) {}, - incoming: make(chan struct{}, 2), - } - session.cancels = []func(){ - ses.AddHandler(session.updateServer), - ses.AddHandler(session.updateState), + ErrorLog: func(err error) {}, + WSTimeout: WSTimeout, + WSMaxRetry: 2, + WSRetryDelay: 2 * time.Second, + WSWaitDuration: 5 * time.Second, } return session } +// updateServer is specifically used to monitor for reconnects. func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) { - // If this is true, then mutex is acquired already. if s.joining.Get() { - if s.state.GuildID != ev.GuildID { - return - } - - s.state.Endpoint = ev.Endpoint - s.state.Token = ev.Token - - s.incoming <- struct{}{} return } s.mut.Lock() defer s.mut.Unlock() - if s.state.GuildID != ev.GuildID { + // Ignore if we haven't connected yet or we're still joining. + if !s.connected || s.state.GuildID != ev.GuildID { return } @@ -126,26 +122,8 @@ func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) { } } -func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) { - if s.state.UserID != ev.UserID { // constant so no mutex - // Not our state. - return - } - - // If this is true, then mutex is acquired already. - if s.joining.Get() { - if s.state.GuildID != ev.GuildID { - return - } - - s.state.SessionID = ev.SessionID - s.state.ChannelID = ev.ChannelID - - s.incoming <- struct{}{} - return - } -} - +// JoinChannel joins a voice channel with the default WS timeout. See +// JoinChannelCtx for more information. func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error { ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) defer cancel() @@ -153,30 +131,33 @@ func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, return s.JoinChannelCtx(ctx, gID, cID, mute, deaf) } +type waitEventChs struct { + serverUpdate chan *gateway.VoiceServerUpdateEvent + stateUpdate chan *gateway.VoiceStateUpdateEvent +} + // JoinChannelCtx joins a voice channel. Callers shouldn't use this method // directly, but rather Voice's. This method shouldn't ever be called // concurrently. func (s *Session) JoinChannelCtx( ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error { - if s.joining.Get() { - return ErrAlreadyConnecting - } - - // Acquire the mutex during join, locking during IO as well. s.mut.Lock() defer s.mut.Unlock() - // Set that we're joining. - s.joining.Set(true) - defer s.joining.Set(false) // reset when done + // Error out if we're already joining. JoinChannel shouldn't be called + // concurrently. + if s.joining.Get() { + return errors.New("JoinChannel working elsewhere") + } - // Ensure gateway and voiceUDP are already closed. - s.ensureClosed() + s.joining.Set(true) + defer s.joining.Set(false) // Set the state. s.state.ChannelID = cID s.state.GuildID = gID + s.detach = s.session.AddHandler(s.updateServer) // Ensure that if `cID` is zero that it passes null to the update event. channelID := discord.NullChannelID @@ -184,34 +165,121 @@ func (s *Session) JoinChannelCtx( channelID = cID } - // https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information - // Send a Voice State Update event to the gateway. - err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{ + chs := waitEventChs{ + serverUpdate: make(chan *gateway.VoiceServerUpdateEvent), + stateUpdate: make(chan *gateway.VoiceStateUpdateEvent), + } + + // Bind the handlers. + cancels := []func(){ + s.session.AddHandler(chs.serverUpdate), + s.session.AddHandler(chs.stateUpdate), + } + // Disconnects the handlers once the function exits. + defer func() { + for _, cancel := range cancels { + cancel() + } + }() + + // Ensure gateway and voiceUDP are already closed. + s.ensureClosed() + + data := gateway.UpdateVoiceStateData{ GuildID: gID, ChannelID: channelID, SelfMute: mute, SelfDeaf: deaf, - }) - if err != nil { - return errors.Wrap(err, "failed to send Voice State Update event") } - // Wait for 2 replies. The above command should reply with these 2 events. - if err := s.waitForIncoming(ctx, 2); err != nil { - return errors.Wrap(err, "failed to wait for needed gateway events") + var err error + + var timer *time.Timer + + // Retry 3 times maximum. + for i := 0; i < s.WSMaxRetry; i++ { + if err = s.askDiscord(ctx, data, chs); err == nil { + break + } + + // If this is the first attempt and the context timed out, it's + // probably the context that's waiting for gateway events. Retry the + // loop. + if i == 0 && errors.Is(err, ctx.Err()) { + continue + } + + if timer == nil { + // Set up a timer. + timer = time.NewTimer(s.WSRetryDelay) + defer timer.Stop() + } else { + timer.Reset(s.WSRetryDelay) + } + + select { + case <-timer.C: + continue + case <-ctx.Done(): + return err + } } // These 2 methods should've updated s.state before sending into these // channels. Since s.state is already filled, we can go ahead and connect. + // Mark the session as connected and move on. This allows one of the + // connected handlers to reconnect on its own. + s.connected = true + return s.reconnectCtx(ctx) } -func (s *Session) waitForIncoming(ctx context.Context, n int) error { - for i := 0; i < n; i++ { +func (s *Session) askDiscord( + ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error { + + // https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information + // Send a Voice State Update event to the gateway. + if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil { + return errors.Wrap(err, "failed to send Voice State Update event") + } + + // Wait for 2 replies. The above command should reply with these 2 events. + if err := s.waitForIncoming(ctx, chs); err != nil { + return errors.Wrap(err, "failed to wait for needed gateway events") + } + + return nil +} + +func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error { + ctx, cancel := context.WithTimeout(ctx, s.WSWaitDuration) + defer cancel() + + state := false + // server is true when we already have the token and endpoint, meaning that + // we don't have to wait for another such event. + server := s.state.Token != "" && s.state.Endpoint != "" + + // Loop until timeout or until we have all the information that we need. + for !(server && state) { select { - case <-s.incoming: - continue + case ev := <-chs.serverUpdate: + if s.state.GuildID != ev.GuildID { + continue + } + s.state.Endpoint = ev.Endpoint + s.state.Token = ev.Token + server = true + + case ev := <-chs.stateUpdate: + if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID { + continue + } + s.state.SessionID = ev.SessionID + s.state.ChannelID = ev.ChannelID + state = true + case <-ctx.Done(): return ctx.Err() } @@ -328,6 +396,14 @@ func (s *Session) LeaveCtx(ctx context.Context) error { s.mut.Lock() defer s.mut.Unlock() + s.connected = false + + // Unbind the handlers. + if s.detach != nil { + s.detach() + s.detach = nil + } + // If we're already closed. if s.gateway == nil && s.voiceUDP == nil { return nil diff --git a/voice/session_test.go b/voice/session_test.go index 7ebfbda..a3ddc94 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -31,7 +31,7 @@ func TestIntegration(t *testing.T) { s, err := state.New("Bot " + config.BotToken) if err != nil { - t.Fatal("Failed to create a new state:", err) + t.Fatal("failed to create a new state:", err) } AddIntents(s.Gateway) @@ -40,7 +40,7 @@ func TestIntegration(t *testing.T) { defer cancel() if err := s.Open(ctx); err != nil { - t.Fatal("Failed to connect:", err) + t.Fatal("failed to connect:", err) } }() @@ -49,17 +49,26 @@ func TestIntegration(t *testing.T) { // Validate the given voice channel. c, err := s.Channel(config.VoiceChID) if err != nil { - t.Fatal("Failed to get channel:", err) + t.Fatal("failed to get channel:", err) } if c.Type != discord.GuildVoice { - t.Fatal("Channel isn't a guild voice channel.") + t.Fatal("channel isn't a guild voice channel.") } log.Println("The voice channel's name is", c.Name) + testVoice(t, s, c) + + // BUG: Discord doesn't want to send the second VoiceServerUpdateEvent. I + // have no idea why. + + // testVoice(t, s, c) +} + +func testVoice(t *testing.T, s *state.State, c *discord.Channel) { v, err := NewSession(s) if err != nil { - t.Fatal("Failed to create a new voice session:", err) + t.Fatal("failed to create a new voice session:", err) } v.ErrorLog = func(err error) { t.Error(err) } @@ -71,10 +80,9 @@ func TestIntegration(t *testing.T) { finish("receiving voice speaking event") }) - // Join the voice channel concurrently. - raceMe(t, "failed to join voice channel", func() (interface{}, error) { - return nil, v.JoinChannel(c.GuildID, c.ID, false, false) - }) + if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil { + t.Fatal("failed to join voice:", err) + } t.Cleanup(func() { log.Println("Leaving the voice channel concurrently.") @@ -104,7 +112,7 @@ func TestIntegration(t *testing.T) { f, err := os.Open("testdata/nico.dca") if err != nil { - t.Fatal("Failed to open nico.dca:", err) + t.Fatal("failed to open nico.dca:", err) } defer f.Close()