arikawa/voice/session.go

455 lines
12 KiB
Go

package voice
import (
"context"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/handleloop"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/diamondburned/arikawa/v3/voice/udp"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
)
// Protocol is the encryption protocol that this library uses.
const Protocol = "xsalsa20_poly1305"
// ErrAlreadyConnecting is returned when the session is already connecting.
var ErrAlreadyConnecting = errors.New("already connecting")
// ErrCannotSend is an error when audio is sent to a closed channel.
var ErrCannotSend = errors.New("cannot send audio to closed channel")
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
var WSTimeout = 10 * time.Second
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct {
*handler.Handler
ErrorLog func(err error)
session *session.Session
looper *handleloop.Loop
detach func()
mut sync.RWMutex
state voicegateway.State // guarded except UserID
// TODO: expose getters mutex-guarded.
gateway *voicegateway.Gateway
voiceUDP *udp.Connection
// end of mutex
WSTimeout time.Duration // global WSTimeout
WSMaxRetry int // 2
WSRetryDelay time.Duration // 2s
WSWaitDuration time.Duration // 5s
// joining determines the behavior of incoming event callbacks (Update).
// If this is true, incoming events will just send into Updated channels. If
// false, events will trigger a reconnection.
joining moreatomic.Bool
connected bool
}
// NewSession creates a new voice session for the current user.
func NewSession(state *state.State) (*Session, error) {
u, err := state.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get me")
}
return NewSessionCustom(state.Session, u.ID), nil
}
// NewSessionCustom creates a new voice session from the given session and user
// ID.
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
hlooper := handleloop.NewLoop(handler)
session := &Session{
Handler: handler,
looper: hlooper,
session: ses,
state: voicegateway.State{
UserID: userID,
},
ErrorLog: func(err error) {},
WSTimeout: WSTimeout,
WSMaxRetry: 2,
WSRetryDelay: 2 * time.Second,
WSWaitDuration: 5 * time.Second,
}
return session
}
// updateServer is specifically used to monitor for reconnects.
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
if s.joining.Get() {
return
}
s.mut.Lock()
defer s.mut.Unlock()
// Ignore if we haven't connected yet or we're still joining.
if !s.connected || s.state.GuildID != ev.GuildID {
return
}
// Reconnect.
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
if err := s.reconnectCtx(ctx); err != nil {
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
}
}
// JoinChannel joins a voice channel with the default WS timeout. See
// JoinChannelCtx for more information.
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
}
type waitEventChs struct {
serverUpdate chan *gateway.VoiceServerUpdateEvent
stateUpdate chan *gateway.VoiceStateUpdateEvent
}
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
// directly, but rather Voice's. This method shouldn't ever be called
// concurrently.
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
s.mut.Lock()
defer s.mut.Unlock()
// Error out if we're already joining. JoinChannel shouldn't be called
// concurrently.
if s.joining.Get() {
return errors.New("JoinChannel working elsewhere")
}
s.joining.Set(true)
defer s.joining.Set(false)
// Set the state.
s.state.ChannelID = cID
s.state.GuildID = gID
s.detach = s.session.AddHandler(s.updateServer)
// Ensure that if `cID` is zero that it passes null to the update event.
channelID := discord.NullChannelID
if cID.IsValid() {
channelID = cID
}
chs := waitEventChs{
serverUpdate: make(chan *gateway.VoiceServerUpdateEvent),
stateUpdate: make(chan *gateway.VoiceStateUpdateEvent),
}
// Bind the handlers.
cancels := []func(){
s.session.AddHandler(chs.serverUpdate),
s.session.AddHandler(chs.stateUpdate),
}
// Disconnects the handlers once the function exits.
defer func() {
for _, cancel := range cancels {
cancel()
}
}()
// Ensure gateway and voiceUDP are already closed.
s.ensureClosed()
data := gateway.UpdateVoiceStateData{
GuildID: gID,
ChannelID: channelID,
SelfMute: mute,
SelfDeaf: deaf,
}
var err error
var timer *time.Timer
// Retry 3 times maximum.
for i := 0; i < s.WSMaxRetry; i++ {
if err = s.askDiscord(ctx, data, chs); err == nil {
break
}
// If this is the first attempt and the context timed out, it's
// probably the context that's waiting for gateway events. Retry the
// loop.
if i == 0 && errors.Is(err, ctx.Err()) {
continue
}
if timer == nil {
// Set up a timer.
timer = time.NewTimer(s.WSRetryDelay)
defer timer.Stop()
} else {
timer.Reset(s.WSRetryDelay)
}
select {
case <-timer.C:
continue
case <-ctx.Done():
return err
}
}
// These 2 methods should've updated s.state before sending into these
// channels. Since s.state is already filled, we can go ahead and connect.
// Mark the session as connected and move on. This allows one of the
// connected handlers to reconnect on its own.
s.connected = true
return s.reconnectCtx(ctx)
}
func (s *Session) askDiscord(
ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error {
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil {
return errors.Wrap(err, "failed to send Voice State Update event")
}
// Wait for 2 replies. The above command should reply with these 2 events.
if err := s.waitForIncoming(ctx, chs); err != nil {
return errors.Wrap(err, "failed to wait for needed gateway events")
}
return nil
}
func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
ctx, cancel := context.WithTimeout(ctx, s.WSWaitDuration)
defer cancel()
state := false
// server is true when we already have the token and endpoint, meaning that
// we don't have to wait for another such event.
server := s.state.Token != "" && s.state.Endpoint != ""
// Loop until timeout or until we have all the information that we need.
for !(server && state) {
select {
case ev := <-chs.serverUpdate:
if s.state.GuildID != ev.GuildID {
continue
}
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
server = true
case ev := <-chs.stateUpdate:
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
continue
}
s.state.SessionID = ev.SessionID
s.state.ChannelID = ev.ChannelID
state = true
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// reconnect uses the current state to reconnect to a new gateway and UDP
// connection.
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Sending stop handle.")
s.looper.Stop()
wsutil.WSDebug("Start gateway.")
s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received.
if err := s.gateway.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to open voice gateway")
}
// Start the handler dispatching
s.looper.Start(s.gateway.Events)
// Get the Ready event.
voiceReady := s.gateway.Ready()
// Prepare the UDP voice connection.
s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection")
}
// Get the session description from the voice gateway.
d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP,
Port: s.voiceUDP.GatewayPort,
Mode: Protocol,
},
})
if err != nil {
return errors.Wrap(err, "failed to select protocol")
}
s.voiceUDP.UseSecret(d.SecretKey)
return nil
}
// Speaking tells Discord we're speaking. This method should not be called
// concurrently.
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
s.mut.RLock()
gateway := s.gateway
s.mut.RUnlock()
return gateway.Speaking(flag)
}
// UseContext tells the UDP voice connection to write with the given context.
func (s *Session) UseContext(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
if s.voiceUDP == nil {
return ErrCannotSend
}
return s.voiceUDP.UseContext(ctx)
}
// VoiceUDPConn gets a voice UDP connection. The caller could use this method to
// circumvent the rapid mutex-read-lock acquire inside Write.
func (s *Session) VoiceUDPConn() *udp.Connection {
s.mut.RLock()
defer s.mut.RUnlock()
return s.voiceUDP
}
// Write writes into the UDP voice connection WITHOUT a timeout. Refer to
// WriteCtx for more information.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)
}
// WriteCtx writes into the UDP voice connection with a context for timeout.
// This method is thread safe as far as calling other methods of Session goes;
// HOWEVER it is not thread safe to call Write itself concurrently.
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
voiceUDP := s.VoiceUDPConn()
if voiceUDP == nil {
return 0, ErrCannotSend
}
return voiceUDP.WriteCtx(ctx, b)
}
// Leave disconnects the current voice session from the currently connected
// channel.
func (s *Session) Leave() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.LeaveCtx(ctx)
}
// LeaveCtx disconencts with a context. Refer to Leave for more information.
func (s *Session) LeaveCtx(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
s.connected = false
// Unbind the handlers.
if s.detach != nil {
s.detach()
s.detach = nil
}
// If we're already closed.
if s.gateway == nil && s.voiceUDP == nil {
return nil
}
s.looper.Stop()
// Notify Discord that we're leaving. This will send a
// VoiceStateUpdateEvent, in which our handler will promptly remove the
// session from the map.
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: s.state.GuildID,
ChannelID: discord.ChannelID(discord.NullSnowflake),
SelfMute: true,
SelfDeaf: true,
})
s.ensureClosed()
// wrap returns nil if err is nil
return errors.Wrap(err, "failed to update voice state")
}
// close ensures everything is closed. It does not acquire the mutex.
func (s *Session) ensureClosed() {
s.looper.Stop()
// Disconnect the UDP connection.
if s.voiceUDP != nil {
s.voiceUDP.Close()
s.voiceUDP = nil
}
// Disconnect the voice gateway, ignoring the error.
if s.gateway != nil {
if err := s.gateway.Close(); err != nil {
wsutil.WSDebug("Uncaught voice gateway close error:", err)
}
s.gateway = nil
}
}
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
return s.VoiceUDPConn().ReadPacket()
}