Fixed a ridiculous race condition in the gateway

This commit is contained in:
diamondburned (Forefront) 2020-01-20 03:06:20 -08:00
parent 2cd9def778
commit 3ca1d352c9
5 changed files with 102 additions and 69 deletions

View File

@ -13,7 +13,7 @@ stages:
unit_test:
stage: test
script:
- go test ./...
- go test -race -v ./...
integration_test:
stage: test
@ -21,5 +21,5 @@ integration_test:
variables:
- $BOT_TOKEN
script:
- go test -tags integration ./...
- go test -tags integration -race -v ./...

View File

@ -102,8 +102,8 @@ type Gateway struct {
OP chan Event
// Filled by methods, internal use
paceDeath <-chan error
handler chan struct{}
done chan struct{}
paceDeath chan error
}
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
@ -153,14 +153,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
// Stop the pacemaker
// If the pacemaker is running:
// Stop the pacemaker and the event handler
g.Pacemaker.Stop()
g.paceDeath = nil
// Stop the event handler
if g.handler != nil {
close(g.handler)
g.handler = nil
if g.done != nil {
// Wait for the event handler to fully exit
<-g.done
// Final clean-up
g.done = nil
}
// Stop the Websocket
@ -193,10 +195,7 @@ func (g *Gateway) Open() error {
}
// Reconnect to the Gateway
if err := g.WS.Redial(ctx); err != nil {
// Close the connection
g.Close()
if err := g.WS.Dial(ctx); err != nil {
// Save the error, retry again
Lerr = errors.Wrap(err, "Failed to reconnect")
continue
@ -204,9 +203,6 @@ func (g *Gateway) Open() error {
// Try to resume the connection
if err := g.Start(); err != nil {
// Close the connection
g.Close()
// If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err == ErrInvalidSession {
@ -234,6 +230,14 @@ func (g *Gateway) Open() error {
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block.
func (g *Gateway) Start() error {
if err := g.start(); err != nil {
g.Close()
return err
}
return nil
}
func (g *Gateway) start() error {
// This is where we'll get our events
ch := g.WS.Listen()
@ -279,30 +283,35 @@ func (g *Gateway) Start() error {
}
// Start the event handler
g.handler = make(chan struct{})
go g.handleWS(g.handler)
g.done = make(chan struct{})
go g.handleWS(g.done)
return nil
}
// handleWS uses the Websocket and parses them into g.Events.
func (g *Gateway) handleWS(stop <-chan struct{}) {
func (g *Gateway) handleWS(done chan struct{}) {
ch := g.WS.Listen()
defer func() {
done <- struct{}{}
}()
for {
select {
case <-stop:
return
case err := <-g.paceDeath:
if err != nil {
// Pacemaker died, pretty fatal. We'll reconnect though.
if err := g.Reconnect(); err != nil {
// Very fatal if this fails. We'll warn the user.
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
if err == nil {
// No error, just exit normally.
return
}
// Then, we'll take the safe way and exit.
return
}
// Pacemaker died, pretty fatal. We'll reconnect though.
if err := g.Reconnect(); err != nil {
// Very fatal if this fails. We'll warn the user.
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
// Then, we'll take the safe way and exit.
return
}
case ev := <-ch:

View File

@ -24,7 +24,8 @@ type Pacemaker struct {
// Event
OnDead func() error
stop chan<- struct{}
stop chan<- struct{}
death chan error
}
func (p *Pacemaker) Echo() {
@ -92,16 +93,15 @@ func (p *Pacemaker) start(stop chan struct{}) error {
}
}
func (p *Pacemaker) StartAsync() (death <-chan error) {
var ch = make(chan error)
func (p *Pacemaker) StartAsync() (death chan error) {
p.death = make(chan error)
stop := make(chan struct{})
p.stop = stop
go func() {
ch <- p.start(stop)
close(ch)
p.death <- p.start(stop)
}()
return ch
return p.death
}

View File

@ -5,6 +5,7 @@ import (
"context"
"io/ioutil"
"net/http"
"sync"
"github.com/diamondburned/arikawa/internal/json"
"github.com/pkg/errors"
@ -39,9 +40,11 @@ type Connection interface {
// Conn is the default Websocket connection. It compresses all payloads using
// zlib.
type Conn struct {
*websocket.Conn
Conn *websocket.Conn
json.Driver
mut sync.Mutex
done chan struct{}
events chan Event
}
@ -60,16 +63,15 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
headers := http.Header{}
headers.Set("Accept-Encoding", "zlib") // enable
c.mut.Lock()
defer c.mut.Unlock()
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
HTTPHeader: headers,
})
c.Conn.SetReadLimit(WSReadLimit)
go func() {
c.readLoop(c.events)
}()
c.readLoop(c.events)
return err
}
@ -78,31 +80,36 @@ func (c *Conn) Listen() <-chan Event {
}
func (c *Conn) readLoop(ch chan Event) {
for {
b, err := c.readAll(context.Background())
if err != nil {
// Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 {
// Is the exit unusual?
if code != websocket.StatusNormalClosure {
// Unusual error, log
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
c.done = make(chan struct{})
go func() {
for {
b, err := c.readAll(context.Background())
if err != nil {
// Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 {
// Is the exit unusual?
if code != websocket.StatusNormalClosure {
// Unusual error, log
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
}
c.done <- struct{}{}
return
}
return
// or it's not fatal, we just log and continue
ch <- Event{nil, errors.Wrap(err, "WS error")}
continue
}
// or it's not fatal, we just log and continue
ch <- Event{nil, errors.Wrap(err, "WS error")}
continue
ch <- Event{b, nil}
}
ch <- Event{b, nil}
}
}()
}
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx)
t, r, err := c.Conn.Reader(ctx)
if err != nil {
return nil, err
}
@ -111,7 +118,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
c.CloseRead(ctx)
c.Conn.CloseRead(ctx)
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
}
@ -122,7 +129,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
c.CloseRead(ctx)
c.Conn.CloseRead(ctx)
return nil, err
}
@ -131,10 +138,22 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream
return c.Write(ctx, websocket.MessageText, b)
return c.Conn.Write(ctx, websocket.MessageText, b)
}
func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting.
defer func() {
<-c.done
close(c.done)
// Set the connection to nil.
c.Conn = nil
// Flush all events.
c.flush()
}()
if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "")
}
@ -146,3 +165,14 @@ func (c *Conn) Close(err error) error {
return c.Conn.Close(websocket.StatusProtocolError, msg)
}
func (c *Conn) flush() {
for {
select {
case <-c.events:
continue
default:
return
}
}
}

View File

@ -50,18 +50,12 @@ func NewCustom(
return ws, nil
}
func (ws *Websocket) Redial(ctx context.Context) error {
func (ws *Websocket) Dial(ctx context.Context) error {
if err := ws.DialLimiter.Wait(ctx); err != nil {
// Expired, fatal error
return errors.Wrap(err, "Failed to wait")
}
// Close the connection
if ws.dialed {
ws.Conn.Close(nil)
}
ws.dialed = true
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
return errors.Wrap(err, "Failed to dial")
}