1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-10 22:16:59 +00:00
arikawa/voice/session.go
Cléo Rebert 8f548d2607 go fmt + gci
Signed-off-by: Cléo Rebert <cleo.rebert@gmail.com>
2023-09-19 11:37:24 -07:00

585 lines
16 KiB
Go

package voice
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/diamondburned/arikawa/v3/utils/ws/ophandler"
"github.com/diamondburned/arikawa/v3/voice/udp"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
)
// Protocol is the encryption protocol that this library uses.
const Protocol = "xsalsa20_poly1305"
// ErrAlreadyConnecting is returned when the session is already connecting.
var ErrAlreadyConnecting = errors.New("already connecting")
// ErrCannotSend is an error when audio is sent to a closed channel.
var ErrCannotSend = errors.New("cannot send audio to closed channel")
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
const WSTimeout = 25 * time.Second
// ReconnectError is emitted into Session.Handler everytime the voice gateway
// fails to be reconnected. It implements the error interface.
type ReconnectError struct {
Err error
}
// Error implements error.
func (e ReconnectError) Error() string {
return "voice reconnect error: " + e.Err.Error()
}
// Unwrap returns e.Err.
func (e ReconnectError) Unwrap() error { return e.Err }
// MainSession abstracts both session.Session and state.State.
type MainSession interface {
// AddHandler describes the method in handler.Handler.
AddHandler(handler interface{}) (rm func())
// Me returns the current user.
Me() (*discord.User, error)
// Channel queries for the channel with the given ID.
Channel(discord.ChannelID) (*discord.Channel, error)
// SendGateway is a helper to send messages over the gateway.
SendGateway(ctx context.Context, m ws.Event) error
}
var (
_ MainSession = (*session.Session)(nil)
_ MainSession = (*state.State)(nil)
)
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct {
*handler.Handler
session MainSession
mut sync.RWMutex
// connected is a non-nil blocking channel after Join is called and is
// closed once Leave is called.
disconnected chan struct{}
state voicegateway.State // guarded except UserID
detachReconnect []func()
// udpManager is the manager for a UDP connection. The user can use this to
// plug in a custom UDP dialer.
udpManager *udp.Manager
gateway *voicegateway.Gateway
gwCancel context.CancelFunc
gwDone <-chan struct{}
WSTimeout time.Duration // global WSTimeout
WSMaxRetry int // 2
WSRetryDelay time.Duration // 2s
WSWaitDuration time.Duration // 5s
// joining determines the behavior of incoming event callbacks (Update).
// If this is true, incoming events will just send into Updated channels. If
// false, events will trigger a reconnection.
joining moreatomic.Bool
// disconnectClosed is true if connected is already closed. It is only used
// to keep track of closing connected.
disconnectClosed bool
}
// NewSession creates a new voice session for the current user.
func NewSession(state MainSession) (*Session, error) {
u, err := state.Me()
if err != nil {
return nil, fmt.Errorf("failed to get me: %w", err)
}
return NewSessionCustom(state, u.ID), nil
}
// NewSessionCustom creates a new voice session from the given session and user
// ID.
func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
closed := make(chan struct{})
close(closed)
session := &Session{
Handler: handler.New(),
session: ses,
state: voicegateway.State{
UserID: userID,
},
udpManager: udp.NewManager(),
WSTimeout: WSTimeout,
WSMaxRetry: 2,
WSRetryDelay: 2 * time.Second,
WSWaitDuration: 5 * time.Second,
// Set this pair of value in so we never have to nil-check the channel.
// We can just assume that it's either closed or connected.
disconnected: closed,
disconnectClosed: true,
}
return session
}
// SetUDPDialer sets the given dialer to be used for dialing UDP voice
// connections.
func (s *Session) SetUDPDialer(d udp.DialFunc) {
s.udpManager.SetDialer(d)
}
func (s *Session) acquireUpdate(f func()) bool {
if s.joining.Get() {
return false
}
s.mut.Lock()
defer s.mut.Unlock()
// Ignore if we haven't connected yet or we're still joining.
if s.udpManager.IsClosed() {
return false
}
f()
return true
}
// updateServer is specifically used to monitor for reconnects.
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
s.acquireUpdate(func() {
if s.state.GuildID != ev.GuildID {
return
}
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
s.reconnectCtx(ctx)
})
}
// updateState is specifically used after connecting to monitor when the bot is
// forced across channels.
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
s.acquireUpdate(func() {
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
return
}
s.state.ChannelID = ev.ChannelID
s.state.SessionID = ev.SessionID
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
s.reconnectCtx(ctx)
})
}
// JoinChannelAndSpeak is a convenient function that calls JoinChannel then
// Speaking.
func (s *Session) JoinChannelAndSpeak(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error {
if err := s.JoinChannel(ctx, chID, mute, deaf); err != nil {
return fmt.Errorf("cannot join channel: %w", err)
}
return s.Speaking(ctx, voicegateway.Microphone)
}
type waitEventChs struct {
serverUpdate chan *gateway.VoiceServerUpdateEvent
stateUpdate chan *gateway.VoiceStateUpdateEvent
}
// JoinChannel joins the given voice channel with the default timeout.
func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error {
var ch *discord.Channel
if chID.IsValid() {
var err error
ch, err = s.session.Channel(chID)
if err != nil {
return fmt.Errorf("invalid channel ID: %w", err)
}
}
s.mut.Lock()
defer s.mut.Unlock()
// Error out if we're already joining. JoinChannel shouldn't be called
// concurrently.
if !s.joining.Acquire() {
return errors.New("JoinChannel working elsewhere")
}
defer s.joining.Set(false)
// Set the state.
if ch != nil {
s.state.ChannelID = ch.ID
s.state.GuildID = ch.GuildID
} else {
s.state.GuildID = 0
// Ensure that if `cID` is zero that it passes null to the update event.
s.state.ChannelID = discord.NullChannelID
}
if s.detachReconnect == nil {
s.detachReconnect = []func(){
s.session.AddHandler(s.updateServer),
s.session.AddHandler(s.updateState),
}
}
chs := waitEventChs{
serverUpdate: make(chan *gateway.VoiceServerUpdateEvent),
stateUpdate: make(chan *gateway.VoiceStateUpdateEvent),
}
// Bind the handlers.
cancels := []func(){
s.session.AddHandler(chs.serverUpdate),
s.session.AddHandler(chs.stateUpdate),
}
// Disconnects the handlers once the function exits.
defer func() {
for _, cancel := range cancels {
cancel()
}
}()
// Ensure gateway and voiceUDP are already closed.
s.ensureClosed()
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
data := &gateway.UpdateVoiceStateCommand{
GuildID: s.state.GuildID,
ChannelID: s.state.ChannelID,
SelfMute: mute,
SelfDeaf: deaf,
}
var err error
var timer *time.Timer
// Retry 3 times maximum.
for i := 0; i < s.WSMaxRetry; i++ {
if err = s.askDiscord(ctx, data, chs); err == nil {
break
}
// If this is the first attempt and the context timed out, it's
// probably the context that's waiting for gateway events. Retry the
// loop.
if i == 0 && errors.Is(err, ctx.Err()) {
continue
}
if timer == nil {
// Set up a timer.
timer = time.NewTimer(s.WSRetryDelay)
defer timer.Stop()
} else {
timer.Reset(s.WSRetryDelay)
}
select {
case <-timer.C:
continue
case <-ctx.Done():
return fmt.Errorf("cannot ask Discord for events: %w", err)
}
}
// These 2 methods should've updated s.state before sending into these
// channels. Since s.state is already filled, we can go ahead and connect.
// Mark the session as connected and move on. This allows one of the
// connected handlers to reconnect on its own.
s.disconnected = make(chan struct{})
return s.reconnectCtx(ctx)
}
func (s *Session) askDiscord(
ctx context.Context, data *gateway.UpdateVoiceStateCommand, chs waitEventChs) error {
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
if err := s.session.SendGateway(ctx, data); err != nil {
return fmt.Errorf("failed to send Voice State Update event: %w", err)
}
// Wait for 2 replies. The above command should reply with these 2 events.
if err := s.waitForIncoming(ctx, chs); err != nil {
return fmt.Errorf("failed to wait for needed gateway events: %w", err)
}
return nil
}
func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
ctx, cancel := context.WithTimeout(ctx, s.WSWaitDuration)
defer cancel()
state := false
// server is true when we already have the token and endpoint, meaning that
// we don't have to wait for another such event.
server := s.state.Token != "" && s.state.Endpoint != ""
// Loop until timeout or until we have all the information that we need.
for !(server && state) {
select {
case ev := <-chs.serverUpdate:
if s.state.GuildID != ev.GuildID {
continue
}
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
server = true
case ev := <-chs.stateUpdate:
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
continue
}
s.state.SessionID = ev.SessionID
s.state.ChannelID = ev.ChannelID
state = true
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// reconnect uses the current state to reconnect to a new gateway and UDP
// connection.
func (s *Session) reconnectCtx(ctx context.Context) error {
ws.WSDebug("Sending stop handle.")
if err := s.udpManager.Pause(ctx); err != nil {
return fmt.Errorf("cannot pause UDP manager: %w", err)
}
defer func() {
if !s.udpManager.Continue() {
panic("UDP manager continued but invalid lock ownership")
}
}()
s.ensureClosed()
ws.WSDebug("Start gateway.")
s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received.
gwctx, gwcancel := context.WithCancel(context.Background())
s.gwCancel = gwcancel
gwch := s.gateway.Connect(gwctx)
ws.WSDebug("Voice Gateway connected")
if err := s.spinGateway(ctx, gwch); err != nil {
ws.WSDebug("Voice spinGateway error:", err)
// Early cancel the gateway.
gwcancel()
// Nil this so future reconnects don't use the invalid gwDone.
s.gwCancel = nil
// Emit the error. It's fine to do this here since this is the only
// place that can error out.
s.Handler.Call(&ReconnectError{err})
return fmt.Errorf("cannot wait for event sequence from voice gateway: %w", err)
}
// Start dispatching.
s.gwDone = ophandler.Loop(gwch, s.Handler)
ws.WSDebug("Voice reconnectCtx finished with no error")
return nil
}
func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
var err error
var conn *udp.Connection
for {
select {
case <-ctx.Done():
return ctx.Err()
case ev, ok := <-gwch:
if !ok {
return fmt.Errorf("voice gateway error: %w", s.gateway.LastError())
}
switch data := ev.Data.(type) {
case *ws.CloseEvent:
return fmt.Errorf("voice gateway error: %w", err)
case *voicegateway.ReadyEvent:
ws.WSDebug("Got ready from voice gateway, SSRC:", data.SSRC)
// Prepare the UDP voice connection.
conn, err = s.udpManager.Dial(ctx, data.Addr(), data.SSRC)
if err != nil {
return fmt.Errorf("failed to open voice UDP connection: %w", err)
}
if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: conn.GatewayIP,
Port: conn.GatewayPort,
Mode: Protocol,
},
}); err != nil {
return fmt.Errorf("failed to send SelectProtocolCommand: %w", err)
}
case *voicegateway.SessionDescriptionEvent:
if conn == nil {
return errors.New("server bug: SessionDescription before Ready")
}
ws.WSDebug("Received secret key from voice gateway")
// We're done.
conn.UseSecret(data.SecretKey)
return nil
}
// Dispatch this event to the handler.
s.Handler.Call(ev.Data)
}
}
}
// Speaking tells Discord we're speaking. This method should not be called
// concurrently.
//
// If only NotSpeaking (0) is given, then even if the gateway cannot be reached,
// a nil error will be returned. This is because sending Discord a not-speaking
// event is a destruction command that doesn't affect the outcome of anything
// done after whatsoever.
func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error {
s.mut.Lock()
gateway := s.gateway
s.mut.Unlock()
if err := gateway.Speaking(ctx, flag); err != nil && flag != 0 {
return err
}
return nil
}
// Write writes into the UDP voice connection. This method is thread safe as far
// as calling other methods of Session goes; HOWEVER it is not thread safe to
// call Write itself concurrently.
func (s *Session) Write(b []byte) (int, error) {
return s.udpManager.Write(b)
}
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
return s.udpManager.ReadPacket()
}
// Leave disconnects the current voice session from the currently connected
// channel.
func (s *Session) Leave(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
s.ensureClosed()
// Unbind the handlers.
if s.detachReconnect != nil {
for _, detach := range s.detachReconnect {
detach()
}
s.detachReconnect = nil
}
// If we're already closed.
if s.gateway == nil && s.udpManager.IsClosed() {
return nil
}
// Notify Discord that we're leaving.
sendErr := s.session.SendGateway(ctx, &gateway.UpdateVoiceStateCommand{
GuildID: s.state.GuildID,
ChannelID: discord.ChannelID(discord.NullSnowflake),
SelfMute: true,
SelfDeaf: true,
})
// Wait for the gateway to exit first before we tell the user of the gateway
// send error.
if err := s.cancelGateway(ctx); err != nil {
return err
}
if sendErr != nil {
return fmt.Errorf("failed to update voice state: %w", sendErr)
}
return nil
}
func (s *Session) cancelGateway(ctx context.Context) error {
if s.gwCancel != nil {
s.gwCancel()
s.gwCancel = nil
// Wait for the previous gateway to finish closing up, but make sure to
// bail if the context expires.
if err := ophandler.WaitForDone(ctx, s.gwDone); err != nil {
return fmt.Errorf("cannot wait for gateway to close: %w", err)
}
}
return nil
}
const (
permanentClose = true
temporaryClose = false
)
// close ensures everything is closed. It does not acquire the mutex.
func (s *Session) ensureClosed() {
// Disconnect the UDP connection. If not permanent, then pause.
s.udpManager.Close()
if s.gwCancel != nil {
s.gwCancel()
// Don't actually clear this field, because we still want the caller to
// be able to wait for the gateway to completely exit using
// cancelGateway.
}
}