diff --git a/gateway/gateway.go b/gateway/gateway.go index ce8cc3c..b48c685 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -9,7 +9,6 @@ package gateway import ( "context" - "log" "net/http" "net/url" "sync" @@ -32,23 +31,6 @@ var ( // Compress = "zlib-stream" ) -var ( - // WSTimeout is the timeout for connecting and writing to the Websocket, - // before Gateway cancels and fails. - WSTimeout = wsutil.DefaultTimeout - // WSBuffer is the size of the Event channel. This has to be at least 1 to - // make space for the first Event: Ready or Resumed. - WSBuffer = 10 - // WSError is the default error handler - WSError = func(err error) { log.Println("Gateway error:", err) } - // WSExtraReadTimeout is the duration to be added to Hello, as a read - // timeout for the websocket. - WSExtraReadTimeout = time.Second - // WSDebug is used for extra debug logging. This is expected to behave - // similarly to log.Println(). - WSDebug = func(v ...interface{}) {} -) - var ( ErrMissingForResume = errors.New("missing session ID or sequence for resuming") ErrWSMaxTries = errors.New("max tries reached") @@ -95,11 +77,7 @@ func BotURL(token string) (*GatewayBotData, error) { } type Gateway struct { - WS *wsutil.Websocket - json.Driver - - // Timeout for connecting and writing to the Websocket, uses default - // WSTimeout (global). + WS *wsutil.Websocket WSTimeout time.Duration // All events sent over are pointers to Event structs (structs suffixed with @@ -139,11 +117,6 @@ type Gateway struct { // NewGateway starts a new Gateway with the default stdlib JSON driver. For more // information, refer to NewGatewayWithDriver. func NewGateway(token string) (*Gateway, error) { - return NewGatewayWithDriver(token, json.Default{}) -} - -// NewGatewayWithDriver connects to the Gateway and authenticates automatically. -func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { URL, err := URL() if err != nil { return nil, errors.Wrap(err, "Failed to get gateway endpoint") @@ -158,18 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { // Append the form to the URL URL += "?" + param.Encode() - return NewCustomGateway(URL, token, driver), nil + return NewCustomGateway(URL, token), nil } -func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway { +func NewCustomGateway(gatewayURL, token string) *Gateway { return &Gateway{ - WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL), - Driver: driver, - WSTimeout: WSTimeout, - Events: make(chan Event, WSBuffer), + WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL), + Events: make(chan Event, wsutil.WSBuffer), Identifier: DefaultIdentifier(token), Sequence: NewSequence(), - ErrorLog: WSError, + ErrorLog: wsutil.WSError, AfterClose: func(error) {}, } } @@ -178,7 +149,7 @@ func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway { func (g *Gateway) Close() error { // Check if the WS is already closed: if g.waitGroup == nil && g.paceDeath == nil { - WSDebug("Gateway is already closed.") + wsutil.WSDebug("Gateway is already closed.") g.AfterClose(nil) return nil @@ -186,15 +157,15 @@ func (g *Gateway) Close() error { // If the pacemaker is running: if g.paceDeath != nil { - WSDebug("Stopping pacemaker...") + wsutil.WSDebug("Stopping pacemaker...") // Stop the pacemaker and the event handler g.Pacemaker.Stop() - WSDebug("Stopped pacemaker.") + wsutil.WSDebug("Stopped pacemaker.") } - WSDebug("Waiting for WaitGroup to be done.") + wsutil.WSDebug("Waiting for WaitGroup to be done.") // This should work, since Pacemaker should signal its loop to stop, which // would also exit our event loop. Both would be 2. @@ -203,7 +174,7 @@ func (g *Gateway) Close() error { // Mark g.waitGroup as empty: g.waitGroup = nil - WSDebug("WaitGroup is done. Closing the websocket.") + wsutil.WSDebug("WaitGroup is done. Closing the websocket.") err := g.WS.Close() g.AfterClose(err) @@ -212,42 +183,47 @@ func (g *Gateway) Close() error { // Reconnect tries to reconnect forever. It will resume the connection if // possible. If an Invalid Session is received, it will start a fresh one. -func (g *Gateway) Reconnect() { - WSDebug("Reconnecting...") +func (g *Gateway) Reconnect() error { + return g.ReconnectContext(context.Background()) +} + +func (g *Gateway) ReconnectContext(ctx context.Context) error { + wsutil.WSDebug("Reconnecting...") // Guarantee the gateway is already closed. Ignore its error, as we're // redialing anyway. g.Close() for i := 1; ; i++ { - WSDebug("Trying to dial, attempt", i) + wsutil.WSDebug("Trying to dial, attempt", i) // Condition: err == ErrInvalidSession: // If the connection is rate limited (documented behavior): // https://discordapp.com/developers/docs/topics/gateway#rate-limiting - if err := g.Open(); err != nil { + if err := g.OpenContext(ctx); err != nil { g.ErrorLog(errors.Wrap(err, "Failed to open gateway")) continue } - WSDebug("Started after attempt:", i) - return + wsutil.WSDebug("Started after attempt:", i) + return nil } } // Open connects to the Websocket and authenticate it. You should usually use // this function over Start(). func (g *Gateway) Open() error { - ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) - defer cancel() + return g.OpenContext(context.Background()) +} +func (g *Gateway) OpenContext(ctx context.Context) error { // Reconnect to the Gateway if err := g.WS.Dial(ctx); err != nil { return errors.Wrap(err, "Failed to reconnect") } - WSDebug("Trying to start...") + wsutil.WSDebug("Trying to start...") // Try to resume the connection if err := g.Start(); err != nil { @@ -266,12 +242,12 @@ func (g *Gateway) Start() error { // defer g.available.Unlock() if err := g.start(); err != nil { - WSDebug("Start failed:", err) + wsutil.WSDebug("Start failed:", err) // Close can be called with the mutex still acquired here, as the // pacemaker hasn't started yet. if err := g.Close(); err != nil { - WSDebug("Failed to close after start fail:", err) + wsutil.WSDebug("Failed to close after start fail:", err) } return err } @@ -293,7 +269,7 @@ func (g *Gateway) start() error { // Wait for an OP 10 Hello var hello HelloEvent - if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil { + if _, err := AssertEvent(<-ch, HelloOP, &hello); err != nil { return errors.Wrap(err, "Error at Hello") } @@ -311,16 +287,16 @@ func (g *Gateway) start() error { } // Expect either READY or RESUMED before continuing. - WSDebug("Waiting for either READY or RESUMED.") + wsutil.WSDebug("Waiting for either READY or RESUMED.") // WaitForEvent should err := WaitForEvent(g, ch, func(op *OP) bool { switch op.EventName { case "READY": - WSDebug("Found READY event.") + wsutil.WSDebug("Found READY event.") return true case "RESUMED": - WSDebug("Found RESUMED event.") + wsutil.WSDebug("Found RESUMED event.") return true } return false @@ -344,7 +320,7 @@ func (g *Gateway) start() error { g.waitGroup.Add(1) go g.handleWS() - WSDebug("Started successfully.") + wsutil.WSDebug("Started successfully.") return nil } @@ -353,7 +329,7 @@ func (g *Gateway) start() error { func (g *Gateway) handleWS() { err := g.eventLoop() g.waitGroup.Done() // mark so Close() can exit. - WSDebug("Event loop stopped.") + wsutil.WSDebug("Event loop stopped.") if err != nil { g.ErrorLog(err) @@ -372,7 +348,7 @@ func (g *Gateway) eventLoop() error { g.paceDeath = nil // mark if err == nil { - WSDebug("Pacemaker stopped without errors.") + wsutil.WSDebug("Pacemaker stopped without errors.") // No error, just exit normally. return nil } @@ -394,7 +370,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { } if v != nil { - b, err := g.Driver.Marshal(v) + b, err := json.Marshal(v) if err != nil { return errors.Wrap(err, "Failed to encode v") } @@ -402,7 +378,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { op.Data = b } - b, err := g.Driver.Marshal(op) + b, err := json.Marshal(op) if err != nil { return errors.Wrap(err, "Failed to encode payload") } diff --git a/gateway/op.go b/gateway/op.go index 0410253..219e973 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -38,21 +38,21 @@ type OP struct { EventName string `json:"t,omitempty"` } -func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) { - op, err := DecodeOP(driver, ev) +func DecodeEvent(ev wsutil.Event, v interface{}) (OPCode, error) { + op, err := DecodeOP(ev) if err != nil { return 0, err } - if err := driver.Unmarshal(op.Data, v); err != nil { + if err := json.Unmarshal(op.Data, v); err != nil { return 0, errors.Wrap(err, "Failed to decode data") } return op.Code, nil } -func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { - op, err := DecodeOP(driver, ev) +func AssertEvent(ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { + op, err := DecodeOP(ev) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{} ) } - if err := driver.Unmarshal(op.Data, v); err != nil { + if err := json.Unmarshal(op.Data, v); err != nil { return op, errors.Wrap(err, "Failed to decode data") } @@ -72,7 +72,7 @@ func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{} } func HandleEvent(g *Gateway, ev wsutil.Event) error { - o, err := DecodeOP(g.Driver, ev) + o, err := DecodeOP(ev) if err != nil { return err } @@ -84,7 +84,7 @@ func HandleEvent(g *Gateway, ev wsutil.Event) error { // regardless. func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error { for ev := range ch { - o, err := DecodeOP(g.Driver, ev) + o, err := DecodeOP(ev) if err != nil { return err } @@ -106,7 +106,7 @@ func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error { return errors.New("Event not found and event channel is closed.") } -func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { +func DecodeOP(ev wsutil.Event) (*OP, error) { if ev.Error != nil { return nil, ev.Error } @@ -116,7 +116,7 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { } var op *OP - if err := driver.Unmarshal(ev.Data, &op); err != nil { + if err := json.Unmarshal(ev.Data, &op); err != nil { return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) } @@ -139,7 +139,7 @@ func HandleOP(g *Gateway, op *OP) error { case ReconnectOP: // Server requests to reconnect, die and retry. - WSDebug("ReconnectOP received.") + wsutil.WSDebug("ReconnectOP received.") // We must reconnect in another goroutine, as running Reconnect // synchronously would prevent the main event loop from exiting. go g.Reconnect() @@ -177,7 +177,7 @@ func HandleOP(g *Gateway, op *OP) error { var ev = fn() // Try and parse the event - if err := g.Driver.Unmarshal(op.Data, ev); err != nil { + if err := json.Unmarshal(op.Data, ev); err != nil { return errors.Wrap(err, "Failed to parse event "+op.EventName) } diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index db3ed2e..f51b0d6 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -5,6 +5,7 @@ import ( "sync/atomic" "time" + "github.com/diamondburned/arikawa/utils/wsutil" "github.com/pkg/errors" ) @@ -59,9 +60,9 @@ func (p *Pacemaker) Dead() bool { func (p *Pacemaker) Stop() { if p.stop != nil { p.stop <- struct{}{} - WSDebug("(*Pacemaker).stop was sent a stop signal.") + wsutil.WSDebug("(*Pacemaker).stop was sent a stop signal.") } else { - WSDebug("(*Pacemaker).stop is nil, skipping.") + wsutil.WSDebug("(*Pacemaker).stop is nil, skipping.") } } @@ -73,13 +74,13 @@ func (p *Pacemaker) start() error { p.Echo() for { - WSDebug("Pacemaker loop restarted.") + wsutil.WSDebug("Pacemaker loop restarted.") if err := p.Pace(); err != nil { return err } - WSDebug("Paced.") + wsutil.WSDebug("Paced.") // Paced, save: atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano()) @@ -90,11 +91,11 @@ func (p *Pacemaker) start() error { select { case <-p.stop: - WSDebug("Received stop signal.") + wsutil.WSDebug("Received stop signal.") return nil case <-tick.C: - WSDebug("Ticked. Restarting.") + wsutil.WSDebug("Ticked. Restarting.") } } } @@ -109,7 +110,7 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { go func() { p.death <- p.start() // Debug. - WSDebug("Pacemaker returned.") + wsutil.WSDebug("Pacemaker returned.") // Mark the stop channel as nil, so later Close() calls won't block forever. p.stop = nil // Mark the pacemaker loop as done. diff --git a/utils/httputil/client.go b/utils/httputil/client.go index 002b248..b5c461b 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -41,7 +41,7 @@ func ResponseNoop(httpdriver.Request, httpdriver.Response) error { func NewClient() *Client { return &Client{ Client: httpdriver.NewClient(), - Driver: json.Default{}, + Driver: json.Default, SchemaEncoder: &DefaultSchema{}, Retries: Retries, OnResponse: ResponseNoop, diff --git a/utils/json/json.go b/utils/json/json.go index 081399e..478bb42 100644 --- a/utils/json/json.go +++ b/utils/json/json.go @@ -7,65 +7,6 @@ import ( "io" ) -type ( - OptionBool = *bool - OptionString = *string - OptionUint = *uint - OptionInt = *int -) - -var ( - True = getBool(true) - False = getBool(false) - - ZeroUint = Uint(0) - ZeroInt = Int(0) - - EmptyString = String("") -) - -func Uint(u uint) OptionUint { - return &u -} - -func Int(i int) OptionInt { - return &i -} - -func String(s string) OptionString { - return &s -} - -func getBool(Bool bool) OptionBool { - return &Bool -} - -// - -// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and -// can be used to delay JSON decoding or precompute a JSON encoding. It's taken -// from encoding/json. -type Raw []byte - -// Raw returns m as the JSON encoding of m. -func (m Raw) MarshalJSON() ([]byte, error) { - if m == nil { - return []byte("null"), nil - } - return m, nil -} - -func (m *Raw) UnmarshalJSON(data []byte) error { - *m = append((*m)[0:0], data...) - return nil -} - -func (m Raw) String() string { - return string(m) -} - -// - type Driver interface { Marshal(v interface{}) ([]byte, error) Unmarshal(data []byte, v interface{}) error @@ -74,20 +15,43 @@ type Driver interface { EncodeStream(w io.Writer, v interface{}) error } -type Default struct{} +type DefaultDriver struct{} -func (d Default) Marshal(v interface{}) ([]byte, error) { +func (d DefaultDriver) Marshal(v interface{}) ([]byte, error) { return json.Marshal(v) } -func (d Default) Unmarshal(data []byte, v interface{}) error { +func (d DefaultDriver) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) } -func (d Default) DecodeStream(r io.Reader, v interface{}) error { +func (d DefaultDriver) DecodeStream(r io.Reader, v interface{}) error { return json.NewDecoder(r).Decode(v) } -func (d Default) EncodeStream(w io.Writer, v interface{}) error { +func (d DefaultDriver) EncodeStream(w io.Writer, v interface{}) error { return json.NewEncoder(w).Encode(v) } + +// Default is the default JSON driver, which uses encoding/json. +var Default Driver = DefaultDriver{} + +// Marshal uses the default driver. +func Marshal(v interface{}) ([]byte, error) { + return Default.Marshal(v) +} + +// Unmarshal uses the default driver. +func Unmarshal(data []byte, v interface{}) error { + return Default.Unmarshal(data, v) +} + +// DecodeStream uses the default driver. +func DecodeStream(r io.Reader, v interface{}) error { + return Default.DecodeStream(r, v) +} + +// EncodeStream uses the default driver. +func EncodeStream(w io.Writer, v interface{}) error { + return Default.EncodeStream(w, v) +} diff --git a/utils/json/option.go b/utils/json/option.go new file mode 100644 index 0000000..0edbc1a --- /dev/null +++ b/utils/json/option.go @@ -0,0 +1,34 @@ +package json + +type ( + OptionBool = *bool + OptionString = *string + OptionUint = *uint + OptionInt = *int +) + +var ( + True = getBool(true) + False = getBool(false) + + ZeroUint = Uint(0) + ZeroInt = Int(0) + + EmptyString = String("") +) + +func Uint(u uint) OptionUint { + return &u +} + +func Int(i int) OptionInt { + return &i +} + +func String(s string) OptionString { + return &s +} + +func getBool(Bool bool) OptionBool { + return &Bool +} diff --git a/utils/json/raw.go b/utils/json/raw.go new file mode 100644 index 0000000..59385c3 --- /dev/null +++ b/utils/json/raw.go @@ -0,0 +1,23 @@ +package json + +// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and +// can be used to delay JSON decoding or precompute a JSON encoding. It's taken +// from encoding/json. +type Raw []byte + +// Raw returns m as the JSON encoding of m. +func (m Raw) MarshalJSON() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + return m, nil +} + +func (m *Raw) UnmarshalJSON(data []byte) error { + *m = append((*m)[0:0], data...) + return nil +} + +func (m Raw) String() string { + return string(m) +} diff --git a/utils/wsutil/conn.go b/utils/wsutil/conn.go index 606529e..effaa8e 100644 --- a/utils/wsutil/conn.go +++ b/utils/wsutil/conn.go @@ -35,7 +35,7 @@ type Connection interface { Listen() <-chan Event // Send allows the caller to send bytes. Thread safety is a requirement. - Send([]byte) error + Send(context.Context, []byte) error // Close should close the websocket connection. The connection will not be // reused. @@ -67,12 +67,16 @@ type Conn struct { var _ Connection = (*Conn)(nil) -func NewConn(driver json.Driver) *Conn { +func NewConn() *Conn { + return NewConnWithDriver(json.Default) +} + +func NewConnWithDriver(driver json.Driver) *Conn { return &Conn{ Driver: driver, dialer: &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: DefaultTimeout, + HandshakeTimeout: WSTimeout, EnableCompression: true, }, // zlib: zlib.NewInflator(), @@ -226,14 +230,27 @@ func (c *Conn) handle() ([]byte, error) { // return nil, errors.New("Unexpected binary message.") } -func (c *Conn) Send(b []byte) error { +func (c *Conn) Send(ctx context.Context, b []byte) error { // If websocket is already closed. if c.writes == nil { return ErrWebsocketClosed } - c.writes <- b - return <-c.errors + // Send the bytes. + select { + case c.writes <- b: + // continue + case <-ctx.Done(): + return ctx.Err() + } + + // Receive the error. + select { + case err := <-c.errors: + return err + case <-ctx.Done(): + return ctx.Err() + } } func (c *Conn) Close() (err error) { diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go index f7f5493..0b87ce2 100644 --- a/utils/wsutil/ws.go +++ b/utils/wsutil/ws.go @@ -4,15 +4,30 @@ package wsutil import ( "context" + "log" "net/url" "time" - "github.com/diamondburned/arikawa/utils/json" "github.com/pkg/errors" "golang.org/x/time/rate" ) -const DefaultTimeout = time.Minute +var ( + // WSTimeout is the timeout for connecting and writing to the Websocket, + // before Gateway cancels and fails. + WSTimeout = time.Minute + // WSBuffer is the size of the Event channel. This has to be at least 1 to + // make space for the first Event: Ready or Resumed. + WSBuffer = 10 + // WSError is the default error handler + WSError = func(err error) { log.Println("Gateway error:", err) } + // WSExtraReadTimeout is the duration to be added to Hello, as a read + // timeout for the websocket. + WSExtraReadTimeout = time.Second + // WSDebug is used for extra debug logging. This is expected to behave + // similarly to log.Println(). + WSDebug = func(v ...interface{}) {} +) type Event struct { Data []byte @@ -25,12 +40,16 @@ type Websocket struct { Conn Connection Addr string + // Timeout for connecting and writing to the Websocket, uses default + // WSTimeout (global). + Timeout time.Duration + SendLimiter *rate.Limiter DialLimiter *rate.Limiter } func New(addr string) *Websocket { - return NewCustom(NewConn(json.Default{}), addr) + return NewCustom(NewConn(), addr) } // NewCustom creates a new undialed Websocket. @@ -39,12 +58,21 @@ func NewCustom(conn Connection, addr string) *Websocket { Conn: conn, Addr: addr, + Timeout: WSTimeout, + SendLimiter: NewSendLimiter(), DialLimiter: NewDialLimiter(), } } func (ws *Websocket) Dial(ctx context.Context) error { + if ws.Timeout > 0 { + tctx, cancel := context.WithTimeout(ctx, ws.Timeout) + defer cancel() + + ctx = tctx + } + if err := ws.DialLimiter.Wait(ctx); err != nil { // Expired, fatal error return errors.Wrap(err, "Failed to wait") @@ -65,11 +93,16 @@ func (ws *Websocket) Listen() <-chan Event { } func (ws *Websocket) Send(b []byte) error { - if err := ws.SendLimiter.Wait(context.Background()); err != nil { + return ws.SendContext(context.Background(), b) +} + +// SendContext is a beta API. +func (ws *Websocket) SendContext(ctx context.Context, b []byte) error { + if err := ws.SendLimiter.Wait(ctx); err != nil { return errors.Wrap(err, "SendLimiter failed") } - return ws.Conn.Send(b) + return ws.Conn.Send(ctx, b) } func (ws *Websocket) Close() error {