Fixed a ridiculous race condition in the gateway
This commit is contained in:
parent
2cd9def778
commit
3ca1d352c9
|
@ -13,7 +13,7 @@ stages:
|
|||
unit_test:
|
||||
stage: test
|
||||
script:
|
||||
- go test ./...
|
||||
- go test -race -v ./...
|
||||
|
||||
integration_test:
|
||||
stage: test
|
||||
|
@ -21,5 +21,5 @@ integration_test:
|
|||
variables:
|
||||
- $BOT_TOKEN
|
||||
script:
|
||||
- go test -tags integration ./...
|
||||
- go test -tags integration -race -v ./...
|
||||
|
||||
|
|
|
@ -102,8 +102,8 @@ type Gateway struct {
|
|||
OP chan Event
|
||||
|
||||
// Filled by methods, internal use
|
||||
paceDeath <-chan error
|
||||
handler chan struct{}
|
||||
done chan struct{}
|
||||
paceDeath chan error
|
||||
}
|
||||
|
||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||
|
@ -153,14 +153,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
// Stop the pacemaker
|
||||
// If the pacemaker is running:
|
||||
// Stop the pacemaker and the event handler
|
||||
g.Pacemaker.Stop()
|
||||
g.paceDeath = nil
|
||||
|
||||
// Stop the event handler
|
||||
if g.handler != nil {
|
||||
close(g.handler)
|
||||
g.handler = nil
|
||||
if g.done != nil {
|
||||
// Wait for the event handler to fully exit
|
||||
<-g.done
|
||||
|
||||
// Final clean-up
|
||||
g.done = nil
|
||||
}
|
||||
|
||||
// Stop the Websocket
|
||||
|
@ -193,10 +195,7 @@ func (g *Gateway) Open() error {
|
|||
}
|
||||
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Redial(ctx); err != nil {
|
||||
// Close the connection
|
||||
g.Close()
|
||||
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
// Save the error, retry again
|
||||
Lerr = errors.Wrap(err, "Failed to reconnect")
|
||||
continue
|
||||
|
@ -204,9 +203,6 @@ func (g *Gateway) Open() error {
|
|||
|
||||
// Try to resume the connection
|
||||
if err := g.Start(); err != nil {
|
||||
// Close the connection
|
||||
g.Close()
|
||||
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
if err == ErrInvalidSession {
|
||||
|
@ -234,6 +230,14 @@ func (g *Gateway) Open() error {
|
|||
// Start authenticates with the websocket, or resumes from a dead Websocket
|
||||
// connection. This function doesn't block. If starting fails, the Gateway is closed (the error from Close is discarded) and the start error is returned.
|
||||
func (g *Gateway) Start() error {
|
||||
if err := g.start(); err != nil {
|
||||
g.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) start() error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
||||
|
@ -279,30 +283,35 @@ func (g *Gateway) Start() error {
|
|||
}
|
||||
|
||||
// Start the event handler
|
||||
g.handler = make(chan struct{})
|
||||
go g.handleWS(g.handler)
|
||||
g.done = make(chan struct{})
|
||||
go g.handleWS(g.done)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleWS reads events from the Websocket and parses them into g.Events.
|
||||
func (g *Gateway) handleWS(stop <-chan struct{}) {
|
||||
func (g *Gateway) handleWS(done chan struct{}) {
|
||||
ch := g.WS.Listen()
|
||||
|
||||
defer func() {
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case err := <-g.paceDeath:
|
||||
if err != nil {
|
||||
// Pacemaker died, pretty fatal. We'll reconnect though.
|
||||
if err := g.Reconnect(); err != nil {
|
||||
// Very fatal if this fails. We'll warn the user.
|
||||
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
|
||||
if err == nil {
|
||||
// No error, just exit normally.
|
||||
return
|
||||
}
|
||||
|
||||
// Then, we'll take the safe way and exit.
|
||||
return
|
||||
}
|
||||
// Pacemaker died, pretty fatal. We'll reconnect though.
|
||||
if err := g.Reconnect(); err != nil {
|
||||
// Very fatal if this fails. We'll warn the user.
|
||||
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
|
||||
|
||||
// Then, we'll take the safe way and exit.
|
||||
return
|
||||
}
|
||||
|
||||
case ev := <-ch:
|
||||
|
|
|
@ -24,7 +24,8 @@ type Pacemaker struct {
|
|||
// Event
|
||||
OnDead func() error
|
||||
|
||||
stop chan<- struct{}
|
||||
stop chan<- struct{}
|
||||
death chan error
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
|
@ -92,16 +93,15 @@ func (p *Pacemaker) start(stop chan struct{}) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Pacemaker) StartAsync() (death <-chan error) {
|
||||
var ch = make(chan error)
|
||||
func (p *Pacemaker) StartAsync() (death chan error) {
|
||||
p.death = make(chan error)
|
||||
|
||||
stop := make(chan struct{})
|
||||
p.stop = stop
|
||||
|
||||
go func() {
|
||||
ch <- p.start(stop)
|
||||
close(ch)
|
||||
p.death <- p.start(stop)
|
||||
}()
|
||||
|
||||
return ch
|
||||
return p.death
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/internal/json"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -39,9 +40,11 @@ type Connection interface {
|
|||
// Conn is the default Websocket connection. It compresses all payloads using
|
||||
// zlib.
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
Conn *websocket.Conn
|
||||
json.Driver
|
||||
|
||||
mut sync.Mutex
|
||||
done chan struct{}
|
||||
events chan Event
|
||||
}
|
||||
|
||||
|
@ -60,16 +63,15 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
headers := http.Header{}
|
||||
headers.Set("Accept-Encoding", "zlib") // enable
|
||||
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
|
||||
HTTPHeader: headers,
|
||||
})
|
||||
|
||||
c.Conn.SetReadLimit(WSReadLimit)
|
||||
|
||||
go func() {
|
||||
c.readLoop(c.events)
|
||||
}()
|
||||
|
||||
c.readLoop(c.events)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -78,31 +80,36 @@ func (c *Conn) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (c *Conn) readLoop(ch chan Event) {
|
||||
for {
|
||||
b, err := c.readAll(context.Background())
|
||||
if err != nil {
|
||||
// Check if the error is a fatal one
|
||||
if code := websocket.CloseStatus(err); code > -1 {
|
||||
// Is the exit unusual?
|
||||
if code != websocket.StatusNormalClosure {
|
||||
// Unusual error, log
|
||||
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
|
||||
c.done = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
b, err := c.readAll(context.Background())
|
||||
if err != nil {
|
||||
// Check if the error is a fatal one
|
||||
if code := websocket.CloseStatus(err); code > -1 {
|
||||
// Is the exit unusual?
|
||||
if code != websocket.StatusNormalClosure {
|
||||
// Unusual error, log
|
||||
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
|
||||
}
|
||||
|
||||
c.done <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
// or it's not fatal, we just log and continue
|
||||
ch <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
continue
|
||||
}
|
||||
|
||||
// or it's not fatal, we just log and continue
|
||||
ch <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
continue
|
||||
ch <- Event{b, nil}
|
||||
}
|
||||
|
||||
ch <- Event{b, nil}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||
t, r, err := c.Reader(ctx)
|
||||
t, r, err := c.Conn.Reader(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -111,7 +118,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
// Probably a zlib payload
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
c.Conn.CloseRead(ctx)
|
||||
return nil,
|
||||
errors.Wrap(err, "Failed to create a zlib reader")
|
||||
}
|
||||
|
@ -122,7 +129,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
c.Conn.CloseRead(ctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -131,10 +138,22 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// TODO: zlib stream
|
||||
return c.Write(ctx, websocket.MessageText, b)
|
||||
return c.Conn.Write(ctx, websocket.MessageText, b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close(err error) error {
|
||||
// Wait for the read loop to exit after the connection is closed.
|
||||
defer func() {
|
||||
<-c.done
|
||||
close(c.done)
|
||||
|
||||
// Set the connection to nil.
|
||||
c.Conn = nil
|
||||
|
||||
// Flush all events.
|
||||
c.flush()
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
||||
}
|
||||
|
@ -146,3 +165,14 @@ func (c *Conn) Close(err error) error {
|
|||
|
||||
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
||||
}
|
||||
|
||||
// flush discards every event that can be received from c.events right away and returns as soon as none is immediately available; it never blocks.
func (c *Conn) flush() {
|
||||
for {
|
||||
select {
|
||||
case <-c.events:
|
||||
continue
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,18 +50,12 @@ func NewCustom(
|
|||
return ws, nil
|
||||
}
|
||||
|
||||
func (ws *Websocket) Redial(ctx context.Context) error {
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return errors.Wrap(err, "Failed to wait")
|
||||
}
|
||||
|
||||
// Close the connection
|
||||
if ws.dialed {
|
||||
ws.Conn.Close(nil)
|
||||
}
|
||||
ws.dialed = true
|
||||
|
||||
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
||||
return errors.Wrap(err, "Failed to dial")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue