Voice: Remove state-keeping of sessions

This commit gets rid of all the code that previously managed different
voice sessions in different guilds. This is because there is rarely ever
a need for this, and most bots that need this could do their own
keeping.

This change, although removes some features off of the package, adds a
lot of clarity on what to do exactly when it comes to connecting to a
voice channel.

In order to make the migration process a bit easier, an example has been
added which guides through using the voice.Session API.
This commit is contained in:
diamondburned 2020-11-30 17:49:18 -08:00
parent 6727f0e728
commit b8994ed0da
9 changed files with 248 additions and 355 deletions

View File

@ -31,7 +31,7 @@
"variables": [ "$BOT_TOKEN" ]
},
"script": [
"go test -coverprofile $COV -race ./...",
"go test -coverprofile $COV -tags unitonly -race ./...",
"go tool cover -func $COV"
]
},
@ -47,7 +47,7 @@
"go get ./...",
# Test this package along with dismock.
"go get $dismock@$dismock_v",
"go test -coverpkg $tested -coverprofile $COV -tags integration -race ./... $dismock",
"go test -coverpkg $tested -coverprofile $COV -race ./... $dismock",
"go mod tidy",
"go tool cover -func $COV"
]

View File

@ -1,46 +1,18 @@
// +build integration
// +build !unitonly
package api
import (
"fmt"
"log"
"os"
"testing"
"time"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/internal/testenv"
)
type testConfig struct {
BotToken string
ChannelID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
t.Fatal("Missing $CHANNEL_ID")
}
id, err := discord.ParseSnowflake(cid)
if err != nil {
t.Fatal("Invalid $CHANNEL_ID:", err)
}
return testConfig{
BotToken: token,
ChannelID: discord.ChannelID(id),
}
}
func TestIntegration(t *testing.T) {
cfg := mustConfig(t)
cfg := testenv.Must(t)
client := NewClient("Bot " + cfg.BotToken)
@ -81,7 +53,7 @@ var emojisToSend = [...]string{
}
func TestReactions(t *testing.T) {
cfg := mustConfig(t)
cfg := testenv.Must(t)
client := NewClient("Bot " + cfg.BotToken)

View File

@ -1,16 +1,16 @@
// +build integration
// +build !unitonly
package gateway
import (
"context"
"log"
"os"
"strings"
"testing"
"time"
"github.com/diamondburned/arikawa/v2/internal/heart"
"github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/utils/wsutil"
)
@ -43,10 +43,7 @@ func TestInvalidToken(t *testing.T) {
}
func TestIntegration(t *testing.T) {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
config := testenv.Must(t)
wsutil.WSError = func(err error) {
t.Error(err)
@ -55,7 +52,7 @@ func TestIntegration(t *testing.T) {
var gateway *Gateway
// NewGateway should call Start for us.
g, err := NewGateway("Bot " + token)
g, err := NewGateway("Bot " + config.BotToken)
if err != nil {
t.Fatal("Failed to make a Gateway:", err)
}

View File

@ -0,0 +1,75 @@
// +build !uintonly
package testenv
import (
"os"
"sync"
"testing"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/pkg/errors"
)
type Env struct {
BotToken string
ChannelID discord.ChannelID
VoiceChID discord.ChannelID
}
var (
env Env
err error
once sync.Once
)
func Must(t *testing.T) Env {
e, err := GetEnv()
if err != nil {
t.Fatal(err)
}
return e
}
func GetEnv() (Env, error) {
once.Do(getEnv)
return env, err
}
func getEnv() {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
err = errors.New("missing $BOT_TOKEN")
return
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
err = errors.New("missing $CHANNEL_ID")
return
}
id, err := discord.ParseSnowflake(cid)
if err != nil {
err = errors.Wrap(err, "invalid $CHANNEL_ID")
return
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
err = errors.New("missing $VOICE_ID")
return
}
vid, err := discord.ParseSnowflake(sid)
if err != nil {
err = errors.Wrap(err, "invalid $VOICE_ID")
return
}
env = Env{
BotToken: token,
ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid),
}
}

View File

@ -5,6 +5,7 @@ import (
"sync"
"time"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/utils/handler"
"github.com/pkg/errors"
@ -19,21 +20,28 @@ import (
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
)
// Protocol is the encryption protocol that this library uses.
const Protocol = "xsalsa20_poly1305"
// ErrAlreadyConnecting is returned when the session is already connecting.
var ErrAlreadyConnecting = errors.New("already connecting")
// ErrCannotSend is an error when audio is sent to a closed channel.
var ErrCannotSend = errors.New("cannot send audio to closed channel")
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
var WSTimeout = 10 * time.Second
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct {
*handler.Handler
ErrorLog func(err error)
session *session.Session
cancels []func()
looper *handleloop.Loop
// joining determines the behavior of incoming event callbacks (Update).
@ -51,13 +59,24 @@ type Session struct {
voiceUDP *udp.Connection
}
func NewSession(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
looper := handleloop.NewLoop(handler)
// NewSession creates a new voice session for the current user.
func NewSession(state *state.State) (*Session, error) {
u, err := state.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get me")
}
return &Session{
return NewSessionCustom(state.Session, u.ID), nil
}
// NewSessionCustom creates a new voice session from the given session and user
// ID.
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
hlooper := handleloop.NewLoop(handler)
session := &Session{
Handler: handler,
looper: looper,
looper: hlooper,
session: ses,
state: voicegateway.State{
UserID: userID,
@ -65,9 +84,15 @@ func NewSession(ses *session.Session, userID discord.UserID) *Session {
ErrorLog: func(err error) {},
incoming: make(chan struct{}, 2),
}
session.cancels = []func(){
ses.AddHandler(session.updateServer),
ses.AddHandler(session.updateState),
}
return session
}
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
// If this is true, then mutex is acquired already.
if s.joining.Get() {
if s.state.GuildID != ev.GuildID {
@ -101,7 +126,7 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
}
}
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
if s.state.UserID != ev.UserID { // constant so no mutex
// Not our state.
return
@ -109,6 +134,10 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
// If this is true, then mutex is acquired already.
if s.joining.Get() {
if s.state.GuildID != ev.GuildID {
return
}
s.state.SessionID = ev.SessionID
s.state.ChannelID = ev.ChannelID
@ -117,20 +146,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
}
}
func (s *Session) JoinChannel(
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
}
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
// directly, but rather Voice's. If this method is called concurrently, an error
// will be returned.
// directly, but rather Voice's. This method shouldn't ever be called
// concurrently.
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
if s.joining.Get() {
return ErrAlreadyConnecting
@ -162,8 +189,8 @@ func (s *Session) JoinChannelCtx(
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
GuildID: gID,
ChannelID: channelID,
SelfMute: muted,
SelfDeaf: deafened,
SelfMute: mute,
SelfDeaf: deaf,
})
if err != nil {
return errors.Wrap(err, "failed to send Voice State Update event")
@ -247,7 +274,7 @@ func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
return gateway.Speaking(flag)
}
// UseContext tells the UDP voice connection to write with the given mutex.
// UseContext tells the UDP voice connection to write with the given context.
func (s *Session) UseContext(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -287,14 +314,17 @@ func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
return voiceUDP.WriteCtx(ctx, b)
}
func (s *Session) Disconnect() error {
// Leave disconnects the current voice session from the currently connected
// channel.
func (s *Session) Leave() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.DisconnectCtx(ctx)
return s.LeaveCtx(ctx)
}
func (s *Session) DisconnectCtx(ctx context.Context) error {
// LeaveCtx disconencts with a context. Refer to Leave for more information.
func (s *Session) LeaveCtx(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -340,9 +370,9 @@ func (s *Session) ensureClosed() {
}
}
// ReadPacket reads a single packet from the UDP connection.
// This is NOT at all thread safe, and must be used very carefully.
// The backing buffer is always reused.
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
return s.voiceUDP.ReadPacket()
}

View File

@ -0,0 +1,70 @@
// +build !unitonly
package voice_test
import (
"io"
"log"
"testing"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/voice"
)
var (
token string
channelID discord.ChannelID
)
func init() {
e, err := testenv.GetEnv()
if err == nil {
token = e.BotToken
channelID = e.VoiceChID
}
}
// pseudo function for example
func writeOpusInto(w io.Writer) {}
// make godoc not show the full file
func TestNoop(t *testing.T) {
t.Skip("noop")
}
func ExampleSession() {
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("failed to make state:", err)
}
// This is required for bots.
voice.AddIntents(s.Gateway)
if err := s.Open(); err != nil {
log.Fatalln("failed to open gateway:", err)
}
defer s.Close()
c, err := s.Channel(channelID)
if err != nil {
log.Fatalln("failed to get channel:", err)
}
v, err := voice.NewSession(s)
if err != nil {
log.Fatalln("failed to create voice session:", err)
}
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
log.Fatalln("failed to join voice channel:", err)
}
defer v.Leave()
// Start writing Opus frames.
for {
writeOpusInto(v)
}
}

View File

@ -1,4 +1,4 @@
// +build integration
// +build !unitonly
package voice
@ -15,14 +15,15 @@ import (
"time"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/internal/testenv"
"github.com/diamondburned/arikawa/v2/state"
"github.com/diamondburned/arikawa/v2/utils/wsutil"
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
"github.com/pkg/errors"
)
func TestIntegration(t *testing.T) {
config := mustConfig(t)
config := testenv.Must(t)
wsutil.WSDebug = func(v ...interface{}) {
_, file, line, _ := runtime.Caller(1)
@ -30,23 +31,19 @@ func TestIntegration(t *testing.T) {
log.Println(append([]interface{}{caller}, v...)...)
}
v, err := NewFromToken("Bot " + config.BotToken)
s, err := state.New("Bot " + config.BotToken)
if err != nil {
t.Fatal("Failed to create a new voice session:", err)
t.Fatal("Failed to create a new state:", err)
}
v.Gateway.AddIntents(gateway.IntentGuildVoiceStates)
AddIntents(s.Gateway)
v.ErrorLog = func(err error) {
t.Error(err)
}
if err := v.Open(); err != nil {
if err := s.Open(); err != nil {
t.Fatal("Failed to connect:", err)
}
t.Cleanup(func() { v.Close() })
t.Cleanup(func() { s.Close() })
// Validate the given voice channel.
c, err := v.Channel(config.VoiceChID)
c, err := s.Channel(config.VoiceChID)
if err != nil {
t.Fatal("Failed to get channel:", err)
}
@ -56,43 +53,48 @@ func TestIntegration(t *testing.T) {
log.Println("The voice channel's name is", c.Name)
v, err := NewSession(s)
if err != nil {
t.Fatal("Failed to create a new voice session:", err)
}
v.ErrorLog = func(err error) { t.Error(err) }
// Grab a timer to benchmark things.
finish := timer()
// Join the voice channel concurrently.
raceValue := raceMe(t, "failed to join voice channel", func() (interface{}, error) {
return v.JoinChannel(c.ID, false, false)
// Add handler to receive speaking update beforehand.
v.AddHandler(func(e *voicegateway.SpeakingEvent) {
finish("receiving voice speaking event")
})
// Join the voice channel concurrently.
raceMe(t, "failed to join voice channel", func() (interface{}, error) {
return nil, v.JoinChannel(c.GuildID, c.ID, false, false)
})
vs := raceValue.(*Session)
t.Cleanup(func() {
log.Println("Disconnecting from the voice channel concurrently.")
log.Println("Leaving the voice channel concurrently.")
raceMe(t, "failed to disconnect", func() (interface{}, error) {
return nil, vs.Disconnect()
raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
return nil, v.Leave()
})
})
finish("joining the voice channel")
// Add handler to receive speaking update
vs.AddHandler(func(e *voicegateway.SpeakingEvent) {
finish("received voice speaking event")
})
// Create a context and only cancel it AFTER we're done sending silence
// frames.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
// Trigger speaking.
if err := vs.Speaking(voicegateway.Microphone); err != nil {
if err := v.Speaking(voicegateway.Microphone); err != nil {
t.Fatal("failed to start speaking:", err)
}
finish("sending the speaking command")
if err := vs.UseContext(ctx); err != nil {
if err := v.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err)
}
@ -117,7 +119,7 @@ func TestIntegration(t *testing.T) {
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(vs, f, framelen); err != nil && err != io.EOF {
if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
}
@ -165,39 +167,6 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf
return val
}
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
t.Helper()
}
// simple shitty benchmark thing
func timer() func(finished string) {
var then = time.Now()

View File

@ -1,235 +1,14 @@
// Package voice handles the Discord voice gateway and UDP connections, as well
// as managing and keeping track of multiple voice sessions.
// Package voice handles the Discord voice gateway and UDP connections. It does
// not handle book-keeping of those sessions.
//
// This package abstracts the subpackage voice/voicesession and voice/udp.
package voice
import (
"context"
"log"
"strconv"
"sync"
"time"
import "github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/discord"
"github.com/diamondburned/arikawa/v2/gateway"
"github.com/diamondburned/arikawa/v2/state"
"github.com/pkg/errors"
)
var (
// defaultErrorHandler is the default error handler
defaultErrorHandler = func(err error) { log.Println("voice gateway error:", err) }
// ErrCannotSend is an error when audio is sent to a closed channel.
ErrCannotSend = errors.New("cannot send audio to closed channel")
)
// Voice represents a Voice Repository used for managing voice sessions.
type Voice struct {
*state.State
// Session holds all of the active voice sessions.
mapmutex sync.Mutex
sessions map[discord.GuildID]*Session
// Callbacks to remove the handlers.
closers []func()
// ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error)
}
// NewFromToken creates a new voice session from the given token.
func NewFromToken(token string) (*Voice, error) {
s, err := state.New(token)
if err != nil {
return nil, errors.Wrap(err, "failed to create a new session")
}
return New(s), nil
}
// New creates a new Voice repository wrapped around a state. The function will
// also automatically add the GuildVoiceStates intent, as that is required.
//
// This function will add the Guilds and GuildVoiceStates intents into the state
// in order to receive the needed events.
func New(s *state.State) *Voice {
// Register the voice intents.
s.Gateway.AddIntents(gateway.IntentGuilds)
s.Gateway.AddIntents(gateway.IntentGuildVoiceStates)
return NewWithoutIntents(s)
}
// NewWithoutIntents creates a new Voice repository wrapped around a state
// without modifying the given Gateway to add intents.
func NewWithoutIntents(s *state.State) *Voice {
v := &Voice{
State: s,
sessions: make(map[discord.GuildID]*Session),
ErrorLog: defaultErrorHandler,
}
// Add the required event handlers to the session.
v.closers = []func(){
s.AddHandler(v.onVoiceStateUpdate),
s.AddHandler(v.onVoiceServerUpdate),
}
return v
}
// onVoiceStateUpdate receives VoiceStateUpdateEvents from the gateway
// to keep track of the current user's voice state.
func (v *Voice) onVoiceStateUpdate(e *gateway.VoiceStateUpdateEvent) {
// Get the current user.
me, err := v.Me()
if err != nil {
v.ErrorLog(err)
return
}
// Ignore the event if it is an update from another user.
if me.ID != e.UserID {
return
}
// Get the stored voice session for the given guild.
vs, ok := v.GetSession(e.GuildID)
if !ok {
return
}
// Do what we must.
vs.UpdateState(e)
// Remove the connection if the current user has disconnected.
if e.ChannelID == 0 {
v.RemoveSession(e.GuildID)
}
}
// onVoiceServerUpdate receives VoiceServerUpdateEvents from the gateway
// to manage the current user's voice connections.
func (v *Voice) onVoiceServerUpdate(e *gateway.VoiceServerUpdateEvent) {
// Get the stored voice session for the given guild.
vs, ok := v.GetSession(e.GuildID)
if !ok {
return
}
// Do what we must.
vs.UpdateServer(e)
}
// GetSession gets a session for a guild with a read lock.
func (v *Voice) GetSession(guildID discord.GuildID) (*Session, bool) {
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// For some reason you cannot just put `return v.sessions[]` and return a bool D:
conn, ok := v.sessions[guildID]
return conn, ok
}
// RemoveSession removes a session.
func (v *Voice) RemoveSession(guildID discord.GuildID) {
v.mapmutex.Lock()
ses, ok := v.sessions[guildID]
if !ok {
v.mapmutex.Unlock()
return
}
delete(v.sessions, guildID)
v.mapmutex.Unlock()
// Ensure that the session is disconnected.
ses.Disconnect()
}
// JoinChannel joins the specified channel in the specified guild.
func (v *Voice) JoinChannel(cID discord.ChannelID, muted, deafened bool) (*Session, error) {
c, err := v.Cabinet.Channel(cID)
if err != nil {
return nil, errors.Wrap(err, "failed to get channel from state")
}
// Get the stored voice session for the given guild.
conn, ok := v.GetSession(c.GuildID)
// Create a new voice session if one does not exist.
if !ok {
u, err := v.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get self")
}
conn = NewSession(v.Session, u.ID)
conn.ErrorLog = v.ErrorLog
v.mapmutex.Lock()
v.sessions[c.GuildID] = conn
v.mapmutex.Unlock()
}
// Connect.
return conn, conn.JoinChannel(c.GuildID, cID, muted, deafened)
}
func (v *Voice) Close() error {
err := &CloseError{
SessionErrors: make(map[discord.GuildID]error),
}
v.mapmutex.Lock()
defer v.mapmutex.Unlock()
// Remove all callback handlers.
for _, fn := range v.closers {
fn()
}
for gID, s := range v.sessions {
log.Println("closing", gID)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
if dErr := s.DisconnectCtx(ctx); dErr != nil {
err.SessionErrors[gID] = dErr
}
cancel()
log.Println("closed", gID)
}
err.StateErr = v.State.Close()
if err.HasError() {
return err
}
return nil
}
type CloseError struct {
SessionErrors map[discord.GuildID]error
StateErr error
}
func (e *CloseError) HasError() bool {
if e.StateErr != nil {
return true
}
return len(e.SessionErrors) > 0
}
func (e *CloseError) Error() string {
if e.StateErr != nil {
return e.StateErr.Error()
}
if len(e.SessionErrors) < 1 {
return ""
}
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
// AddIntents adds the needed voice intents into gw. Bots should always call
// this before Open if voice is required.
func AddIntents(gw *gateway.Gateway) {
gw.AddIntents(gateway.IntentGuilds)
gw.AddIntents(gateway.IntentGuildVoiceStates)
}

View File

@ -99,7 +99,8 @@ func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
type SpeakingFlag uint64
const (
Microphone SpeakingFlag = 1 << iota
NotSpeaking SpeakingFlag = 0
Microphone SpeakingFlag = 1 << iota
Soundshare
Priority
)