arikawa/voice/session_test.go

301 lines
6.6 KiB
Go

package voice
import (
"context"
"log"
"math/rand"
"os"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/diamondburned/arikawa/v3/voice/testdata"
"github.com/diamondburned/arikawa/v3/voice/udp"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
"github.com/pkg/errors"
)
// TestMain installs a websocket debug hook that prefixes every debug message
// with the file and line of its caller, then runs the package's tests and
// exits with their result code.
func TestMain(m *testing.M) {
	ws.WSDebug = func(v ...interface{}) {
		// Skip one frame so the reported location is the code that invoked
		// the debug hook rather than this closure.
		_, file, line, _ := runtime.Caller(1)

		args := make([]interface{}, 0, len(v)+1)
		args = append(args, file+":"+strconv.Itoa(line))
		args = append(args, v...)

		log.Println(args...)
	}

	os.Exit(m.Run())
}
// testState bundles an opened state with the voice channel that the tests in
// this file join.
type testState struct {
// State is embedded so that its methods can be called directly on the
// test state.
*state.State
// channel is the voice channel the tests operate on.
channel *discord.Channel
}
// testOpen opens a bot state using the token from the test environment and
// looks up the configured voice channel. The test is aborted if the connection
// cannot be opened within a minute, if the channel cannot be fetched, or if
// the channel is not a guild voice channel. The state is closed through
// t.Cleanup.
func testOpen(t *testing.T) *testState {
	env := testenv.Must(t)

	st := state.New("Bot " + env.BotToken)
	AddIntents(st)

	// The timeout only bounds the opening handshake, so release it right
	// after Open returns.
	openCtx, cancelOpen := context.WithTimeout(context.Background(), time.Minute)
	openErr := st.Open(openCtx)
	cancelOpen()
	if openErr != nil {
		t.Fatal("failed to connect:", openErr)
	}

	t.Cleanup(func() { st.Close() })

	// Make sure the configured channel exists and is a voice channel.
	voiceCh, err := st.Channel(env.VoiceChID)
	switch {
	case err != nil:
		t.Fatal("failed to get channel:", err)
	case voiceCh.Type != discord.GuildVoice:
		t.Fatal("channel isn't a guild voice channel.")
	}

	t.Log("The voice channel's name is", voiceCh.Name)

	opened := &testState{State: st, channel: voiceCh}
	return opened
}
// TestIntegration runs the join/speak/leave cycle twice, in order, against the
// same opened state.
func TestIntegration(t *testing.T) {
	shared := testOpen(t)

	for _, name := range []string{"1st", "2nd"} {
		t.Run(name, func(t *testing.T) { testIntegrationOnce(t, shared) })
	}
}
// testIntegrationOnce creates a voice session on s, joins the test voice
// channel, writes the bundled Nico audio into the session and leaves the
// channel again when the test finishes. The time taken by each step is logged.
func testIntegrationOnce(t *testing.T, s *testState) {
	v, err := NewSession(s)
	if err != nil {
		t.Fatal("failed to create a new voice session:", err)
	}

	// Grab a timer to benchmark things.
	finish := timer()

	// Add handler to receive speaking update beforehand.
	v.AddHandler(func(e *voicegateway.SpeakingEvent) {
		finish("receiving voice speaking event")
	})

	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
	t.Cleanup(cancel)

	if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil {
		t.Fatal("failed to join voice:", err)
	}

	// Cleanups run in last-added-first-called order, so leaving happens
	// before ctx is cancelled by the cleanup registered above.
	t.Cleanup(func() {
		t.Log("Leaving the voice channel concurrently.")
		raceMe(t, "failed to leave voice channel", func() error {
			return v.Leave(ctx)
		})
	})

	finish("joining the voice channel")
	finish("sending the speaking command")

	// The writer reports its result through a buffered channel: if the
	// context expires first, the goroutine can still deliver its result and
	// exit instead of blocking forever, and it never touches t after the test
	// function has returned.
	writeErr := make(chan error, 1)
	go func() {
		writeErr <- testdata.WriteOpus(v, testdata.Nico)
	}()

	select {
	case <-ctx.Done():
		t.Error("timed out waiting for voice to be done")
	case err := <-writeErr:
		if err != nil {
			t.Error(err)
		}
		finish("copying the audio")
	}
}
// raceMe runs fn from several goroutines at once to shake out data races in
// it. Every non-nil error is logged; once all calls have returned, the test is
// failed with the most recently recorded error wrapped in wrapErr.
func raceMe(t *testing.T, wrapErr string, fn func() error) {
	t.Helper()

	const attempts = 3 // number of concurrent calls

	var (
		wg      sync.WaitGroup
		mu      sync.Mutex // guards lastErr
		lastErr error
	)

	wg.Add(attempts)
	for i := 0; i < attempts; i++ {
		go func() {
			defer wg.Done()

			callErr := fn()
			if callErr == nil {
				return
			}

			mu.Lock()
			lastErr = callErr
			mu.Unlock()

			t.Log("Potential race test error:", callErr)
		}()
	}
	wg.Wait()

	if lastErr == nil {
		return
	}
	t.Fatal("race test failed:", errors.Wrap(lastErr, wrapErr))
}
// timer returns a crude stopwatch for logging how long successive steps take.
// Each call to the returned function logs the time elapsed since the previous
// call (or since timer was called, for the first one) and restarts the clock.
//
// The returned function is safe for concurrent use: it is called both from
// test bodies and from registered event handlers, so access to the shared
// start time is serialized with a mutex.
func timer() func(finished string) {
	var (
		mu   sync.Mutex // guards then
		then = time.Now()
	)

	return func(finished string) {
		mu.Lock()
		defer mu.Unlock()

		now := time.Now()
		log.Println("Finished", finished+", took", now.Sub(then))
		then = now
	}
}
// TestKickedOut moves the bot out of the voice channel while audio is being
// written and expects the write to fail with udp.ErrManagerClosed.
func TestKickedOut(t *testing.T) {
	err := testReconnect(t, func(s *testState) {
		me, err := s.Me()
		if err != nil {
			// Include the cause so that a failure here can be diagnosed.
			t.Fatal("cannot get me:", err)
		}

		if err := s.ModifyMember(s.channel.GuildID, me.ID, api.ModifyMemberData{
			// Kick the bot out.
			VoiceChannel: discord.NullChannelID,
		}); err != nil {
			t.Error("cannot kick the bot out:", err)
		}
	})

	if !errors.Is(err, udp.ErrManagerClosed) {
		t.Error("unexpected error while sending nico.dca:", err)
	}
}
// TestRegionChange switches the voice channel to a different, randomly picked
// voice region while audio is being written and expects the write to finish
// without an error. The channel's original region is restored afterwards.
func TestRegionChange(t *testing.T) {
	// opened is set by the interrupt callback so that the region can be
	// restored once writing is done. It stays nil if the callback never runs.
	var opened *testState

	err := testReconnect(t, func(s *testState) {
		opened = s
		t.Log("got voice region", s.channel.RTCRegionID)

		regions, err := s.VoiceRegionsGuild(s.channel.GuildID)
		if err != nil {
			t.Error("cannot get voice region:", err)
			return
		}

		rand.Shuffle(len(regions), func(i, j int) {
			regions[i], regions[j] = regions[j], regions[i]
		})

		// Pick the first region that differs from the current one.
		var anyRegion string
		for _, region := range regions {
			if region.ID != s.channel.RTCRegionID {
				anyRegion = region.ID
				break
			}
		}

		t.Log("changing voice region to", anyRegion)

		if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{
			RTCRegionID: option.NewNullableString(anyRegion),
		}); err != nil {
			t.Error("cannot change voice region:", err)
		}
	})
	if err != nil {
		t.Error("unexpected error while sending nico.dca:", err)
	}

	// Without this guard, a run in which the interrupt never fired would
	// panic on a nil dereference below instead of reporting a failure.
	if opened == nil {
		t.Fatal("interrupt was never called; voice region was not changed")
	}

	s := opened

	// Change voice region back.
	if err := s.ModifyChannel(s.channel.ID, api.ModifyChannelData{
		RTCRegionID: option.NewNullableString(s.channel.RTCRegionID),
	}); err != nil {
		t.Error("cannot change voice region back:", err)
	}

	t.Log("changed voice region back to", s.channel.RTCRegionID)
}
// testReconnect opens a state, joins the test voice channel and writes the
// bundled Nico audio into the voice session. The first Write call made at
// least 450ms after joining invokes interrupt once, before the write itself,
// so that callers can disturb the connection in the middle of the stream.
//
// It returns the error from writing the audio. Setup failures abort the test.
func testReconnect(t *testing.T, interrupt func(*testState)) error {
	s := testOpen(t)

	v, err := NewSession(s)
	if err != nil {
		t.Fatal("failed to create a new voice session:", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
	t.Cleanup(cancel)

	if err := v.JoinChannelAndSpeak(ctx, s.channel.ID, false, false); err != nil {
		t.Fatal("failed to join voice:", err)
	}

	// Cleanups run in last-added-first-called order, so this runs before ctx
	// is cancelled by the cleanup registered above.
	t.Cleanup(func() {
		if err := v.Speaking(ctx, voicegateway.NotSpeaking); err != nil {
			t.Error("cannot stop speaking:", err)
		}
		if err := v.Leave(ctx); err != nil {
			t.Error("cannot leave voice:", err)
		}
	})

	// The signal channel is buffered so that the goroutine below can always
	// deliver its tick and exit, even if Write is never called again to
	// receive it.
	interruptAt := make(chan struct{}, 1)
	go func() {
		<-time.After(450 * time.Millisecond)
		interruptAt <- struct{}{}
	}()

	// Use a WriterFunc so we can interrupt the writing. The audio gets 450ms
	// of uninterrupted writing first.
	// NOTE(review): this assumes the bundled audio is longer than 450ms;
	// otherwise interrupt is never called.
	interruptWriter := testdata.WriterFunc(func(b []byte) (int, error) {
		select {
		case <-interruptAt:
			interrupt(s)
		default:
			// Not time yet, or already interrupted.
		}

		return v.Write(b)
	})

	return testdata.WriteOpus(interruptWriter, testdata.Nico)
}