1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-12-01 17:00:56 +00:00

More race condition fixes

This commit is contained in:
diamondburned (Forefront) 2020-01-28 19:54:22 -08:00
parent 81bc30968b
commit b62ba3ecc0
4 changed files with 43 additions and 37 deletions

View file

@ -40,7 +40,7 @@ var (
// gateway. // gateway.
WSRetries = uint(5) WSRetries = uint(5)
// WSError is the default error handler // WSError is the default error handler
WSError = func(err error) {} WSError = func(err error) { log.Println("Gateway error:", err) }
// WSFatal is the default fatal handler, which is called when the Gateway // WSFatal is the default fatal handler, which is called when the Gateway
// can't recover. // can't recover.
WSFatal = func(err error) { log.Fatalln("Gateway failed:", err) } WSFatal = func(err error) { log.Fatalln("Gateway failed:", err) }
@ -168,8 +168,11 @@ func (g *Gateway) Close() error {
// Reconnects and resumes. // Reconnects and resumes.
func (g *Gateway) Reconnect() error { func (g *Gateway) Reconnect() error {
// Close, but we don't care about the error (I think) // If the event loop is not dead:
g.Close() if g.done != nil {
g.Close()
}
// Actually a reconnect at this point. // Actually a reconnect at this point.
return g.Open() return g.Open()
} }

View file

@ -1,6 +1,7 @@
package gateway package gateway
import ( import (
"sync/atomic"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -8,16 +9,16 @@ import (
var ErrDead = errors.New("no heartbeat replied") var ErrDead = errors.New("no heartbeat replied")
// Time is a UnixNano timestamp.
type Time = int64
type Pacemaker struct { type Pacemaker struct {
// Heartrate is the received duration between heartbeats. // Heartrate is the received duration between heartbeats.
Heartrate time.Duration Heartrate time.Duration
// LastBeat logs the received heartbeats, with the newest one // Time in nanoseconds, guarded by atomic read/writes.
// first. SentBeat Time
// LastBeat [2]time.Time EchoBeat Time
SentBeat time.Time
EchoBeat time.Time
// Any callback that returns an error will stop the pacer. // Any callback that returns an error will stop the pacer.
Pace func() error Pace func() error
@ -31,7 +32,7 @@ type Pacemaker struct {
func (p *Pacemaker) Echo() { func (p *Pacemaker) Echo() {
// Swap our received heartbeats // Swap our received heartbeats
// p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0] // p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
p.EchoBeat = time.Now() atomic.StoreInt64(&p.EchoBeat, time.Now().UnixNano())
} }
// Dead, if true, will have Pace return an ErrDead. // Dead, if true, will have Pace return an ErrDead.
@ -44,11 +45,16 @@ func (p *Pacemaker) Dead() bool {
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2 return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
*/ */
if p.EchoBeat.IsZero() || p.SentBeat.IsZero() { var (
echo = atomic.LoadInt64(&p.EchoBeat)
sent = atomic.LoadInt64(&p.SentBeat)
)
if echo == 0 || sent == 0 {
return false return false
} }
return p.SentBeat.Sub(p.EchoBeat) > p.Heartrate*2 return sent-echo > int64(p.Heartrate)*2
} }
func (p *Pacemaker) Stop() { func (p *Pacemaker) Stop() {
@ -58,14 +64,6 @@ func (p *Pacemaker) Stop() {
} }
} }
// Start beats until it's dead.
func (p *Pacemaker) Start() error {
stop := make(chan struct{})
p.stop = stop
return p.start(stop)
}
func (p *Pacemaker) start(stop chan struct{}) error { func (p *Pacemaker) start(stop chan struct{}) error {
tick := time.NewTicker(p.Heartrate) tick := time.NewTicker(p.Heartrate)
defer tick.Stop() defer tick.Stop()
@ -83,8 +81,8 @@ func (p *Pacemaker) start(stop chan struct{}) error {
return err return err
} }
// Paced, save // Paced, save:
p.SentBeat = time.Now() atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
if p.Dead() { if p.Dead() {
return ErrDead return ErrDead

View file

@ -3,10 +3,13 @@ package wsutil
import ( import (
"compress/zlib" "compress/zlib"
"context" "context"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"sync" "sync"
stderr "errors"
"github.com/diamondburned/arikawa/internal/json" "github.com/diamondburned/arikawa/internal/json"
"github.com/pkg/errors" "github.com/pkg/errors"
"nhooyr.io/websocket" "nhooyr.io/websocket"
@ -44,7 +47,6 @@ type Conn struct {
json.Driver json.Driver
mut sync.Mutex mut sync.Mutex
done chan struct{}
events chan Event events chan Event
} }
@ -71,7 +73,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
}) })
c.Conn.SetReadLimit(WSReadLimit) c.Conn.SetReadLimit(WSReadLimit)
c.readLoop(c.events) c.events = make(chan Event, WSBuffer)
c.readLoop()
return err return err
} }
@ -79,31 +82,36 @@ func (c *Conn) Listen() <-chan Event {
return c.events return c.events
} }
func (c *Conn) readLoop(ch chan Event) { func (c *Conn) readLoop() {
c.done = make(chan struct{})
go func() { go func() {
defer close(c.events)
for { for {
b, err := c.readAll(context.Background()) b, err := c.readAll(context.Background())
if err != nil { if err != nil {
// Is the error an EOF?
if stderr.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Check if the error is a fatal one // Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 { if code := websocket.CloseStatus(err); code > -1 {
// Is the exit unusual? // Is the exit unusual?
if code != websocket.StatusNormalClosure { if code != websocket.StatusNormalClosure {
// Unusual error, log // Unusual error, log
ch <- Event{nil, errors.Wrap(err, "WS fatal")} c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
} }
c.done <- struct{}{}
return return
} }
// or it's not fatal, we just log and continue // or it's not fatal, we just log and continue
ch <- Event{nil, errors.Wrap(err, "WS error")} c.events <- Event{nil, errors.Wrap(err, "WS error")}
continue continue
} }
ch <- Event{b, nil} c.events <- Event{b, nil}
} }
}() }()
} }
@ -147,8 +155,8 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
func (c *Conn) Close(err error) error { func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting. // Wait for the read loop to exit after exiting.
defer func() { defer func() {
<-c.done <-c.events
close(c.done) c.events = nil
// Set the connection to nil. // Set the connection to nil.
c.Conn = nil c.Conn = nil

View file

@ -64,10 +64,7 @@ func (ws *Websocket) Dial(ctx context.Context) error {
} }
func (ws *Websocket) Listen() <-chan Event { func (ws *Websocket) Listen() <-chan Event {
if ws.listener == nil { return ws.Conn.Listen()
ws.listener = ws.Conn.Listen()
}
return ws.listener
} }
func (ws *Websocket) Send(ctx context.Context, b []byte) error { func (ws *Websocket) Send(ctx context.Context, b []byte) error {