mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-12-01 17:00:56 +00:00
More race condition fixes
This commit is contained in:
parent
81bc30968b
commit
b62ba3ecc0
|
|
@ -40,7 +40,7 @@ var (
|
||||||
// gateway.
|
// gateway.
|
||||||
WSRetries = uint(5)
|
WSRetries = uint(5)
|
||||||
// WSError is the default error handler
|
// WSError is the default error handler
|
||||||
WSError = func(err error) {}
|
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||||
// WSFatal is the default fatal handler, which is called when the Gateway
|
// WSFatal is the default fatal handler, which is called when the Gateway
|
||||||
// can't recover.
|
// can't recover.
|
||||||
WSFatal = func(err error) { log.Fatalln("Gateway failed:", err) }
|
WSFatal = func(err error) { log.Fatalln("Gateway failed:", err) }
|
||||||
|
|
@ -168,8 +168,11 @@ func (g *Gateway) Close() error {
|
||||||
|
|
||||||
// Reconnects and resumes.
|
// Reconnects and resumes.
|
||||||
func (g *Gateway) Reconnect() error {
|
func (g *Gateway) Reconnect() error {
|
||||||
// Close, but we don't care about the error (I think)
|
// If the event loop is not dead:
|
||||||
g.Close()
|
if g.done != nil {
|
||||||
|
g.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// Actually a reconnect at this point.
|
// Actually a reconnect at this point.
|
||||||
return g.Open()
|
return g.Open()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
@ -8,16 +9,16 @@ import (
|
||||||
|
|
||||||
var ErrDead = errors.New("no heartbeat replied")
|
var ErrDead = errors.New("no heartbeat replied")
|
||||||
|
|
||||||
|
// Time is a UnixNano timestamp.
|
||||||
|
type Time = int64
|
||||||
|
|
||||||
type Pacemaker struct {
|
type Pacemaker struct {
|
||||||
// Heartrate is the received duration between heartbeats.
|
// Heartrate is the received duration between heartbeats.
|
||||||
Heartrate time.Duration
|
Heartrate time.Duration
|
||||||
|
|
||||||
// LastBeat logs the received heartbeats, with the newest one
|
// Time in nanoseconds, guarded by atomic read/writes.
|
||||||
// first.
|
SentBeat Time
|
||||||
// LastBeat [2]time.Time
|
EchoBeat Time
|
||||||
|
|
||||||
SentBeat time.Time
|
|
||||||
EchoBeat time.Time
|
|
||||||
|
|
||||||
// Any callback that returns an error will stop the pacer.
|
// Any callback that returns an error will stop the pacer.
|
||||||
Pace func() error
|
Pace func() error
|
||||||
|
|
@ -31,7 +32,7 @@ type Pacemaker struct {
|
||||||
func (p *Pacemaker) Echo() {
|
func (p *Pacemaker) Echo() {
|
||||||
// Swap our received heartbeats
|
// Swap our received heartbeats
|
||||||
// p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
|
// p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
|
||||||
p.EchoBeat = time.Now()
|
atomic.StoreInt64(&p.EchoBeat, time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dead, if true, will have Pace return an ErrDead.
|
// Dead, if true, will have Pace return an ErrDead.
|
||||||
|
|
@ -44,11 +45,16 @@ func (p *Pacemaker) Dead() bool {
|
||||||
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
|
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
|
||||||
*/
|
*/
|
||||||
|
|
||||||
if p.EchoBeat.IsZero() || p.SentBeat.IsZero() {
|
var (
|
||||||
|
echo = atomic.LoadInt64(&p.EchoBeat)
|
||||||
|
sent = atomic.LoadInt64(&p.SentBeat)
|
||||||
|
)
|
||||||
|
|
||||||
|
if echo == 0 || sent == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.SentBeat.Sub(p.EchoBeat) > p.Heartrate*2
|
return sent-echo > int64(p.Heartrate)*2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pacemaker) Stop() {
|
func (p *Pacemaker) Stop() {
|
||||||
|
|
@ -58,14 +64,6 @@ func (p *Pacemaker) Stop() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start beats until it's dead.
|
|
||||||
func (p *Pacemaker) Start() error {
|
|
||||||
stop := make(chan struct{})
|
|
||||||
p.stop = stop
|
|
||||||
|
|
||||||
return p.start(stop)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pacemaker) start(stop chan struct{}) error {
|
func (p *Pacemaker) start(stop chan struct{}) error {
|
||||||
tick := time.NewTicker(p.Heartrate)
|
tick := time.NewTicker(p.Heartrate)
|
||||||
defer tick.Stop()
|
defer tick.Stop()
|
||||||
|
|
@ -83,8 +81,8 @@ func (p *Pacemaker) start(stop chan struct{}) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Paced, save
|
// Paced, save:
|
||||||
p.SentBeat = time.Now()
|
atomic.StoreInt64(&p.SentBeat, time.Now().UnixNano())
|
||||||
|
|
||||||
if p.Dead() {
|
if p.Dead() {
|
||||||
return ErrDead
|
return ErrDead
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,13 @@ package wsutil
|
||||||
import (
|
import (
|
||||||
"compress/zlib"
|
"compress/zlib"
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
stderr "errors"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/internal/json"
|
"github.com/diamondburned/arikawa/internal/json"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
@ -44,7 +47,6 @@ type Conn struct {
|
||||||
json.Driver
|
json.Driver
|
||||||
|
|
||||||
mut sync.Mutex
|
mut sync.Mutex
|
||||||
done chan struct{}
|
|
||||||
events chan Event
|
events chan Event
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -71,7 +73,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||||
})
|
})
|
||||||
c.Conn.SetReadLimit(WSReadLimit)
|
c.Conn.SetReadLimit(WSReadLimit)
|
||||||
|
|
||||||
c.readLoop(c.events)
|
c.events = make(chan Event, WSBuffer)
|
||||||
|
c.readLoop()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -79,31 +82,36 @@ func (c *Conn) Listen() <-chan Event {
|
||||||
return c.events
|
return c.events
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readLoop(ch chan Event) {
|
func (c *Conn) readLoop() {
|
||||||
c.done = make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(c.events)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
b, err := c.readAll(context.Background())
|
b, err := c.readAll(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Is the error an EOF?
|
||||||
|
if stderr.Is(err, io.EOF) {
|
||||||
|
// Yes it is, exit.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the error is a fatal one
|
// Check if the error is a fatal one
|
||||||
if code := websocket.CloseStatus(err); code > -1 {
|
if code := websocket.CloseStatus(err); code > -1 {
|
||||||
// Is the exit unusual?
|
// Is the exit unusual?
|
||||||
if code != websocket.StatusNormalClosure {
|
if code != websocket.StatusNormalClosure {
|
||||||
// Unusual error, log
|
// Unusual error, log
|
||||||
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
|
c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.done <- struct{}{}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// or it's not fatal, we just log and continue
|
// or it's not fatal, we just log and continue
|
||||||
ch <- Event{nil, errors.Wrap(err, "WS error")}
|
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- Event{b, nil}
|
c.events <- Event{b, nil}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
@ -147,8 +155,8 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
func (c *Conn) Close(err error) error {
|
func (c *Conn) Close(err error) error {
|
||||||
// Wait for the read loop to exit after exiting.
|
// Wait for the read loop to exit after exiting.
|
||||||
defer func() {
|
defer func() {
|
||||||
<-c.done
|
<-c.events
|
||||||
close(c.done)
|
c.events = nil
|
||||||
|
|
||||||
// Set the connection to nil.
|
// Set the connection to nil.
|
||||||
c.Conn = nil
|
c.Conn = nil
|
||||||
|
|
|
||||||
|
|
@ -64,10 +64,7 @@ func (ws *Websocket) Dial(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Listen() <-chan Event {
|
func (ws *Websocket) Listen() <-chan Event {
|
||||||
if ws.listener == nil {
|
return ws.Conn.Listen()
|
||||||
ws.listener = ws.Conn.Listen()
|
|
||||||
}
|
|
||||||
return ws.listener
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
|
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue