From 3ca1d352c9c5bd39f9f387970289327ed64d0707 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Mon, 20 Jan 2020 03:06:20 -0800 Subject: [PATCH] Fixed a ridiculous race condition in the gateway --- .gitlab-ci.yml | 4 +- gateway/gateway.go | 65 ++++++++++++++++++-------------- gateway/pacemaker.go | 12 +++--- internal/wsutil/conn.go | 82 ++++++++++++++++++++++++++++------------- internal/wsutil/ws.go | 8 +--- 5 files changed, 102 insertions(+), 69 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 541655c..415c259 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -13,7 +13,7 @@ stages: unit_test: stage: test script: - - go test ./... + - go test -race -v ./... integration_test: stage: test @@ -21,5 +21,5 @@ integration_test: variables: - $BOT_TOKEN script: - - go test -tags integration ./... + - go test -tags integration -race -v ./... diff --git a/gateway/gateway.go b/gateway/gateway.go index daec068..881b0be 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -102,8 +102,8 @@ type Gateway struct { OP chan Event // Filled by methods, internal use - paceDeath <-chan error - handler chan struct{} + done chan struct{} + paceDeath chan error } // NewGateway starts a new Gateway with the default stdlib JSON driver. For more @@ -153,14 +153,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { - // Stop the pacemaker + // If the pacemaker is running: + // Stop the pacemaker and the event handler g.Pacemaker.Stop() - g.paceDeath = nil - // Stop the event handler - if g.handler != nil { - close(g.handler) - g.handler = nil + if g.done != nil { + // Wait for the event handler to fully exit + <-g.done + + // Final clean-up + g.done = nil } // Stop the Websocket @@ -193,10 +195,7 @@ func (g *Gateway) Open() error { } // Reconnect to the Gateway - if err := g.WS.Redial(ctx); err != nil { - // Close the connection - g.Close() - + if err := g.WS.Dial(ctx); err != nil { // Save the error, retry again Lerr = errors.Wrap(err, "Failed to reconnect") continue @@ -204,9 +203,6 @@ func (g *Gateway) Open() error { // Try to resume the connection if err := g.Start(); err != nil { - // Close the connection - g.Close() - // If the connection is rate limited (documented behavior): // https://discordapp.com/developers/docs/topics/gateway#rate-limiting if err == ErrInvalidSession { @@ -234,6 +230,14 @@ func (g *Gateway) Open() error { // Start authenticates with the websocket, or resume from a dead Websocket // connection. This function doesn't block. func (g *Gateway) Start() error { + if err := g.start(); err != nil { + g.Close() + return err + } + return nil +} + +func (g *Gateway) start() error { // This is where we'll get our events ch := g.WS.Listen() @@ -279,30 +283,35 @@ func (g *Gateway) Start() error { } // Start the event handler - g.handler = make(chan struct{}) - go g.handleWS(g.handler) + g.done = make(chan struct{}) + go g.handleWS(g.done) return nil } // handleWS uses the Websocket and parses them into g.Events. -func (g *Gateway) handleWS(stop <-chan struct{}) { +func (g *Gateway) handleWS(done chan struct{}) { ch := g.WS.Listen() + defer func() { + done <- struct{}{} + }() + for { select { - case <-stop: - return case err := <-g.paceDeath: - if err != nil { - // Pacemaker died, pretty fatal. We'll reconnect though. - if err := g.Reconnect(); err != nil { - // Very fatal if this fails. We'll warn the user. - g.FatalLog(errors.Wrap(err, "Failed to reconnect")) + if err == nil { + // No error, just exit normally. + return + } - // Then, we'll take the safe way and exit. - return - } + // Pacemaker died, pretty fatal. We'll reconnect though. + if err := g.Reconnect(); err != nil { + // Very fatal if this fails. We'll warn the user. + g.FatalLog(errors.Wrap(err, "Failed to reconnect")) + + // Then, we'll take the safe way and exit. + return } case ev := <-ch: diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index 66511ce..2d72093 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -24,7 +24,8 @@ type Pacemaker struct { // Event OnDead func() error - stop chan<- struct{} + stop chan<- struct{} + death chan error } func (p *Pacemaker) Echo() { @@ -92,16 +93,15 @@ func (p *Pacemaker) start(stop chan struct{}) error { } } -func (p *Pacemaker) StartAsync() (death <-chan error) { - var ch = make(chan error) +func (p *Pacemaker) StartAsync() (death chan error) { + p.death = make(chan error) stop := make(chan struct{}) p.stop = stop go func() { - ch <- p.start(stop) - close(ch) + p.death <- p.start(stop) }() - return ch + return p.death } diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index 4e5562b..f57fbf9 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -5,6 +5,7 @@ import ( "context" "io/ioutil" "net/http" + "sync" "github.com/diamondburned/arikawa/internal/json" "github.com/pkg/errors" @@ -39,9 +40,11 @@ type Connection interface { // Conn is the default Websocket connection. It compresses all payloads using // zlib. type Conn struct { - *websocket.Conn + Conn *websocket.Conn json.Driver + mut sync.Mutex + done chan struct{} events chan Event } @@ -60,16 +63,15 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { headers := http.Header{} headers.Set("Accept-Encoding", "zlib") // enable + c.mut.Lock() + defer c.mut.Unlock() + c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{ HTTPHeader: headers, }) - c.Conn.SetReadLimit(WSReadLimit) - go func() { - c.readLoop(c.events) - }() - + c.readLoop(c.events) return err } @@ -78,31 +80,36 @@ func (c *Conn) Listen() <-chan Event { } func (c *Conn) readLoop(ch chan Event) { - for { - b, err := c.readAll(context.Background()) - if err != nil { - // Check if the error is a fatal one - if code := websocket.CloseStatus(err); code > -1 { - // Is the exit unusual? - if code != websocket.StatusNormalClosure { - // Unusual error, log - ch <- Event{nil, errors.Wrap(err, "WS fatal")} + c.done = make(chan struct{}) + + go func() { + for { + b, err := c.readAll(context.Background()) + if err != nil { + // Check if the error is a fatal one + if code := websocket.CloseStatus(err); code > -1 { + // Is the exit unusual? + if code != websocket.StatusNormalClosure { + // Unusual error, log + ch <- Event{nil, errors.Wrap(err, "WS fatal")} + } + + c.done <- struct{}{} + return } - return + // or it's not fatal, we just log and continue + ch <- Event{nil, errors.Wrap(err, "WS error")} + continue } - // or it's not fatal, we just log and continue - ch <- Event{nil, errors.Wrap(err, "WS error")} - continue + ch <- Event{b, nil} } - - ch <- Event{b, nil} - } + }() } func (c *Conn) readAll(ctx context.Context) ([]byte, error) { - t, r, err := c.Reader(ctx) + t, r, err := c.Conn.Reader(ctx) if err != nil { return nil, err } @@ -111,7 +118,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { // Probably a zlib payload z, err := zlib.NewReader(r) if err != nil { - c.CloseRead(ctx) + c.Conn.CloseRead(ctx) return nil, errors.Wrap(err, "Failed to create a zlib reader") } @@ -122,7 +129,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { b, err := ioutil.ReadAll(r) if err != nil { - c.CloseRead(ctx) + c.Conn.CloseRead(ctx) return nil, err } @@ -131,10 +138,22 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { func (c *Conn) Send(ctx context.Context, b []byte) error { // TODO: zlib stream - return c.Write(ctx, websocket.MessageText, b) + return c.Conn.Write(ctx, websocket.MessageText, b) } func (c *Conn) Close(err error) error { + // Wait for the read loop to exit after exiting. + defer func() { + <-c.done + close(c.done) + + // Set the connection to nil. + c.Conn = nil + + // Flush all events. + c.flush() + }() + if err == nil { return c.Conn.Close(websocket.StatusNormalClosure, "") } @@ -146,3 +165,14 @@ func (c *Conn) Close(err error) error { return c.Conn.Close(websocket.StatusProtocolError, msg) } + +func (c *Conn) flush() { + for { + select { + case <-c.events: + continue + default: + return + } + } +} diff --git a/internal/wsutil/ws.go b/internal/wsutil/ws.go index 61cdd15..0bedbfc 100644 --- a/internal/wsutil/ws.go +++ b/internal/wsutil/ws.go @@ -50,18 +50,12 @@ func NewCustom( return ws, nil } -func (ws *Websocket) Redial(ctx context.Context) error { +func (ws *Websocket) Dial(ctx context.Context) error { if err := ws.DialLimiter.Wait(ctx); err != nil { // Expired, fatal error return errors.Wrap(err, "Failed to wait") } - // Close the connection - if ws.dialed { - ws.Conn.Close(nil) - } - ws.dialed = true - if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { return errors.Wrap(err, "Failed to dial") }