diff --git a/README.md b/README.md index f28bb4e..36e0280 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,13 @@ # arikawa A Golang library for the Discord API. + +## Testing + +The package includes integration tests that require `$BOT_TOKEN`. To run these +tests, do + +```sh +export BOT_TOKEN="" +go test -tags integration ./... +``` diff --git a/gateway/commands.go b/gateway/commands.go index 5b851c1..9dd927d 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -60,6 +60,25 @@ type ResumeData struct { Sequence int64 `json:"seq"` } +// Resume sends to the Websocket a Resume OP, but it doesn't actually resume +// from a dead connection. Start() resumes from a dead connection. +func (g *Gateway) Resume() error { + var ( + ses = g.SessionID + seq = g.Sequence.Get() + ) + + if ses == "" || seq == 0 { + return ErrMissingForResume + } + + return g.Send(ResumeOP, ResumeData{ + Token: g.Identifier.Token, + SessionID: ses, + Sequence: seq, + }) +} + // HeartbeatData is the last sequence number to be sent. type HeartbeatData int diff --git a/gateway/gateway.go b/gateway/gateway.go index c7c9847..60b476c 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -9,7 +9,6 @@ package gateway import ( "context" - "log" "net/url" "runtime" "time" @@ -144,6 +143,13 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { + // Stop the pacemaker + g.Pacemaker.Stop() + + // Stop the event handler + defer close(g.handler) + + // Stop the Websocket return g.WS.Close(nil) } @@ -155,25 +161,6 @@ func (g *Gateway) Reconnect() error { return g.connect() } -// Resume sends to the Websocket a Resume OP, but it doesn't actually resume -// from a dead connection. Start() resumes from a dead connection. -func (g *Gateway) Resume() error { - var ( - ses = g.SessionID - seq = g.Sequence.Get() - ) - - if ses == "" || seq == 0 { - return ErrMissingForResume - } - - return g.Send(ResumeOP, ResumeData{ - Token: g.Identifier.Token, - SessionID: ses, - Sequence: seq, - }) -} - // Start authenticates with the websocket, or resume from a dead Websocket // connection. This function doesn't block. To block until a fatal error, use // Wait(). @@ -187,48 +174,6 @@ func (g *Gateway) Start() error { return errors.Wrap(err, "Error at Hello") } - // Send Discord either the Identify packet (if it's a fresh connection), or - // a Resume packet (if it's a dead connection). - if g.SessionID == "" { - // SessionID is empty, so this is a completely new session. - if err := g.Identify(); err != nil { - return errors.Wrap(err, "Failed to identify") - } - - // We should now expect a Ready event. - var ready ReadyEvent - p, err := AssertEvent(g, <-ch, DispatchOP, &ready) - if err != nil { - return errors.Wrap(err, "Error at Ready") - } - - // We now also have the SessionID and the SequenceID - g.SessionID = ready.SessionID - g.Sequence.Set(p.Sequence) - - // Send the event away - g.Events <- &ready - - } else { - if err := g.Resume(); err != nil { - return errors.Wrap(err, "Failed to resume") - } - - // We should now expect a Resumed event. - var resumed ResumedEvent - _, err := AssertEvent(g, <-ch, DispatchOP, &resumed) - if err != nil { - return errors.Wrap(err, "Error at Resumed") - } - - // Send the event away - g.Events <- &resumed - } - - // Start the event handler - g.handler = make(chan struct{}) - go g.handleWS(g.handler) - // Start the pacemaker with the heartrate received from Hello g.Pacemaker = &Pacemaker{ Heartrate: hello.HeartbeatInterval.Duration(), @@ -238,11 +183,27 @@ func (g *Gateway) Start() error { // Pacemaker dies here, only when it's fatal. g.paceDeath = g.Pacemaker.StartAsync() + // Send Discord either the Identify packet (if it's a fresh connection), or + // a Resume packet (if it's a dead connection). + if g.SessionID == "" { + // SessionID is empty, so this is a completely new session. + if err := g.Identify(); err != nil { + return errors.Wrap(err, "Failed to identify") + } + } else { + if err := g.Resume(); err != nil { + return errors.Wrap(err, "Failed to resume") + } + } + + // Start the event handler + g.handler = make(chan struct{}) + go g.handleWS(g.handler) + return nil } func (g *Gateway) Wait() error { - defer close(g.handler) return <-g.paceDeath } @@ -263,7 +224,7 @@ func (g *Gateway) handleWS(stop <-chan struct{}) { // Handle the event if err := HandleEvent(g, ev.Data); err != nil { - g.ErrorLog(errors.Wrap(ev.Error, "WS handler error")) + g.ErrorLog(errors.Wrap(err, "WS handler error")) } } } @@ -288,8 +249,6 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { return errors.Wrap(err, "Failed to encode payload") } - log.Println("->", len(b), string(b)) - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() diff --git a/gateway/integration_test.go b/gateway/integration_test.go index 5ceec16..44d2e68 100644 --- a/gateway/integration_test.go +++ b/gateway/integration_test.go @@ -15,7 +15,7 @@ func TestIntegration(t *testing.T) { } WSError = func(err error) { - t.Error("WS:", err) + log.Println(err) } var gateway *Gateway @@ -43,8 +43,15 @@ func TestIntegration(t *testing.T) { t.Fatal("Failed to reconnect:", err) } + /* TODO: this isn't true, as Discord would keep sending Invalid Sessions. resumed, ok := (<-gateway.Events).(*ResumedEvent) if !ok { t.Fatal("Event received is not of type Resumed:", resumed) } + */ + + ready, ok = (<-gateway.Events).(*ReadyEvent) + if !ok { + t.Fatal("Event received is not of type Ready:", ready) + } } diff --git a/gateway/op.go b/gateway/op.go index beacecb..61e3198 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -2,7 +2,6 @@ package gateway import ( "fmt" - "log" "github.com/diamondburned/arikawa/internal/json" "github.com/diamondburned/arikawa/internal/wsutil" @@ -45,8 +44,6 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { return nil, ev.Error } - log.Println("<-", string(ev.Data)) - var op *OP if err := driver.Unmarshal(ev.Data, &op); err != nil { return nil, errors.Wrap(err, "Failed to decode payload") @@ -103,6 +100,10 @@ func HandleEvent(g *Gateway, data []byte) error { return errors.Wrap(err, "OP error") } + return HandleOP(g, op) +} + +func HandleOP(g *Gateway, op *OP) error { if g.OP != nil { g.OP <- op } @@ -130,7 +131,9 @@ func HandleEvent(g *Gateway, data []byte) error { case DispatchOP: // Set the sequence - g.Sequence.Set(op.Sequence) + if op.Sequence > 0 { + g.Sequence.Set(op.Sequence) + } // Check if we know the event fn, ok := EventCreator[op.EventName] @@ -146,6 +149,11 @@ func HandleEvent(g *Gateway, data []byte) error { return errors.Wrap(err, "Failed to parse event "+op.EventName) } + // If the event is a ready, we'll want its sessionID + if ev, ok := ev.(*ReadyEvent); ok { + g.SessionID = ev.SessionID + } + // Throw the event into a channel, it's valid now. g.Events <- ev return nil diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index fccfdb0..c8607e4 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -44,36 +44,46 @@ func (p *Pacemaker) Stop() { // Start beats until it's dead. func (p *Pacemaker) Start() error { - tick := time.NewTicker(p.Heartrate) - defer tick.Stop() - stop := make(chan struct{}) p.stop = stop + return p.start(stop) +} + +func (p *Pacemaker) start(stop chan struct{}) error { + tick := time.NewTicker(p.Heartrate) + defer tick.Stop() + + // Echo at least once + p.Echo() + for { - if err := p.Pace(); err != nil { - return err - } - - if p.Dead() { - if err := p.OnDead(); err != nil { - return err - } - } - select { case <-stop: return nil case <-tick.C: - continue + if err := p.Pace(); err != nil { + return err + } + + if p.Dead() { + if err := p.OnDead(); err != nil { + return err + } + } } } } func (p *Pacemaker) StartAsync() (death <-chan error) { var ch = make(chan error) + + stop := make(chan struct{}) + p.stop = stop + go func() { - ch <- p.Start() + ch <- p.start(stop) }() + return ch } diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index 2aed7c2..10f03b5 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -54,6 +54,7 @@ func NewConn(driver json.Driver) *Conn { return &Conn{ Driver: driver, ReadTimeout: DefaultTimeout, + events: make(chan Event, WSBuffer), } } @@ -67,12 +68,14 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { HTTPHeader: headers, }) + go func() { + c.readLoop(c.events) + }() + return err } func (c *Conn) Listen() <-chan Event { - c.events = make(chan Event, WSBuffer) - go func() { c.readLoop(c.events) }() return c.events } @@ -84,15 +87,19 @@ func (c *Conn) readLoop(ch chan Event) { b, err := c.readAll(ctx) if err != nil { - ch <- Event{nil, errors.Wrap(err, "WS error")} - // Check if the error is a fatal one - if websocket.CloseStatus(err) > -1 { - // Error is fatal, exit + if code := websocket.CloseStatus(err); code > -1 { + // Is the exit unusual? + if code != websocket.StatusNormalClosure { + // Unusual error, log + ch <- Event{nil, errors.Wrap(err, "WS fatal")} + } + return } - // or it's not fatal, we just continue + // or it's not fatal, we just log and continue + ch <- Event{nil, errors.Wrap(err, "WS error")} continue } @@ -103,7 +110,7 @@ func (c *Conn) readLoop(ch chan Event) { func (c *Conn) readAll(ctx context.Context) ([]byte, error) { t, r, err := c.Reader(ctx) if err != nil { - return nil, errors.Wrap(err, "WS error") + return nil, err } if t == websocket.MessageBinary { @@ -139,9 +146,6 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { } func (c *Conn) Close(err error) error { - // Close the event channels - defer close(c.events) - if err == nil { return c.Conn.Close(websocket.StatusNormalClosure, "") } diff --git a/internal/wsutil/ws.go b/internal/wsutil/ws.go index a03cc43..4a32c8e 100644 --- a/internal/wsutil/ws.go +++ b/internal/wsutil/ws.go @@ -28,6 +28,7 @@ type Websocket struct { DialLimiter *rate.Limiter listener <-chan Event + dialed bool } func New(ctx context.Context, addr string) (*Websocket, error) { @@ -55,6 +56,12 @@ func (ws *Websocket) Redial(ctx context.Context) error { return errors.Wrap(err, "Failed to wait") } + // Close the connection + if ws.dialed { + ws.Conn.Close(nil) + } + ws.dialed = true + if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { return errors.Wrap(err, "Failed to dial") }