1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-05-22 23:31:06 +00:00

Merge wsutil bug fixes and changes onto v2

This commit is contained in:
diamondburned 2020-10-29 11:25:09 -07:00
commit 8d21c5f43f
2 changed files with 100 additions and 86 deletions

View file

@ -6,7 +6,6 @@ import (
"context"
"io"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
@ -29,32 +28,32 @@ var ErrWebsocketClosed = errors.New("websocket is closed")
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL.
// modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
// Listen sends over events constantly. Error will be non-nil if Data is
// nil, so check for Error first.
// Listen returns an event channel that sends over events constantly. It can
// return nil if there isn't an ongoing connection.
Listen() <-chan Event
// Send allows the caller to send bytes. Thread safety is a requirement.
// Send allows the caller to send bytes. It does not need to clean itself
// up on errors, as the Websocket wrapper will do that.
Send(context.Context, []byte) error
// Close should close the websocket connection. The connection will not be
// reused.
// Close should close the websocket connection. The underlying connection
// may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error.
Close() error
}
// Conn is the default Websocket connection. It compresses all payloads using
// zlib.
// Conn is the default Websocket connection. It tries to compresses all payloads
// using zlib.
type Conn struct {
mutex sync.Mutex
Conn *websocket.Conn
dialer *websocket.Dialer
Dialer *websocket.Dialer
Conn *websocket.Conn
events chan Event
}
@ -73,24 +72,19 @@ func NewConn() *Conn {
// NewConn creates a new default websocket connection with a custom dialer.
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
return &Conn{dialer: dialer}
return &Conn{Dialer: dialer}
}
func (c *Conn) Dial(ctx context.Context, addr string) error {
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
// Enable compression:
headers := http.Header{
"Accept-Encoding": {"zlib"},
}
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
var err error
c.mutex.Lock()
defer c.mutex.Unlock()
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
@ -101,10 +95,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
return err
}
// Listen returns an event channel if there is a connection associated with it.
// It returns nil if there is none.
func (c *Conn) Listen() <-chan Event {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.events
}
@ -112,31 +105,23 @@ func (c *Conn) Listen() <-chan Event {
var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
// We need to clean up ourselves if things are erroring out.
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
return err
}
return nil
}
func (c *Conn) Close() error {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
WSDebug("Conn: Acquiring write lock...")
// Acquire the write lock forever.
c.mutex.Lock()
defer c.mutex.Unlock()
WSDebug("Conn: Write lock acquired; closing.")
// Close the WS.
err := c.closeWS()
err := c.Conn.Close()
WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...")
@ -148,29 +133,9 @@ func (c *Conn) Close() error {
WSDebug("Flushed events.")
// Mark c.Conn as empty.
c.Conn = nil
return err
}
func (c *Conn) closeWS() error {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.
// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)
// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
return c.Conn.Close()
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
@ -224,7 +189,7 @@ func (state *loopState) handle() ([]byte, error) {
}
if t == websocket.BinaryMessage {
// Probably a zlib payload
// Probably a zlib payload.
if state.zlib == nil {
z, err := zlib.NewReader(r)

View file

@ -5,7 +5,7 @@ package wsutil
import (
"context"
"log"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
@ -33,9 +33,16 @@ type Event struct {
Error error
}
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
type Websocket struct {
Conn Connection
Addr string
mutex sync.Mutex
conn Connection
addr string
closed bool
// Constants. These must not be changed after the Websocket instance is used
// once, as they are not thread-safe.
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
@ -45,6 +52,7 @@ type Websocket struct {
DialLimiter *rate.Limiter
}
// New creates a default Websocket with the given address.
func New(addr string) *Websocket {
return NewCustom(NewConn(), addr)
}
@ -52,8 +60,9 @@ func New(addr string) *Websocket {
// NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{
Conn: conn,
Addr: addr,
conn: conn,
addr: addr,
closed: true,
Timeout: WSTimeout,
@ -62,6 +71,7 @@ func NewCustom(conn Connection, addr string) *Websocket {
}
}
// Dial waits until the rate limiter allows then dials the websocket.
func (ws *Websocket) Dial(ctx context.Context) error {
if ws.Timeout > 0 {
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
@ -75,46 +85,85 @@ func (ws *Websocket) Dial(ctx context.Context) error {
return errors.Wrap(err, "failed to wait")
}
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
ws.mutex.Lock()
defer ws.mutex.Unlock()
if !ws.closed {
WSDebug("Old connection not yet closed while dialog; closing it.")
ws.conn.Close()
}
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
return errors.Wrap(err, "failed to dial")
}
ws.closed = false
return nil
}
// Listen returns the inner event channel or nil if the Websocket connection is
// not alive.
func (ws *Websocket) Listen() <-chan Event {
return ws.Conn.Listen()
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return nil
}
return ws.conn.Listen()
}
// Send sends b over the Websocket without a timeout.
func (ws *Websocket) Send(b []byte) error {
return ws.SendCtx(context.Background(), b)
}
// SendCtx sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil {
return errors.Wrap(err, "SendLimiter failed")
}
return ws.Conn.Send(ctx, b)
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return ErrWebsocketClosed
}
if err := ws.conn.Send(ctx, b); err != nil {
ws.close()
return err
}
return nil
}
// Close closes the websocket connection. It assumes that the Websocket is
// closed even when it returns an error. If the Websocket was already closed
// before, nil will be returned.
func (ws *Websocket) Close() error {
return ws.Conn.Close()
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired; closing.")
return ws.close()
}
func InjectValues(rawurl string, values url.Values) string {
u, err := url.Parse(rawurl)
if err != nil {
// Unknown URL, return as-is.
return rawurl
// close closes the Websocket without acquiring the mutex. Refer to Close for
// more information.
func (ws *Websocket) close() error {
if ws.closed {
return nil
}
// Append additional parameters:
var q = u.Query()
for k, v := range values {
q[k] = append(q[k], v...)
}
u.RawQuery = q.Encode()
return u.String()
err := ws.conn.Close()
ws.closed = true
return err
}