1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-11-27 06:35:48 +00:00

Merge wsutil bug fixes and changes onto v2

This commit is contained in:
diamondburned 2020-10-29 11:25:09 -07:00
commit 8d21c5f43f
2 changed files with 100 additions and 86 deletions

View file

@ -6,7 +6,6 @@ import (
"context" "context"
"io" "io"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -29,32 +28,32 @@ var ErrWebsocketClosed = errors.New("websocket is closed")
// Connection is an interface that abstracts around a generic Websocket driver. // Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including // This connection expects the driver to handle compression by itself, including
// modifying the connection URL. // modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
type Connection interface { type Connection interface {
// Dial dials the address (string). Context needs to be passed in for // Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called. // timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error Dial(context.Context, string) error
// Listen sends over events constantly. Error will be non-nil if Data is // Listen returns an event channel that sends over events constantly. It can
// nil, so check for Error first. // return nil if there isn't an ongoing connection.
Listen() <-chan Event Listen() <-chan Event
// Send allows the caller to send bytes. Thread safety is a requirement. // Send allows the caller to send bytes. It does not need to clean itself
// up on errors, as the Websocket wrapper will do that.
Send(context.Context, []byte) error Send(context.Context, []byte) error
// Close should close the websocket connection. The connection will not be // Close should close the websocket connection. The underlying connection
// reused. // may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error.
Close() error Close() error
} }
// Conn is the default Websocket connection. It compresses all payloads using // Conn is the default Websocket connection. It tries to compresses all payloads
// zlib. // using zlib.
type Conn struct { type Conn struct {
mutex sync.Mutex Dialer *websocket.Dialer
Conn *websocket.Conn
Conn *websocket.Conn
dialer *websocket.Dialer
events chan Event events chan Event
} }
@ -73,24 +72,19 @@ func NewConn() *Conn {
// NewConn creates a new default websocket connection with a custom dialer. // NewConn creates a new default websocket connection with a custom dialer.
func NewConnWithDialer(dialer *websocket.Dialer) *Conn { func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
return &Conn{dialer: dialer} return &Conn{Dialer: dialer}
} }
func (c *Conn) Dial(ctx context.Context, addr string) error { func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
// Enable compression: // Enable compression:
headers := http.Header{ headers := http.Header{
"Accept-Encoding": {"zlib"}, "Accept-Encoding": {"zlib"},
} }
// BUG which prevents stream compression. c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
// See https://github.com/golang/go/issues/31514.
var err error
c.mutex.Lock()
defer c.mutex.Unlock()
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to dial WS") return errors.Wrap(err, "failed to dial WS")
} }
@ -101,10 +95,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
return err return err
} }
// Listen returns an event channel if there is a connection associated with it.
// It returns nil if there is none.
func (c *Conn) Listen() <-chan Event { func (c *Conn) Listen() <-chan Event {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.events return c.events
} }
@ -112,31 +105,23 @@ func (c *Conn) Listen() <-chan Event {
var resetDeadline = time.Time{} var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
d, ok := ctx.Deadline() d, ok := ctx.Deadline()
if ok { if ok {
c.Conn.SetWriteDeadline(d) c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline) defer c.Conn.SetWriteDeadline(resetDeadline)
} }
return c.Conn.WriteMessage(websocket.TextMessage, b) // We need to clean up ourselves if things are erroring out.
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
return err
}
return nil
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
WSDebug("Conn: Acquiring write lock...")
// Acquire the write lock forever.
c.mutex.Lock()
defer c.mutex.Unlock()
WSDebug("Conn: Write lock acquired; closing.")
// Close the WS. // Close the WS.
err := c.closeWS() err := c.Conn.Close()
WSDebug("Conn: Websocket closed; error:", err) WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...") WSDebug("Conn: Flusing events...")
@ -148,29 +133,9 @@ func (c *Conn) Close() error {
WSDebug("Flushed events.") WSDebug("Flushed events.")
// Mark c.Conn as empty.
c.Conn = nil
return err return err
} }
func (c *Conn) closeWS() error {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.
// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)
// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
return c.Conn.Close()
}
// loopState is a thread-unsafe disposable state container for the read loop. // loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that // It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself. // doesn't involve the websocket connection itself.
@ -224,7 +189,7 @@ func (state *loopState) handle() ([]byte, error) {
} }
if t == websocket.BinaryMessage { if t == websocket.BinaryMessage {
// Probably a zlib payload // Probably a zlib payload.
if state.zlib == nil { if state.zlib == nil {
z, err := zlib.NewReader(r) z, err := zlib.NewReader(r)

View file

@ -5,7 +5,7 @@ package wsutil
import ( import (
"context" "context"
"log" "log"
"net/url" "sync"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -33,9 +33,16 @@ type Event struct {
Error error Error error
} }
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
type Websocket struct { type Websocket struct {
Conn Connection mutex sync.Mutex
Addr string conn Connection
addr string
closed bool
// Constants. These must not be changed after the Websocket instance is used
// once, as they are not thread-safe.
// Timeout for connecting and writing to the Websocket, uses default // Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global). // WSTimeout (global).
@ -45,6 +52,7 @@ type Websocket struct {
DialLimiter *rate.Limiter DialLimiter *rate.Limiter
} }
// New creates a default Websocket with the given address.
func New(addr string) *Websocket { func New(addr string) *Websocket {
return NewCustom(NewConn(), addr) return NewCustom(NewConn(), addr)
} }
@ -52,8 +60,9 @@ func New(addr string) *Websocket {
// NewCustom creates a new undialed Websocket. // NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) *Websocket { func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{ return &Websocket{
Conn: conn, conn: conn,
Addr: addr, addr: addr,
closed: true,
Timeout: WSTimeout, Timeout: WSTimeout,
@ -62,6 +71,7 @@ func NewCustom(conn Connection, addr string) *Websocket {
} }
} }
// Dial waits until the rate limiter allows then dials the websocket.
func (ws *Websocket) Dial(ctx context.Context) error { func (ws *Websocket) Dial(ctx context.Context) error {
if ws.Timeout > 0 { if ws.Timeout > 0 {
tctx, cancel := context.WithTimeout(ctx, ws.Timeout) tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
@ -75,46 +85,85 @@ func (ws *Websocket) Dial(ctx context.Context) error {
return errors.Wrap(err, "failed to wait") return errors.Wrap(err, "failed to wait")
} }
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { ws.mutex.Lock()
defer ws.mutex.Unlock()
if !ws.closed {
WSDebug("Old connection not yet closed while dialog; closing it.")
ws.conn.Close()
}
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
return errors.Wrap(err, "failed to dial") return errors.Wrap(err, "failed to dial")
} }
ws.closed = false
return nil return nil
} }
// Listen returns the inner event channel or nil if the Websocket connection is
// not alive.
func (ws *Websocket) Listen() <-chan Event { func (ws *Websocket) Listen() <-chan Event {
return ws.Conn.Listen() ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return nil
}
return ws.conn.Listen()
} }
// Send sends b over the Websocket without a timeout.
func (ws *Websocket) Send(b []byte) error { func (ws *Websocket) Send(b []byte) error {
return ws.SendCtx(context.Background(), b) return ws.SendCtx(context.Background(), b)
} }
// SendCtx sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error { func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil { if err := ws.SendLimiter.Wait(ctx); err != nil {
return errors.Wrap(err, "SendLimiter failed") return errors.Wrap(err, "SendLimiter failed")
} }
return ws.Conn.Send(ctx, b) ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return ErrWebsocketClosed
}
if err := ws.conn.Send(ctx, b); err != nil {
ws.close()
return err
}
return nil
} }
// Close closes the websocket connection. It assumes that the Websocket is
// closed even when it returns an error. If the Websocket was already closed
// before, nil will be returned.
func (ws *Websocket) Close() error { func (ws *Websocket) Close() error {
return ws.Conn.Close() WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired; closing.")
return ws.close()
} }
func InjectValues(rawurl string, values url.Values) string { // close closes the Websocket without acquiring the mutex. Refer to Close for
u, err := url.Parse(rawurl) // more information.
if err != nil { func (ws *Websocket) close() error {
// Unknown URL, return as-is. if ws.closed {
return rawurl return nil
} }
// Append additional parameters: err := ws.conn.Close()
var q = u.Query() ws.closed = true
for k, v := range values { return err
q[k] = append(q[k], v...)
}
u.RawQuery = q.Encode()
return u.String()
} }