mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 18:53:30 +00:00
Gateway: Migrated functions and variables to other packages, added JSON default codecs
This commit is contained in:
parent
c0c17085ba
commit
2f076c041e
|
@ -9,7 +9,6 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
@ -32,23 +31,6 @@ var (
|
|||
// Compress = "zlib-stream"
|
||||
)
|
||||
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = wsutil.DefaultTimeout
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
|
||||
ErrWSMaxTries = errors.New("max tries reached")
|
||||
|
@ -95,11 +77,7 @@ func BotURL(token string) (*GatewayBotData, error) {
|
|||
}
|
||||
|
||||
type Gateway struct {
|
||||
WS *wsutil.Websocket
|
||||
json.Driver
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
WS *wsutil.Websocket
|
||||
WSTimeout time.Duration
|
||||
|
||||
// All events sent over are pointers to Event structs (structs suffixed with
|
||||
|
@ -139,11 +117,6 @@ type Gateway struct {
|
|||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||
// information, refer to NewGatewayWithDriver.
|
||||
func NewGateway(token string) (*Gateway, error) {
|
||||
return NewGatewayWithDriver(token, json.Default{})
|
||||
}
|
||||
|
||||
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
|
||||
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||
URL, err := URL()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
||||
|
@ -158,18 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
// Append the form to the URL
|
||||
URL += "?" + param.Encode()
|
||||
|
||||
return NewCustomGateway(URL, token, driver), nil
|
||||
return NewCustomGateway(URL, token), nil
|
||||
}
|
||||
|
||||
func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
||||
func NewCustomGateway(gatewayURL, token string) *Gateway {
|
||||
return &Gateway{
|
||||
WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL),
|
||||
Driver: driver,
|
||||
WSTimeout: WSTimeout,
|
||||
Events: make(chan Event, WSBuffer),
|
||||
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
|
||||
Events: make(chan Event, wsutil.WSBuffer),
|
||||
Identifier: DefaultIdentifier(token),
|
||||
Sequence: NewSequence(),
|
||||
ErrorLog: WSError,
|
||||
ErrorLog: wsutil.WSError,
|
||||
AfterClose: func(error) {},
|
||||
}
|
||||
}
|
||||
|
@ -178,7 +149,7 @@ func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway {
|
|||
func (g *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if g.waitGroup == nil && g.paceDeath == nil {
|
||||
WSDebug("Gateway is already closed.")
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
|
||||
g.AfterClose(nil)
|
||||
return nil
|
||||
|
@ -186,15 +157,15 @@ func (g *Gateway) Close() error {
|
|||
|
||||
// If the pacemaker is running:
|
||||
if g.paceDeath != nil {
|
||||
WSDebug("Stopping pacemaker...")
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
g.Pacemaker.Stop()
|
||||
|
||||
WSDebug("Stopped pacemaker.")
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
WSDebug("Waiting for WaitGroup to be done.")
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
|
@ -203,7 +174,7 @@ func (g *Gateway) Close() error {
|
|||
// Mark g.waitGroup as empty:
|
||||
g.waitGroup = nil
|
||||
|
||||
WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := g.WS.Close()
|
||||
g.AfterClose(err)
|
||||
|
@ -212,42 +183,47 @@ func (g *Gateway) Close() error {
|
|||
|
||||
// Reconnect tries to reconnect forever. It will resume the connection if
|
||||
// possible. If an Invalid Session is received, it will start a fresh one.
|
||||
func (g *Gateway) Reconnect() {
|
||||
WSDebug("Reconnecting...")
|
||||
func (g *Gateway) Reconnect() error {
|
||||
return g.ReconnectContext(context.Background())
|
||||
}
|
||||
|
||||
func (g *Gateway) ReconnectContext(ctx context.Context) error {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
// redialing anyway.
|
||||
g.Close()
|
||||
|
||||
for i := 1; ; i++ {
|
||||
WSDebug("Trying to dial, attempt", i)
|
||||
wsutil.WSDebug("Trying to dial, attempt", i)
|
||||
|
||||
// Condition: err == ErrInvalidSession:
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
|
||||
if err := g.Open(); err != nil {
|
||||
if err := g.OpenContext(ctx); err != nil {
|
||||
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
|
||||
continue
|
||||
}
|
||||
|
||||
WSDebug("Started after attempt:", i)
|
||||
return
|
||||
wsutil.WSDebug("Started after attempt:", i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Open connects to the Websocket and authenticate it. You should usually use
|
||||
// this function over Start().
|
||||
func (g *Gateway) Open() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
return g.OpenContext(context.Background())
|
||||
}
|
||||
|
||||
func (g *Gateway) OpenContext(ctx context.Context) error {
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "Failed to reconnect")
|
||||
}
|
||||
|
||||
WSDebug("Trying to start...")
|
||||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to resume the connection
|
||||
if err := g.Start(); err != nil {
|
||||
|
@ -266,12 +242,12 @@ func (g *Gateway) Start() error {
|
|||
// defer g.available.Unlock()
|
||||
|
||||
if err := g.start(); err != nil {
|
||||
WSDebug("Start failed:", err)
|
||||
wsutil.WSDebug("Start failed:", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := g.Close(); err != nil {
|
||||
WSDebug("Failed to close after start fail:", err)
|
||||
wsutil.WSDebug("Failed to close after start fail:", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -293,7 +269,7 @@ func (g *Gateway) start() error {
|
|||
|
||||
// Wait for an OP 10 Hello
|
||||
var hello HelloEvent
|
||||
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
||||
if _, err := AssertEvent(<-ch, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "Error at Hello")
|
||||
}
|
||||
|
||||
|
@ -311,16 +287,16 @@ func (g *Gateway) start() error {
|
|||
}
|
||||
|
||||
// Expect either READY or RESUMED before continuing.
|
||||
WSDebug("Waiting for either READY or RESUMED.")
|
||||
wsutil.WSDebug("Waiting for either READY or RESUMED.")
|
||||
|
||||
// WaitForEvent should
|
||||
err := WaitForEvent(g, ch, func(op *OP) bool {
|
||||
switch op.EventName {
|
||||
case "READY":
|
||||
WSDebug("Found READY event.")
|
||||
wsutil.WSDebug("Found READY event.")
|
||||
return true
|
||||
case "RESUMED":
|
||||
WSDebug("Found RESUMED event.")
|
||||
wsutil.WSDebug("Found RESUMED event.")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -344,7 +320,7 @@ func (g *Gateway) start() error {
|
|||
g.waitGroup.Add(1)
|
||||
go g.handleWS()
|
||||
|
||||
WSDebug("Started successfully.")
|
||||
wsutil.WSDebug("Started successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -353,7 +329,7 @@ func (g *Gateway) start() error {
|
|||
func (g *Gateway) handleWS() {
|
||||
err := g.eventLoop()
|
||||
g.waitGroup.Done() // mark so Close() can exit.
|
||||
WSDebug("Event loop stopped.")
|
||||
wsutil.WSDebug("Event loop stopped.")
|
||||
|
||||
if err != nil {
|
||||
g.ErrorLog(err)
|
||||
|
@ -372,7 +348,7 @@ func (g *Gateway) eventLoop() error {
|
|||
g.paceDeath = nil // mark
|
||||
|
||||
if err == nil {
|
||||
WSDebug("Pacemaker stopped without errors.")
|
||||
wsutil.WSDebug("Pacemaker stopped without errors.")
|
||||
// No error, just exit normally.
|
||||
return nil
|
||||
}
|
||||
|
@ -394,7 +370,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := g.Driver.Marshal(v)
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode v")
|
||||
}
|
||||
|
@ -402,7 +378,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := g.Driver.Marshal(op)
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
|
|
@ -38,21 +38,21 @@ type OP struct {
|
|||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) {
|
||||
op, err := DecodeOP(driver, ev)
|
||||
func DecodeEvent(ev wsutil.Event, v interface{}) (OPCode, error) {
|
||||
op, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return op.Code, nil
|
||||
}
|
||||
|
||||
func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
||||
op, err := DecodeOP(driver, ev)
|
||||
func AssertEvent(ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
||||
op, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
|
|||
)
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||
return op, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}
|
|||
}
|
||||
|
||||
func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
||||
o, err := DecodeOP(g.Driver, ev)
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ func HandleEvent(g *Gateway, ev wsutil.Event) error {
|
|||
// regardless.
|
||||
func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
||||
for ev := range ch {
|
||||
o, err := DecodeOP(g.Driver, ev)
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error {
|
|||
return errors.New("Event not found and event channel is closed.")
|
||||
}
|
||||
|
||||
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
||||
func DecodeOP(ev wsutil.Event) (*OP, error) {
|
||||
if ev.Error != nil {
|
||||
return nil, ev.Error
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
|||
}
|
||||
|
||||
var op *OP
|
||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
||||
if err := json.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
||||
}
|
||||
|
||||
|
@ -139,7 +139,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
|||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
WSDebug("ReconnectOP received.")
|
||||
wsutil.WSDebug("ReconnectOP received.")
|
||||
// We must reconnect in another goroutine, as running Reconnect
|
||||
// synchronously would prevent the main event loop from exiting.
|
||||
go g.Reconnect()
|
||||
|
@ -177,7 +177,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
|||
var ev = fn()
|
||||
|
||||
// Try and parse the event
|
||||
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "Failed to parse event "+op.EventName)
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -59,9 +60,9 @@ func (p *Pacemaker) Dead() bool {
|
|||
func (p *Pacemaker) Stop() {
|
||||
if p.stop != nil {
|
||||
p.stop <- struct{}{}
|
||||
WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
||||
wsutil.WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
||||
} else {
|
||||
WSDebug("(*Pacemaker).stop is nil, skipping.")
|
||||
wsutil.WSDebug("(*Pacemaker).stop is nil, skipping.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,13 +74,13 @@ func (p *Pacemaker) start() error {
|
|||
p.Echo()
|
||||
|
||||
for {
|
||||
WSDebug("Pacemaker loop restarted.")
|
||||
wsutil.WSDebug("Pacemaker loop restarted.")
|
||||
|
||||
if err := p.Pace(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
WSDebug("Paced.")
|
||||
wsutil.WSDebug("Paced.")
|
||||
|
||||
// Paced, save:
|
||||
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
|
||||
|
@ -90,11 +91,11 @@ func (p *Pacemaker) start() error {
|
|||
|
||||
select {
|
||||
case <-p.stop:
|
||||
WSDebug("Received stop signal.")
|
||||
wsutil.WSDebug("Received stop signal.")
|
||||
return nil
|
||||
|
||||
case <-tick.C:
|
||||
WSDebug("Ticked. Restarting.")
|
||||
wsutil.WSDebug("Ticked. Restarting.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -109,7 +110,7 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
|||
go func() {
|
||||
p.death <- p.start()
|
||||
// Debug.
|
||||
WSDebug("Pacemaker returned.")
|
||||
wsutil.WSDebug("Pacemaker returned.")
|
||||
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
||||
p.stop = nil
|
||||
// Mark the pacemaker loop as done.
|
||||
|
|
|
@ -41,7 +41,7 @@ func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
|||
func NewClient() *Client {
|
||||
return &Client{
|
||||
Client: httpdriver.NewClient(),
|
||||
Driver: json.Default{},
|
||||
Driver: json.Default,
|
||||
SchemaEncoder: &DefaultSchema{},
|
||||
Retries: Retries,
|
||||
OnResponse: ResponseNoop,
|
||||
|
|
|
@ -7,65 +7,6 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
type (
|
||||
OptionBool = *bool
|
||||
OptionString = *string
|
||||
OptionUint = *uint
|
||||
OptionInt = *int
|
||||
)
|
||||
|
||||
var (
|
||||
True = getBool(true)
|
||||
False = getBool(false)
|
||||
|
||||
ZeroUint = Uint(0)
|
||||
ZeroInt = Int(0)
|
||||
|
||||
EmptyString = String("")
|
||||
)
|
||||
|
||||
func Uint(u uint) OptionUint {
|
||||
return &u
|
||||
}
|
||||
|
||||
func Int(i int) OptionInt {
|
||||
return &i
|
||||
}
|
||||
|
||||
func String(s string) OptionString {
|
||||
return &s
|
||||
}
|
||||
|
||||
func getBool(Bool bool) OptionBool {
|
||||
return &Bool
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||
// from encoding/json.
|
||||
type Raw []byte
|
||||
|
||||
// Raw returns m as the JSON encoding of m.
|
||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type Driver interface {
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
|
@ -74,20 +15,43 @@ type Driver interface {
|
|||
EncodeStream(w io.Writer, v interface{}) error
|
||||
}
|
||||
|
||||
type Default struct{}
|
||||
type DefaultDriver struct{}
|
||||
|
||||
func (d Default) Marshal(v interface{}) ([]byte, error) {
|
||||
func (d DefaultDriver) Marshal(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func (d Default) Unmarshal(data []byte, v interface{}) error {
|
||||
func (d DefaultDriver) Unmarshal(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (d Default) DecodeStream(r io.Reader, v interface{}) error {
|
||||
func (d DefaultDriver) DecodeStream(r io.Reader, v interface{}) error {
|
||||
return json.NewDecoder(r).Decode(v)
|
||||
}
|
||||
|
||||
func (d Default) EncodeStream(w io.Writer, v interface{}) error {
|
||||
func (d DefaultDriver) EncodeStream(w io.Writer, v interface{}) error {
|
||||
return json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// Default is the default JSON driver, which uses encoding/json.
|
||||
var Default Driver = DefaultDriver{}
|
||||
|
||||
// Marshal uses the default driver.
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
return Default.Marshal(v)
|
||||
}
|
||||
|
||||
// Unmarshal uses the default driver.
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
return Default.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// DecodeStream uses the default driver.
|
||||
func DecodeStream(r io.Reader, v interface{}) error {
|
||||
return Default.DecodeStream(r, v)
|
||||
}
|
||||
|
||||
// EncodeStream uses the default driver.
|
||||
func EncodeStream(w io.Writer, v interface{}) error {
|
||||
return Default.EncodeStream(w, v)
|
||||
}
|
||||
|
|
34
utils/json/option.go
Normal file
34
utils/json/option.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package json
|
||||
|
||||
type (
|
||||
OptionBool = *bool
|
||||
OptionString = *string
|
||||
OptionUint = *uint
|
||||
OptionInt = *int
|
||||
)
|
||||
|
||||
var (
|
||||
True = getBool(true)
|
||||
False = getBool(false)
|
||||
|
||||
ZeroUint = Uint(0)
|
||||
ZeroInt = Int(0)
|
||||
|
||||
EmptyString = String("")
|
||||
)
|
||||
|
||||
func Uint(u uint) OptionUint {
|
||||
return &u
|
||||
}
|
||||
|
||||
func Int(i int) OptionInt {
|
||||
return &i
|
||||
}
|
||||
|
||||
func String(s string) OptionString {
|
||||
return &s
|
||||
}
|
||||
|
||||
func getBool(Bool bool) OptionBool {
|
||||
return &Bool
|
||||
}
|
23
utils/json/raw.go
Normal file
23
utils/json/raw.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package json
|
||||
|
||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||
// from encoding/json.
|
||||
type Raw []byte
|
||||
|
||||
// Raw returns m as the JSON encoding of m.
|
||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
|
@ -35,7 +35,7 @@ type Connection interface {
|
|||
Listen() <-chan Event
|
||||
|
||||
// Send allows the caller to send bytes. Thread safety is a requirement.
|
||||
Send([]byte) error
|
||||
Send(context.Context, []byte) error
|
||||
|
||||
// Close should close the websocket connection. The connection will not be
|
||||
// reused.
|
||||
|
@ -67,12 +67,16 @@ type Conn struct {
|
|||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
||||
func NewConn(driver json.Driver) *Conn {
|
||||
func NewConn() *Conn {
|
||||
return NewConnWithDriver(json.Default)
|
||||
}
|
||||
|
||||
func NewConnWithDriver(driver json.Driver) *Conn {
|
||||
return &Conn{
|
||||
Driver: driver,
|
||||
dialer: &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: DefaultTimeout,
|
||||
HandshakeTimeout: WSTimeout,
|
||||
EnableCompression: true,
|
||||
},
|
||||
// zlib: zlib.NewInflator(),
|
||||
|
@ -226,14 +230,27 @@ func (c *Conn) handle() ([]byte, error) {
|
|||
// return nil, errors.New("Unexpected binary message.")
|
||||
}
|
||||
|
||||
func (c *Conn) Send(b []byte) error {
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// If websocket is already closed.
|
||||
if c.writes == nil {
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
c.writes <- b
|
||||
return <-c.errors
|
||||
// Send the bytes.
|
||||
select {
|
||||
case c.writes <- b:
|
||||
// continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Receive the error.
|
||||
select {
|
||||
case err := <-c.errors:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
|
|
|
@ -4,15 +4,30 @@ package wsutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const DefaultTimeout = time.Minute
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = time.Minute
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Data []byte
|
||||
|
@ -25,12 +40,16 @@ type Websocket struct {
|
|||
Conn Connection
|
||||
Addr string
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
Timeout time.Duration
|
||||
|
||||
SendLimiter *rate.Limiter
|
||||
DialLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
func New(addr string) *Websocket {
|
||||
return NewCustom(NewConn(json.Default{}), addr)
|
||||
return NewCustom(NewConn(), addr)
|
||||
}
|
||||
|
||||
// NewCustom creates a new undialed Websocket.
|
||||
|
@ -39,12 +58,21 @@ func NewCustom(conn Connection, addr string) *Websocket {
|
|||
Conn: conn,
|
||||
Addr: addr,
|
||||
|
||||
Timeout: WSTimeout,
|
||||
|
||||
SendLimiter: NewSendLimiter(),
|
||||
DialLimiter: NewDialLimiter(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
if ws.Timeout > 0 {
|
||||
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
|
||||
defer cancel()
|
||||
|
||||
ctx = tctx
|
||||
}
|
||||
|
||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return errors.Wrap(err, "Failed to wait")
|
||||
|
@ -65,11 +93,16 @@ func (ws *Websocket) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
|
||||
return ws.SendContext(context.Background(), b)
|
||||
}
|
||||
|
||||
// SendContext is a beta API.
|
||||
func (ws *Websocket) SendContext(ctx context.Context, b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
return ws.Conn.Send(b)
|
||||
return ws.Conn.Send(ctx, b)
|
||||
}
|
||||
|
||||
func (ws *Websocket) Close() error {
|
||||
|
|
Loading…
Reference in a new issue