From 3d47bada0745c95466bd4584ebc34181cc96daab Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sat, 11 Apr 2020 12:34:40 -0700 Subject: [PATCH] Gateway: Fixed the double Close and Gateway ReconnectOP bugs --- utils/wsutil/conn.go | 75 +++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/utils/wsutil/conn.go b/utils/wsutil/conn.go index a526b71..606529e 100644 --- a/utils/wsutil/conn.go +++ b/utils/wsutil/conn.go @@ -6,6 +6,7 @@ import ( "context" "io" "net/http" + "sync" "time" "github.com/diamondburned/arikawa/utils/json" @@ -18,6 +19,9 @@ const CopyBufferSize = 2048 // CloseDeadline controls the deadline to wait for sending the Close frame. var CloseDeadline = time.Second +// ErrWebsocketClosed is returned if the websocket is already closed. +var ErrWebsocketClosed = errors.New("Websocket is closed.") + // Connection is an interface that abstracts around a generic Websocket driver. // This connection expects the driver to handle compression by itself, including // modifying the connection URL. @@ -45,14 +49,17 @@ type Conn struct { json.Driver dialer *websocket.Dialer - // mut sync.RWMutex events chan Event // write channels writes chan []byte errors chan error - buf bytes.Buffer + buf bytes.Buffer + zlib io.ReadCloser // (compress/zlib).reader + + // nil until Dial(). + closeOnce *sync.Once // zlib *zlib.Inflator // zlib.NewReader // buf []byte // io.Copy buffer @@ -91,6 +98,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { return errors.Wrap(err, "Failed to dial WS") } + // Set up the closer. + c.closeOnce = &sync.Once{} + c.events = make(chan Event) go c.readLoop() @@ -173,13 +183,21 @@ func (c *Conn) handle() ([]byte, error) { if t == websocket.BinaryMessage { // Probably a zlib payload - z, err := zlib.NewReader(r) - if err != nil { - return nil, errors.Wrap(err, "Failed to create a zlib reader") + + if c.zlib == nil { + z, err := zlib.NewReader(r) + if err != nil { + return nil, errors.Wrap(err, "Failed to create a zlib reader") + } + c.zlib = z + } else { + if err := c.zlib.(zlib.Resetter).Reset(r, nil); err != nil { + return nil, errors.Wrap(err, "Failed to reset zlib reader") + } } - defer z.Close() - r = z + defer c.zlib.Close() + r = c.zlib } return readAll(&c.buf, r) @@ -209,39 +227,40 @@ func (c *Conn) handle() ([]byte, error) { } func (c *Conn) Send(b []byte) error { - // Don't send a nil byte slice. That would confuse the write loop. - if b == nil { - return nil - } - // If websocket is already closed. if c.writes == nil { - return errors.New("Websocket is closed.") + return ErrWebsocketClosed } c.writes <- b return <-c.errors } -func (c *Conn) Close() error { - // Close c.writes. This should trigger the websocket to close itself. - close(c.writes) +func (c *Conn) Close() (err error) { + // Use a sync.Once to guarantee that other Close() calls block until the + // main call is done. It also prevents future calls. + c.closeOnce.Do(func() { + // Close c.writes. This should trigger the websocket to close itself. + close(c.writes) + // Mark c.writes as empty. + c.writes = nil - // Wait for the write loop to exit by flusing the errors channel. - var err = <-c.errors - for range c.errors { - } + // Wait for the write loop to exit by flusing the errors channel. + err = <-c.errors // get close error + for range c.errors { // then flush + } - // Flush all events before closing the channel. This will return as soon as - // c.events is closed, or after closed. - for range c.events { - } + // Flush all events before closing the channel. This will return as soon as + // c.events is closed, or after closed. + for range c.events { + } - // Mark c.events as empty. - c.events = nil + // Mark c.events as empty. + c.events = nil - // Mark c.Conn as empty. - c.Conn = nil + // Mark c.Conn as empty. + c.Conn = nil + }) return err }