1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-30 18:53:30 +00:00

Gateway: Migrated functions and variables to other packages, added JSON default codecs

This commit is contained in:
diamondburned (Forefront) 2020-04-23 23:34:08 -07:00
parent c0c17085ba
commit 2f076c041e
9 changed files with 203 additions and 155 deletions

View file

@ -9,7 +9,6 @@ package gateway
import (
"context"
"log"
"net/http"
"net/url"
"sync"
@ -32,23 +31,6 @@ var (
// Compress = "zlib-stream"
)
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = wsutil.DefaultTimeout
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
)
var (
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
ErrWSMaxTries = errors.New("max tries reached")
@ -95,11 +77,7 @@ func BotURL(token string) (*GatewayBotData, error) {
}
type Gateway struct {
WS *wsutil.Websocket
json.Driver
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
WS *wsutil.Websocket
WSTimeout time.Duration
// All events sent over are pointers to Event structs (structs suffixed with
@ -139,11 +117,6 @@ type Gateway struct {
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
// information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) {
return NewGatewayWithDriver(token, json.Default{})
}
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
URL, err := URL()
if err != nil {
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
@ -158,18 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
// Append the form to the URL
URL += "?" + param.Encode()
return NewCustomGateway(URL, token, driver), nil
return NewCustomGateway(URL, token), nil
}
func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
func NewCustomGateway(gatewayURL, token string) *Gateway {
return &Gateway{
WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL),
Driver: driver,
WSTimeout: WSTimeout,
Events: make(chan Event, WSBuffer),
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
Events: make(chan Event, wsutil.WSBuffer),
Identifier: DefaultIdentifier(token),
Sequence: NewSequence(),
ErrorLog: WSError,
ErrorLog: wsutil.WSError,
AfterClose: func(error) {},
}
}
@ -178,7 +149,7 @@ func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
func (g *Gateway) Close() error {
// Check if the WS is already closed:
if g.waitGroup == nil && g.paceDeath == nil {
WSDebug("Gateway is already closed.")
wsutil.WSDebug("Gateway is already closed.")
g.AfterClose(nil)
return nil
@ -186,15 +157,15 @@ func (g *Gateway) Close() error {
// If the pacemaker is running:
if g.paceDeath != nil {
WSDebug("Stopping pacemaker...")
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
g.Pacemaker.Stop()
WSDebug("Stopped pacemaker.")
wsutil.WSDebug("Stopped pacemaker.")
}
WSDebug("Waiting for WaitGroup to be done.")
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
@ -203,7 +174,7 @@ func (g *Gateway) Close() error {
// Mark g.waitGroup as empty:
g.waitGroup = nil
WSDebug("WaitGroup is done. Closing the websocket.")
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := g.WS.Close()
g.AfterClose(err)
@ -212,42 +183,47 @@ func (g *Gateway) Close() error {
// Reconnect tries to reconnect forever. It will resume the connection if
// possible. If an Invalid Session is received, it will start a fresh one.
func (g *Gateway) Reconnect() {
WSDebug("Reconnecting...")
func (g *Gateway) Reconnect() error {
return g.ReconnectContext(context.Background())
}
func (g *Gateway) ReconnectContext(ctx context.Context) error {
wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're
// redialing anyway.
g.Close()
for i := 1; ; i++ {
WSDebug("Trying to dial, attempt", i)
wsutil.WSDebug("Trying to dial, attempt", i)
// Condition: err == ErrInvalidSession:
// If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err := g.Open(); err != nil {
if err := g.OpenContext(ctx); err != nil {
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
continue
}
WSDebug("Started after attempt:", i)
return
wsutil.WSDebug("Started after attempt:", i)
return nil
}
}
// Open connects to the Websocket and authenticate it. You should usually use
// this function over Start().
func (g *Gateway) Open() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return g.OpenContext(context.Background())
}
func (g *Gateway) OpenContext(ctx context.Context) error {
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "Failed to reconnect")
}
WSDebug("Trying to start...")
wsutil.WSDebug("Trying to start...")
// Try to resume the connection
if err := g.Start(); err != nil {
@ -266,12 +242,12 @@ func (g *Gateway) Start() error {
// defer g.available.Unlock()
if err := g.start(); err != nil {
WSDebug("Start failed:", err)
wsutil.WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the
// pacemaker hasn't started yet.
if err := g.Close(); err != nil {
WSDebug("Failed to close after start fail:", err)
wsutil.WSDebug("Failed to close after start fail:", err)
}
return err
}
@ -293,7 +269,7 @@ func (g *Gateway) start() error {
// Wait for an OP 10 Hello
var hello HelloEvent
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
if _, err := AssertEvent(<-ch, HelloOP, &hello); err != nil {
return errors.Wrap(err, "Error at Hello")
}
@ -311,16 +287,16 @@ func (g *Gateway) start() error {
}
// Expect either READY or RESUMED before continuing.
WSDebug("Waiting for either READY or RESUMED.")
wsutil.WSDebug("Waiting for either READY or RESUMED.")
// WaitForEvent should
err := WaitForEvent(g, ch, func(op *OP) bool {
switch op.EventName {
case "READY":
WSDebug("Found READY event.")
wsutil.WSDebug("Found READY event.")
return true
case "RESUMED":
WSDebug("Found RESUMED event.")
wsutil.WSDebug("Found RESUMED event.")
return true
}
return false
@ -344,7 +320,7 @@ func (g *Gateway) start() error {
g.waitGroup.Add(1)
go g.handleWS()
WSDebug("Started successfully.")
wsutil.WSDebug("Started successfully.")
return nil
}
@ -353,7 +329,7 @@ func (g *Gateway) start() error {
func (g *Gateway) handleWS() {
err := g.eventLoop()
g.waitGroup.Done() // mark so Close() can exit.
WSDebug("Event loop stopped.")
wsutil.WSDebug("Event loop stopped.")
if err != nil {
g.ErrorLog(err)
@ -372,7 +348,7 @@ func (g *Gateway) eventLoop() error {
g.paceDeath = nil // mark
if err == nil {
WSDebug("Pacemaker stopped without errors.")
wsutil.WSDebug("Pacemaker stopped without errors.")
// No error, just exit normally.
return nil
}
@ -394,7 +370,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
}
if v != nil {
b, err := g.Driver.Marshal(v)
b, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "Failed to encode v")
}
@ -402,7 +378,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
op.Data = b
}
b, err := g.Driver.Marshal(op)
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "Failed to encode payload")
}

View file

@ -38,21 +38,21 @@ type OP struct {
EventName string `json:"t,omitempty"`
}
func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) {
op, err := DecodeOP(driver, ev)
func DecodeEvent(ev wsutil.Event, v interface{}) (OPCode, error) {
op, err := DecodeOP(ev)
if err != nil {
return 0, err
}
if err := driver.Unmarshal(op.Data, v); err != nil {
if err := json.Unmarshal(op.Data, v); err != nil {
return 0, errors.Wrap(err, "Failed to decode data")
}
return op.Code, nil
}
func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
op, err := DecodeOP(driver, ev)
func AssertEvent(ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
op, err := DecodeOP(ev)
if err != nil {
return nil, err
}
@ -64,7 +64,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
)
}
if err := driver.Unmarshal(op.Data, v); err != nil {
if err := json.Unmarshal(op.Data, v); err != nil {
return op, errors.Wrap(err, "Failed to decode data")
}
@ -72,7 +72,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
}
func HandleEvent(g *Gateway, ev wsutil.Event) error {
o, err := DecodeOP(g.Driver, ev)
o, err := DecodeOP(ev)
if err != nil {
return err
}
@ -84,7 +84,7 @@ func HandleEvent(g *Gateway, ev wsutil.Event) error {
// regardless.
func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
for ev := range ch {
o, err := DecodeOP(g.Driver, ev)
o, err := DecodeOP(ev)
if err != nil {
return err
}
@ -106,7 +106,7 @@ func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
return errors.New("Event not found and event channel is closed.")
}
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
func DecodeOP(ev wsutil.Event) (*OP, error) {
if ev.Error != nil {
return nil, ev.Error
}
@ -116,7 +116,7 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
}
var op *OP
if err := driver.Unmarshal(ev.Data, &op); err != nil {
if err := json.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
}
@ -139,7 +139,7 @@ func HandleOP(g *Gateway, op *OP) error {
case ReconnectOP:
// Server requests to reconnect, die and retry.
WSDebug("ReconnectOP received.")
wsutil.WSDebug("ReconnectOP received.")
// We must reconnect in another goroutine, as running Reconnect
// synchronously would prevent the main event loop from exiting.
go g.Reconnect()
@ -177,7 +177,7 @@ func HandleOP(g *Gateway, op *OP) error {
var ev = fn()
// Try and parse the event
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "Failed to parse event "+op.EventName)
}

View file

@ -5,6 +5,7 @@ import (
"sync/atomic"
"time"
"github.com/diamondburned/arikawa/utils/wsutil"
"github.com/pkg/errors"
)
@ -59,9 +60,9 @@ func (p *Pacemaker) Dead() bool {
func (p *Pacemaker) Stop() {
if p.stop != nil {
p.stop <- struct{}{}
WSDebug("(*Pacemaker).stop was sent a stop signal.")
wsutil.WSDebug("(*Pacemaker).stop was sent a stop signal.")
} else {
WSDebug("(*Pacemaker).stop is nil, skipping.")
wsutil.WSDebug("(*Pacemaker).stop is nil, skipping.")
}
}
@ -73,13 +74,13 @@ func (p *Pacemaker) start() error {
p.Echo()
for {
WSDebug("Pacemaker loop restarted.")
wsutil.WSDebug("Pacemaker loop restarted.")
if err := p.Pace(); err != nil {
return err
}
WSDebug("Paced.")
wsutil.WSDebug("Paced.")
// Paced, save:
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
@ -90,11 +91,11 @@ func (p *Pacemaker) start() error {
select {
case <-p.stop:
WSDebug("Received stop signal.")
wsutil.WSDebug("Received stop signal.")
return nil
case <-tick.C:
WSDebug("Ticked. Restarting.")
wsutil.WSDebug("Ticked. Restarting.")
}
}
}
@ -109,7 +110,7 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
go func() {
p.death <- p.start()
// Debug.
WSDebug("Pacemaker returned.")
wsutil.WSDebug("Pacemaker returned.")
// Mark the stop channel as nil, so later Close() calls won't block forever.
p.stop = nil
// Mark the pacemaker loop as done.

View file

@ -41,7 +41,7 @@ func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
func NewClient() *Client {
return &Client{
Client: httpdriver.NewClient(),
Driver: json.Default{},
Driver: json.Default,
SchemaEncoder: &DefaultSchema{},
Retries: Retries,
OnResponse: ResponseNoop,

View file

@ -7,65 +7,6 @@ import (
"io"
)
type (
OptionBool = *bool
OptionString = *string
OptionUint = *uint
OptionInt = *int
)
var (
True = getBool(true)
False = getBool(false)
ZeroUint = Uint(0)
ZeroInt = Int(0)
EmptyString = String("")
)
func Uint(u uint) OptionUint {
return &u
}
func Int(i int) OptionInt {
return &i
}
func String(s string) OptionString {
return &s
}
func getBool(Bool bool) OptionBool {
return &Bool
}
//
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
// from encoding/json.
type Raw []byte
// Raw returns m as the JSON encoding of m.
func (m Raw) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
func (m *Raw) UnmarshalJSON(data []byte) error {
*m = append((*m)[0:0], data...)
return nil
}
func (m Raw) String() string {
return string(m)
}
//
type Driver interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
@ -74,20 +15,43 @@ type Driver interface {
EncodeStream(w io.Writer, v interface{}) error
}
type Default struct{}
type DefaultDriver struct{}
func (d Default) Marshal(v interface{}) ([]byte, error) {
func (d DefaultDriver) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
func (d Default) Unmarshal(data []byte, v interface{}) error {
func (d DefaultDriver) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}
func (d Default) DecodeStream(r io.Reader, v interface{}) error {
func (d DefaultDriver) DecodeStream(r io.Reader, v interface{}) error {
return json.NewDecoder(r).Decode(v)
}
func (d Default) EncodeStream(w io.Writer, v interface{}) error {
func (d DefaultDriver) EncodeStream(w io.Writer, v interface{}) error {
return json.NewEncoder(w).Encode(v)
}
// Default is the default JSON driver, which uses encoding/json.
var Default Driver = DefaultDriver{}
// Marshal uses the default driver.
func Marshal(v interface{}) ([]byte, error) {
return Default.Marshal(v)
}
// Unmarshal uses the default driver.
func Unmarshal(data []byte, v interface{}) error {
return Default.Unmarshal(data, v)
}
// DecodeStream uses the default driver.
func DecodeStream(r io.Reader, v interface{}) error {
return Default.DecodeStream(r, v)
}
// EncodeStream uses the default driver.
func EncodeStream(w io.Writer, v interface{}) error {
return Default.EncodeStream(w, v)
}

34
utils/json/option.go Normal file
View file

@ -0,0 +1,34 @@
package json
type (
OptionBool = *bool
OptionString = *string
OptionUint = *uint
OptionInt = *int
)
var (
True = getBool(true)
False = getBool(false)
ZeroUint = Uint(0)
ZeroInt = Int(0)
EmptyString = String("")
)
func Uint(u uint) OptionUint {
return &u
}
func Int(i int) OptionInt {
return &i
}
func String(s string) OptionString {
return &s
}
func getBool(Bool bool) OptionBool {
return &Bool
}

23
utils/json/raw.go Normal file
View file

@ -0,0 +1,23 @@
package json
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
// from encoding/json.
type Raw []byte
// Raw returns m as the JSON encoding of m.
func (m Raw) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
func (m *Raw) UnmarshalJSON(data []byte) error {
*m = append((*m)[0:0], data...)
return nil
}
func (m Raw) String() string {
return string(m)
}

View file

@ -35,7 +35,7 @@ type Connection interface {
Listen() <-chan Event
// Send allows the caller to send bytes. Thread safety is a requirement.
Send([]byte) error
Send(context.Context, []byte) error
// Close should close the websocket connection. The connection will not be
// reused.
@ -67,12 +67,16 @@ type Conn struct {
var _ Connection = (*Conn)(nil)
func NewConn(driver json.Driver) *Conn {
func NewConn() *Conn {
return NewConnWithDriver(json.Default)
}
func NewConnWithDriver(driver json.Driver) *Conn {
return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: DefaultTimeout,
HandshakeTimeout: WSTimeout,
EnableCompression: true,
},
// zlib: zlib.NewInflator(),
@ -226,14 +230,27 @@ func (c *Conn) handle() ([]byte, error) {
// return nil, errors.New("Unexpected binary message.")
}
func (c *Conn) Send(b []byte) error {
func (c *Conn) Send(ctx context.Context, b []byte) error {
// If websocket is already closed.
if c.writes == nil {
return ErrWebsocketClosed
}
c.writes <- b
return <-c.errors
// Send the bytes.
select {
case c.writes <- b:
// continue
case <-ctx.Done():
return ctx.Err()
}
// Receive the error.
select {
case err := <-c.errors:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) Close() (err error) {

View file

@ -4,15 +4,30 @@ package wsutil
import (
"context"
"log"
"net/url"
"time"
"github.com/diamondburned/arikawa/utils/json"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
const DefaultTimeout = time.Minute
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = time.Minute
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
)
type Event struct {
Data []byte
@ -25,12 +40,16 @@ type Websocket struct {
Conn Connection
Addr string
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
Timeout time.Duration
SendLimiter *rate.Limiter
DialLimiter *rate.Limiter
}
func New(addr string) *Websocket {
return NewCustom(NewConn(json.Default{}), addr)
return NewCustom(NewConn(), addr)
}
// NewCustom creates a new undialed Websocket.
@ -39,12 +58,21 @@ func NewCustom(conn Connection, addr string) *Websocket {
Conn: conn,
Addr: addr,
Timeout: WSTimeout,
SendLimiter: NewSendLimiter(),
DialLimiter: NewDialLimiter(),
}
}
func (ws *Websocket) Dial(ctx context.Context) error {
if ws.Timeout > 0 {
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
defer cancel()
ctx = tctx
}
if err := ws.DialLimiter.Wait(ctx); err != nil {
// Expired, fatal error
return errors.Wrap(err, "Failed to wait")
@ -65,11 +93,16 @@ func (ws *Websocket) Listen() <-chan Event {
}
func (ws *Websocket) Send(b []byte) error {
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
return ws.SendContext(context.Background(), b)
}
// SendContext is a beta API.
func (ws *Websocket) SendContext(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil {
return errors.Wrap(err, "SendLimiter failed")
}
return ws.Conn.Send(b)
return ws.Conn.Send(ctx, b)
}
func (ws *Websocket) Close() error {