1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-07 12:38:05 +00:00

Voice: Remove state-keeping of sessions

This commit gets rid of all the code that previously managed different
voice sessions in different guilds. This is because there is rarely ever
a need for this, and most bots that need this could do their own
keeping.

This change, although removes some features off of the package, adds a
lot of clarity on what to do exactly when it comes to connecting to a
voice channel.

In order to make the migration process a bit easier, an example has been
added which guides through using the voice.Session API.
This commit is contained in:
diamondburned 2020-11-30 17:49:18 -08:00
parent 6727f0e728
commit b8994ed0da
9 changed files with 248 additions and 355 deletions

View file

@ -31,7 +31,7 @@
"variables": [ "$BOT_TOKEN" ] "variables": [ "$BOT_TOKEN" ]
}, },
"script": [ "script": [
"go test -coverprofile $COV -race ./...", "go test -coverprofile $COV -tags unitonly -race ./...",
"go tool cover -func $COV" "go tool cover -func $COV"
] ]
}, },
@ -47,7 +47,7 @@
"go get ./...", "go get ./...",
# Test this package along with dismock. # Test this package along with dismock.
"go get $dismock@$dismock_v", "go get $dismock@$dismock_v",
"go test -coverpkg $tested -coverprofile $COV -tags integration -race ./... $dismock", "go test -coverpkg $tested -coverprofile $COV -race ./... $dismock",
"go mod tidy", "go mod tidy",
"go tool cover -func $COV" "go tool cover -func $COV"
] ]

View file

@ -1,46 +1,18 @@
// +build integration // +build !unitonly
package api package api
import ( import (
"fmt" "fmt"
"log" "log"
"os"
"testing" "testing"
"time" "time"
"github.com/diamondburned/arikawa/v2/discord" "github.com/diamondburned/arikawa/v2/internal/testenv"
) )
type testConfig struct {
BotToken string
ChannelID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
t.Fatal("Missing $CHANNEL_ID")
}
id, err := discord.ParseSnowflake(cid)
if err != nil {
t.Fatal("Invalid $CHANNEL_ID:", err)
}
return testConfig{
BotToken: token,
ChannelID: discord.ChannelID(id),
}
}
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
cfg := mustConfig(t) cfg := testenv.Must(t)
client := NewClient("Bot " + cfg.BotToken) client := NewClient("Bot " + cfg.BotToken)
@ -81,7 +53,7 @@ var emojisToSend = [...]string{
} }
func TestReactions(t *testing.T) { func TestReactions(t *testing.T) {
cfg := mustConfig(t) cfg := testenv.Must(t)
client := NewClient("Bot " + cfg.BotToken) client := NewClient("Bot " + cfg.BotToken)

View file

@ -1,16 +1,16 @@
// +build integration // +build !unitonly
package gateway package gateway
import ( import (
"context" "context"
"log" "log"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/diamondburned/arikawa/v2/internal/heart" "github.com/diamondburned/arikawa/v2/internal/heart"
"github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/utils/wsutil" "github.com/diamondburned/arikawa/v2/utils/wsutil"
) )
@ -43,10 +43,7 @@ func TestInvalidToken(t *testing.T) {
} }
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
var token = os.Getenv("BOT_TOKEN") config := testenv.Must(t)
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
wsutil.WSError = func(err error) { wsutil.WSError = func(err error) {
t.Error(err) t.Error(err)
@ -55,7 +52,7 @@ func TestIntegration(t *testing.T) {
var gateway *Gateway var gateway *Gateway
// NewGateway should call Start for us. // NewGateway should call Start for us.
g, err := NewGateway("Bot " + token) g, err := NewGateway("Bot " + config.BotToken)
if err != nil { if err != nil {
t.Fatal("Failed to make a Gateway:", err) t.Fatal("Failed to make a Gateway:", err)
} }

View file

@ -0,0 +1,75 @@
// +build !uintonly
package testenv
import (
"os"
"sync"
"testing"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/pkg/errors"
)
type Env struct {
BotToken string
ChannelID discord.ChannelID
VoiceChID discord.ChannelID
}
var (
env Env
err error
once sync.Once
)
func Must(t *testing.T) Env {
e, err := GetEnv()
if err != nil {
t.Fatal(err)
}
return e
}
func GetEnv() (Env, error) {
once.Do(getEnv)
return env, err
}
func getEnv() {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
err = errors.New("missing $BOT_TOKEN")
return
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
err = errors.New("missing $CHANNEL_ID")
return
}
id, err := discord.ParseSnowflake(cid)
if err != nil {
err = errors.Wrap(err, "invalid $CHANNEL_ID")
return
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
err = errors.New("missing $VOICE_ID")
return
}
vid, err := discord.ParseSnowflake(sid)
if err != nil {
err = errors.Wrap(err, "invalid $VOICE_ID")
return
}
env = Env{
BotToken: token,
ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid),
}
}

View file

@ -5,6 +5,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/utils/handler" "github.com/diamondburned/arikawa/v2/utils/handler"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -19,21 +20,28 @@ import (
"github.com/diamondburned/arikawa/v2/voice/voicegateway" "github.com/diamondburned/arikawa/v2/voice/voicegateway"
) )
// Protocol is the encryption protocol that this library uses.
const Protocol = "xsalsa20_poly1305" const Protocol = "xsalsa20_poly1305"
// ErrAlreadyConnecting is returned when the session is already connecting. // ErrAlreadyConnecting is returned when the session is already connecting.
var ErrAlreadyConnecting = errors.New("already connecting") var ErrAlreadyConnecting = errors.New("already connecting")
// ErrCannotSend is an error when audio is sent to a closed channel.
var ErrCannotSend = errors.New("cannot send audio to closed channel")
// WSTimeout is the duration to wait for a gateway operation including Session // WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't // to complete before erroring out. This only applies to functions that don't
// take in a context already. // take in a context already.
var WSTimeout = 10 * time.Second var WSTimeout = 10 * time.Second
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct { type Session struct {
*handler.Handler *handler.Handler
ErrorLog func(err error) ErrorLog func(err error)
session *session.Session session *session.Session
cancels []func()
looper *handleloop.Loop looper *handleloop.Loop
// joining determines the behavior of incoming event callbacks (Update). // joining determines the behavior of incoming event callbacks (Update).
@ -51,13 +59,24 @@ type Session struct {
voiceUDP *udp.Connection voiceUDP *udp.Connection
} }
func NewSession(ses *session.Session, userID discord.UserID) *Session { // NewSession creates a new voice session for the current user.
handler := handler.New() func NewSession(state *state.State) (*Session, error) {
looper := handleloop.NewLoop(handler) u, err := state.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get me")
}
return &Session{ return NewSessionCustom(state.Session, u.ID), nil
}
// NewSessionCustom creates a new voice session from the given session and user
// ID.
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
hlooper := handleloop.NewLoop(handler)
session := &Session{
Handler: handler, Handler: handler,
looper: looper, looper: hlooper,
session: ses, session: ses,
state: voicegateway.State{ state: voicegateway.State{
UserID: userID, UserID: userID,
@ -65,9 +84,15 @@ func NewSession(ses *session.Session, userID discord.UserID) *Session {
ErrorLog: func(err error) {}, ErrorLog: func(err error) {},
incoming: make(chan struct{}, 2), incoming: make(chan struct{}, 2),
} }
session.cancels = []func(){
ses.AddHandler(session.updateServer),
ses.AddHandler(session.updateState),
}
return session
} }
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) { func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
// If this is true, then mutex is acquired already. // If this is true, then mutex is acquired already.
if s.joining.Get() { if s.joining.Get() {
if s.state.GuildID != ev.GuildID { if s.state.GuildID != ev.GuildID {
@ -101,7 +126,7 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
} }
} }
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) { func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
if s.state.UserID != ev.UserID { // constant so no mutex if s.state.UserID != ev.UserID { // constant so no mutex
// Not our state. // Not our state.
return return
@ -109,6 +134,10 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
// If this is true, then mutex is acquired already. // If this is true, then mutex is acquired already.
if s.joining.Get() { if s.joining.Get() {
if s.state.GuildID != ev.GuildID {
return
}
s.state.SessionID = ev.SessionID s.state.SessionID = ev.SessionID
s.state.ChannelID = ev.ChannelID s.state.ChannelID = ev.ChannelID
@ -117,20 +146,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
} }
} }
func (s *Session) JoinChannel( func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel() defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened) return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
} }
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method // JoinChannelCtx joins a voice channel. Callers shouldn't use this method
// directly, but rather Voice's. If this method is called concurrently, an error // directly, but rather Voice's. This method shouldn't ever be called
// will be returned. // concurrently.
func (s *Session) JoinChannelCtx( func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error { ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
if s.joining.Get() { if s.joining.Get() {
return ErrAlreadyConnecting return ErrAlreadyConnecting
@ -162,8 +189,8 @@ func (s *Session) JoinChannelCtx(
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{ err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: gID, GuildID: gID,
ChannelID: channelID, ChannelID: channelID,
SelfMute: muted, SelfMute: mute,
SelfDeaf: deafened, SelfDeaf: deaf,
}) })
if err != nil { if err != nil {
return errors.Wrap(err, "failed to send Voice State Update event") return errors.Wrap(err, "failed to send Voice State Update event")
@ -247,7 +274,7 @@ func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
return gateway.Speaking(flag) return gateway.Speaking(flag)
} }
// UseContext tells the UDP voice connection to write with the given mutex. // UseContext tells the UDP voice connection to write with the given context.
func (s *Session) UseContext(ctx context.Context) error { func (s *Session) UseContext(ctx context.Context) error {
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
@ -287,14 +314,17 @@ func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
return voiceUDP.WriteCtx(ctx, b) return voiceUDP.WriteCtx(ctx, b)
} }
func (s *Session) Disconnect() error { // Leave disconnects the current voice session from the currently connected
// channel.
func (s *Session) Leave() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel() defer cancel()
return s.DisconnectCtx(ctx) return s.LeaveCtx(ctx)
} }
func (s *Session) DisconnectCtx(ctx context.Context) error { // LeaveCtx disconencts with a context. Refer to Leave for more information.
func (s *Session) LeaveCtx(ctx context.Context) error {
s.mut.Lock() s.mut.Lock()
defer s.mut.Unlock() defer s.mut.Unlock()
@ -340,9 +370,9 @@ func (s *Session) ensureClosed() {
} }
} }
// ReadPacket reads a single packet from the UDP connection. // ReadPacket reads a single packet from the UDP connection. This is NOT at all
// This is NOT at all thread safe, and must be used very carefully. // thread safe, and must be used very carefully. The backing buffer is always
// The backing buffer is always reused. // reused.
func (s *Session) ReadPacket() (*udp.Packet, error) { func (s *Session) ReadPacket() (*udp.Packet, error) {
return s.voiceUDP.ReadPacket() return s.voiceUDP.ReadPacket()
} }

View file

@ -0,0 +1,70 @@
// +build !unitonly
package voice_test
import (
"io"
"log"
"testing"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/voice"
)
var (
token string
channelID discord.ChannelID
)
func init() {
e, err := testenv.GetEnv()
if err == nil {
token = e.BotToken
channelID = e.VoiceChID
}
}
// pseudo function for example
func writeOpusInto(w io.Writer) {}
// make godoc not show the full file
func TestNoop(t *testing.T) {
t.Skip("noop")
}
func ExampleSession() {
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("failed to make state:", err)
}
// This is required for bots.
voice.AddIntents(s.Gateway)
if err := s.Open(); err != nil {
log.Fatalln("failed to open gateway:", err)
}
defer s.Close()
c, err := s.Channel(channelID)
if err != nil {
log.Fatalln("failed to get channel:", err)
}
v, err := voice.NewSession(s)
if err != nil {
log.Fatalln("failed to create voice session:", err)
}
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
log.Fatalln("failed to join voice channel:", err)
}
defer v.Leave()
// Start writing Opus frames.
for {
writeOpusInto(v)
}
}

View file

@ -1,4 +1,4 @@
// +build integration // +build !unitonly
package voice package voice
@ -15,14 +15,15 @@ import (
"time" "time"
"github.com/diamondburned/arikawa/v2/discord" "github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/gateway" "github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/utils/wsutil" "github.com/diamondburned/arikawa/v2/utils/wsutil"
"github.com/diamondburned/arikawa/v2/voice/voicegateway" "github.com/diamondburned/arikawa/v2/voice/voicegateway"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
config := mustConfig(t) config := testenv.Must(t)
wsutil.WSDebug = func(v ...interface{}) { wsutil.WSDebug = func(v ...interface{}) {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
@ -30,23 +31,19 @@ func TestIntegration(t *testing.T) {
log.Println(append([]interface{}{caller}, v...)...) log.Println(append([]interface{}{caller}, v...)...)
} }
v, err := NewFromToken("Bot " + config.BotToken) s, err := state.New("Bot " + config.BotToken)
if err != nil { if err != nil {
t.Fatal("Failed to create a new voice session:", err) t.Fatal("Failed to create a new state:", err)
} }
v.Gateway.AddIntents(gateway.IntentGuildVoiceStates) AddIntents(s.Gateway)
v.ErrorLog = func(err error) { if err := s.Open(); err != nil {
t.Error(err)
}
if err := v.Open(); err != nil {
t.Fatal("Failed to connect:", err) t.Fatal("Failed to connect:", err)
} }
t.Cleanup(func() { v.Close() }) t.Cleanup(func() { s.Close() })
// Validate the given voice channel. // Validate the given voice channel.
c, err := v.Channel(config.VoiceChID) c, err := s.Channel(config.VoiceChID)
if err != nil { if err != nil {
t.Fatal("Failed to get channel:", err) t.Fatal("Failed to get channel:", err)
} }
@ -56,43 +53,48 @@ func TestIntegration(t *testing.T) {
log.Println("The voice channel's name is", c.Name) log.Println("The voice channel's name is", c.Name)
v, err := NewSession(s)
if err != nil {
t.Fatal("Failed to create a new voice session:", err)
}
v.ErrorLog = func(err error) { t.Error(err) }
// Grab a timer to benchmark things. // Grab a timer to benchmark things.
finish := timer() finish := timer()
// Join the voice channel concurrently. // Add handler to receive speaking update beforehand.
raceValue := raceMe(t, "failed to join voice channel", func() (interface{}, error) { v.AddHandler(func(e *voicegateway.SpeakingEvent) {
return v.JoinChannel(c.ID, false, false) finish("receiving voice speaking event")
})
// Join the voice channel concurrently.
raceMe(t, "failed to join voice channel", func() (interface{}, error) {
return nil, v.JoinChannel(c.GuildID, c.ID, false, false)
}) })
vs := raceValue.(*Session)
t.Cleanup(func() { t.Cleanup(func() {
log.Println("Disconnecting from the voice channel concurrently.") log.Println("Leaving the voice channel concurrently.")
raceMe(t, "failed to disconnect", func() (interface{}, error) { raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
return nil, vs.Disconnect() return nil, v.Leave()
}) })
}) })
finish("joining the voice channel") finish("joining the voice channel")
// Add handler to receive speaking update
vs.AddHandler(func(e *voicegateway.SpeakingEvent) {
finish("received voice speaking event")
})
// Create a context and only cancel it AFTER we're done sending silence // Create a context and only cancel it AFTER we're done sending silence
// frames. // frames.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel) t.Cleanup(cancel)
// Trigger speaking. // Trigger speaking.
if err := vs.Speaking(voicegateway.Microphone); err != nil { if err := v.Speaking(voicegateway.Microphone); err != nil {
t.Fatal("failed to start speaking:", err) t.Fatal("failed to start speaking:", err)
} }
finish("sending the speaking command") finish("sending the speaking command")
if err := vs.UseContext(ctx); err != nil { if err := v.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err) t.Fatal("failed to set ctx into vs:", err)
} }
@ -117,7 +119,7 @@ func TestIntegration(t *testing.T) {
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:])) framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame. // Copy the frame.
if _, err := io.CopyN(vs, f, framelen); err != nil && err != io.EOF { if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err) t.Fatal("failed to write:", err)
} }
} }
@ -165,39 +167,6 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf
return val return val
} }
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
t.Helper()
}
// simple shitty benchmark thing // simple shitty benchmark thing
func timer() func(finished string) { func timer() func(finished string) {
var then = time.Now() var then = time.Now()

View file

@ -1,235 +1,14 @@
// Package voice handles the Discord voice gateway and UDP connections, as well // Package voice handles the Discord voice gateway and UDP connections. It does
// as managing and keeping track of multiple voice sessions. // not handle book-keeping of those sessions.
// //
// This package abstracts the subpackage voice/voicesession and voice/udp. // This package abstracts the subpackage voice/voicesession and voice/udp.
package voice package voice
import ( import "github.com/diamondburned/arikawa/v2/gateway"
"context"
"log"
"strconv"
"sync"
"time"
"github.com/diamondburned/arikawa/v2/discord" // AddIntents adds the needed voice intents into gw. Bots should always call
"github.com/diamondburned/arikawa/v2/gateway" // this before Open if voice is required.
"github.com/diamondburned/arikawa/v2/state" func AddIntents(gw *gateway.Gateway) {
"github.com/pkg/errors" gw.AddIntents(gateway.IntentGuilds)
) gw.AddIntents(gateway.IntentGuildVoiceStates)
var (
// defaultErrorHandler is the default error handler
defaultErrorHandler = func(err error) { log.Println("voice gateway error:", err) }
// ErrCannotSend is an error when audio is sent to a closed channel.
ErrCannotSend = errors.New("cannot send audio to closed channel")
)
// Voice represents a Voice Repository used for managing voice sessions.
type Voice struct {
*state.State
// Session holds all of the active voice sessions.
mapmutex sync.Mutex
sessions map[discord.GuildID]*Session
// Callbacks to remove the handlers.
closers []func()
// ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error)
}
// NewFromToken creates a new voice session from the given token.
func NewFromToken(token string) (*Voice, error) {
s, err := state.New(token)
if err != nil {
return nil, errors.Wrap(err, "failed to create a new session")
}
return New(s), nil
}
// New creates a new Voice repository wrapped around a state. The function will
// also automatically add the GuildVoiceStates intent, as that is required.
//
// This function will add the Guilds and GuildVoiceStates intents into the state
// in order to receive the needed events.
func New(s *state.State) *Voice {
// Register the voice intents.
s.Gateway.AddIntents(gateway.IntentGuilds)
s.Gateway.AddIntents(gateway.IntentGuildVoiceStates)
return NewWithoutIntents(s)
}
// NewWithoutIntents creates a new Voice repository wrapped around a state
// without modifying the given Gateway to add intents.
func NewWithoutIntents(s *state.State) *Voice {
v := &Voice{
State: s,
sessions: make(map[discord.GuildID]*Session),
ErrorLog: defaultErrorHandler,
}
// Add the required event handlers to the session.
v.closers = []func(){
s.AddHandler(v.onVoiceStateUpdate),
s.AddHandler(v.onVoiceServerUpdate),
}
return v
}
// onVoiceStateUpdate receives VoiceStateUpdateEvents from the gateway
// to keep track of the current user's voice state.
func (v *Voice) onVoiceStateUpdate(e *gateway.VoiceStateUpdateEvent) {
// Get the current user.
me, err := v.Me()
if err != nil {
v.ErrorLog(err)
return
}
// Ignore the event if it is an update from another user.
if me.ID != e.UserID {
return
}
// Get the stored voice session for the given guild.
vs, ok := v.GetSession(e.GuildID)
if !ok {
return
}
// Do what we must.
vs.UpdateState(e)
// Remove the connection if the current user has disconnected.
if e.ChannelID == 0 {
v.RemoveSession(e.GuildID)
}
}
// onVoiceServerUpdate receives VoiceServerUpdateEvents from the gateway
// to manage the current user's voice connections.
func (v *Voice) onVoiceServerUpdate(e *gateway.VoiceServerUpdateEvent) {
// Get the stored voice session for the given guild.
vs, ok := v.GetSession(e.GuildID)
if !ok {
return
}
// Do what we must.
vs.UpdateServer(e)
}
// GetSession gets a session for a guild with a read lock.
func (v *Voice) GetSession(guildID discord.GuildID) (*Session, bool) {
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// For some reason you cannot just put `return v.sessions[]` and return a bool D:
conn, ok := v.sessions[guildID]
return conn, ok
}
// RemoveSession removes a session.
func (v *Voice) RemoveSession(guildID discord.GuildID) {
v.mapmutex.Lock()
ses, ok := v.sessions[guildID]
if !ok {
v.mapmutex.Unlock()
return
}
delete(v.sessions, guildID)
v.mapmutex.Unlock()
// Ensure that the session is disconnected.
ses.Disconnect()
}
// JoinChannel joins the specified channel in the specified guild.
func (v *Voice) JoinChannel(cID discord.ChannelID, muted, deafened bool) (*Session, error) {
c, err := v.Cabinet.Channel(cID)
if err != nil {
return nil, errors.Wrap(err, "failed to get channel from state")
}
// Get the stored voice session for the given guild.
conn, ok := v.GetSession(c.GuildID)
// Create a new voice session if one does not exist.
if !ok {
u, err := v.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get self")
}
conn = NewSession(v.Session, u.ID)
conn.ErrorLog = v.ErrorLog
v.mapmutex.Lock()
v.sessions[c.GuildID] = conn
v.mapmutex.Unlock()
}
// Connect.
return conn, conn.JoinChannel(c.GuildID, cID, muted, deafened)
}
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.GuildID]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// Remove all callback handlers.
for _, fn := range v.closers {
fn()
}
for gID, s := range v.sessions {
log.Println("closing", gID)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
if dErr := s.DisconnectCtx(ctx); dErr != nil {
err.SessionErrors[gID] = dErr
}
cancel()
log.Println("closed", gID)
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}
type CloseError struct {
SessionErrors map[discord.GuildID]error
StateErr error
}
func (e *CloseError) HasError() bool {
if e.StateErr != nil {
return true
}
return len(e.SessionErrors) > 0
}
func (e *CloseError) Error() string {
if e.StateErr != nil {
return e.StateErr.Error()
}
if len(e.SessionErrors) < 1 {
return ""
}
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
} }

View file

@ -99,7 +99,8 @@ func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
type SpeakingFlag uint64 type SpeakingFlag uint64
const ( const (
Microphone SpeakingFlag = 1 << iota NotSpeaking SpeakingFlag = 0
Microphone SpeakingFlag = 1 << iota
Soundshare Soundshare
Priority Priority
) )