Gateway: Migrated functions and variables to other packages, added JSON default codecs
This commit is contained in:
parent
c0c17085ba
commit
2f076c041e
|
@ -9,7 +9,6 @@ package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -32,23 +31,6 @@ var (
|
||||||
// Compress = "zlib-stream"
|
// Compress = "zlib-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
|
||||||
// before Gateway cancels and fails.
|
|
||||||
WSTimeout = wsutil.DefaultTimeout
|
|
||||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
|
||||||
// make space for the first Event: Ready or Resumed.
|
|
||||||
WSBuffer = 10
|
|
||||||
// WSError is the default error handler
|
|
||||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
|
||||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
|
||||||
// timeout for the websocket.
|
|
||||||
WSExtraReadTimeout = time.Second
|
|
||||||
// WSDebug is used for extra debug logging. This is expected to behave
|
|
||||||
// similarly to log.Println().
|
|
||||||
WSDebug = func(v ...interface{}) {}
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
|
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
|
||||||
ErrWSMaxTries = errors.New("max tries reached")
|
ErrWSMaxTries = errors.New("max tries reached")
|
||||||
|
@ -96,10 +78,6 @@ func BotURL(token string) (*GatewayBotData, error) {
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
WS *wsutil.Websocket
|
WS *wsutil.Websocket
|
||||||
json.Driver
|
|
||||||
|
|
||||||
// Timeout for connecting and writing to the Websocket, uses default
|
|
||||||
// WSTimeout (global).
|
|
||||||
WSTimeout time.Duration
|
WSTimeout time.Duration
|
||||||
|
|
||||||
// All events sent over are pointers to Event structs (structs suffixed with
|
// All events sent over are pointers to Event structs (structs suffixed with
|
||||||
|
@ -139,11 +117,6 @@ type Gateway struct {
|
||||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||||
// information, refer to NewGatewayWithDriver.
|
// information, refer to NewGatewayWithDriver.
|
||||||
func NewGateway(token string) (*Gateway, error) {
|
func NewGateway(token string) (*Gateway, error) {
|
||||||
return NewGatewayWithDriver(token, json.Default{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
|
|
||||||
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|
||||||
URL, err := URL()
|
URL, err := URL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
||||||
|
@ -158,18 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||||
// Append the form to the URL
|
// Append the form to the URL
|
||||||
URL += "?" + param.Encode()
|
URL += "?" + param.Encode()
|
||||||
|
|
||||||
return NewCustomGateway(URL, token, driver), nil
|
return NewCustomGateway(URL, token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
func NewCustomGateway(gatewayURL, token string) *Gateway {
|
||||||
return &Gateway{
|
return &Gateway{
|
||||||
WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL),
|
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
|
||||||
Driver: driver,
|
Events: make(chan Event, wsutil.WSBuffer),
|
||||||
WSTimeout: WSTimeout,
|
|
||||||
Events: make(chan Event, WSBuffer),
|
|
||||||
Identifier: DefaultIdentifier(token),
|
Identifier: DefaultIdentifier(token),
|
||||||
Sequence: NewSequence(),
|
Sequence: NewSequence(),
|
||||||
ErrorLog: WSError,
|
ErrorLog: wsutil.WSError,
|
||||||
AfterClose: func(error) {},
|
AfterClose: func(error) {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -178,7 +149,7 @@ func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
||||||
func (g *Gateway) Close() error {
|
func (g *Gateway) Close() error {
|
||||||
// Check if the WS is already closed:
|
// Check if the WS is already closed:
|
||||||
if g.waitGroup == nil && g.paceDeath == nil {
|
if g.waitGroup == nil && g.paceDeath == nil {
|
||||||
WSDebug("Gateway is already closed.")
|
wsutil.WSDebug("Gateway is already closed.")
|
||||||
|
|
||||||
g.AfterClose(nil)
|
g.AfterClose(nil)
|
||||||
return nil
|
return nil
|
||||||
|
@ -186,15 +157,15 @@ func (g *Gateway) Close() error {
|
||||||
|
|
||||||
// If the pacemaker is running:
|
// If the pacemaker is running:
|
||||||
if g.paceDeath != nil {
|
if g.paceDeath != nil {
|
||||||
WSDebug("Stopping pacemaker...")
|
wsutil.WSDebug("Stopping pacemaker...")
|
||||||
|
|
||||||
// Stop the pacemaker and the event handler
|
// Stop the pacemaker and the event handler
|
||||||
g.Pacemaker.Stop()
|
g.Pacemaker.Stop()
|
||||||
|
|
||||||
WSDebug("Stopped pacemaker.")
|
wsutil.WSDebug("Stopped pacemaker.")
|
||||||
}
|
}
|
||||||
|
|
||||||
WSDebug("Waiting for WaitGroup to be done.")
|
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||||
|
|
||||||
// This should work, since Pacemaker should signal its loop to stop, which
|
// This should work, since Pacemaker should signal its loop to stop, which
|
||||||
// would also exit our event loop. Both would be 2.
|
// would also exit our event loop. Both would be 2.
|
||||||
|
@ -203,7 +174,7 @@ func (g *Gateway) Close() error {
|
||||||
// Mark g.waitGroup as empty:
|
// Mark g.waitGroup as empty:
|
||||||
g.waitGroup = nil
|
g.waitGroup = nil
|
||||||
|
|
||||||
WSDebug("WaitGroup is done. Closing the websocket.")
|
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||||
|
|
||||||
err := g.WS.Close()
|
err := g.WS.Close()
|
||||||
g.AfterClose(err)
|
g.AfterClose(err)
|
||||||
|
@ -212,42 +183,47 @@ func (g *Gateway) Close() error {
|
||||||
|
|
||||||
// Reconnect tries to reconnect forever. It will resume the connection if
|
// Reconnect tries to reconnect forever. It will resume the connection if
|
||||||
// possible. If an Invalid Session is received, it will start a fresh one.
|
// possible. If an Invalid Session is received, it will start a fresh one.
|
||||||
func (g *Gateway) Reconnect() {
|
func (g *Gateway) Reconnect() error {
|
||||||
WSDebug("Reconnecting...")
|
return g.ReconnectContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) ReconnectContext(ctx context.Context) error {
|
||||||
|
wsutil.WSDebug("Reconnecting...")
|
||||||
|
|
||||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||||
// redialing anyway.
|
// redialing anyway.
|
||||||
g.Close()
|
g.Close()
|
||||||
|
|
||||||
for i := 1; ; i++ {
|
for i := 1; ; i++ {
|
||||||
WSDebug("Trying to dial, attempt", i)
|
wsutil.WSDebug("Trying to dial, attempt", i)
|
||||||
|
|
||||||
// Condition: err == ErrInvalidSession:
|
// Condition: err == ErrInvalidSession:
|
||||||
// If the connection is rate limited (documented behavior):
|
// If the connection is rate limited (documented behavior):
|
||||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||||
|
|
||||||
if err := g.Open(); err != nil {
|
if err := g.OpenContext(ctx); err != nil {
|
||||||
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
|
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
WSDebug("Started after attempt:", i)
|
wsutil.WSDebug("Started after attempt:", i)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open connects to the Websocket and authenticate it. You should usually use
|
// Open connects to the Websocket and authenticate it. You should usually use
|
||||||
// this function over Start().
|
// this function over Start().
|
||||||
func (g *Gateway) Open() error {
|
func (g *Gateway) Open() error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
return g.OpenContext(context.Background())
|
||||||
defer cancel()
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) OpenContext(ctx context.Context) error {
|
||||||
// Reconnect to the Gateway
|
// Reconnect to the Gateway
|
||||||
if err := g.WS.Dial(ctx); err != nil {
|
if err := g.WS.Dial(ctx); err != nil {
|
||||||
return errors.Wrap(err, "Failed to reconnect")
|
return errors.Wrap(err, "Failed to reconnect")
|
||||||
}
|
}
|
||||||
|
|
||||||
WSDebug("Trying to start...")
|
wsutil.WSDebug("Trying to start...")
|
||||||
|
|
||||||
// Try to resume the connection
|
// Try to resume the connection
|
||||||
if err := g.Start(); err != nil {
|
if err := g.Start(); err != nil {
|
||||||
|
@ -266,12 +242,12 @@ func (g *Gateway) Start() error {
|
||||||
// defer g.available.Unlock()
|
// defer g.available.Unlock()
|
||||||
|
|
||||||
if err := g.start(); err != nil {
|
if err := g.start(); err != nil {
|
||||||
WSDebug("Start failed:", err)
|
wsutil.WSDebug("Start failed:", err)
|
||||||
|
|
||||||
// Close can be called with the mutex still acquired here, as the
|
// Close can be called with the mutex still acquired here, as the
|
||||||
// pacemaker hasn't started yet.
|
// pacemaker hasn't started yet.
|
||||||
if err := g.Close(); err != nil {
|
if err := g.Close(); err != nil {
|
||||||
WSDebug("Failed to close after start fail:", err)
|
wsutil.WSDebug("Failed to close after start fail:", err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -293,7 +269,7 @@ func (g *Gateway) start() error {
|
||||||
|
|
||||||
// Wait for an OP 10 Hello
|
// Wait for an OP 10 Hello
|
||||||
var hello HelloEvent
|
var hello HelloEvent
|
||||||
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
if _, err := AssertEvent(<-ch, HelloOP, &hello); err != nil {
|
||||||
return errors.Wrap(err, "Error at Hello")
|
return errors.Wrap(err, "Error at Hello")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,16 +287,16 @@ func (g *Gateway) start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expect either READY or RESUMED before continuing.
|
// Expect either READY or RESUMED before continuing.
|
||||||
WSDebug("Waiting for either READY or RESUMED.")
|
wsutil.WSDebug("Waiting for either READY or RESUMED.")
|
||||||
|
|
||||||
// WaitForEvent should
|
// WaitForEvent should
|
||||||
err := WaitForEvent(g, ch, func(op *OP) bool {
|
err := WaitForEvent(g, ch, func(op *OP) bool {
|
||||||
switch op.EventName {
|
switch op.EventName {
|
||||||
case "READY":
|
case "READY":
|
||||||
WSDebug("Found READY event.")
|
wsutil.WSDebug("Found READY event.")
|
||||||
return true
|
return true
|
||||||
case "RESUMED":
|
case "RESUMED":
|
||||||
WSDebug("Found RESUMED event.")
|
wsutil.WSDebug("Found RESUMED event.")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -344,7 +320,7 @@ func (g *Gateway) start() error {
|
||||||
g.waitGroup.Add(1)
|
g.waitGroup.Add(1)
|
||||||
go g.handleWS()
|
go g.handleWS()
|
||||||
|
|
||||||
WSDebug("Started successfully.")
|
wsutil.WSDebug("Started successfully.")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -353,7 +329,7 @@ func (g *Gateway) start() error {
|
||||||
func (g *Gateway) handleWS() {
|
func (g *Gateway) handleWS() {
|
||||||
err := g.eventLoop()
|
err := g.eventLoop()
|
||||||
g.waitGroup.Done() // mark so Close() can exit.
|
g.waitGroup.Done() // mark so Close() can exit.
|
||||||
WSDebug("Event loop stopped.")
|
wsutil.WSDebug("Event loop stopped.")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.ErrorLog(err)
|
g.ErrorLog(err)
|
||||||
|
@ -372,7 +348,7 @@ func (g *Gateway) eventLoop() error {
|
||||||
g.paceDeath = nil // mark
|
g.paceDeath = nil // mark
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
WSDebug("Pacemaker stopped without errors.")
|
wsutil.WSDebug("Pacemaker stopped without errors.")
|
||||||
// No error, just exit normally.
|
// No error, just exit normally.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -394,7 +370,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if v != nil {
|
if v != nil {
|
||||||
b, err := g.Driver.Marshal(v)
|
b, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Failed to encode v")
|
return errors.Wrap(err, "Failed to encode v")
|
||||||
}
|
}
|
||||||
|
@ -402,7 +378,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||||
op.Data = b
|
op.Data = b
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := g.Driver.Marshal(op)
|
b, err := json.Marshal(op)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Failed to encode payload")
|
return errors.Wrap(err, "Failed to encode payload")
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,21 +38,21 @@ type OP struct {
|
||||||
EventName string `json:"t,omitempty"`
|
EventName string `json:"t,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) {
|
func DecodeEvent(ev wsutil.Event, v interface{}) (OPCode, error) {
|
||||||
op, err := DecodeOP(driver, ev)
|
op, err := DecodeOP(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||||
return 0, errors.Wrap(err, "Failed to decode data")
|
return 0, errors.Wrap(err, "Failed to decode data")
|
||||||
}
|
}
|
||||||
|
|
||||||
return op.Code, nil
|
return op.Code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
func AssertEvent(ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
||||||
op, err := DecodeOP(driver, ev)
|
op, err := DecodeOP(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||||
return op, errors.Wrap(err, "Failed to decode data")
|
return op, errors.Wrap(err, "Failed to decode data")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
||||||
o, err := DecodeOP(g.Driver, ev)
|
o, err := DecodeOP(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
||||||
// regardless.
|
// regardless.
|
||||||
func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
||||||
for ev := range ch {
|
for ev := range ch {
|
||||||
o, err := DecodeOP(g.Driver, ev)
|
o, err := DecodeOP(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -106,7 +106,7 @@ func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
||||||
return errors.New("Event not found and event channel is closed.")
|
return errors.New("Event not found and event channel is closed.")
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
func DecodeOP(ev wsutil.Event) (*OP, error) {
|
||||||
if ev.Error != nil {
|
if ev.Error != nil {
|
||||||
return nil, ev.Error
|
return nil, ev.Error
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var op *OP
|
var op *OP
|
||||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
if err := json.Unmarshal(ev.Data, &op); err != nil {
|
||||||
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
||||||
|
|
||||||
case ReconnectOP:
|
case ReconnectOP:
|
||||||
// Server requests to reconnect, die and retry.
|
// Server requests to reconnect, die and retry.
|
||||||
WSDebug("ReconnectOP received.")
|
wsutil.WSDebug("ReconnectOP received.")
|
||||||
// We must reconnect in another goroutine, as running Reconnect
|
// We must reconnect in another goroutine, as running Reconnect
|
||||||
// synchronously would prevent the main event loop from exiting.
|
// synchronously would prevent the main event loop from exiting.
|
||||||
go g.Reconnect()
|
go g.Reconnect()
|
||||||
|
@ -177,7 +177,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
||||||
var ev = fn()
|
var ev = fn()
|
||||||
|
|
||||||
// Try and parse the event
|
// Try and parse the event
|
||||||
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
|
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||||
return errors.Wrap(err, "Failed to parse event "+op.EventName)
|
return errors.Wrap(err, "Failed to parse event "+op.EventName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,9 +60,9 @@ func (p *Pacemaker) Dead() bool {
|
||||||
func (p *Pacemaker) Stop() {
|
func (p *Pacemaker) Stop() {
|
||||||
if p.stop != nil {
|
if p.stop != nil {
|
||||||
p.stop <- struct{}{}
|
p.stop <- struct{}{}
|
||||||
WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
wsutil.WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
||||||
} else {
|
} else {
|
||||||
WSDebug("(*Pacemaker).stop is nil, skipping.")
|
wsutil.WSDebug("(*Pacemaker).stop is nil, skipping.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,13 +74,13 @@ func (p *Pacemaker) start() error {
|
||||||
p.Echo()
|
p.Echo()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
WSDebug("Pacemaker loop restarted.")
|
wsutil.WSDebug("Pacemaker loop restarted.")
|
||||||
|
|
||||||
if err := p.Pace(); err != nil {
|
if err := p.Pace(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
WSDebug("Paced.")
|
wsutil.WSDebug("Paced.")
|
||||||
|
|
||||||
// Paced, save:
|
// Paced, save:
|
||||||
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
|
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
|
||||||
|
@ -90,11 +91,11 @@ func (p *Pacemaker) start() error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-p.stop:
|
case <-p.stop:
|
||||||
WSDebug("Received stop signal.")
|
wsutil.WSDebug("Received stop signal.")
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case <-tick.C:
|
case <-tick.C:
|
||||||
WSDebug("Ticked. Restarting.")
|
wsutil.WSDebug("Ticked. Restarting.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,7 +110,7 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||||
go func() {
|
go func() {
|
||||||
p.death <- p.start()
|
p.death <- p.start()
|
||||||
// Debug.
|
// Debug.
|
||||||
WSDebug("Pacemaker returned.")
|
wsutil.WSDebug("Pacemaker returned.")
|
||||||
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
||||||
p.stop = nil
|
p.stop = nil
|
||||||
// Mark the pacemaker loop as done.
|
// Mark the pacemaker loop as done.
|
||||||
|
|
|
@ -41,7 +41,7 @@ func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
||||||
func NewClient() *Client {
|
func NewClient() *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
Client: httpdriver.NewClient(),
|
Client: httpdriver.NewClient(),
|
||||||
Driver: json.Default{},
|
Driver: json.Default,
|
||||||
SchemaEncoder: &DefaultSchema{},
|
SchemaEncoder: &DefaultSchema{},
|
||||||
Retries: Retries,
|
Retries: Retries,
|
||||||
OnResponse: ResponseNoop,
|
OnResponse: ResponseNoop,
|
||||||
|
|
|
@ -7,65 +7,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
|
||||||
OptionBool = *bool
|
|
||||||
OptionString = *string
|
|
||||||
OptionUint = *uint
|
|
||||||
OptionInt = *int
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
True = getBool(true)
|
|
||||||
False = getBool(false)
|
|
||||||
|
|
||||||
ZeroUint = Uint(0)
|
|
||||||
ZeroInt = Int(0)
|
|
||||||
|
|
||||||
EmptyString = String("")
|
|
||||||
)
|
|
||||||
|
|
||||||
func Uint(u uint) OptionUint {
|
|
||||||
return &u
|
|
||||||
}
|
|
||||||
|
|
||||||
func Int(i int) OptionInt {
|
|
||||||
return &i
|
|
||||||
}
|
|
||||||
|
|
||||||
func String(s string) OptionString {
|
|
||||||
return &s
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBool(Bool bool) OptionBool {
|
|
||||||
return &Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
|
|
||||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
|
||||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
|
||||||
// from encoding/json.
|
|
||||||
type Raw []byte
|
|
||||||
|
|
||||||
// Raw returns m as the JSON encoding of m.
|
|
||||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
|
||||||
if m == nil {
|
|
||||||
return []byte("null"), nil
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
|
||||||
*m = append((*m)[0:0], data...)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Raw) String() string {
|
|
||||||
return string(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
|
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Marshal(v interface{}) ([]byte, error)
|
Marshal(v interface{}) ([]byte, error)
|
||||||
Unmarshal(data []byte, v interface{}) error
|
Unmarshal(data []byte, v interface{}) error
|
||||||
|
@ -74,20 +15,43 @@ type Driver interface {
|
||||||
EncodeStream(w io.Writer, v interface{}) error
|
EncodeStream(w io.Writer, v interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Default struct{}
|
type DefaultDriver struct{}
|
||||||
|
|
||||||
func (d Default) Marshal(v interface{}) ([]byte, error) {
|
func (d DefaultDriver) Marshal(v interface{}) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Default) Unmarshal(data []byte, v interface{}) error {
|
func (d DefaultDriver) Unmarshal(data []byte, v interface{}) error {
|
||||||
return json.Unmarshal(data, v)
|
return json.Unmarshal(data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Default) DecodeStream(r io.Reader, v interface{}) error {
|
func (d DefaultDriver) DecodeStream(r io.Reader, v interface{}) error {
|
||||||
return json.NewDecoder(r).Decode(v)
|
return json.NewDecoder(r).Decode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Default) EncodeStream(w io.Writer, v interface{}) error {
|
func (d DefaultDriver) EncodeStream(w io.Writer, v interface{}) error {
|
||||||
return json.NewEncoder(w).Encode(v)
|
return json.NewEncoder(w).Encode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default is the default JSON driver, which uses encoding/json.
|
||||||
|
var Default Driver = DefaultDriver{}
|
||||||
|
|
||||||
|
// Marshal uses the default driver.
|
||||||
|
func Marshal(v interface{}) ([]byte, error) {
|
||||||
|
return Default.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal uses the default driver.
|
||||||
|
func Unmarshal(data []byte, v interface{}) error {
|
||||||
|
return Default.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeStream uses the default driver.
|
||||||
|
func DecodeStream(r io.Reader, v interface{}) error {
|
||||||
|
return Default.DecodeStream(r, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeStream uses the default driver.
|
||||||
|
func EncodeStream(w io.Writer, v interface{}) error {
|
||||||
|
return Default.EncodeStream(w, v)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
package json
|
||||||
|
|
||||||
|
type (
|
||||||
|
OptionBool = *bool
|
||||||
|
OptionString = *string
|
||||||
|
OptionUint = *uint
|
||||||
|
OptionInt = *int
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
True = getBool(true)
|
||||||
|
False = getBool(false)
|
||||||
|
|
||||||
|
ZeroUint = Uint(0)
|
||||||
|
ZeroInt = Int(0)
|
||||||
|
|
||||||
|
EmptyString = String("")
|
||||||
|
)
|
||||||
|
|
||||||
|
func Uint(u uint) OptionUint {
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func Int(i int) OptionInt {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
|
func String(s string) OptionString {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBool(Bool bool) OptionBool {
|
||||||
|
return &Bool
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package json
|
||||||
|
|
||||||
|
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||||
|
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||||
|
// from encoding/json.
|
||||||
|
type Raw []byte
|
||||||
|
|
||||||
|
// Raw returns m as the JSON encoding of m.
|
||||||
|
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||||
|
if m == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||||
|
*m = append((*m)[0:0], data...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Raw) String() string {
|
||||||
|
return string(m)
|
||||||
|
}
|
|
@ -35,7 +35,7 @@ type Connection interface {
|
||||||
Listen() <-chan Event
|
Listen() <-chan Event
|
||||||
|
|
||||||
// Send allows the caller to send bytes. Thread safety is a requirement.
|
// Send allows the caller to send bytes. Thread safety is a requirement.
|
||||||
Send([]byte) error
|
Send(context.Context, []byte) error
|
||||||
|
|
||||||
// Close should close the websocket connection. The connection will not be
|
// Close should close the websocket connection. The connection will not be
|
||||||
// reused.
|
// reused.
|
||||||
|
@ -67,12 +67,16 @@ type Conn struct {
|
||||||
|
|
||||||
var _ Connection = (*Conn)(nil)
|
var _ Connection = (*Conn)(nil)
|
||||||
|
|
||||||
func NewConn(driver json.Driver) *Conn {
|
func NewConn() *Conn {
|
||||||
|
return NewConnWithDriver(json.Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnWithDriver(driver json.Driver) *Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
Driver: driver,
|
Driver: driver,
|
||||||
dialer: &websocket.Dialer{
|
dialer: &websocket.Dialer{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
HandshakeTimeout: DefaultTimeout,
|
HandshakeTimeout: WSTimeout,
|
||||||
EnableCompression: true,
|
EnableCompression: true,
|
||||||
},
|
},
|
||||||
// zlib: zlib.NewInflator(),
|
// zlib: zlib.NewInflator(),
|
||||||
|
@ -226,14 +230,27 @@ func (c *Conn) handle() ([]byte, error) {
|
||||||
// return nil, errors.New("Unexpected binary message.")
|
// return nil, errors.New("Unexpected binary message.")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Send(b []byte) error {
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
// If websocket is already closed.
|
// If websocket is already closed.
|
||||||
if c.writes == nil {
|
if c.writes == nil {
|
||||||
return ErrWebsocketClosed
|
return ErrWebsocketClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writes <- b
|
// Send the bytes.
|
||||||
return <-c.errors
|
select {
|
||||||
|
case c.writes <- b:
|
||||||
|
// continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive the error.
|
||||||
|
select {
|
||||||
|
case err := <-c.errors:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() (err error) {
|
||||||
|
|
|
@ -4,15 +4,30 @@ package wsutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/utils/json"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultTimeout = time.Minute
|
var (
|
||||||
|
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||||
|
// before Gateway cancels and fails.
|
||||||
|
WSTimeout = time.Minute
|
||||||
|
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||||
|
// make space for the first Event: Ready or Resumed.
|
||||||
|
WSBuffer = 10
|
||||||
|
// WSError is the default error handler
|
||||||
|
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||||
|
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||||
|
// timeout for the websocket.
|
||||||
|
WSExtraReadTimeout = time.Second
|
||||||
|
// WSDebug is used for extra debug logging. This is expected to behave
|
||||||
|
// similarly to log.Println().
|
||||||
|
WSDebug = func(v ...interface{}) {}
|
||||||
|
)
|
||||||
|
|
||||||
type Event struct {
|
type Event struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
|
@ -25,12 +40,16 @@ type Websocket struct {
|
||||||
Conn Connection
|
Conn Connection
|
||||||
Addr string
|
Addr string
|
||||||
|
|
||||||
|
// Timeout for connecting and writing to the Websocket, uses default
|
||||||
|
// WSTimeout (global).
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
SendLimiter *rate.Limiter
|
SendLimiter *rate.Limiter
|
||||||
DialLimiter *rate.Limiter
|
DialLimiter *rate.Limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(addr string) *Websocket {
|
func New(addr string) *Websocket {
|
||||||
return NewCustom(NewConn(json.Default{}), addr)
|
return NewCustom(NewConn(), addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCustom creates a new undialed Websocket.
|
// NewCustom creates a new undialed Websocket.
|
||||||
|
@ -39,12 +58,21 @@ func NewCustom(conn Connection, addr string) *Websocket {
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
|
|
||||||
|
Timeout: WSTimeout,
|
||||||
|
|
||||||
SendLimiter: NewSendLimiter(),
|
SendLimiter: NewSendLimiter(),
|
||||||
DialLimiter: NewDialLimiter(),
|
DialLimiter: NewDialLimiter(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||||
|
if ws.Timeout > 0 {
|
||||||
|
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx = tctx
|
||||||
|
}
|
||||||
|
|
||||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||||
// Expired, fatal error
|
// Expired, fatal error
|
||||||
return errors.Wrap(err, "Failed to wait")
|
return errors.Wrap(err, "Failed to wait")
|
||||||
|
@ -65,11 +93,16 @@ func (ws *Websocket) Listen() <-chan Event {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Send(b []byte) error {
|
func (ws *Websocket) Send(b []byte) error {
|
||||||
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
|
return ws.SendContext(context.Background(), b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendContext is a beta API.
|
||||||
|
func (ws *Websocket) SendContext(ctx context.Context, b []byte) error {
|
||||||
|
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||||
return errors.Wrap(err, "SendLimiter failed")
|
return errors.Wrap(err, "SendLimiter failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return ws.Conn.Send(b)
|
return ws.Conn.Send(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Close() error {
|
func (ws *Websocket) Close() error {
|
||||||
|
|
Loading…
Reference in New Issue