diff --git a/utils/wsutil/conn.go b/utils/wsutil/conn.go index 523015e..750a742 100644 --- a/utils/wsutil/conn.go +++ b/utils/wsutil/conn.go @@ -6,7 +6,6 @@ import ( "context" "io" "net/http" - "sync" "time" "github.com/gorilla/websocket" @@ -29,32 +28,32 @@ var ErrWebsocketClosed = errors.New("websocket is closed") // Connection is an interface that abstracts around a generic Websocket driver. // This connection expects the driver to handle compression by itself, including -// modifying the connection URL. +// modifying the connection URL. The implementation doesn't have to be safe for +// concurrent use. type Connection interface { // Dial dials the address (string). Context needs to be passed in for // timeout. This method should also be re-usable after Close is called. Dial(context.Context, string) error - // Listen sends over events constantly. Error will be non-nil if Data is - // nil, so check for Error first. + // Listen returns an event channel that sends over events constantly. It can + // return nil if there isn't an ongoing connection. Listen() <-chan Event - // Send allows the caller to send bytes. Thread safety is a requirement. + // Send allows the caller to send bytes. It does not need to clean itself + // up on errors, as the Websocket wrapper will do that. Send(context.Context, []byte) error - // Close should close the websocket connection. The connection will not be - // reused. + // Close should close the websocket connection. The underlying connection + // may be reused, but this Connection instance will be reused with Dial. The + // Connection must still be reusable even if Close returns an error. Close() error } -// Conn is the default Websocket connection. It compresses all payloads using -// zlib. +// Conn is the default Websocket connection. It tries to compresses all payloads +// using zlib. type Conn struct { - mutex sync.Mutex - - Conn *websocket.Conn - - dialer *websocket.Dialer + Dialer *websocket.Dialer + Conn *websocket.Conn events chan Event } @@ -73,24 +72,19 @@ func NewConn() *Conn { // NewConn creates a new default websocket connection with a custom dialer. func NewConnWithDialer(dialer *websocket.Dialer) *Conn { - return &Conn{dialer: dialer} + return &Conn{Dialer: dialer} } -func (c *Conn) Dial(ctx context.Context, addr string) error { +func (c *Conn) Dial(ctx context.Context, addr string) (err error) { + // BUG which prevents stream compression. + // See https://github.com/golang/go/issues/31514. + // Enable compression: headers := http.Header{ "Accept-Encoding": {"zlib"}, } - // BUG which prevents stream compression. - // See https://github.com/golang/go/issues/31514. - - var err error - - c.mutex.Lock() - defer c.mutex.Unlock() - - c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers) + c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers) if err != nil { return errors.Wrap(err, "failed to dial WS") } @@ -101,10 +95,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { return err } +// Listen returns an event channel if there is a connection associated with it. +// It returns nil if there is none. func (c *Conn) Listen() <-chan Event { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.events } @@ -112,31 +105,23 @@ func (c *Conn) Listen() <-chan Event { var resetDeadline = time.Time{} func (c *Conn) Send(ctx context.Context, b []byte) error { - c.mutex.Lock() - defer c.mutex.Unlock() - d, ok := ctx.Deadline() if ok { c.Conn.SetWriteDeadline(d) defer c.Conn.SetWriteDeadline(resetDeadline) } - return c.Conn.WriteMessage(websocket.TextMessage, b) + // We need to clean up ourselves if things are erroring out. + if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil { + return err + } + + return nil } func (c *Conn) Close() error { - // Use a sync.Once to guarantee that other Close() calls block until the - // main call is done. It also prevents future calls. - WSDebug("Conn: Acquiring write lock...") - - // Acquire the write lock forever. - c.mutex.Lock() - defer c.mutex.Unlock() - - WSDebug("Conn: Write lock acquired; closing.") - // Close the WS. - err := c.closeWS() + err := c.Conn.Close() WSDebug("Conn: Websocket closed; error:", err) WSDebug("Conn: Flusing events...") @@ -148,29 +133,9 @@ func (c *Conn) Close() error { WSDebug("Flushed events.") - // Mark c.Conn as empty. - c.Conn = nil - return err } -func (c *Conn) closeWS() error { - // We can't close with a write control here, since it will invalidate the - // old session, breaking resumes. - - // // Quick deadline: - // deadline := time.Now().Add(CloseDeadline) - - // // Make a closure message: - // msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "") - - // // Send a close message before closing the connection. We're not error - // // checking this because it's not important. - // err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline) - - return c.Conn.Close() -} - // loopState is a thread-unsafe disposable state container for the read loop. // It's made to completely separate the read loop of any synchronization that // doesn't involve the websocket connection itself. @@ -224,7 +189,7 @@ func (state *loopState) handle() ([]byte, error) { } if t == websocket.BinaryMessage { - // Probably a zlib payload + // Probably a zlib payload. if state.zlib == nil { z, err := zlib.NewReader(r) diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go index dcbb54d..dfaa315 100644 --- a/utils/wsutil/ws.go +++ b/utils/wsutil/ws.go @@ -5,7 +5,7 @@ package wsutil import ( "context" "log" - "net/url" + "sync" "time" "github.com/pkg/errors" @@ -33,9 +33,16 @@ type Event struct { Error error } +// Websocket is a wrapper around a websocket Conn with thread safety and rate +// limiting for sending and throttling. type Websocket struct { - Conn Connection - Addr string + mutex sync.Mutex + conn Connection + addr string + closed bool + + // Constants. These must not be changed after the Websocket instance is used + // once, as they are not thread-safe. // Timeout for connecting and writing to the Websocket, uses default // WSTimeout (global). @@ -45,6 +52,7 @@ type Websocket struct { DialLimiter *rate.Limiter } +// New creates a default Websocket with the given address. func New(addr string) *Websocket { return NewCustom(NewConn(), addr) } @@ -52,8 +60,9 @@ func New(addr string) *Websocket { // NewCustom creates a new undialed Websocket. func NewCustom(conn Connection, addr string) *Websocket { return &Websocket{ - Conn: conn, - Addr: addr, + conn: conn, + addr: addr, + closed: true, Timeout: WSTimeout, @@ -62,6 +71,7 @@ func NewCustom(conn Connection, addr string) *Websocket { } } +// Dial waits until the rate limiter allows then dials the websocket. func (ws *Websocket) Dial(ctx context.Context) error { if ws.Timeout > 0 { tctx, cancel := context.WithTimeout(ctx, ws.Timeout) @@ -75,46 +85,85 @@ func (ws *Websocket) Dial(ctx context.Context) error { return errors.Wrap(err, "failed to wait") } - if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { + ws.mutex.Lock() + defer ws.mutex.Unlock() + + if !ws.closed { + WSDebug("Old connection not yet closed while dialog; closing it.") + ws.conn.Close() + } + + if err := ws.conn.Dial(ctx, ws.addr); err != nil { return errors.Wrap(err, "failed to dial") } + ws.closed = false + return nil } +// Listen returns the inner event channel or nil if the Websocket connection is +// not alive. func (ws *Websocket) Listen() <-chan Event { - return ws.Conn.Listen() + ws.mutex.Lock() + defer ws.mutex.Unlock() + + if ws.closed { + return nil + } + + return ws.conn.Listen() } +// Send sends b over the Websocket without a timeout. func (ws *Websocket) Send(b []byte) error { return ws.SendCtx(context.Background(), b) } +// SendCtx sends b over the Websocket with a deadline. It closes the internal +// Websocket if the Send method errors out. func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error { if err := ws.SendLimiter.Wait(ctx); err != nil { return errors.Wrap(err, "SendLimiter failed") } - return ws.Conn.Send(ctx, b) + ws.mutex.Lock() + defer ws.mutex.Unlock() + + if ws.closed { + return ErrWebsocketClosed + } + + if err := ws.conn.Send(ctx, b); err != nil { + ws.close() + return err + } + + return nil } +// Close closes the websocket connection. It assumes that the Websocket is +// closed even when it returns an error. If the Websocket was already closed +// before, nil will be returned. func (ws *Websocket) Close() error { - return ws.Conn.Close() + WSDebug("Conn: Acquiring mutex lock to close...") + + ws.mutex.Lock() + defer ws.mutex.Unlock() + + WSDebug("Conn: Write mutex acquired; closing.") + + return ws.close() } -func InjectValues(rawurl string, values url.Values) string { - u, err := url.Parse(rawurl) - if err != nil { - // Unknown URL, return as-is. - return rawurl +// close closes the Websocket without acquiring the mutex. Refer to Close for +// more information. +func (ws *Websocket) close() error { + if ws.closed { + return nil } - // Append additional parameters: - var q = u.Query() - for k, v := range values { - q[k] = append(q[k], v...) - } - - u.RawQuery = q.Encode() - return u.String() + err := ws.conn.Close() + ws.closed = true + return err }