Fixed a ridiculous race condition in the gateway

This commit is contained in:
diamondburned (Forefront) 2020-01-20 03:06:20 -08:00
parent 2cd9def778
commit 3ca1d352c9
5 changed files with 102 additions and 69 deletions

View File

@ -13,7 +13,7 @@ stages:
unit_test: unit_test:
stage: test stage: test
script: script:
- go test ./... - go test -race -v ./...
integration_test: integration_test:
stage: test stage: test
@ -21,5 +21,5 @@ integration_test:
variables: variables:
- $BOT_TOKEN - $BOT_TOKEN
script: script:
- go test -tags integration ./... - go test -tags integration -race -v ./...

View File

@ -102,8 +102,8 @@ type Gateway struct {
OP chan Event OP chan Event
// Filled by methods, internal use // Filled by methods, internal use
paceDeath <-chan error done chan struct{}
handler chan struct{} paceDeath chan error
} }
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more // NewGateway starts a new Gateway with the default stdlib JSON driver. For more
@ -153,14 +153,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
// Close closes the underlying Websocket connection. // Close closes the underlying Websocket connection.
func (g *Gateway) Close() error { func (g *Gateway) Close() error {
// Stop the pacemaker // If the pacemaker is running:
// Stop the pacemaker and the event handler
g.Pacemaker.Stop() g.Pacemaker.Stop()
g.paceDeath = nil
// Stop the event handler if g.done != nil {
if g.handler != nil { // Wait for the event handler to fully exit
close(g.handler) <-g.done
g.handler = nil
// Final clean-up
g.done = nil
} }
// Stop the Websocket // Stop the Websocket
@ -193,10 +195,7 @@ func (g *Gateway) Open() error {
} }
// Reconnect to the Gateway // Reconnect to the Gateway
if err := g.WS.Redial(ctx); err != nil { if err := g.WS.Dial(ctx); err != nil {
// Close the connection
g.Close()
// Save the error, retry again // Save the error, retry again
Lerr = errors.Wrap(err, "Failed to reconnect") Lerr = errors.Wrap(err, "Failed to reconnect")
continue continue
@ -204,9 +203,6 @@ func (g *Gateway) Open() error {
// Try to resume the connection // Try to resume the connection
if err := g.Start(); err != nil { if err := g.Start(); err != nil {
// Close the connection
g.Close()
// If the connection is rate limited (documented behavior): // If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting // https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err == ErrInvalidSession { if err == ErrInvalidSession {
@ -234,6 +230,14 @@ func (g *Gateway) Open() error {
// Start authenticates with the websocket, or resume from a dead Websocket // Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block. // connection. This function doesn't block.
func (g *Gateway) Start() error { func (g *Gateway) Start() error {
if err := g.start(); err != nil {
g.Close()
return err
}
return nil
}
func (g *Gateway) start() error {
// This is where we'll get our events // This is where we'll get our events
ch := g.WS.Listen() ch := g.WS.Listen()
@ -279,22 +283,28 @@ func (g *Gateway) Start() error {
} }
// Start the event handler // Start the event handler
g.handler = make(chan struct{}) g.done = make(chan struct{})
go g.handleWS(g.handler) go g.handleWS(g.done)
return nil return nil
} }
// handleWS uses the Websocket and parses them into g.Events. // handleWS uses the Websocket and parses them into g.Events.
func (g *Gateway) handleWS(stop <-chan struct{}) { func (g *Gateway) handleWS(done chan struct{}) {
ch := g.WS.Listen() ch := g.WS.Listen()
defer func() {
done <- struct{}{}
}()
for { for {
select { select {
case <-stop:
return
case err := <-g.paceDeath: case err := <-g.paceDeath:
if err != nil { if err == nil {
// No error, just exit normally.
return
}
// Pacemaker died, pretty fatal. We'll reconnect though. // Pacemaker died, pretty fatal. We'll reconnect though.
if err := g.Reconnect(); err != nil { if err := g.Reconnect(); err != nil {
// Very fatal if this fails. We'll warn the user. // Very fatal if this fails. We'll warn the user.
@ -303,7 +313,6 @@ func (g *Gateway) handleWS(stop <-chan struct{}) {
// Then, we'll take the safe way and exit. // Then, we'll take the safe way and exit.
return return
} }
}
case ev := <-ch: case ev := <-ch:
// Check for error // Check for error

View File

@ -25,6 +25,7 @@ type Pacemaker struct {
OnDead func() error OnDead func() error
stop chan<- struct{} stop chan<- struct{}
death chan error
} }
func (p *Pacemaker) Echo() { func (p *Pacemaker) Echo() {
@ -92,16 +93,15 @@ func (p *Pacemaker) start(stop chan struct{}) error {
} }
} }
func (p *Pacemaker) StartAsync() (death <-chan error) { func (p *Pacemaker) StartAsync() (death chan error) {
var ch = make(chan error) p.death = make(chan error)
stop := make(chan struct{}) stop := make(chan struct{})
p.stop = stop p.stop = stop
go func() { go func() {
ch <- p.start(stop) p.death <- p.start(stop)
close(ch)
}() }()
return ch return p.death
} }

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"sync"
"github.com/diamondburned/arikawa/internal/json" "github.com/diamondburned/arikawa/internal/json"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -39,9 +40,11 @@ type Connection interface {
// Conn is the default Websocket connection. It compresses all payloads using // Conn is the default Websocket connection. It compresses all payloads using
// zlib. // zlib.
type Conn struct { type Conn struct {
*websocket.Conn Conn *websocket.Conn
json.Driver json.Driver
mut sync.Mutex
done chan struct{}
events chan Event events chan Event
} }
@ -60,16 +63,15 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
headers := http.Header{} headers := http.Header{}
headers.Set("Accept-Encoding", "zlib") // enable headers.Set("Accept-Encoding", "zlib") // enable
c.mut.Lock()
defer c.mut.Unlock()
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{ c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
HTTPHeader: headers, HTTPHeader: headers,
}) })
c.Conn.SetReadLimit(WSReadLimit) c.Conn.SetReadLimit(WSReadLimit)
go func() {
c.readLoop(c.events) c.readLoop(c.events)
}()
return err return err
} }
@ -78,6 +80,9 @@ func (c *Conn) Listen() <-chan Event {
} }
func (c *Conn) readLoop(ch chan Event) { func (c *Conn) readLoop(ch chan Event) {
c.done = make(chan struct{})
go func() {
for { for {
b, err := c.readAll(context.Background()) b, err := c.readAll(context.Background())
if err != nil { if err != nil {
@ -89,6 +94,7 @@ func (c *Conn) readLoop(ch chan Event) {
ch <- Event{nil, errors.Wrap(err, "WS fatal")} ch <- Event{nil, errors.Wrap(err, "WS fatal")}
} }
c.done <- struct{}{}
return return
} }
@ -99,10 +105,11 @@ func (c *Conn) readLoop(ch chan Event) {
ch <- Event{b, nil} ch <- Event{b, nil}
} }
}()
} }
func (c *Conn) readAll(ctx context.Context) ([]byte, error) { func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx) t, r, err := c.Conn.Reader(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,7 +118,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload // Probably a zlib payload
z, err := zlib.NewReader(r) z, err := zlib.NewReader(r)
if err != nil { if err != nil {
c.CloseRead(ctx) c.Conn.CloseRead(ctx)
return nil, return nil,
errors.Wrap(err, "Failed to create a zlib reader") errors.Wrap(err, "Failed to create a zlib reader")
} }
@ -122,7 +129,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
c.CloseRead(ctx) c.Conn.CloseRead(ctx)
return nil, err return nil, err
} }
@ -131,10 +138,22 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream // TODO: zlib stream
return c.Write(ctx, websocket.MessageText, b) return c.Conn.Write(ctx, websocket.MessageText, b)
} }
func (c *Conn) Close(err error) error { func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting.
defer func() {
<-c.done
close(c.done)
// Set the connection to nil.
c.Conn = nil
// Flush all events.
c.flush()
}()
if err == nil { if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "") return c.Conn.Close(websocket.StatusNormalClosure, "")
} }
@ -146,3 +165,14 @@ func (c *Conn) Close(err error) error {
return c.Conn.Close(websocket.StatusProtocolError, msg) return c.Conn.Close(websocket.StatusProtocolError, msg)
} }
func (c *Conn) flush() {
for {
select {
case <-c.events:
continue
default:
return
}
}
}

View File

@ -50,18 +50,12 @@ func NewCustom(
return ws, nil return ws, nil
} }
func (ws *Websocket) Redial(ctx context.Context) error { func (ws *Websocket) Dial(ctx context.Context) error {
if err := ws.DialLimiter.Wait(ctx); err != nil { if err := ws.DialLimiter.Wait(ctx); err != nil {
// Expired, fatal error // Expired, fatal error
return errors.Wrap(err, "Failed to wait") return errors.Wrap(err, "Failed to wait")
} }
// Close the connection
if ws.dialed {
ws.Conn.Close(nil)
}
ws.dialed = true
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
return errors.Wrap(err, "Failed to dial") return errors.Wrap(err, "Failed to dial")
} }