1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-11-27 14:46:21 +00:00

wsutil: Websocket wrapper thread safety for simpler Conn impl

This commit removes the thread safety requirement that Conn
implementations must satisfy. It moves the mutex guards as well as the
multiple close wrapper over to the Websocket wrapper instead.
This commit is contained in:
diamondburned 2020-10-29 11:24:45 -07:00
parent 160a4e6606
commit 7668fe940c
2 changed files with 100 additions and 86 deletions

View file

@ -6,7 +6,6 @@ import (
"context"
"io"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
@ -29,32 +28,32 @@ var ErrWebsocketClosed = errors.New("websocket is closed")
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL.
// modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
// Listen sends over events constantly. Error will be non-nil if Data is
// nil, so check for Error first.
// Listen returns an event channel that sends over events constantly. It can
// return nil if there isn't an ongoing connection.
Listen() <-chan Event
// Send allows the caller to send bytes. Thread safety is a requirement.
// Send allows the caller to send bytes. It does not need to clean itself
// up on errors, as the Websocket wrapper will do that.
Send(context.Context, []byte) error
// Close should close the websocket connection. The connection will not be
// reused.
// Close should close the websocket connection. The underlying connection
// may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error.
Close() error
}
// Conn is the default Websocket connection. It compresses all payloads using
// zlib.
// Conn is the default Websocket connection. It tries to compresses all payloads
// using zlib.
type Conn struct {
mutex sync.Mutex
Conn *websocket.Conn
dialer *websocket.Dialer
Dialer *websocket.Dialer
Conn *websocket.Conn
events chan Event
}
@ -73,24 +72,19 @@ func NewConn() *Conn {
// NewConn creates a new default websocket connection with a custom dialer.
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
return &Conn{dialer: dialer}
return &Conn{Dialer: dialer}
}
func (c *Conn) Dial(ctx context.Context, addr string) error {
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
// Enable compression:
headers := http.Header{
"Accept-Encoding": {"zlib"},
}
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
var err error
c.mutex.Lock()
defer c.mutex.Unlock()
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
@ -101,10 +95,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
return err
}
// Listen returns an event channel if there is a connection associated with it.
// It returns nil if there is none.
func (c *Conn) Listen() <-chan Event {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.events
}
@ -112,31 +105,23 @@ func (c *Conn) Listen() <-chan Event {
var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
// We need to clean up ourselves if things are erroring out.
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
return err
}
return nil
}
func (c *Conn) Close() error {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
WSDebug("Conn: Acquiring write lock...")
// Acquire the write lock forever.
c.mutex.Lock()
defer c.mutex.Unlock()
WSDebug("Conn: Write lock acquired; closing.")
// Close the WS.
err := c.closeWS()
err := c.Conn.Close()
WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...")
@ -148,29 +133,9 @@ func (c *Conn) Close() error {
WSDebug("Flushed events.")
// Mark c.Conn as empty.
c.Conn = nil
return err
}
func (c *Conn) closeWS() error {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.
// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)
// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
return c.Conn.Close()
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
@ -224,7 +189,7 @@ func (state *loopState) handle() ([]byte, error) {
}
if t == websocket.BinaryMessage {
// Probably a zlib payload
// Probably a zlib payload.
if state.zlib == nil {
z, err := zlib.NewReader(r)

View file

@ -5,7 +5,7 @@ package wsutil
import (
"context"
"log"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
@ -33,9 +33,16 @@ type Event struct {
Error error
}
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
type Websocket struct {
Conn Connection
Addr string
mutex sync.Mutex
conn Connection
addr string
closed bool
// Constants. These must not be changed after the Websocket instance is used
// once, as they are not thread-safe.
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
@ -45,6 +52,7 @@ type Websocket struct {
DialLimiter *rate.Limiter
}
// New creates a default Websocket with the given address.
func New(addr string) *Websocket {
return NewCustom(NewConn(), addr)
}
@ -52,8 +60,9 @@ func New(addr string) *Websocket {
// NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{
Conn: conn,
Addr: addr,
conn: conn,
addr: addr,
closed: true,
Timeout: WSTimeout,
@ -62,6 +71,7 @@ func NewCustom(conn Connection, addr string) *Websocket {
}
}
// Dial waits until the rate limiter allows then dials the websocket.
func (ws *Websocket) Dial(ctx context.Context) error {
if ws.Timeout > 0 {
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
@ -75,46 +85,85 @@ func (ws *Websocket) Dial(ctx context.Context) error {
return errors.Wrap(err, "failed to wait")
}
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
ws.mutex.Lock()
defer ws.mutex.Unlock()
if !ws.closed {
WSDebug("Old connection not yet closed while dialog; closing it.")
ws.conn.Close()
}
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
return errors.Wrap(err, "failed to dial")
}
ws.closed = false
return nil
}
// Listen returns the inner event channel or nil if the Websocket connection is
// not alive.
func (ws *Websocket) Listen() <-chan Event {
return ws.Conn.Listen()
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return nil
}
return ws.conn.Listen()
}
// Send sends b over the Websocket without a timeout.
func (ws *Websocket) Send(b []byte) error {
return ws.SendCtx(context.Background(), b)
}
// SendCtx sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil {
return errors.Wrap(err, "SendLimiter failed")
}
return ws.Conn.Send(ctx, b)
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return ErrWebsocketClosed
}
if err := ws.conn.Send(ctx, b); err != nil {
ws.close()
return err
}
return nil
}
// Close closes the websocket connection. It assumes that the Websocket is
// closed even when it returns an error. If the Websocket was already closed
// before, nil will be returned.
func (ws *Websocket) Close() error {
return ws.Conn.Close()
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired; closing.")
return ws.close()
}
func InjectValues(rawurl string, values url.Values) string {
u, err := url.Parse(rawurl)
if err != nil {
// Unknown URL, return as-is.
return rawurl
// close closes the Websocket without acquiring the mutex. Refer to Close for
// more information.
func (ws *Websocket) close() error {
if ws.closed {
return nil
}
// Append additional parameters:
var q = u.Query()
for k, v := range values {
q[k] = append(q[k], v...)
}
u.RawQuery = q.Encode()
return u.String()
err := ws.conn.Close()
ws.closed = true
return err
}