mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-26 20:09:37 +00:00
Merge pull request #14 from matthewpi/feature/voice
This commit is contained in:
commit
ccf4c69801
|
@ -109,7 +109,9 @@ func (s Seconds) Duration() time.Duration {
|
|||
|
||||
//
|
||||
|
||||
type Milliseconds int
|
||||
// Milliseconds is in float64 because some Discord events return time with a
|
||||
// trailing decimal.
|
||||
type Milliseconds float64
|
||||
|
||||
func DurationToMilliseconds(dura time.Duration) Milliseconds {
|
||||
return Milliseconds(dura.Milliseconds())
|
||||
|
@ -120,5 +122,6 @@ func (ms Milliseconds) String() string {
|
|||
}
|
||||
|
||||
func (ms Milliseconds) Duration() time.Duration {
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
const f64ms = Milliseconds(time.Millisecond)
|
||||
return time.Duration(ms * f64ms)
|
||||
}
|
||||
|
|
|
@ -68,10 +68,10 @@ func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
|
|||
}
|
||||
|
||||
type UpdateVoiceStateData struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
SelfMute bool `json:"self_mute"`
|
||||
SelfDeaf bool `json:"self_deaf"`
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
ChannelID *discord.Snowflake `json:"channel_id"` // nullable
|
||||
SelfMute bool `json:"self_mute"`
|
||||
SelfDeaf bool `json:"self_deaf"`
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
|
||||
|
|
|
@ -9,7 +9,6 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
@ -32,23 +31,6 @@ var (
|
|||
// Compress = "zlib-stream"
|
||||
)
|
||||
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = wsutil.DefaultTimeout
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
|
||||
ErrWSMaxTries = errors.New("max tries reached")
|
||||
|
@ -95,11 +77,7 @@ func BotURL(token string) (*GatewayBotData, error) {
|
|||
}
|
||||
|
||||
type Gateway struct {
|
||||
WS *wsutil.Websocket
|
||||
json.Driver
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
WS *wsutil.Websocket
|
||||
WSTimeout time.Duration
|
||||
|
||||
// All events sent over are pointers to Event structs (structs suffixed with
|
||||
|
@ -110,8 +88,8 @@ type Gateway struct {
|
|||
SessionID string
|
||||
|
||||
Identifier *Identifier
|
||||
Pacemaker *Pacemaker
|
||||
Sequence *Sequence
|
||||
PacerLoop *wsutil.PacemakerLoop
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
|
@ -120,30 +98,18 @@ type Gateway struct {
|
|||
// reconnections or any type of connection interruptions.
|
||||
AfterClose func(err error) // noop by default
|
||||
|
||||
// Only use for debugging
|
||||
|
||||
// If this channel is non-nil, all incoming OP packets will also be sent
|
||||
// here. This should be buffered, so to not block the main loop.
|
||||
OP chan *OP
|
||||
|
||||
// Mutex to hold off calls when the WS is not available. Doesn't block if
|
||||
// Start() is not called or Close() is called. Also doesn't block for
|
||||
// Identify or Resume.
|
||||
// available sync.RWMutex
|
||||
|
||||
// Filled by methods, internal use
|
||||
paceDeath chan error
|
||||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||
// information, refer to NewGatewayWithDriver.
|
||||
func NewGateway(token string) (*Gateway, error) {
|
||||
return NewGatewayWithDriver(token, json.Default{})
|
||||
}
|
||||
|
||||
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
|
||||
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||
URL, err := URL()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
||||
|
@ -158,18 +124,19 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
// Append the form to the URL
|
||||
URL += "?" + param.Encode()
|
||||
|
||||
return NewCustomGateway(URL, token, driver), nil
|
||||
return NewCustomGateway(URL, token), nil
|
||||
}
|
||||
|
||||
func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
||||
func NewCustomGateway(gatewayURL, token string) *Gateway {
|
||||
return &Gateway{
|
||||
WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL),
|
||||
Driver: driver,
|
||||
WSTimeout: WSTimeout,
|
||||
Events: make(chan Event, WSBuffer),
|
||||
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
|
||||
WSTimeout: wsutil.WSTimeout,
|
||||
|
||||
Events: make(chan Event, wsutil.WSBuffer),
|
||||
Identifier: DefaultIdentifier(token),
|
||||
Sequence: NewSequence(),
|
||||
ErrorLog: WSError,
|
||||
|
||||
ErrorLog: wsutil.WSError,
|
||||
AfterClose: func(error) {},
|
||||
}
|
||||
}
|
||||
|
@ -177,24 +144,24 @@ func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
|||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if g.waitGroup == nil && g.paceDeath == nil {
|
||||
WSDebug("Gateway is already closed.")
|
||||
if g.waitGroup == nil && g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
|
||||
g.AfterClose(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the pacemaker is running:
|
||||
if g.paceDeath != nil {
|
||||
WSDebug("Stopping pacemaker...")
|
||||
if !g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
g.Pacemaker.Stop()
|
||||
g.PacerLoop.Stop()
|
||||
|
||||
WSDebug("Stopped pacemaker.")
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
WSDebug("Waiting for WaitGroup to be done.")
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
|
@ -203,7 +170,7 @@ func (g *Gateway) Close() error {
|
|||
// Mark g.waitGroup as empty:
|
||||
g.waitGroup = nil
|
||||
|
||||
WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := g.WS.Close()
|
||||
g.AfterClose(err)
|
||||
|
@ -212,42 +179,47 @@ func (g *Gateway) Close() error {
|
|||
|
||||
// Reconnect tries to reconnect forever. It will resume the connection if
|
||||
// possible. If an Invalid Session is received, it will start a fresh one.
|
||||
func (g *Gateway) Reconnect() {
|
||||
WSDebug("Reconnecting...")
|
||||
func (g *Gateway) Reconnect() error {
|
||||
return g.ReconnectContext(context.Background())
|
||||
}
|
||||
|
||||
func (g *Gateway) ReconnectContext(ctx context.Context) error {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
// redialing anyway.
|
||||
g.Close()
|
||||
|
||||
for i := 1; ; i++ {
|
||||
WSDebug("Trying to dial, attempt", i)
|
||||
wsutil.WSDebug("Trying to dial, attempt", i)
|
||||
|
||||
// Condition: err == ErrInvalidSession:
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
|
||||
if err := g.Open(); err != nil {
|
||||
if err := g.OpenContext(ctx); err != nil {
|
||||
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
|
||||
continue
|
||||
}
|
||||
|
||||
WSDebug("Started after attempt:", i)
|
||||
return
|
||||
wsutil.WSDebug("Started after attempt:", i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Open connects to the Websocket and authenticate it. You should usually use
|
||||
// this function over Start().
|
||||
func (g *Gateway) Open() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
return g.OpenContext(context.Background())
|
||||
}
|
||||
|
||||
func (g *Gateway) OpenContext(ctx context.Context) error {
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "Failed to reconnect")
|
||||
}
|
||||
|
||||
WSDebug("Trying to start...")
|
||||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to resume the connection
|
||||
if err := g.Start(); err != nil {
|
||||
|
@ -266,24 +238,18 @@ func (g *Gateway) Start() error {
|
|||
// defer g.available.Unlock()
|
||||
|
||||
if err := g.start(); err != nil {
|
||||
WSDebug("Start failed:", err)
|
||||
wsutil.WSDebug("Start failed:", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := g.Close(); err != nil {
|
||||
WSDebug("Failed to close after start fail:", err)
|
||||
wsutil.WSDebug("Failed to close after start fail:", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait is deprecated. The gateway will reconnect forever. This function will
|
||||
// panic.
|
||||
func (g *Gateway) Wait() error {
|
||||
panic("Wait is deprecated. defer (*Gateway).Close() is required.")
|
||||
}
|
||||
|
||||
func (g *Gateway) start() error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
@ -293,7 +259,7 @@ func (g *Gateway) start() error {
|
|||
|
||||
// Wait for an OP 10 Hello
|
||||
var hello HelloEvent
|
||||
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
||||
if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "Error at Hello")
|
||||
}
|
||||
|
||||
|
@ -311,16 +277,16 @@ func (g *Gateway) start() error {
|
|||
}
|
||||
|
||||
// Expect either READY or RESUMED before continuing.
|
||||
WSDebug("Waiting for either READY or RESUMED.")
|
||||
wsutil.WSDebug("Waiting for either READY or RESUMED.")
|
||||
|
||||
// WaitForEvent should
|
||||
err := WaitForEvent(g, ch, func(op *OP) bool {
|
||||
err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool {
|
||||
switch op.EventName {
|
||||
case "READY":
|
||||
WSDebug("Found READY event.")
|
||||
wsutil.WSDebug("Found READY event.")
|
||||
return true
|
||||
case "RESUMED":
|
||||
WSDebug("Found RESUMED event.")
|
||||
wsutil.WSDebug("Found RESUMED event.")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -330,30 +296,23 @@ func (g *Gateway) start() error {
|
|||
return errors.Wrap(err, "First error")
|
||||
}
|
||||
|
||||
// Start the pacemaker with the heartrate received from Hello, after
|
||||
// initializing everything. This ensures we only heartbeat if the websocket
|
||||
// is authenticated.
|
||||
g.Pacemaker = &Pacemaker{
|
||||
Heartrate: hello.HeartbeatInterval.Duration(),
|
||||
Pace: g.Heartbeat,
|
||||
}
|
||||
// Pacemaker dies here, only when it's fatal.
|
||||
g.paceDeath = g.Pacemaker.StartAsync(g.waitGroup)
|
||||
// Use the pacemaker loop.
|
||||
g.PacerLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, g)
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
g.waitGroup.Add(1)
|
||||
go g.handleWS()
|
||||
|
||||
WSDebug("Started successfully.")
|
||||
wsutil.WSDebug("Started successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleWS uses the Websocket and parses them into g.Events.
|
||||
func (g *Gateway) handleWS() {
|
||||
err := g.eventLoop()
|
||||
err := g.PacerLoop.Run()
|
||||
g.waitGroup.Done() // mark so Close() can exit.
|
||||
WSDebug("Event loop stopped.")
|
||||
wsutil.WSDebug("Event loop stopped.")
|
||||
|
||||
if err != nil {
|
||||
g.ErrorLog(err)
|
||||
|
@ -362,39 +321,13 @@ func (g *Gateway) handleWS() {
|
|||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) eventLoop() error {
|
||||
ch := g.WS.Listen()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-g.paceDeath:
|
||||
// Got a paceDeath, we're exiting from here on out.
|
||||
g.paceDeath = nil // mark
|
||||
|
||||
if err == nil {
|
||||
WSDebug("Pacemaker stopped without errors.")
|
||||
// No error, just exit normally.
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "Pacemaker died, reconnecting")
|
||||
|
||||
case ev := <-ch:
|
||||
// Handle the event
|
||||
if err := HandleEvent(g, ev); err != nil {
|
||||
g.ErrorLog(errors.Wrap(err, "WS handler error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||
var op = OP{
|
||||
var op = wsutil.OP{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := g.Driver.Marshal(v)
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode v")
|
||||
}
|
||||
|
@ -402,7 +335,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := g.Driver.Marshal(op)
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
|
|
@ -8,10 +8,12 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
WSDebug = func(v ...interface{}) {
|
||||
wsutil.WSDebug = func(v ...interface{}) {
|
||||
log.Println(append([]interface{}{"Debug:"}, v...)...)
|
||||
}
|
||||
}
|
||||
|
@ -41,7 +43,7 @@ func TestIntegration(t *testing.T) {
|
|||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
WSError = func(err error) {
|
||||
wsutil.WSError = func(err error) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -77,7 +79,11 @@ func TestIntegration(t *testing.T) {
|
|||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Try and reconnect forever:
|
||||
gotimeout(t, gateway.Reconnect)
|
||||
gotimeout(t, func() {
|
||||
if err := gateway.Reconnect(); err != nil {
|
||||
t.Fatal("Unexpected error while reconnecting:", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for the desired event:
|
||||
gotimeout(t, func() {
|
||||
|
|
110
gateway/op.go
110
gateway/op.go
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OPCode uint8
|
||||
type OPCode = wsutil.OPCode
|
||||
|
||||
const (
|
||||
DispatchOP OPCode = 0 // recv
|
||||
|
@ -29,117 +29,19 @@ const (
|
|||
GuildSubscriptionsOP OPCode = 14
|
||||
)
|
||||
|
||||
type OP struct {
|
||||
Code OPCode `json:"op"`
|
||||
Data json.Raw `json:"d,omitempty"`
|
||||
|
||||
// Only for Dispatch (op 0)
|
||||
Sequence int64 `json:"s,omitempty"`
|
||||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) {
|
||||
op, err := DecodeOP(driver, ev)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return op.Code, nil
|
||||
}
|
||||
|
||||
func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
||||
op, err := DecodeOP(driver, ev)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op.Code != code {
|
||||
return op, fmt.Errorf(
|
||||
"Unexpected OP Code: %d, expected %d (%s)",
|
||||
op.Code, code, op.Data,
|
||||
)
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
return op, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
||||
o, err := DecodeOP(g.Driver, ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return HandleOP(g, o)
|
||||
}
|
||||
|
||||
// WaitForEvent blocks until fn() returns true. All incoming events are handled
|
||||
// regardless.
|
||||
func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
||||
for ev := range ch {
|
||||
o, err := DecodeOP(g.Driver, ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle the *OP first, in case it's an Invalid Session. This should
|
||||
// also prevent a race condition with things that need Ready after
|
||||
// Open().
|
||||
if err := HandleOP(g, o); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Are these events what we're looking for? If we've found the event,
|
||||
// return.
|
||||
if fn(o) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("Event not found and event channel is closed.")
|
||||
}
|
||||
|
||||
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
||||
if ev.Error != nil {
|
||||
return nil, ev.Error
|
||||
}
|
||||
|
||||
if len(ev.Data) == 0 {
|
||||
return nil, errors.New("Empty payload")
|
||||
}
|
||||
|
||||
var op *OP
|
||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func HandleOP(g *Gateway, op *OP) error {
|
||||
if g.OP != nil {
|
||||
g.OP <- op
|
||||
}
|
||||
|
||||
func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
||||
switch op.Code {
|
||||
case HeartbeatAckOP:
|
||||
// Heartbeat from the server?
|
||||
g.Pacemaker.Echo()
|
||||
g.PacerLoop.Echo()
|
||||
|
||||
case HeartbeatOP:
|
||||
// Server requesting a heartbeat.
|
||||
return g.Pacemaker.Pace()
|
||||
return g.PacerLoop.Pace()
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
WSDebug("ReconnectOP received.")
|
||||
wsutil.WSDebug("ReconnectOP received.")
|
||||
// We must reconnect in another goroutine, as running Reconnect
|
||||
// synchronously would prevent the main event loop from exiting.
|
||||
go g.Reconnect()
|
||||
|
@ -181,7 +83,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
|||
var ev = fn()
|
||||
|
||||
// Try and parse the event
|
||||
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "Failed to parse event "+op.EventName)
|
||||
}
|
||||
|
||||
|
|
3
go.mod
3
go.mod
|
@ -5,8 +5,11 @@ go 1.13
|
|||
require (
|
||||
github.com/gorilla/schema v1.1.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
github.com/sasha-s/go-deadlock v0.2.0
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 // indirect
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
)
|
||||
|
|
5
go.sum
5
go.sum
|
@ -2,10 +2,15 @@ github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY=
|
|||
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
|
||||
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa h1:xiD6U6h+QMkAwI195dFwdku2N+enlCy9XwFTnEXaCQo=
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa/go.mod h1:KKzWrLiWu6EpzxZBPmPisPgq6oL+do2yLa0C0BTx5fA=
|
||||
github.com/sasha-s/go-deadlock v0.2.0 h1:lMqc+fUb7RrFS3gQLtoQsJ7/6TV/pAIFvBsqX73DK8Y=
|
||||
github.com/sasha-s/go-deadlock v0.2.0/go.mod h1:StQn567HiB1fF2yJ44N9au7wOhrPS3iZqiDbRupzT10=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
package gateway
|
||||
// Package heart implements a general purpose pacemaker.
|
||||
package heart
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
@ -8,18 +9,35 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Debug is the default logger that Pacemaker uses.
|
||||
var Debug = func(v ...interface{}) {}
|
||||
|
||||
var ErrDead = errors.New("no heartbeat replied")
|
||||
|
||||
// Time is a UnixNano timestamp.
|
||||
type Time = int64
|
||||
// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic.
|
||||
type AtomicTime struct {
|
||||
unixnano int64
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Get() int64 {
|
||||
return atomic.LoadInt64(&t.unixnano)
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Set(time time.Time) {
|
||||
atomic.StoreInt64(&t.unixnano, time.UnixNano())
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Time() time.Time {
|
||||
return time.Unix(0, t.Get())
|
||||
}
|
||||
|
||||
type Pacemaker struct {
|
||||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate time.Duration
|
||||
|
||||
// Time in nanoseconds, guarded by atomic read/writes.
|
||||
SentBeat Time
|
||||
EchoBeat Time
|
||||
SentBeat AtomicTime
|
||||
EchoBeat AtomicTime
|
||||
|
||||
// Any callback that returns an error will stop the pacer.
|
||||
Pace func() error
|
||||
|
@ -28,10 +46,17 @@ type Pacemaker struct {
|
|||
death chan error
|
||||
}
|
||||
|
||||
func NewPacemaker(heartrate time.Duration, pacer func() error) *Pacemaker {
|
||||
return &Pacemaker{
|
||||
Heartrate: heartrate,
|
||||
Pace: pacer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
// Swap our received heartbeats
|
||||
// p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
|
||||
atomic.StoreInt64(&p.EchoBeat, time.Now().UnixNano())
|
||||
p.EchoBeat.Set(time.Now())
|
||||
}
|
||||
|
||||
// Dead, if true, will have Pace return an ErrDead.
|
||||
|
@ -45,8 +70,8 @@ func (p *Pacemaker) Dead() bool {
|
|||
*/
|
||||
|
||||
var (
|
||||
echo = atomic.LoadInt64(&p.EchoBeat)
|
||||
sent = atomic.LoadInt64(&p.SentBeat)
|
||||
echo = p.EchoBeat.Get()
|
||||
sent = p.SentBeat.Get()
|
||||
)
|
||||
|
||||
if echo == 0 || sent == 0 {
|
||||
|
@ -59,13 +84,18 @@ func (p *Pacemaker) Dead() bool {
|
|||
func (p *Pacemaker) Stop() {
|
||||
if p.stop != nil {
|
||||
p.stop <- struct{}{}
|
||||
WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
||||
Debug("(*Pacemaker).stop was sent a stop signal.")
|
||||
} else {
|
||||
WSDebug("(*Pacemaker).stop is nil, skipping.")
|
||||
Debug("(*Pacemaker).stop is nil, skipping.")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pacemaker) start() error {
|
||||
// Reset states to its old position.
|
||||
p.EchoBeat.Set(time.Time{})
|
||||
p.SentBeat.Set(time.Time{})
|
||||
|
||||
// Create a new ticker.
|
||||
tick := time.NewTicker(p.Heartrate)
|
||||
defer tick.Stop()
|
||||
|
||||
|
@ -73,16 +103,16 @@ func (p *Pacemaker) start() error {
|
|||
p.Echo()
|
||||
|
||||
for {
|
||||
WSDebug("Pacemaker loop restarted.")
|
||||
Debug("Pacemaker loop restarted.")
|
||||
|
||||
if err := p.Pace(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
WSDebug("Paced.")
|
||||
Debug("Paced.")
|
||||
|
||||
// Paced, save:
|
||||
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
|
@ -90,11 +120,11 @@ func (p *Pacemaker) start() error {
|
|||
|
||||
select {
|
||||
case <-p.stop:
|
||||
WSDebug("Received stop signal.")
|
||||
Debug("Received stop signal.")
|
||||
return nil
|
||||
|
||||
case <-tick.C:
|
||||
WSDebug("Ticked. Restarting.")
|
||||
Debug("Ticked. Restarting.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -104,16 +134,21 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
|||
p.death = make(chan error)
|
||||
p.stop = make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
if wg != nil {
|
||||
wg.Add(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.death <- p.start()
|
||||
// Debug.
|
||||
WSDebug("Pacemaker returned.")
|
||||
Debug("Pacemaker returned.")
|
||||
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
||||
p.stop = nil
|
||||
|
||||
// Mark the pacemaker loop as done.
|
||||
wg.Done()
|
||||
if wg != nil {
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
return p.death
|
|
@ -41,7 +41,7 @@ func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
|||
func NewClient() *Client {
|
||||
return &Client{
|
||||
Client: httpdriver.NewClient(),
|
||||
Driver: json.Default{},
|
||||
Driver: json.Default,
|
||||
SchemaEncoder: &DefaultSchema{},
|
||||
Retries: Retries,
|
||||
OnResponse: ResponseNoop,
|
||||
|
|
|
@ -7,65 +7,6 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
type (
|
||||
OptionBool = *bool
|
||||
OptionString = *string
|
||||
OptionUint = *uint
|
||||
OptionInt = *int
|
||||
)
|
||||
|
||||
var (
|
||||
True = getBool(true)
|
||||
False = getBool(false)
|
||||
|
||||
ZeroUint = Uint(0)
|
||||
ZeroInt = Int(0)
|
||||
|
||||
EmptyString = String("")
|
||||
)
|
||||
|
||||
func Uint(u uint) OptionUint {
|
||||
return &u
|
||||
}
|
||||
|
||||
func Int(i int) OptionInt {
|
||||
return &i
|
||||
}
|
||||
|
||||
func String(s string) OptionString {
|
||||
return &s
|
||||
}
|
||||
|
||||
func getBool(Bool bool) OptionBool {
|
||||
return &Bool
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||
// from encoding/json.
|
||||
type Raw []byte
|
||||
|
||||
// Raw returns m as the JSON encoding of m.
|
||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type Driver interface {
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
|
@ -74,20 +15,43 @@ type Driver interface {
|
|||
EncodeStream(w io.Writer, v interface{}) error
|
||||
}
|
||||
|
||||
type Default struct{}
|
||||
type DefaultDriver struct{}
|
||||
|
||||
func (d Default) Marshal(v interface{}) ([]byte, error) {
|
||||
func (d DefaultDriver) Marshal(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func (d Default) Unmarshal(data []byte, v interface{}) error {
|
||||
func (d DefaultDriver) Unmarshal(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (d Default) DecodeStream(r io.Reader, v interface{}) error {
|
||||
func (d DefaultDriver) DecodeStream(r io.Reader, v interface{}) error {
|
||||
return json.NewDecoder(r).Decode(v)
|
||||
}
|
||||
|
||||
func (d Default) EncodeStream(w io.Writer, v interface{}) error {
|
||||
func (d DefaultDriver) EncodeStream(w io.Writer, v interface{}) error {
|
||||
return json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// Default is the default JSON driver, which uses encoding/json.
|
||||
var Default Driver = DefaultDriver{}
|
||||
|
||||
// Marshal uses the default driver.
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
return Default.Marshal(v)
|
||||
}
|
||||
|
||||
// Unmarshal uses the default driver.
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
return Default.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// DecodeStream uses the default driver.
|
||||
func DecodeStream(r io.Reader, v interface{}) error {
|
||||
return Default.DecodeStream(r, v)
|
||||
}
|
||||
|
||||
// EncodeStream uses the default driver.
|
||||
func EncodeStream(w io.Writer, v interface{}) error {
|
||||
return Default.EncodeStream(w, v)
|
||||
}
|
||||
|
|
34
utils/json/option.go
Normal file
34
utils/json/option.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package json
|
||||
|
||||
type (
|
||||
OptionBool = *bool
|
||||
OptionString = *string
|
||||
OptionUint = *uint
|
||||
OptionInt = *int
|
||||
)
|
||||
|
||||
var (
|
||||
True = getBool(true)
|
||||
False = getBool(false)
|
||||
|
||||
ZeroUint = Uint(0)
|
||||
ZeroInt = Int(0)
|
||||
|
||||
EmptyString = String("")
|
||||
)
|
||||
|
||||
func Uint(u uint) OptionUint {
|
||||
return &u
|
||||
}
|
||||
|
||||
func Int(i int) OptionInt {
|
||||
return &i
|
||||
}
|
||||
|
||||
func String(s string) OptionString {
|
||||
return &s
|
||||
}
|
||||
|
||||
func getBool(Bool bool) OptionBool {
|
||||
return &Bool
|
||||
}
|
23
utils/json/raw.go
Normal file
23
utils/json/raw.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package json
|
||||
|
||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||
// from encoding/json.
|
||||
type Raw []byte
|
||||
|
||||
// Raw returns m as the JSON encoding of m.
|
||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
19
utils/moreatomic/bool.go
Normal file
19
utils/moreatomic/bool.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package moreatomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type Bool struct {
|
||||
val uint32
|
||||
}
|
||||
|
||||
func (b *Bool) Get() bool {
|
||||
return atomic.LoadUint32(&b.val) == 1
|
||||
}
|
||||
|
||||
func (b *Bool) Set(val bool) {
|
||||
var x = uint32(0)
|
||||
if val {
|
||||
x = 1
|
||||
}
|
||||
atomic.StoreUint32(&b.val, x)
|
||||
}
|
33
utils/moreatomic/mutex.go
Normal file
33
utils/moreatomic/mutex.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package moreatomic
|
||||
|
||||
import "github.com/sasha-s/go-deadlock"
|
||||
|
||||
type BusyMutex struct {
|
||||
busy Bool
|
||||
mut deadlock.Mutex
|
||||
}
|
||||
|
||||
func (m *BusyMutex) TryLock() bool {
|
||||
if m.busy.Get() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mut.Lock()
|
||||
m.busy.Set(true)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *BusyMutex) IsBusy() bool {
|
||||
return m.busy.Get()
|
||||
}
|
||||
|
||||
func (m *BusyMutex) Lock() {
|
||||
m.mut.Lock()
|
||||
m.busy.Set(true)
|
||||
}
|
||||
|
||||
func (m *BusyMutex) Unlock() {
|
||||
m.busy.Set(false)
|
||||
m.mut.Unlock()
|
||||
}
|
16
utils/moreatomic/serial.go
Normal file
16
utils/moreatomic/serial.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package moreatomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type Serial struct {
|
||||
serial uint32
|
||||
}
|
||||
|
||||
func (s *Serial) Get() int {
|
||||
return int(atomic.LoadUint32(&s.serial))
|
||||
}
|
||||
|
||||
func (s *Serial) Incr() int {
|
||||
atomic.AddUint32(&s.serial, 1)
|
||||
return s.Get()
|
||||
}
|
17
utils/moreatomic/snowflake.go
Normal file
17
utils/moreatomic/snowflake.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package moreatomic
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
)
|
||||
|
||||
type Snowflake int64
|
||||
|
||||
func (s *Snowflake) Get() discord.Snowflake {
|
||||
return discord.Snowflake(atomic.LoadInt64((*int64)(s)))
|
||||
}
|
||||
|
||||
func (s *Snowflake) Set(id discord.Snowflake) {
|
||||
atomic.StoreInt64((*int64)(s), int64(id))
|
||||
}
|
18
utils/moreatomic/string.go
Normal file
18
utils/moreatomic/string.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package moreatomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type String struct {
|
||||
v atomic.Value
|
||||
}
|
||||
|
||||
func (s *String) Get() string {
|
||||
if v, ok := s.v.Load().(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *String) Set(str string) {
|
||||
s.v.Store(str)
|
||||
}
|
46
utils/moreatomic/time.go
Normal file
46
utils/moreatomic/time.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package moreatomic
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Time struct {
|
||||
unixnano int64
|
||||
}
|
||||
|
||||
func Now() *Time {
|
||||
return &Time{
|
||||
unixnano: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Time) Get() time.Time {
|
||||
nano := atomic.LoadInt64(&t.unixnano)
|
||||
return time.Unix(0, nano)
|
||||
}
|
||||
|
||||
func (t *Time) Set(time time.Time) {
|
||||
atomic.StoreInt64(&t.unixnano, time.UnixNano())
|
||||
}
|
||||
|
||||
// HasBeen checks if it has been this long since the last time. If yes, it will
|
||||
// set the time.
|
||||
func (t *Time) HasBeen(dura time.Duration) bool {
|
||||
now := time.Now()
|
||||
nano := atomic.LoadInt64(&t.unixnano)
|
||||
|
||||
// We have to be careful of zero values.
|
||||
if nano != 0 {
|
||||
// Subtract the duration to now. If subtracted now is before the stored
|
||||
// time, that means it hasn't been that long yet. We also have to be careful
|
||||
// of an unitialized time.
|
||||
if now.Add(-dura).Before(time.Unix(0, nano)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// It has been that long, so store the variable.
|
||||
t.Set(now)
|
||||
return true
|
||||
}
|
|
@ -35,7 +35,7 @@ type Connection interface {
|
|||
Listen() <-chan Event
|
||||
|
||||
// Send allows the caller to send bytes. Thread safety is a requirement.
|
||||
Send([]byte) error
|
||||
Send(context.Context, []byte) error
|
||||
|
||||
// Close should close the websocket connection. The connection will not be
|
||||
// reused.
|
||||
|
@ -67,12 +67,16 @@ type Conn struct {
|
|||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
||||
func NewConn(driver json.Driver) *Conn {
|
||||
func NewConn() *Conn {
|
||||
return NewConnWithDriver(json.Default)
|
||||
}
|
||||
|
||||
func NewConnWithDriver(driver json.Driver) *Conn {
|
||||
return &Conn{
|
||||
Driver: driver,
|
||||
dialer: &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: DefaultTimeout,
|
||||
HandshakeTimeout: WSTimeout,
|
||||
EnableCompression: true,
|
||||
},
|
||||
// zlib: zlib.NewInflator(),
|
||||
|
@ -226,14 +230,27 @@ func (c *Conn) handle() ([]byte, error) {
|
|||
// return nil, errors.New("Unexpected binary message.")
|
||||
}
|
||||
|
||||
func (c *Conn) Send(b []byte) error {
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// If websocket is already closed.
|
||||
if c.writes == nil {
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
c.writes <- b
|
||||
return <-c.errors
|
||||
// Send the bytes.
|
||||
select {
|
||||
case c.writes <- b:
|
||||
// continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Receive the error.
|
||||
select {
|
||||
case err := <-c.errors:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
|
|
108
utils/wsutil/heart.go
Normal file
108
utils/wsutil/heart.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/heart"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// TODO API
|
||||
type EventLoop interface {
|
||||
Heartbeat() error
|
||||
HandleOP(*OP) error
|
||||
// HandleEvent(ev Event) error
|
||||
}
|
||||
|
||||
// PacemakerLoop provides an event loop with a pacemaker.
|
||||
type PacemakerLoop struct {
|
||||
pacemaker *heart.Pacemaker // let's not copy this
|
||||
pacedeath chan error
|
||||
|
||||
events <-chan Event
|
||||
handler func(*OP) error
|
||||
|
||||
Extras ExtraHandlers
|
||||
|
||||
ErrorLog func(error)
|
||||
}
|
||||
|
||||
func NewLoop(heartrate time.Duration, evs <-chan Event, evl EventLoop) *PacemakerLoop {
|
||||
pacemaker := heart.NewPacemaker(heartrate, evl.Heartbeat)
|
||||
|
||||
return &PacemakerLoop{
|
||||
pacemaker: pacemaker,
|
||||
events: evs,
|
||||
handler: evl.HandleOP,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) errorLog(err error) {
|
||||
if p.ErrorLog == nil {
|
||||
WSDebug("Uncaught error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.ErrorLog(err)
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Pace() error {
|
||||
return p.pacemaker.Pace()
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Echo() {
|
||||
p.pacemaker.Echo()
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Stop() {
|
||||
p.pacemaker.Stop()
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Stopped() bool {
|
||||
return p == nil || p.pacedeath == nil
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Run() error {
|
||||
// If the event loop is already running.
|
||||
if p.pacedeath != nil {
|
||||
return nil
|
||||
}
|
||||
// callers should explicitly handle waitgroups.
|
||||
p.pacedeath = p.pacemaker.StartAsync(nil)
|
||||
|
||||
defer func() {
|
||||
// mark pacedeath once done
|
||||
p.pacedeath = nil
|
||||
|
||||
WSDebug("Pacemaker loop has exited.")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-p.pacedeath:
|
||||
return errors.Wrap(err, "Pacemaker died, reconnecting")
|
||||
|
||||
case ev, ok := <-p.events:
|
||||
if !ok {
|
||||
// Events channel is closed. Kill the pacemaker manually and
|
||||
// die.
|
||||
p.pacemaker.Stop()
|
||||
return <-p.pacedeath
|
||||
}
|
||||
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
p.errorLog(errors.Wrap(err, "Failed to decode OP"))
|
||||
continue // ignore
|
||||
}
|
||||
|
||||
// Check the events before handling.
|
||||
p.Extras.Check(o)
|
||||
|
||||
// Handle the event
|
||||
if err := p.handler(o); err != nil {
|
||||
p.errorLog(errors.Wrap(err, "Handler failed"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
168
utils/wsutil/op.go
Normal file
168
utils/wsutil/op.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/diamondburned/arikawa/utils/moreatomic"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var ErrEmptyPayload = errors.New("Empty payload")
|
||||
|
||||
// OPCode is a generic type for websocket OP codes.
|
||||
type OPCode uint8
|
||||
|
||||
type OP struct {
|
||||
Code OPCode `json:"op"`
|
||||
Data json.Raw `json:"d,omitempty"`
|
||||
|
||||
// Only for Gateway Dispatch (op 0)
|
||||
Sequence int64 `json:"s,omitempty"`
|
||||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
func (op *OP) UnmarshalData(v interface{}) error {
|
||||
return json.Unmarshal(op.Data, v)
|
||||
}
|
||||
|
||||
func DecodeOP(ev Event) (*OP, error) {
|
||||
if ev.Error != nil {
|
||||
return nil, ev.Error
|
||||
}
|
||||
|
||||
if len(ev.Data) == 0 {
|
||||
return nil, ErrEmptyPayload
|
||||
}
|
||||
|
||||
var op *OP
|
||||
if err := json.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) {
|
||||
op, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op.Code != code {
|
||||
return op, fmt.Errorf(
|
||||
"Unexpected OP Code: %d, expected %d (%s)",
|
||||
op.Code, code, op.Data,
|
||||
)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||
return op, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
type EventHandler interface {
|
||||
HandleOP(op *OP) error
|
||||
}
|
||||
|
||||
func HandleEvent(h EventHandler, ev Event) error {
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.HandleOP(o)
|
||||
}
|
||||
|
||||
// WaitForEvent blocks until fn() returns true. All incoming events are handled
|
||||
// regardless.
|
||||
func WaitForEvent(h EventHandler, ch <-chan Event, fn func(*OP) bool) error {
|
||||
for ev := range ch {
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle the *OP first, in case it's an Invalid Session. This should
|
||||
// also prevent a race condition with things that need Ready after
|
||||
// Open().
|
||||
if err := h.HandleOP(o); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Are these events what we're looking for? If we've found the event,
|
||||
// return.
|
||||
if fn(o) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("Event not found and event channel is closed.")
|
||||
}
|
||||
|
||||
type ExtraHandlers struct {
|
||||
mutex sync.Mutex
|
||||
handlers map[uint32]*ExtraHandler
|
||||
serial uint32
|
||||
}
|
||||
|
||||
type ExtraHandler struct {
|
||||
Check func(*OP) bool
|
||||
send chan *OP
|
||||
|
||||
closed moreatomic.Bool
|
||||
}
|
||||
|
||||
func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) {
|
||||
handler := &ExtraHandler{
|
||||
Check: check,
|
||||
send: make(chan *OP),
|
||||
}
|
||||
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
if ex.handlers == nil {
|
||||
ex.handlers = make(map[uint32]*ExtraHandler, 1)
|
||||
}
|
||||
|
||||
i := ex.serial
|
||||
ex.serial++
|
||||
|
||||
ex.handlers[i] = handler
|
||||
|
||||
return handler.send, func() {
|
||||
// Check the atomic bool before acquiring the mutex. Might help a bit in
|
||||
// performance.
|
||||
if handler.closed.Get() {
|
||||
return
|
||||
}
|
||||
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
delete(ex.handlers, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Check runs and sends OP data. It is not thread-safe.
|
||||
func (ex *ExtraHandlers) Check(op *OP) {
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
for i, handler := range ex.handlers {
|
||||
if handler.Check(op) {
|
||||
// Attempt to send.
|
||||
handler.send <- op
|
||||
|
||||
// Mark the handler as closed.
|
||||
handler.closed.Set(true)
|
||||
|
||||
// Delete the handler.
|
||||
delete(ex.handlers, i)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,15 +4,30 @@ package wsutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const DefaultTimeout = time.Minute
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = time.Minute
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Data []byte
|
||||
|
@ -25,12 +40,16 @@ type Websocket struct {
|
|||
Conn Connection
|
||||
Addr string
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
Timeout time.Duration
|
||||
|
||||
SendLimiter *rate.Limiter
|
||||
DialLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
func New(addr string) *Websocket {
|
||||
return NewCustom(NewConn(json.Default{}), addr)
|
||||
return NewCustom(NewConn(), addr)
|
||||
}
|
||||
|
||||
// NewCustom creates a new undialed Websocket.
|
||||
|
@ -39,12 +58,21 @@ func NewCustom(conn Connection, addr string) *Websocket {
|
|||
Conn: conn,
|
||||
Addr: addr,
|
||||
|
||||
Timeout: WSTimeout,
|
||||
|
||||
SendLimiter: NewSendLimiter(),
|
||||
DialLimiter: NewDialLimiter(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
if ws.Timeout > 0 {
|
||||
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
|
||||
defer cancel()
|
||||
|
||||
ctx = tctx
|
||||
}
|
||||
|
||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return errors.Wrap(err, "Failed to wait")
|
||||
|
@ -65,11 +93,16 @@ func (ws *Websocket) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
|
||||
return ws.SendContext(context.Background(), b)
|
||||
}
|
||||
|
||||
// SendContext is a beta API.
|
||||
func (ws *Websocket) SendContext(ctx context.Context, b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
return ws.Conn.Send(b)
|
||||
return ws.Conn.Send(ctx, b)
|
||||
}
|
||||
|
||||
func (ws *Websocket) Close() error {
|
||||
|
|
62
voice/README.md
Normal file
62
voice/README.md
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Voice
|
||||
|
||||
## Terminology
|
||||
* **Discord Gateway** - The standard Discord Gateway users connect to and receive update events from
|
||||
* **Discord Voice Gateway** - The Discord Voice gateway that allows voice connections to be configured
|
||||
* **Voice Server** - What the Discord Voice Gateway allows connection to for sending of Opus voice packets over UDP
|
||||
* **Voice Packet** - Opus encoded UDP packet that contains audio
|
||||
* **Application** - Could be a custom Discord Client or Bot (nothing that is within this package)
|
||||
* **Library** - Code within this package
|
||||
|
||||
## Connection Flow
|
||||
* The **application** would get a new `*Voice` instance by calling `NewVoice()`
|
||||
* When the **application** wants to connect to a voice channel they would call `JoinChannel()` on
|
||||
the stored `*Voice` instance
|
||||
|
||||
---
|
||||
|
||||
* The **library** sends a [Voice State Update](https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information-gateway-voice-state-update-example)
|
||||
to the **Discord Gateway**
|
||||
* The **library** waits until it receives a [Voice Server Update](https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information-example-voice-server-update-payload)
|
||||
from the **Discord Gateway**
|
||||
* Once a *Voice Server Update* event is received, a new connection is opened to the **Discord Voice Gateway**
|
||||
|
||||
---
|
||||
|
||||
* When the connection is opened an [Identify Event](https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload)
|
||||
or [Resume Event](https://discordapp.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload)
|
||||
is sent to the **Discord Voice Gateway** depending on if the **library** is reconnecting
|
||||
* The **Discord Voice Gateway** should respond with a [Hello Event](https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-hello-payload-since-v3)
|
||||
which will be used to create a new `*gateway.Pacemaker` and start sending heartbeats to the **Discord Voice Gateway**
|
||||
|
||||
---
|
||||
|
||||
* The **Discord Voice Gateway** should also respond with a [Ready Event](https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-ready-payload)
|
||||
once the connection is opened, providing the required information to connect to a **Voice Server**
|
||||
* Using the information provided in the *Ready Event*, a new UDP connection is opened to the **Voice Server**
|
||||
and [IP Discovery](https://discordapp.com/developers/docs/topics/voice-connections#ip-discovery) occurs
|
||||
* After *IP Discovery* returns the **Application**'s external ip and port it connected to the **Voice Server**
|
||||
with, the **library** sends a [Select Protocol Event](https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload)
|
||||
to the **Discord Voice Gateway**
|
||||
* The **library** waits until it receives a [Session Description Event](https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-session-description-payload)
|
||||
from the **Discord Voice Gateway**
|
||||
* Once the *Session Description Event* is received, [Speaking Events](https://discordapp.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload)
|
||||
and **Voice Packets** can begin to be sent to the **Discord Voice Gateway** and **Voice Server** respectively
|
||||
|
||||
## Usage
|
||||
* The **application** would get a new `*Voice` instance by calling `NewVoice()` and keep it
|
||||
stored for when it needs to open voice connections
|
||||
* When the **application** wants to connect to a voice channel they would call `JoinChannel()` on
|
||||
the stored `*Voice` instance
|
||||
* `JoinChannel()` will block as it follows the [Connection Flow](#connection-flow), returning an
|
||||
`error` if one occurs and a `*Connection` if it was successful
|
||||
* The **application** should now call `*Connection#Speaking()` with the wanted [voice flag](https://discordapp.com/developers/docs/topics/voice-connections#speaking)
|
||||
(`Microphone`, `Soundshare`, `Priority`)
|
||||
* The **application** can now send **Voice Packets** to the `*Connection#OpusSend` channel
|
||||
which will be sent to the **Voice Server**
|
||||
* When the **application** wants to stop sending **Voice Packets** they should call
|
||||
`*Connection#StopSpeaking()`, any required voice cleanup (closing streams, etc), then
|
||||
`*Connection#Disconnect()`
|
||||
|
||||
## Examples
|
||||
###### Coming SoonTM
|
167
voice/integration_test.go
Normal file
167
voice/integration_test.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
// +build integration
|
||||
|
||||
package voice
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/voice/voicegateway"
|
||||
)
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.Snowflake
|
||||
}
|
||||
|
||||
func mustConfig(t *testing.T) testConfig {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
t.Fatal("Missing $VOICE_ID")
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
t.Fatal("Invalid $VOICE_ID:", err)
|
||||
}
|
||||
|
||||
return testConfig{
|
||||
BotToken: token,
|
||||
VoiceChID: id,
|
||||
}
|
||||
}
|
||||
|
||||
// file is only a few bytes lolmao
|
||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
f.Close()
|
||||
})
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func catchRead(t *testing.T, err error) bool {
|
||||
t.Helper()
|
||||
|
||||
if err == io.EOF {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("Failed to read:", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
config := mustConfig(t)
|
||||
|
||||
wsutil.WSDebug = func(v ...interface{}) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
caller := file + ":" + strconv.Itoa(line)
|
||||
log.Println(append([]interface{}{caller}, v...)...)
|
||||
}
|
||||
|
||||
// heart.Debug = func(v ...interface{}) {
|
||||
// log.Println(append([]interface{}{"Pacemaker:"}, v...)...)
|
||||
// }
|
||||
|
||||
s, err := state.New("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create a new session:", err)
|
||||
}
|
||||
|
||||
v := NewVoice(s)
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
t.Fatal("Failed to connect:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// Validate the given voice channel.
|
||||
c, err := s.Channel(config.VoiceChID)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to get channel:", err)
|
||||
}
|
||||
if c.Type != discord.GuildVoice {
|
||||
t.Fatal("Channel isn't a guild voice channel.")
|
||||
}
|
||||
|
||||
// Grab a timer to benchmark things.
|
||||
finish := timer()
|
||||
|
||||
// Join the voice channel.
|
||||
vs, err := v.JoinChannel(c.GuildID, c.ID, false, false)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to join channel:", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Println("Disconnecting from the voice channel.")
|
||||
if err := vs.Disconnect(); err != nil {
|
||||
t.Fatal("Failed to disconnect:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
finish("joining the voice channel")
|
||||
|
||||
// Trigger speaking.
|
||||
if err := vs.Speaking(voicegateway.Microphone); err != nil {
|
||||
t.Fatal("Failed to start speaking:", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Println("Stopping speaking.") // sounds grammatically wrong
|
||||
if err := vs.StopSpeaking(); err != nil {
|
||||
t.Fatal("Failed to stop speaking:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
finish("sending the speaking command")
|
||||
|
||||
// Copy the audio?
|
||||
nicoReadTo(t, vs)
|
||||
|
||||
finish("copying the audio")
|
||||
}
|
||||
|
||||
// simple shitty benchmark thing
|
||||
func timer() func(finished string) {
|
||||
var then = time.Now()
|
||||
|
||||
return func(finished string) {
|
||||
now := time.Now()
|
||||
log.Println("Finished", finished+", took", now.Sub(then))
|
||||
then = now
|
||||
}
|
||||
}
|
11
voice/packet.go
Normal file
11
voice/packet.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package voice
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice
|
||||
type Packet struct {
|
||||
Version byte // Single byte value of 0x80 - 1 byte
|
||||
Type byte // Single byte value of 0x78 - 1 byte
|
||||
Sequence uint16 // Unsigned short (big endian) - 4 bytes
|
||||
Timestamp uint32 // Unsigned integer (big endian) - 4 bytes
|
||||
SSRC uint32 // Unsigned integer (big endian) - 4 bytes
|
||||
Opus []byte // Binary data
|
||||
}
|
258
voice/session.go
Normal file
258
voice/session.go
Normal file
|
@ -0,0 +1,258 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/session"
|
||||
"github.com/diamondburned/arikawa/utils/moreatomic"
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/voice/udp"
|
||||
"github.com/diamondburned/arikawa/voice/voicegateway"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const Protocol = "xsalsa20_poly1305"
|
||||
|
||||
var OpusSilence = [...]byte{0xF8, 0xFF, 0xFE}
|
||||
|
||||
type Session struct {
|
||||
session *session.Session
|
||||
state voicegateway.State
|
||||
|
||||
ErrorLog func(err error)
|
||||
|
||||
// Filled by events.
|
||||
// sessionID string
|
||||
// token string
|
||||
// endpoint string
|
||||
|
||||
// joining determines the behavior of incoming event callbacks (Update).
|
||||
// If this is true, incoming events will just send into Updated channels. If
|
||||
// false, events will trigger a reconnection.
|
||||
joining moreatomic.Bool
|
||||
incoming chan struct{} // used only when joining == true
|
||||
|
||||
mut sync.RWMutex
|
||||
|
||||
// TODO: expose getters mutex-guarded.
|
||||
gateway *voicegateway.Gateway
|
||||
voiceUDP *udp.Connection
|
||||
|
||||
muted bool
|
||||
deafened bool
|
||||
speaking bool
|
||||
}
|
||||
|
||||
func NewSession(ses *session.Session, userID discord.Snowflake) *Session {
|
||||
return &Session{
|
||||
session: ses,
|
||||
state: voicegateway.State{
|
||||
UserID: userID,
|
||||
},
|
||||
ErrorLog: func(err error) {},
|
||||
incoming: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) UpdateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||
// If this is true, then mutex is acquired already.
|
||||
if s.joining.Get() {
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
s.incoming <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
// Reconnect.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
if err := s.reconnect(); err != nil {
|
||||
s.ErrorLog(errors.Wrap(err, "Failed to reconnect after voice server update"))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||
if s.state.UserID != ev.UserID {
|
||||
// Not our state.
|
||||
return
|
||||
}
|
||||
|
||||
// If this is true, then mutex is acquired already.
|
||||
if s.joining.Get() {
|
||||
s.state.SessionID = ev.SessionID
|
||||
s.state.ChannelID = ev.ChannelID
|
||||
|
||||
s.incoming <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) error {
|
||||
// Acquire the mutex during join, locking during IO as well.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
// Set that we're joining.
|
||||
s.joining.Set(true)
|
||||
defer s.joining.Set(false) // reset when done
|
||||
|
||||
// ensure gateeway and voiceUDP is already closed.
|
||||
s.ensureClosed()
|
||||
|
||||
// Set the state.
|
||||
s.state.ChannelID = cID
|
||||
s.state.GuildID = gID
|
||||
|
||||
s.muted = muted
|
||||
s.deafened = deafened
|
||||
s.speaking = false
|
||||
|
||||
// Ensure that if `cID` is zero that it passes null to the update event.
|
||||
var channelID *discord.Snowflake
|
||||
if cID.Valid() {
|
||||
channelID = &cID
|
||||
}
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
||||
// Send a Voice State Update event to the gateway.
|
||||
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
|
||||
GuildID: gID,
|
||||
ChannelID: channelID,
|
||||
SelfMute: muted,
|
||||
SelfDeaf: deafened,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to send Voice State Update event")
|
||||
}
|
||||
|
||||
// Wait for replies. The above command should reply with these 2 events.
|
||||
<-s.incoming
|
||||
<-s.incoming
|
||||
|
||||
// These 2 methods should've updated s.state before sending into these
|
||||
// channels. Since s.state is already filled, we can go ahead and connect.
|
||||
|
||||
return s.reconnect()
|
||||
}
|
||||
|
||||
// reconnect uses the current state to reconnect to a new gateway and UDP
|
||||
// connection.
|
||||
func (s *Session) reconnect() (err error) {
|
||||
s.gateway = voicegateway.New(s.state)
|
||||
|
||||
// Open the voice gateway. The function will block until Ready is received.
|
||||
if err := s.gateway.Open(); err != nil {
|
||||
return errors.Wrap(err, "Failed to open voice gateway")
|
||||
}
|
||||
|
||||
// Get the Ready event.
|
||||
voiceReady := s.gateway.Ready()
|
||||
|
||||
// Prepare the UDP voice connection.
|
||||
s.voiceUDP, err = udp.DialConnection(voiceReady.Addr(), voiceReady.SSRC)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to open voice UDP connection")
|
||||
}
|
||||
|
||||
// Get the session description from the voice gateway.
|
||||
d, err := s.gateway.SessionDescription(voicegateway.SelectProtocol{
|
||||
Protocol: "udp",
|
||||
Data: voicegateway.SelectProtocolData{
|
||||
Address: s.voiceUDP.GatewayIP,
|
||||
Port: s.voiceUDP.GatewayPort,
|
||||
Mode: Protocol,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to select protocol")
|
||||
}
|
||||
|
||||
// Start the UDP loop.
|
||||
go s.voiceUDP.Start(&d.SecretKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Speaking tells Discord we're speaking. This calls
|
||||
// (*voicegateway.Gateway).Speaking().
|
||||
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
|
||||
// TODO: maybe we don't need to mutex protect IO.
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
return s.gateway.Speaking(flag)
|
||||
}
|
||||
|
||||
func (s *Session) StopSpeaking() error {
|
||||
// Send 5 frames of silence.
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := s.Write(OpusSilence[:]); err != nil {
|
||||
return errors.Wrapf(err, "Failed to send frame %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return 0, ErrCannotSend
|
||||
}
|
||||
return s.voiceUDP.Write(b)
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
// If we're already closed.
|
||||
if s.gateway == nil && s.voiceUDP == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify Discord that we're leaving. This will send a
|
||||
// VoiceStateUpdateEvent, in which our handler will promptly remove the
|
||||
// session from the map.
|
||||
|
||||
err := s.session.Gateway.UpdateVoiceState(gateway.UpdateVoiceStateData{
|
||||
GuildID: s.state.GuildID,
|
||||
ChannelID: nil,
|
||||
SelfMute: true,
|
||||
SelfDeaf: true,
|
||||
})
|
||||
|
||||
s.ensureClosed()
|
||||
// wrap returns nil if err is nil
|
||||
return errors.Wrap(err, "Failed to update voice state")
|
||||
}
|
||||
|
||||
// close ensures everything is closed. It does not acquire the mutex.
|
||||
func (s *Session) ensureClosed() {
|
||||
// If we're already closed.
|
||||
if s.gateway == nil && s.voiceUDP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Disconnect the UDP connection.
|
||||
if s.voiceUDP != nil {
|
||||
s.voiceUDP.Close()
|
||||
s.voiceUDP = nil
|
||||
}
|
||||
|
||||
// Disconnect the voice gateway, ignoring the error.
|
||||
if s.gateway != nil {
|
||||
if err := s.gateway.Close(); err != nil {
|
||||
wsutil.WSDebug("Uncaught voice gateway close error:", err)
|
||||
}
|
||||
s.gateway = nil
|
||||
}
|
||||
}
|
BIN
voice/testdata/nico.dca
vendored
Normal file
BIN
voice/testdata/nico.dca
vendored
Normal file
Binary file not shown.
158
voice/udp/udp.go
Normal file
158
voice/udp/udp.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
GatewayIP string
|
||||
GatewayPort uint16
|
||||
|
||||
ssrc uint32
|
||||
|
||||
sequence uint16
|
||||
timestamp uint32
|
||||
nonce [24]byte
|
||||
|
||||
conn *net.UDPConn
|
||||
close chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
send chan []byte
|
||||
reply chan error
|
||||
}
|
||||
|
||||
func DialConnection(addr string, ssrc uint32) (*Connection, error) {
|
||||
// Resolve the host.
|
||||
a, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to resolve host")
|
||||
}
|
||||
|
||||
// Create a new UDP connection.
|
||||
conn, err := net.DialUDP("udp", nil, a)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to dial host")
|
||||
}
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#ip-discovery
|
||||
ssrcBuffer := [70]byte{
|
||||
0x1, 0x2,
|
||||
}
|
||||
binary.BigEndian.PutUint16(ssrcBuffer[2:4], 70)
|
||||
binary.BigEndian.PutUint32(ssrcBuffer[4:8], ssrc)
|
||||
|
||||
_, err = conn.Write(ssrcBuffer[:])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to write SSRC buffer")
|
||||
}
|
||||
|
||||
var ipBuffer [70]byte
|
||||
|
||||
// ReadFull makes sure to read all 70 bytes.
|
||||
_, err = io.ReadFull(conn, ipBuffer[:])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to read IP buffer")
|
||||
}
|
||||
|
||||
ipbody := ipBuffer[4:68]
|
||||
|
||||
nullPos := bytes.Index(ipbody, []byte{'\x00'})
|
||||
if nullPos < 0 {
|
||||
return nil, errors.New("UDP IP discovery did not contain a null terminator")
|
||||
}
|
||||
|
||||
ip := ipbody[:nullPos]
|
||||
port := binary.LittleEndian.Uint16(ipBuffer[68:70])
|
||||
|
||||
return &Connection{
|
||||
GatewayIP: string(ip),
|
||||
GatewayPort: port,
|
||||
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
send: make(chan []byte),
|
||||
reply: make(chan error),
|
||||
close: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Start(secret *[32]byte) {
|
||||
header := [12]byte{
|
||||
0: 0x80, // Version + Flags
|
||||
1: 0x78, // Payload Type
|
||||
// [2:4] // Sequence
|
||||
// [4:8] // Timestamp
|
||||
}
|
||||
|
||||
// Write SSRC to the header.
|
||||
binary.BigEndian.PutUint32(header[8:12], c.ssrc) // SSRC
|
||||
|
||||
// 50 sends per second, 960 samples each at 48kHz
|
||||
frequency := time.NewTicker(time.Millisecond * 20)
|
||||
defer frequency.Stop()
|
||||
|
||||
var b []byte
|
||||
var ok bool
|
||||
|
||||
// Close these channels at the end so Write() doesn't block.
|
||||
defer func() {
|
||||
close(c.send)
|
||||
close(c.closed)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case b, ok = <-c.send:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-c.close:
|
||||
return
|
||||
}
|
||||
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(header[2:4], c.sequence)
|
||||
c.sequence++
|
||||
|
||||
binary.BigEndian.PutUint32(header[4:8], c.timestamp)
|
||||
c.timestamp += 960 // Samples
|
||||
|
||||
copy(c.nonce[:], header[:])
|
||||
|
||||
toSend := secretbox.Seal(header[:], b, &c.nonce, secret)
|
||||
|
||||
select {
|
||||
case <-frequency.C:
|
||||
case <-c.close:
|
||||
return
|
||||
}
|
||||
|
||||
_, err := c.conn.Write(toSend)
|
||||
c.reply <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
close(c.close)
|
||||
<-c.closed
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Write sends bytes into the voice UDP connection.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
c.send <- b
|
||||
if err := <-c.reply; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
136
voice/voice.go
Normal file
136
voice/voice.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
// Package voice is coming soon to an arikawa near you!
|
||||
package voice
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// defaultErrorHandler is the default error handler
|
||||
defaultErrorHandler = func(err error) { log.Println("Voice gateway error:", err) }
|
||||
|
||||
// ErrCannotSend is an error when audio is sent to a closed channel.
|
||||
ErrCannotSend = errors.New("cannot send audio to closed channel")
|
||||
)
|
||||
|
||||
// Voice represents a Voice Repository used for managing voice sessions.
|
||||
type Voice struct {
|
||||
*state.State
|
||||
|
||||
// Session holds all of the active voice sessions.
|
||||
mapmutex sync.Mutex
|
||||
sessions map[discord.Snowflake]*Session // guildID:Session
|
||||
|
||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
||||
ErrorLog func(err error)
|
||||
}
|
||||
|
||||
// NewVoice creates a new Voice repository wrapped around a state.
|
||||
func NewVoice(s *state.State) *Voice {
|
||||
v := &Voice{
|
||||
State: s,
|
||||
sessions: make(map[discord.Snowflake]*Session),
|
||||
ErrorLog: defaultErrorHandler,
|
||||
}
|
||||
|
||||
// Add the required event handlers to the session.
|
||||
s.AddHandler(v.onVoiceStateUpdate)
|
||||
s.AddHandler(v.onVoiceServerUpdate)
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// onVoiceStateUpdate receives VoiceStateUpdateEvents from the gateway
|
||||
// to keep track of the current user's voice state.
|
||||
func (v *Voice) onVoiceStateUpdate(e *gateway.VoiceStateUpdateEvent) {
|
||||
// Get the current user.
|
||||
me, err := v.Me()
|
||||
if err != nil {
|
||||
v.ErrorLog(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ignore the event if it is an update from another user.
|
||||
if me.ID != e.UserID {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the stored voice session for the given guild.
|
||||
vs, ok := v.GetSession(e.GuildID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Do what we must.
|
||||
vs.UpdateState(e)
|
||||
|
||||
// Remove the connection if the current user has disconnected.
|
||||
if e.ChannelID == 0 {
|
||||
v.RemoveSession(e.GuildID)
|
||||
}
|
||||
}
|
||||
|
||||
// onVoiceServerUpdate receives VoiceServerUpdateEvents from the gateway
|
||||
// to manage the current user's voice connections.
|
||||
func (v *Voice) onVoiceServerUpdate(e *gateway.VoiceServerUpdateEvent) {
|
||||
// Get the stored voice session for the given guild.
|
||||
vs, ok := v.GetSession(e.GuildID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Do what we must.
|
||||
vs.UpdateServer(e)
|
||||
}
|
||||
|
||||
// GetSession gets a session for a guild with a read lock.
|
||||
func (v *Voice) GetSession(guildID discord.Snowflake) (*Session, bool) {
|
||||
v.mapmutex.Lock()
|
||||
defer v.mapmutex.Unlock()
|
||||
|
||||
// For some reason you cannot just put `return v.sessions[]` and return a bool D:
|
||||
conn, ok := v.sessions[guildID]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// RemoveSession removes a session.
|
||||
func (v *Voice) RemoveSession(guildID discord.Snowflake) {
|
||||
v.mapmutex.Lock()
|
||||
defer v.mapmutex.Unlock()
|
||||
|
||||
// Ensure that the session is disconnected.
|
||||
if ses, ok := v.sessions[guildID]; ok {
|
||||
ses.Disconnect()
|
||||
}
|
||||
|
||||
delete(v.sessions, guildID)
|
||||
}
|
||||
|
||||
// JoinChannel joins the specified channel in the specified guild.
|
||||
func (v *Voice) JoinChannel(gID, cID discord.Snowflake, muted, deafened bool) (*Session, error) {
|
||||
// Get the stored voice session for the given guild.
|
||||
conn, ok := v.GetSession(gID)
|
||||
|
||||
// Create a new voice session if one does not exist.
|
||||
if !ok {
|
||||
u, err := v.Me()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get self")
|
||||
}
|
||||
|
||||
conn = NewSession(v.Session, u.ID)
|
||||
|
||||
v.mapmutex.Lock()
|
||||
v.sessions[gID] = conn
|
||||
v.mapmutex.Unlock()
|
||||
}
|
||||
|
||||
// Connect.
|
||||
return conn, conn.JoinChannel(gID, cID, muted, deafened)
|
||||
}
|
125
voice/voicegateway/commands.go
Normal file
125
voice/voicegateway/commands.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMissingForIdentify is an error when we are missing information to identify.
|
||||
ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify")
|
||||
|
||||
// ErrMissingForResume is an error when we are missing information to resume.
|
||||
ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming")
|
||||
)
|
||||
|
||||
// OPCode 0
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload
|
||||
type IdentifyData struct {
|
||||
GuildID discord.Snowflake `json:"server_id"` // yes, this should be "server_id"
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
|
||||
func (c *Gateway) Identify() error {
|
||||
guildID := c.state.GuildID
|
||||
userID := c.state.UserID
|
||||
sessionID := c.state.SessionID
|
||||
token := c.state.Token
|
||||
|
||||
if guildID == 0 || userID == 0 || sessionID == "" || token == "" {
|
||||
return ErrMissingForIdentify
|
||||
}
|
||||
|
||||
return c.Send(IdentifyOP, IdentifyData{
|
||||
GuildID: guildID,
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
// OPCode 1
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload
|
||||
type SelectProtocol struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Data SelectProtocolData `json:"data"`
|
||||
}
|
||||
|
||||
type SelectProtocolData struct {
|
||||
Address string `json:"address"`
|
||||
Port uint16 `json:"port"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
|
||||
func (c *Gateway) SelectProtocol(data SelectProtocol) error {
|
||||
return c.Send(SelectProtocolOP, data)
|
||||
}
|
||||
|
||||
// OPCode 3
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
|
||||
type Heartbeat uint64
|
||||
|
||||
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
|
||||
func (c *Gateway) Heartbeat() error {
|
||||
return c.Send(HeartbeatOP, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#speaking
|
||||
type SpeakingFlag uint64
|
||||
|
||||
const (
|
||||
Microphone SpeakingFlag = 1 << iota
|
||||
Soundshare
|
||||
Priority
|
||||
)
|
||||
|
||||
// OPCode 5
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload
|
||||
type SpeakingData struct {
|
||||
Speaking SpeakingFlag `json:"speaking"`
|
||||
Delay int `json:"delay"`
|
||||
SSRC uint32 `json:"ssrc"`
|
||||
}
|
||||
|
||||
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (c *Gateway) Speaking(flag SpeakingFlag) error {
|
||||
// How do we allow a user to stop speaking?
|
||||
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
|
||||
|
||||
return c.Send(SpeakingOP, SpeakingData{
|
||||
Speaking: flag,
|
||||
Delay: 0,
|
||||
SSRC: c.ready.SSRC,
|
||||
})
|
||||
}
|
||||
|
||||
// OPCode 7
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload
|
||||
type ResumeData struct {
|
||||
GuildID discord.Snowflake `json:"server_id"` // yes, this should be "server_id"
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
|
||||
func (c *Gateway) Resume() error {
|
||||
guildID := c.state.GuildID
|
||||
sessionID := c.state.SessionID
|
||||
token := c.state.Token
|
||||
|
||||
if !guildID.Valid() || sessionID == "" || token == "" {
|
||||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return c.Send(ResumeOP, ResumeData{
|
||||
GuildID: guildID,
|
||||
SessionID: sessionID,
|
||||
Token: token,
|
||||
})
|
||||
}
|
52
voice/voicegateway/events.go
Normal file
52
voice/voicegateway/events.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
)
|
||||
|
||||
// OPCode 2
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-ready-payload
|
||||
type ReadyEvent struct {
|
||||
SSRC uint32 `json:"ssrc"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Modes []string `json:"modes"`
|
||||
Experiments []string `json:"experiments"`
|
||||
|
||||
// From Discord's API Docs:
|
||||
//
|
||||
// `heartbeat_interval` here is an erroneous field and should be ignored.
|
||||
// The correct `heartbeat_interval` value comes from the Hello payload.
|
||||
|
||||
// HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
func (r ReadyEvent) Addr() string {
|
||||
return r.IP + ":" + strconv.Itoa(r.Port)
|
||||
}
|
||||
|
||||
// OPCode 4
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-session-description-payload
|
||||
type SessionDescriptionEvent struct {
|
||||
Mode string `json:"mode"`
|
||||
SecretKey [32]byte `json:"secret_key"`
|
||||
}
|
||||
|
||||
// OPCode 5
|
||||
type SpeakingEvent SpeakingData
|
||||
|
||||
// OPCode 6
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-ack-payload
|
||||
type HeartbeatACKEvent uint64
|
||||
|
||||
// OPCode 8
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#heartbeating-example-hello-payload-since-v3
|
||||
type HelloEvent struct {
|
||||
HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
// OPCode 9
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resumed-payload
|
||||
type ResumedEvent struct{}
|
311
voice/voicegateway/gateway.go
Normal file
311
voice/voicegateway/gateway.go
Normal file
|
@ -0,0 +1,311 @@
|
|||
//
|
||||
// For the brave souls who get this far: You are the chosen ones,
|
||||
// the valiant knights of programming who toil away, without rest,
|
||||
// fixing our most awful code. To you, true saviors, kings of men,
|
||||
// I say this: never gonna give you up, never gonna let you down,
|
||||
// never gonna run around and desert you. Never gonna make you cry,
|
||||
// never gonna say goodbye. Never gonna tell a lie and hurt you.
|
||||
//
|
||||
|
||||
package voicegateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/diamondburned/arikawa/utils/moreatomic"
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// Version represents the current version of the Discord Gateway Gateway this package uses.
|
||||
Version = "4"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoSessionID = errors.New("no sessionID was received")
|
||||
ErrNoEndpoint = errors.New("no endpoint was received")
|
||||
)
|
||||
|
||||
// State contains state information of a voice gateway.
|
||||
type State struct {
|
||||
GuildID discord.Snowflake
|
||||
ChannelID discord.Snowflake
|
||||
UserID discord.Snowflake
|
||||
|
||||
SessionID string
|
||||
Token string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
// Gateway represents a Discord Gateway Gateway connection.
|
||||
type Gateway struct {
|
||||
state State // constant
|
||||
|
||||
mutex sync.RWMutex
|
||||
ready ReadyEvent
|
||||
|
||||
ws *wsutil.Websocket
|
||||
|
||||
Timeout time.Duration
|
||||
reconnect moreatomic.Bool
|
||||
|
||||
EventLoop *wsutil.PacemakerLoop
|
||||
|
||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
||||
ErrorLog func(err error)
|
||||
// AfterClose is called after each close. Error can be non-nil, as this is
|
||||
// called even when the Gateway is gracefully closed. It's used mainly for
|
||||
// reconnections or any type of connection interruptions. (defaults to noop)
|
||||
AfterClose func(err error)
|
||||
|
||||
// Filled by methods, internal use
|
||||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
func New(state State) *Gateway {
|
||||
return &Gateway{
|
||||
state: state,
|
||||
Timeout: wsutil.WSTimeout,
|
||||
ErrorLog: wsutil.WSError,
|
||||
AfterClose: func(error) {},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: get rid of
|
||||
func (c *Gateway) Ready() ReadyEvent {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
return c.ready
|
||||
}
|
||||
|
||||
// Open shouldn't be used, but JoinServer instead.
|
||||
func (c *Gateway) Open() error {
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")")
|
||||
c.ws = wsutil.New(endpoint)
|
||||
|
||||
// Create a new context with a timeout for the connection.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the Gateway Gateway.
|
||||
if err := c.ws.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "Failed to connect to voice gateway")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to start or resume the connection.
|
||||
if err := c.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start .
|
||||
func (c *Gateway) start() error {
|
||||
if err := c.__start(); err != nil {
|
||||
wsutil.WSDebug("Start failed: ", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := c.Close(); err != nil {
|
||||
wsutil.WSDebug("Failed to close after start fail: ", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// this function blocks until READY.
|
||||
func (c *Gateway) __start() error {
|
||||
// Make a new WaitGroup for use in background loops:
|
||||
c.waitGroup = new(sync.WaitGroup)
|
||||
|
||||
ch := c.ws.Listen()
|
||||
|
||||
// Wait for hello.
|
||||
wsutil.WSDebug("Waiting for Hello..")
|
||||
|
||||
var hello *HelloEvent
|
||||
_, err := wsutil.AssertEvent(<-ch, HelloOP, &hello)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error at Hello")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Received Hello")
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
// Turns out Hello is sent right away on connection start.
|
||||
if !c.reconnect.Get() {
|
||||
if err := c.Identify(); err != nil {
|
||||
return errors.Wrap(err, "Failed to identify")
|
||||
}
|
||||
} else {
|
||||
if err := c.Resume(); err != nil {
|
||||
return errors.Wrap(err, "Failed to resume")
|
||||
}
|
||||
}
|
||||
// This bool is because we should only try and Resume once.
|
||||
c.reconnect.Set(false)
|
||||
|
||||
// Wait for either Ready or Resumed.
|
||||
err = wsutil.WaitForEvent(c, ch, func(op *wsutil.OP) bool {
|
||||
return op.Code == ReadyOP || op.Code == ResumedOP
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to wait for Ready or Resumed")
|
||||
}
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
c.waitGroup.Add(1)
|
||||
|
||||
// Start the websocket handler.
|
||||
go c.handleWS(wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, c))
|
||||
|
||||
wsutil.WSDebug("Started successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close .
|
||||
func (c *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if c.waitGroup == nil && c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
|
||||
c.AfterClose(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the pacemaker is running:
|
||||
if !c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
c.EventLoop.Stop()
|
||||
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
c.waitGroup.Wait()
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
c.waitGroup = nil
|
||||
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := c.ws.Close()
|
||||
c.AfterClose(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Gateway) Reconnect() error {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
// redialing anyway.
|
||||
c.Close()
|
||||
|
||||
c.reconnect.Set(true)
|
||||
|
||||
// Condition: err == ErrInvalidSession:
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
|
||||
if err := c.Open(); err != nil {
|
||||
return errors.Wrap(err, "Failed to reopen gateway")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Reconnected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Gateway) SessionDescription(sp SelectProtocol) (*SessionDescriptionEvent, error) {
|
||||
// Add the handler first.
|
||||
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
|
||||
return op.Code == SessionDescriptionOP
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if err := c.SelectProtocol(sp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sesdesc *SessionDescriptionEvent
|
||||
|
||||
// Wait for SessionDescriptionOP packet.
|
||||
if err := (<-ch).UnmarshalData(&sesdesc); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal session description")
|
||||
}
|
||||
|
||||
return sesdesc, nil
|
||||
}
|
||||
|
||||
// handleWS .
|
||||
func (c *Gateway) handleWS(evl *wsutil.PacemakerLoop) {
|
||||
c.EventLoop = evl
|
||||
err := c.EventLoop.Run()
|
||||
|
||||
c.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("Event loop stopped.")
|
||||
|
||||
if err != nil {
|
||||
c.ErrorLog(err)
|
||||
c.Reconnect()
|
||||
// Reconnect should spawn another eventLoop in its Start function.
|
||||
}
|
||||
}
|
||||
|
||||
// Send .
|
||||
func (c *Gateway) Send(code OPCode, v interface{}) error {
|
||||
return c.send(code, v)
|
||||
}
|
||||
|
||||
// send .
|
||||
func (c *Gateway) send(code OPCode, v interface{}) error {
|
||||
if c.ws == nil {
|
||||
return errors.New("tried to send data to a connection without a Websocket")
|
||||
}
|
||||
|
||||
if c.ws.Conn == nil {
|
||||
return errors.New("tried to send data to a connection with a closed Websocket")
|
||||
}
|
||||
|
||||
var op = wsutil.OP{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode v")
|
||||
}
|
||||
|
||||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return c.ws.Send(b)
|
||||
}
|
72
voice/voicegateway/op.go
Normal file
72
voice/voicegateway/op.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OPCode represents a Discord Gateway Gateway operation code.
|
||||
type OPCode = wsutil.OPCode
|
||||
|
||||
const (
|
||||
IdentifyOP OPCode = 0 // send
|
||||
SelectProtocolOP OPCode = 1 // send
|
||||
ReadyOP OPCode = 2 // receive
|
||||
HeartbeatOP OPCode = 3 // send
|
||||
SessionDescriptionOP OPCode = 4 // receive
|
||||
SpeakingOP OPCode = 5 // send/receive
|
||||
HeartbeatAckOP OPCode = 6 // receive
|
||||
ResumeOP OPCode = 7 // send
|
||||
HelloOP OPCode = 8 // receive
|
||||
ResumedOP OPCode = 9 // receive
|
||||
// ClientDisconnectOP OPCode = 13 // receive
|
||||
)
|
||||
|
||||
func (c *Gateway) HandleOP(op *wsutil.OP) error {
|
||||
switch op.Code {
|
||||
// Gives information required to make a UDP connection
|
||||
case ReadyOP:
|
||||
if err := unmarshalMutex(op.Data, &c.ready, &c.mutex); err != nil {
|
||||
return errors.Wrap(err, "Failed to parse READY event")
|
||||
}
|
||||
|
||||
// Gives information about the encryption mode and secret key for sending voice packets
|
||||
case SessionDescriptionOP:
|
||||
// ?
|
||||
// Already handled by Session.
|
||||
|
||||
// Someone started or stopped speaking.
|
||||
case SpeakingOP:
|
||||
// ?
|
||||
// TODO: handler in Session
|
||||
|
||||
// Heartbeat response from the server
|
||||
case HeartbeatAckOP:
|
||||
c.EventLoop.Echo()
|
||||
|
||||
// Hello server, we hear you! :)
|
||||
case HelloOP:
|
||||
// ?
|
||||
// Already handled on initial connection.
|
||||
|
||||
// Server is saying the connection was resumed, no data here.
|
||||
case ResumedOP:
|
||||
wsutil.WSDebug("Gateway connection has been resumed.")
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown OP code %d", op.Code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalMutex(d []byte, v interface{}, m *sync.RWMutex) error {
|
||||
m.Lock()
|
||||
err := json.Unmarshal(d, v)
|
||||
m.Unlock()
|
||||
return err
|
||||
}
|
Loading…
Reference in a new issue