diff --git a/gateway/gateway.go b/gateway/gateway.go index ee4a501..e913b03 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -40,7 +40,7 @@ var ( // gateway. WSRetries = uint(5) // WSError is the default error handler - WSError = func(err error) {} + WSError = func(err error) { log.Println("Gateway error:", err) } // WSFatal is the default fatal handler, which is called when the Gateway // can't recover. WSFatal = func(err error) { log.Fatalln("Gateway failed:", err) } @@ -168,8 +168,11 @@ func (g *Gateway) Close() error { // Reconnects and resumes. func (g *Gateway) Reconnect() error { - // Close, but we don't care about the error (I think) - g.Close() + // If the event loop is not dead: + if g.done != nil { + g.Close() + } + // Actually a reconnect at this point. return g.Open() } diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index 2d72093..86ef60c 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -1,6 +1,7 @@ package gateway import ( + "sync/atomic" "time" "github.com/pkg/errors" @@ -8,16 +9,16 @@ import ( var ErrDead = errors.New("no heartbeat replied") +// Time is a UnixNano timestamp. +type Time = int64 + type Pacemaker struct { // Heartrate is the received duration between heartbeats. Heartrate time.Duration - // LastBeat logs the received heartbeats, with the newest one - // first. - // LastBeat [2]time.Time - - SentBeat time.Time - EchoBeat time.Time + // Time in nanoseconds, guarded by atomic read/writes. + SentBeat Time + EchoBeat Time // Any callback that returns an error will stop the pacer. Pace func() error @@ -31,7 +32,7 @@ type Pacemaker struct { func (p *Pacemaker) Echo() { // Swap our received heartbeats // p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0] - p.EchoBeat = time.Now() + atomic.StoreInt64(&p.EchoBeat, time.Now().UnixNano()) } // Dead, if true, will have Pace return an ErrDead. @@ -44,11 +45,16 @@ func (p *Pacemaker) Dead() bool { return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2 */ - if p.EchoBeat.IsZero() || p.SentBeat.IsZero() { + var ( + echo = atomic.LoadInt64(&p.EchoBeat) + sent = atomic.LoadInt64(&p.SentBeat) + ) + + if echo == 0 || sent == 0 { return false } - return p.SentBeat.Sub(p.EchoBeat) > p.Heartrate*2 + return sent-echo > int64(p.Heartrate)*2 } func (p *Pacemaker) Stop() { @@ -58,14 +64,6 @@ func (p *Pacemaker) Stop() { } } -// Start beats until it's dead. -func (p *Pacemaker) Start() error { - stop := make(chan struct{}) - p.stop = stop - - return p.start(stop) -} - func (p *Pacemaker) start(stop chan struct{}) error { tick := time.NewTicker(p.Heartrate) defer tick.Stop() @@ -83,8 +81,8 @@ func (p *Pacemaker) start(stop chan struct{}) error { return err } - // Paced, save - p.SentBeat = time.Now() + // Paced, save: + atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano()) if p.Dead() { return ErrDead diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index 6bcbb7e..aa412f0 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -3,10 +3,13 @@ package wsutil import ( "compress/zlib" "context" + "io" "io/ioutil" "net/http" "sync" + stderr "errors" + "github.com/diamondburned/arikawa/internal/json" "github.com/pkg/errors" "nhooyr.io/websocket" @@ -44,7 +47,6 @@ type Conn struct { json.Driver mut sync.Mutex - done chan struct{} events chan Event } @@ -71,7 +73,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { }) c.Conn.SetReadLimit(WSReadLimit) - c.readLoop(c.events) + c.events = make(chan Event, WSBuffer) + c.readLoop() return err } @@ -79,31 +82,36 @@ func (c *Conn) Listen() <-chan Event { return c.events } -func (c *Conn) readLoop(ch chan Event) { - c.done = make(chan struct{}) - +func (c *Conn) readLoop() { go func() { + defer close(c.events) + for { b, err := c.readAll(context.Background()) if err != nil { + // Is the error an EOF? + if stderr.Is(err, io.EOF) { + // Yes it is, exit. + return + } + // Check if the error is a fatal one if code := websocket.CloseStatus(err); code > -1 { // Is the exit unusual? if code != websocket.StatusNormalClosure { // Unusual error, log - ch <- Event{nil, errors.Wrap(err, "WS fatal")} + c.events <- Event{nil, errors.Wrap(err, "WS fatal")} } - c.done <- struct{}{} return } // or it's not fatal, we just log and continue - ch <- Event{nil, errors.Wrap(err, "WS error")} + c.events <- Event{nil, errors.Wrap(err, "WS error")} continue } - ch <- Event{b, nil} + c.events <- Event{b, nil} } }() } @@ -147,8 +155,8 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Close(err error) error { // Wait for the read loop to exit after exiting. defer func() { - <-c.done - close(c.done) + <-c.events + c.events = nil // Set the connection to nil. c.Conn = nil diff --git a/internal/wsutil/ws.go b/internal/wsutil/ws.go index 0bedbfc..b758770 100644 --- a/internal/wsutil/ws.go +++ b/internal/wsutil/ws.go @@ -64,10 +64,7 @@ func (ws *Websocket) Dial(ctx context.Context) error { } func (ws *Websocket) Listen() <-chan Event { - if ws.listener == nil { - ws.listener = ws.Conn.Listen() - } - return ws.listener + return ws.Conn.Listen() } func (ws *Websocket) Send(ctx context.Context, b []byte) error {