From 85b793a1a7cfbf531cba220c0d6548ee0e2058bb Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 2 Feb 2020 14:12:54 -0800 Subject: [PATCH] Fixed some race conditions --- gateway/gateway.go | 5 +++- gateway/integration_test.go | 20 +++++++++------ gateway/op.go | 2 +- gateway/sequence.go | 10 +++----- go.mod | 1 + go.sum | 1 + internal/wsutil/conn.go | 50 +++++++++++++------------------------ 7 files changed, 41 insertions(+), 48 deletions(-) diff --git a/gateway/gateway.go b/gateway/gateway.go index e913b03..6443460 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -160,6 +160,7 @@ func (g *Gateway) Close() error { if g.done != nil { // Wait for the event handler to fully exit <-g.done + g.done = nil } // Stop the Websocket @@ -198,6 +199,7 @@ func (g *Gateway) Open() error { if err := g.WS.Dial(ctx); err != nil { // Save the error, retry again Lerr = errors.Wrap(err, "Failed to reconnect") + g.ErrorLog(err) continue } @@ -295,7 +297,6 @@ func (g *Gateway) handleWS() { defer func() { g.done <- struct{}{} - g.done = nil }() for { @@ -306,6 +307,8 @@ func (g *Gateway) handleWS() { return } + g.ErrorLog(errors.Wrap(err, "Pacemaker died")) + // Pacemaker died, pretty fatal. We'll reconnect though. if err := g.Reconnect(); err != nil { // Very fatal if this fails. We'll warn the user. diff --git a/gateway/integration_test.go b/gateway/integration_test.go index d4973f2..7b0ac11 100644 --- a/gateway/integration_test.go +++ b/gateway/integration_test.go @@ -16,7 +16,7 @@ func TestIntegration(t *testing.T) { } WSError = func(err error) { - log.Println(err) + t.Fatal(err) } var gateway *Gateway @@ -49,14 +49,20 @@ func TestIntegration(t *testing.T) { t.Fatal("Failed to reconnect:", err) } - /* TODO: We're not testing this, as Discord will replay events before it - * sends the Resumed event. + timeout := time.After(10 * time.Second) - resumed, ok := (<-gateway.Events).(*ResumedEvent) - if !ok { - t.Fatal("Event received is not of type Resumed:", resumed) +Main: + for { + select { + case ev := <-gateway.Events: + switch ev.(type) { + case *ResumedEvent, *ReadyEvent: + break Main + } + case <-timeout: + t.Fatal("Timed out waiting for ResumedEvent") + } } - */ if err := g.Close(); err != nil { t.Fatal("Failed to close Gateway:", err) diff --git a/gateway/op.go b/gateway/op.go index 89688c5..84d7169 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -99,7 +99,7 @@ func HandleEvent(g *Gateway, data []byte) error { // Parse the raw data into an OP struct var op *OP if err := g.Driver.Unmarshal(data, &op); err != nil { - return errors.Wrap(err, "OP error") + return errors.Wrap(err, "OP error: "+string(data)) } return HandleOP(g, op) diff --git a/gateway/sequence.go b/gateway/sequence.go index 59b5bc5..d7a22bb 100644 --- a/gateway/sequence.go +++ b/gateway/sequence.go @@ -2,13 +2,11 @@ package gateway import "sync/atomic" -type Sequence struct { - seq int64 -} +type Sequence int64 func NewSequence() *Sequence { - return &Sequence{0} + return (*Sequence)(new(int64)) } -func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) } -func (s *Sequence) Get() int64 { return atomic.LoadInt64(&s.seq) } +func (s *Sequence) Set(seq int64) { atomic.StoreInt64((*int64)(s), seq) } +func (s *Sequence) Get() int64 { return atomic.LoadInt64((*int64)(s)) } diff --git a/go.mod b/go.mod index 1ca46d2..b613640 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/diamondburned/arikawa go 1.13 require ( + github.com/davecgh/go-spew v1.1.1 github.com/gorilla/schema v1.1.0 github.com/pkg/errors v0.8.1 github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa diff --git a/go.sum b/go.sum index d5af972..6495d55 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index aa412f0..12eba3c 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -15,7 +15,6 @@ import ( "nhooyr.io/websocket" ) -var WSBuffer = 12 var WSReadLimit int64 = 8192000 // 8 MiB // Connection is an interface that abstracts around a generic Websocket driver. @@ -55,7 +54,7 @@ var _ Connection = (*Conn)(nil) func NewConn(driver json.Driver) *Conn { return &Conn{ Driver: driver, - events: make(chan Event, WSBuffer), + events: make(chan Event), } } @@ -73,7 +72,7 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { }) c.Conn.SetReadLimit(WSReadLimit) - c.events = make(chan Event, WSBuffer) + c.events = make(chan Event) c.readLoop() return err } @@ -83,11 +82,13 @@ func (c *Conn) Listen() <-chan Event { } func (c *Conn) readLoop() { + conn := c.Conn + go func() { defer close(c.events) for { - b, err := c.readAll(context.Background()) + b, err := readAll(conn, context.Background()) if err != nil { // Is the error an EOF? if stderr.Is(err, io.EOF) { @@ -97,18 +98,15 @@ func (c *Conn) readLoop() { // Check if the error is a fatal one if code := websocket.CloseStatus(err); code > -1 { - // Is the exit unusual? - if code != websocket.StatusNormalClosure { - // Unusual error, log - c.events <- Event{nil, errors.Wrap(err, "WS fatal")} + // Is the exit normal? + if code == websocket.StatusNormalClosure { + return } - - return } - // or it's not fatal, we just log and continue + // Unusual error; log: c.events <- Event{nil, errors.Wrap(err, "WS error")} - continue + return } c.events <- Event{b, nil} @@ -116,8 +114,8 @@ func (c *Conn) readLoop() { }() } -func (c *Conn) readAll(ctx context.Context) ([]byte, error) { - t, r, err := c.Conn.Reader(ctx) +func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) { + t, r, err := c.Reader(ctx) if err != nil { return nil, err } @@ -126,7 +124,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { // Probably a zlib payload z, err := zlib.NewReader(r) if err != nil { - c.Conn.CloseRead(ctx) + c.CloseRead(ctx) return nil, errors.Wrap(err, "Failed to create a zlib reader") } @@ -137,7 +135,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { b, err := ioutil.ReadAll(r) if err != nil { - c.Conn.CloseRead(ctx) + c.CloseRead(ctx) return nil, err } @@ -145,9 +143,6 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { } func (c *Conn) Send(ctx context.Context, b []byte) error { - c.mut.Lock() - defer c.mut.Unlock() - // TODO: zlib stream return c.Conn.Write(ctx, websocket.MessageText, b) } @@ -155,14 +150,14 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Close(err error) error { // Wait for the read loop to exit after exiting. defer func() { + c.mut.Lock() + defer c.mut.Unlock() + <-c.events c.events = nil // Set the connection to nil. c.Conn = nil - - // Flush all events. - c.flush() }() if err == nil { @@ -176,14 +171,3 @@ func (c *Conn) Close(err error) error { return c.Conn.Close(websocket.StatusProtocolError, msg) } - -func (c *Conn) flush() { - for { - select { - case <-c.events: - continue - default: - return - } - } -}