mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-07 12:38:05 +00:00
Voice: Remove state-keeping of sessions
This commit gets rid of all the code that previously managed different voice sessions in different guilds. This is because there is rarely ever a need for this, and most bots that need this could do their own keeping. This change, although removes some features off of the package, adds a lot of clarity on what to do exactly when it comes to connecting to a voice channel. In order to make the migration process a bit easier, an example has been added which guides through using the voice.Session API.
This commit is contained in:
parent
6727f0e728
commit
b8994ed0da
|
@ -31,7 +31,7 @@
|
||||||
"variables": [ "$BOT_TOKEN" ]
|
"variables": [ "$BOT_TOKEN" ]
|
||||||
},
|
},
|
||||||
"script": [
|
"script": [
|
||||||
"go test -coverprofile $COV -race ./...",
|
"go test -coverprofile $COV -tags unitonly -race ./...",
|
||||||
"go tool cover -func $COV"
|
"go tool cover -func $COV"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -47,7 +47,7 @@
|
||||||
"go get ./...",
|
"go get ./...",
|
||||||
# Test this package along with dismock.
|
# Test this package along with dismock.
|
||||||
"go get $dismock@$dismock_v",
|
"go get $dismock@$dismock_v",
|
||||||
"go test -coverpkg $tested -coverprofile $COV -tags integration -race ./... $dismock",
|
"go test -coverpkg $tested -coverprofile $COV -race ./... $dismock",
|
||||||
"go mod tidy",
|
"go mod tidy",
|
||||||
"go tool cover -func $COV"
|
"go tool cover -func $COV"
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,46 +1,18 @@
|
||||||
// +build integration
|
// +build !unitonly
|
||||||
|
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/v2/discord"
|
"github.com/diamondburned/arikawa/v2/internal/testenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testConfig struct {
|
|
||||||
BotToken string
|
|
||||||
ChannelID discord.ChannelID
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustConfig(t *testing.T) testConfig {
|
|
||||||
var token = os.Getenv("BOT_TOKEN")
|
|
||||||
if token == "" {
|
|
||||||
t.Fatal("Missing $BOT_TOKEN")
|
|
||||||
}
|
|
||||||
|
|
||||||
var cid = os.Getenv("CHANNEL_ID")
|
|
||||||
if cid == "" {
|
|
||||||
t.Fatal("Missing $CHANNEL_ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := discord.ParseSnowflake(cid)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Invalid $CHANNEL_ID:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return testConfig{
|
|
||||||
BotToken: token,
|
|
||||||
ChannelID: discord.ChannelID(id),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration(t *testing.T) {
|
func TestIntegration(t *testing.T) {
|
||||||
cfg := mustConfig(t)
|
cfg := testenv.Must(t)
|
||||||
|
|
||||||
client := NewClient("Bot " + cfg.BotToken)
|
client := NewClient("Bot " + cfg.BotToken)
|
||||||
|
|
||||||
|
@ -81,7 +53,7 @@ var emojisToSend = [...]string{
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReactions(t *testing.T) {
|
func TestReactions(t *testing.T) {
|
||||||
cfg := mustConfig(t)
|
cfg := testenv.Must(t)
|
||||||
|
|
||||||
client := NewClient("Bot " + cfg.BotToken)
|
client := NewClient("Bot " + cfg.BotToken)
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
// +build integration
|
// +build !unitonly
|
||||||
|
|
||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/v2/internal/heart"
|
"github.com/diamondburned/arikawa/v2/internal/heart"
|
||||||
|
"github.com/diamondburned/arikawa/v2/internal/testenv"
|
||||||
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,10 +43,7 @@ func TestInvalidToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegration(t *testing.T) {
|
func TestIntegration(t *testing.T) {
|
||||||
var token = os.Getenv("BOT_TOKEN")
|
config := testenv.Must(t)
|
||||||
if token == "" {
|
|
||||||
t.Fatal("Missing $BOT_TOKEN")
|
|
||||||
}
|
|
||||||
|
|
||||||
wsutil.WSError = func(err error) {
|
wsutil.WSError = func(err error) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
@ -55,7 +52,7 @@ func TestIntegration(t *testing.T) {
|
||||||
var gateway *Gateway
|
var gateway *Gateway
|
||||||
|
|
||||||
// NewGateway should call Start for us.
|
// NewGateway should call Start for us.
|
||||||
g, err := NewGateway("Bot " + token)
|
g, err := NewGateway("Bot " + config.BotToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to make a Gateway:", err)
|
t.Fatal("Failed to make a Gateway:", err)
|
||||||
}
|
}
|
||||||
|
|
75
internal/testenv/testenv.go
Normal file
75
internal/testenv/testenv.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
// +build !uintonly
|
||||||
|
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/v2/discord"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Env struct {
|
||||||
|
BotToken string
|
||||||
|
ChannelID discord.ChannelID
|
||||||
|
VoiceChID discord.ChannelID
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
env Env
|
||||||
|
err error
|
||||||
|
once sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func Must(t *testing.T) Env {
|
||||||
|
e, err := GetEnv()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEnv() (Env, error) {
|
||||||
|
once.Do(getEnv)
|
||||||
|
return env, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnv() {
|
||||||
|
var token = os.Getenv("BOT_TOKEN")
|
||||||
|
if token == "" {
|
||||||
|
err = errors.New("missing $BOT_TOKEN")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cid = os.Getenv("CHANNEL_ID")
|
||||||
|
if cid == "" {
|
||||||
|
err = errors.New("missing $CHANNEL_ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := discord.ParseSnowflake(cid)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.Wrap(err, "invalid $CHANNEL_ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var sid = os.Getenv("VOICE_ID")
|
||||||
|
if sid == "" {
|
||||||
|
err = errors.New("missing $VOICE_ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vid, err := discord.ParseSnowflake(sid)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.Wrap(err, "invalid $VOICE_ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
env = Env{
|
||||||
|
BotToken: token,
|
||||||
|
ChannelID: discord.ChannelID(id),
|
||||||
|
VoiceChID: discord.ChannelID(vid),
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/v2/state"
|
||||||
"github.com/diamondburned/arikawa/v2/utils/handler"
|
"github.com/diamondburned/arikawa/v2/utils/handler"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -19,21 +20,28 @@ import (
|
||||||
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
|
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Protocol is the encryption protocol that this library uses.
|
||||||
const Protocol = "xsalsa20_poly1305"
|
const Protocol = "xsalsa20_poly1305"
|
||||||
|
|
||||||
// ErrAlreadyConnecting is returned when the session is already connecting.
|
// ErrAlreadyConnecting is returned when the session is already connecting.
|
||||||
var ErrAlreadyConnecting = errors.New("already connecting")
|
var ErrAlreadyConnecting = errors.New("already connecting")
|
||||||
|
|
||||||
|
// ErrCannotSend is an error when audio is sent to a closed channel.
|
||||||
|
var ErrCannotSend = errors.New("cannot send audio to closed channel")
|
||||||
|
|
||||||
// WSTimeout is the duration to wait for a gateway operation including Session
|
// WSTimeout is the duration to wait for a gateway operation including Session
|
||||||
// to complete before erroring out. This only applies to functions that don't
|
// to complete before erroring out. This only applies to functions that don't
|
||||||
// take in a context already.
|
// take in a context already.
|
||||||
var WSTimeout = 10 * time.Second
|
var WSTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// Session is a single voice session that wraps around the voice gateway and UDP
|
||||||
|
// connection.
|
||||||
type Session struct {
|
type Session struct {
|
||||||
*handler.Handler
|
*handler.Handler
|
||||||
ErrorLog func(err error)
|
ErrorLog func(err error)
|
||||||
|
|
||||||
session *session.Session
|
session *session.Session
|
||||||
|
cancels []func()
|
||||||
looper *handleloop.Loop
|
looper *handleloop.Loop
|
||||||
|
|
||||||
// joining determines the behavior of incoming event callbacks (Update).
|
// joining determines the behavior of incoming event callbacks (Update).
|
||||||
|
@ -51,13 +59,24 @@ type Session struct {
|
||||||
voiceUDP *udp.Connection
|
voiceUDP *udp.Connection
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(ses *session.Session, userID discord.UserID) *Session {
|
// NewSession creates a new voice session for the current user.
|
||||||
handler := handler.New()
|
func NewSession(state *state.State) (*Session, error) {
|
||||||
looper := handleloop.NewLoop(handler)
|
u, err := state.Me()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to get me")
|
||||||
|
}
|
||||||
|
|
||||||
return &Session{
|
return NewSessionCustom(state.Session, u.ID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionCustom creates a new voice session from the given session and user
|
||||||
|
// ID.
|
||||||
|
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
|
||||||
|
handler := handler.New()
|
||||||
|
hlooper := handleloop.NewLoop(handler)
|
||||||
|
session := &Session{
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
looper: looper,
|
looper: hlooper,
|
||||||
session: ses,
|
session: ses,
|
||||||
state: voicegateway.State{
|
state: voicegateway.State{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
@ -65,9 +84,15 @@ func NewSession(ses *session.Session, userID discord.UserID) *Session {
|
||||||
ErrorLog: func(err error) {},
|
ErrorLog: func(err error) {},
|
||||||
incoming: make(chan struct{}, 2),
|
incoming: make(chan struct{}, 2),
|
||||||
}
|
}
|
||||||
|
session.cancels = []func(){
|
||||||
|
ses.AddHandler(session.updateServer),
|
||||||
|
ses.AddHandler(session.updateState),
|
||||||
|
}
|
||||||
|
|
||||||
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||||
// If this is true, then mutex is acquired already.
|
// If this is true, then mutex is acquired already.
|
||||||
if s.joining.Get() {
|
if s.joining.Get() {
|
||||||
if s.state.GuildID != ev.GuildID {
|
if s.state.GuildID != ev.GuildID {
|
||||||
|
@ -101,7 +126,7 @@ func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||||
if s.state.UserID != ev.UserID { // constant so no mutex
|
if s.state.UserID != ev.UserID { // constant so no mutex
|
||||||
// Not our state.
|
// Not our state.
|
||||||
return
|
return
|
||||||
|
@ -109,6 +134,10 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||||
|
|
||||||
// If this is true, then mutex is acquired already.
|
// If this is true, then mutex is acquired already.
|
||||||
if s.joining.Get() {
|
if s.joining.Get() {
|
||||||
|
if s.state.GuildID != ev.GuildID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
s.state.SessionID = ev.SessionID
|
s.state.SessionID = ev.SessionID
|
||||||
s.state.ChannelID = ev.ChannelID
|
s.state.ChannelID = ev.ChannelID
|
||||||
|
|
||||||
|
@ -117,20 +146,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) JoinChannel(
|
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||||
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
|
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
||||||
// directly, but rather Voice's. If this method is called concurrently, an error
|
// directly, but rather Voice's. This method shouldn't ever be called
|
||||||
// will be returned.
|
// concurrently.
|
||||||
func (s *Session) JoinChannelCtx(
|
func (s *Session) JoinChannelCtx(
|
||||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||||
|
|
||||||
if s.joining.Get() {
|
if s.joining.Get() {
|
||||||
return ErrAlreadyConnecting
|
return ErrAlreadyConnecting
|
||||||
|
@ -162,8 +189,8 @@ func (s *Session) JoinChannelCtx(
|
||||||
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
||||||
GuildID: gID,
|
GuildID: gID,
|
||||||
ChannelID: channelID,
|
ChannelID: channelID,
|
||||||
SelfMute: muted,
|
SelfMute: mute,
|
||||||
SelfDeaf: deafened,
|
SelfDeaf: deaf,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to send Voice State Update event")
|
return errors.Wrap(err, "failed to send Voice State Update event")
|
||||||
|
@ -247,7 +274,7 @@ func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
|
||||||
return gateway.Speaking(flag)
|
return gateway.Speaking(flag)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseContext tells the UDP voice connection to write with the given mutex.
|
// UseContext tells the UDP voice connection to write with the given context.
|
||||||
func (s *Session) UseContext(ctx context.Context) error {
|
func (s *Session) UseContext(ctx context.Context) error {
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
@ -287,14 +314,17 @@ func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||||
return voiceUDP.WriteCtx(ctx, b)
|
return voiceUDP.WriteCtx(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Disconnect() error {
|
// Leave disconnects the current voice session from the currently connected
|
||||||
|
// channel.
|
||||||
|
func (s *Session) Leave() error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return s.DisconnectCtx(ctx)
|
return s.LeaveCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) DisconnectCtx(ctx context.Context) error {
|
// LeaveCtx disconencts with a context. Refer to Leave for more information.
|
||||||
|
func (s *Session) LeaveCtx(ctx context.Context) error {
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
|
||||||
|
@ -340,9 +370,9 @@ func (s *Session) ensureClosed() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadPacket reads a single packet from the UDP connection.
|
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
|
||||||
// This is NOT at all thread safe, and must be used very carefully.
|
// thread safe, and must be used very carefully. The backing buffer is always
|
||||||
// The backing buffer is always reused.
|
// reused.
|
||||||
func (s *Session) ReadPacket() (*udp.Packet, error) {
|
func (s *Session) ReadPacket() (*udp.Packet, error) {
|
||||||
return s.voiceUDP.ReadPacket()
|
return s.voiceUDP.ReadPacket()
|
||||||
}
|
}
|
||||||
|
|
70
voice/session_example_test.go
Normal file
70
voice/session_example_test.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
// +build !unitonly
|
||||||
|
|
||||||
|
package voice_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/v2/discord"
|
||||||
|
"github.com/diamondburned/arikawa/v2/internal/testenv"
|
||||||
|
"github.com/diamondburned/arikawa/v2/state"
|
||||||
|
"github.com/diamondburned/arikawa/v2/voice"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
token string
|
||||||
|
channelID discord.ChannelID
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
e, err := testenv.GetEnv()
|
||||||
|
if err == nil {
|
||||||
|
token = e.BotToken
|
||||||
|
channelID = e.VoiceChID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pseudo function for example
|
||||||
|
func writeOpusInto(w io.Writer) {}
|
||||||
|
|
||||||
|
// make godoc not show the full file
|
||||||
|
func TestNoop(t *testing.T) {
|
||||||
|
t.Skip("noop")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleSession() {
|
||||||
|
s, err := state.New("Bot " + token)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed to make state:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is required for bots.
|
||||||
|
voice.AddIntents(s.Gateway)
|
||||||
|
|
||||||
|
if err := s.Open(); err != nil {
|
||||||
|
log.Fatalln("failed to open gateway:", err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
c, err := s.Channel(channelID)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed to get channel:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err := voice.NewSession(s)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed to create voice session:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
|
||||||
|
log.Fatalln("failed to join voice channel:", err)
|
||||||
|
}
|
||||||
|
defer v.Leave()
|
||||||
|
|
||||||
|
// Start writing Opus frames.
|
||||||
|
for {
|
||||||
|
writeOpusInto(v)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
// +build integration
|
// +build !unitonly
|
||||||
|
|
||||||
package voice
|
package voice
|
||||||
|
|
||||||
|
@ -15,14 +15,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/v2/discord"
|
"github.com/diamondburned/arikawa/v2/discord"
|
||||||
"github.com/diamondburned/arikawa/v2/gateway"
|
"github.com/diamondburned/arikawa/v2/internal/testenv"
|
||||||
|
"github.com/diamondburned/arikawa/v2/state"
|
||||||
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
"github.com/diamondburned/arikawa/v2/utils/wsutil"
|
||||||
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
|
"github.com/diamondburned/arikawa/v2/voice/voicegateway"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIntegration(t *testing.T) {
|
func TestIntegration(t *testing.T) {
|
||||||
config := mustConfig(t)
|
config := testenv.Must(t)
|
||||||
|
|
||||||
wsutil.WSDebug = func(v ...interface{}) {
|
wsutil.WSDebug = func(v ...interface{}) {
|
||||||
_, file, line, _ := runtime.Caller(1)
|
_, file, line, _ := runtime.Caller(1)
|
||||||
|
@ -30,23 +31,19 @@ func TestIntegration(t *testing.T) {
|
||||||
log.Println(append([]interface{}{caller}, v...)...)
|
log.Println(append([]interface{}{caller}, v...)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := NewFromToken("Bot " + config.BotToken)
|
s, err := state.New("Bot " + config.BotToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to create a new voice session:", err)
|
t.Fatal("Failed to create a new state:", err)
|
||||||
}
|
}
|
||||||
v.Gateway.AddIntents(gateway.IntentGuildVoiceStates)
|
AddIntents(s.Gateway)
|
||||||
|
|
||||||
v.ErrorLog = func(err error) {
|
if err := s.Open(); err != nil {
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.Open(); err != nil {
|
|
||||||
t.Fatal("Failed to connect:", err)
|
t.Fatal("Failed to connect:", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { v.Close() })
|
t.Cleanup(func() { s.Close() })
|
||||||
|
|
||||||
// Validate the given voice channel.
|
// Validate the given voice channel.
|
||||||
c, err := v.Channel(config.VoiceChID)
|
c, err := s.Channel(config.VoiceChID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to get channel:", err)
|
t.Fatal("Failed to get channel:", err)
|
||||||
}
|
}
|
||||||
|
@ -56,43 +53,48 @@ func TestIntegration(t *testing.T) {
|
||||||
|
|
||||||
log.Println("The voice channel's name is", c.Name)
|
log.Println("The voice channel's name is", c.Name)
|
||||||
|
|
||||||
|
v, err := NewSession(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to create a new voice session:", err)
|
||||||
|
}
|
||||||
|
v.ErrorLog = func(err error) { t.Error(err) }
|
||||||
|
|
||||||
// Grab a timer to benchmark things.
|
// Grab a timer to benchmark things.
|
||||||
finish := timer()
|
finish := timer()
|
||||||
|
|
||||||
// Join the voice channel concurrently.
|
// Add handler to receive speaking update beforehand.
|
||||||
raceValue := raceMe(t, "failed to join voice channel", func() (interface{}, error) {
|
v.AddHandler(func(e *voicegateway.SpeakingEvent) {
|
||||||
return v.JoinChannel(c.ID, false, false)
|
finish("receiving voice speaking event")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Join the voice channel concurrently.
|
||||||
|
raceMe(t, "failed to join voice channel", func() (interface{}, error) {
|
||||||
|
return nil, v.JoinChannel(c.GuildID, c.ID, false, false)
|
||||||
})
|
})
|
||||||
vs := raceValue.(*Session)
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
log.Println("Disconnecting from the voice channel concurrently.")
|
log.Println("Leaving the voice channel concurrently.")
|
||||||
|
|
||||||
raceMe(t, "failed to disconnect", func() (interface{}, error) {
|
raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
|
||||||
return nil, vs.Disconnect()
|
return nil, v.Leave()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
finish("joining the voice channel")
|
finish("joining the voice channel")
|
||||||
|
|
||||||
// Add handler to receive speaking update
|
|
||||||
vs.AddHandler(func(e *voicegateway.SpeakingEvent) {
|
|
||||||
finish("received voice speaking event")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a context and only cancel it AFTER we're done sending silence
|
// Create a context and only cancel it AFTER we're done sending silence
|
||||||
// frames.
|
// frames.
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
// Trigger speaking.
|
// Trigger speaking.
|
||||||
if err := vs.Speaking(voicegateway.Microphone); err != nil {
|
if err := v.Speaking(voicegateway.Microphone); err != nil {
|
||||||
t.Fatal("failed to start speaking:", err)
|
t.Fatal("failed to start speaking:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
finish("sending the speaking command")
|
finish("sending the speaking command")
|
||||||
|
|
||||||
if err := vs.UseContext(ctx); err != nil {
|
if err := v.UseContext(ctx); err != nil {
|
||||||
t.Fatal("failed to set ctx into vs:", err)
|
t.Fatal("failed to set ctx into vs:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +119,7 @@ func TestIntegration(t *testing.T) {
|
||||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||||
|
|
||||||
// Copy the frame.
|
// Copy the frame.
|
||||||
if _, err := io.CopyN(vs, f, framelen); err != nil && err != io.EOF {
|
if _, err := io.CopyN(v, f, framelen); err != nil && err != io.EOF {
|
||||||
t.Fatal("failed to write:", err)
|
t.Fatal("failed to write:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,39 +167,6 @@ func raceMe(t *testing.T, wrapErr string, fn func() (interface{}, error)) interf
|
||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
|
|
||||||
type testConfig struct {
|
|
||||||
BotToken string
|
|
||||||
VoiceChID discord.ChannelID
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustConfig(t *testing.T) testConfig {
|
|
||||||
var token = os.Getenv("BOT_TOKEN")
|
|
||||||
if token == "" {
|
|
||||||
t.Fatal("Missing $BOT_TOKEN")
|
|
||||||
}
|
|
||||||
|
|
||||||
var sid = os.Getenv("VOICE_ID")
|
|
||||||
if sid == "" {
|
|
||||||
t.Fatal("Missing $VOICE_ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := discord.ParseSnowflake(sid)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Invalid $VOICE_ID:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return testConfig{
|
|
||||||
BotToken: token,
|
|
||||||
VoiceChID: discord.ChannelID(id),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// file is only a few bytes lolmao
|
|
||||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// simple shitty benchmark thing
|
// simple shitty benchmark thing
|
||||||
func timer() func(finished string) {
|
func timer() func(finished string) {
|
||||||
var then = time.Now()
|
var then = time.Now()
|
237
voice/voice.go
237
voice/voice.go
|
@ -1,235 +1,14 @@
|
||||||
// Package voice handles the Discord voice gateway and UDP connections, as well
|
// Package voice handles the Discord voice gateway and UDP connections. It does
|
||||||
// as managing and keeping track of multiple voice sessions.
|
// not handle book-keeping of those sessions.
|
||||||
//
|
//
|
||||||
// This package abstracts the subpackage voice/voicesession and voice/udp.
|
// This package abstracts the subpackage voice/voicesession and voice/udp.
|
||||||
package voice
|
package voice
|
||||||
|
|
||||||
import (
|
import "github.com/diamondburned/arikawa/v2/gateway"
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/v2/discord"
|
// AddIntents adds the needed voice intents into gw. Bots should always call
|
||||||
"github.com/diamondburned/arikawa/v2/gateway"
|
// this before Open if voice is required.
|
||||||
"github.com/diamondburned/arikawa/v2/state"
|
func AddIntents(gw *gateway.Gateway) {
|
||||||
"github.com/pkg/errors"
|
gw.AddIntents(gateway.IntentGuilds)
|
||||||
)
|
gw.AddIntents(gateway.IntentGuildVoiceStates)
|
||||||
|
|
||||||
var (
|
|
||||||
// defaultErrorHandler is the default error handler
|
|
||||||
defaultErrorHandler = func(err error) { log.Println("voice gateway error:", err) }
|
|
||||||
|
|
||||||
// ErrCannotSend is an error when audio is sent to a closed channel.
|
|
||||||
ErrCannotSend = errors.New("cannot send audio to closed channel")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Voice represents a Voice Repository used for managing voice sessions.
|
|
||||||
type Voice struct {
|
|
||||||
*state.State
|
|
||||||
|
|
||||||
// Session holds all of the active voice sessions.
|
|
||||||
mapmutex sync.Mutex
|
|
||||||
sessions map[discord.GuildID]*Session
|
|
||||||
|
|
||||||
// Callbacks to remove the handlers.
|
|
||||||
closers []func()
|
|
||||||
|
|
||||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
|
||||||
ErrorLog func(err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFromToken creates a new voice session from the given token.
|
|
||||||
func NewFromToken(token string) (*Voice, error) {
|
|
||||||
s, err := state.New(token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to create a new session")
|
|
||||||
}
|
|
||||||
|
|
||||||
return New(s), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new Voice repository wrapped around a state. The function will
|
|
||||||
// also automatically add the GuildVoiceStates intent, as that is required.
|
|
||||||
//
|
|
||||||
// This function will add the Guilds and GuildVoiceStates intents into the state
|
|
||||||
// in order to receive the needed events.
|
|
||||||
func New(s *state.State) *Voice {
|
|
||||||
// Register the voice intents.
|
|
||||||
s.Gateway.AddIntents(gateway.IntentGuilds)
|
|
||||||
s.Gateway.AddIntents(gateway.IntentGuildVoiceStates)
|
|
||||||
return NewWithoutIntents(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWithoutIntents creates a new Voice repository wrapped around a state
|
|
||||||
// without modifying the given Gateway to add intents.
|
|
||||||
func NewWithoutIntents(s *state.State) *Voice {
|
|
||||||
v := &Voice{
|
|
||||||
State: s,
|
|
||||||
sessions: make(map[discord.GuildID]*Session),
|
|
||||||
ErrorLog: defaultErrorHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the required event handlers to the session.
|
|
||||||
v.closers = []func(){
|
|
||||||
s.AddHandler(v.onVoiceStateUpdate),
|
|
||||||
s.AddHandler(v.onVoiceServerUpdate),
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// onVoiceStateUpdate receives VoiceStateUpdateEvents from the gateway
|
|
||||||
// to keep track of the current user's voice state.
|
|
||||||
func (v *Voice) onVoiceStateUpdate(e *gateway.VoiceStateUpdateEvent) {
|
|
||||||
// Get the current user.
|
|
||||||
me, err := v.Me()
|
|
||||||
if err != nil {
|
|
||||||
v.ErrorLog(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore the event if it is an update from another user.
|
|
||||||
if me.ID != e.UserID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the stored voice session for the given guild.
|
|
||||||
vs, ok := v.GetSession(e.GuildID)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do what we must.
|
|
||||||
vs.UpdateState(e)
|
|
||||||
|
|
||||||
// Remove the connection if the current user has disconnected.
|
|
||||||
if e.ChannelID == 0 {
|
|
||||||
v.RemoveSession(e.GuildID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// onVoiceServerUpdate receives VoiceServerUpdateEvents from the gateway
|
|
||||||
// to manage the current user's voice connections.
|
|
||||||
func (v *Voice) onVoiceServerUpdate(e *gateway.VoiceServerUpdateEvent) {
|
|
||||||
// Get the stored voice session for the given guild.
|
|
||||||
vs, ok := v.GetSession(e.GuildID)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do what we must.
|
|
||||||
vs.UpdateServer(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSession gets a session for a guild with a read lock.
|
|
||||||
func (v *Voice) GetSession(guildID discord.GuildID) (*Session, bool) {
|
|
||||||
v.mapmutex.Lock()
|
|
||||||
defer v.mapmutex.Unlock()
|
|
||||||
|
|
||||||
// For some reason you cannot just put `return v.sessions[]` and return a bool D:
|
|
||||||
conn, ok := v.sessions[guildID]
|
|
||||||
return conn, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveSession removes a session.
|
|
||||||
func (v *Voice) RemoveSession(guildID discord.GuildID) {
|
|
||||||
v.mapmutex.Lock()
|
|
||||||
ses, ok := v.sessions[guildID]
|
|
||||||
if !ok {
|
|
||||||
v.mapmutex.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(v.sessions, guildID)
|
|
||||||
v.mapmutex.Unlock()
|
|
||||||
|
|
||||||
// Ensure that the session is disconnected.
|
|
||||||
ses.Disconnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
// JoinChannel joins the specified channel in the specified guild.
|
|
||||||
func (v *Voice) JoinChannel(cID discord.ChannelID, muted, deafened bool) (*Session, error) {
|
|
||||||
c, err := v.Cabinet.Channel(cID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to get channel from state")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the stored voice session for the given guild.
|
|
||||||
conn, ok := v.GetSession(c.GuildID)
|
|
||||||
|
|
||||||
// Create a new voice session if one does not exist.
|
|
||||||
if !ok {
|
|
||||||
u, err := v.Me()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to get self")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn = NewSession(v.Session, u.ID)
|
|
||||||
conn.ErrorLog = v.ErrorLog
|
|
||||||
|
|
||||||
v.mapmutex.Lock()
|
|
||||||
v.sessions[c.GuildID] = conn
|
|
||||||
v.mapmutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect.
|
|
||||||
return conn, conn.JoinChannel(c.GuildID, cID, muted, deafened)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Voice) Close() error {
|
|
||||||
err := &CloseError{
|
|
||||||
SessionErrors: make(map[discord.GuildID]error),
|
|
||||||
}
|
|
||||||
|
|
||||||
v.mapmutex.Lock()
|
|
||||||
defer v.mapmutex.Unlock()
|
|
||||||
|
|
||||||
// Remove all callback handlers.
|
|
||||||
for _, fn := range v.closers {
|
|
||||||
fn()
|
|
||||||
}
|
|
||||||
|
|
||||||
for gID, s := range v.sessions {
|
|
||||||
log.Println("closing", gID)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
||||||
if dErr := s.DisconnectCtx(ctx); dErr != nil {
|
|
||||||
err.SessionErrors[gID] = dErr
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
log.Println("closed", gID)
|
|
||||||
}
|
|
||||||
|
|
||||||
err.StateErr = v.State.Close()
|
|
||||||
if err.HasError() {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type CloseError struct {
|
|
||||||
SessionErrors map[discord.GuildID]error
|
|
||||||
StateErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CloseError) HasError() bool {
|
|
||||||
if e.StateErr != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(e.SessionErrors) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CloseError) Error() string {
|
|
||||||
if e.StateErr != nil {
|
|
||||||
return e.StateErr.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(e.SessionErrors) < 1 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,8 @@ func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
|
||||||
type SpeakingFlag uint64
|
type SpeakingFlag uint64
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Microphone SpeakingFlag = 1 << iota
|
NotSpeaking SpeakingFlag = 0
|
||||||
|
Microphone SpeakingFlag = 1 << iota
|
||||||
Soundshare
|
Soundshare
|
||||||
Priority
|
Priority
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue