mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-05-22 23:31:06 +00:00
Merge wsutil bug fixes and changes onto v2
This commit is contained in:
commit
8d21c5f43f
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -29,32 +28,32 @@ var ErrWebsocketClosed = errors.New("websocket is closed")
|
|||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself, including
|
||||
// modifying the connection URL.
|
||||
// modifying the connection URL. The implementation doesn't have to be safe for
|
||||
// concurrent use.
|
||||
type Connection interface {
|
||||
// Dial dials the address (string). Context needs to be passed in for
|
||||
// timeout. This method should also be re-usable after Close is called.
|
||||
Dial(context.Context, string) error
|
||||
|
||||
// Listen sends over events constantly. Error will be non-nil if Data is
|
||||
// nil, so check for Error first.
|
||||
// Listen returns an event channel that sends over events constantly. It can
|
||||
// return nil if there isn't an ongoing connection.
|
||||
Listen() <-chan Event
|
||||
|
||||
// Send allows the caller to send bytes. Thread safety is a requirement.
|
||||
// Send allows the caller to send bytes. It does not need to clean itself
|
||||
// up on errors, as the Websocket wrapper will do that.
|
||||
Send(context.Context, []byte) error
|
||||
|
||||
// Close should close the websocket connection. The connection will not be
|
||||
// reused.
|
||||
// Close should close the websocket connection. The underlying connection
|
||||
// may be reused, but this Connection instance will be reused with Dial. The
|
||||
// Connection must still be reusable even if Close returns an error.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Conn is the default Websocket connection. It compresses all payloads using
|
||||
// zlib.
|
||||
// Conn is the default Websocket connection. It tries to compresses all payloads
|
||||
// using zlib.
|
||||
type Conn struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
Conn *websocket.Conn
|
||||
|
||||
dialer *websocket.Dialer
|
||||
Dialer *websocket.Dialer
|
||||
Conn *websocket.Conn
|
||||
events chan Event
|
||||
}
|
||||
|
||||
|
@ -73,24 +72,19 @@ func NewConn() *Conn {
|
|||
|
||||
// NewConn creates a new default websocket connection with a custom dialer.
|
||||
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
|
||||
return &Conn{dialer: dialer}
|
||||
return &Conn{Dialer: dialer}
|
||||
}
|
||||
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
|
||||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
// Enable compression:
|
||||
headers := http.Header{
|
||||
"Accept-Encoding": {"zlib"},
|
||||
}
|
||||
|
||||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
var err error
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
||||
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
@ -101,10 +95,9 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Listen returns an event channel if there is a connection associated with it.
|
||||
// It returns nil if there is none.
|
||||
func (c *Conn) Listen() <-chan Event {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.events
|
||||
}
|
||||
|
||||
|
@ -112,31 +105,23 @@ func (c *Conn) Listen() <-chan Event {
|
|||
var resetDeadline = time.Time{}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
d, ok := ctx.Deadline()
|
||||
if ok {
|
||||
c.Conn.SetWriteDeadline(d)
|
||||
defer c.Conn.SetWriteDeadline(resetDeadline)
|
||||
}
|
||||
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, b)
|
||||
// We need to clean up ourselves if things are erroring out.
|
||||
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
// Use a sync.Once to guarantee that other Close() calls block until the
|
||||
// main call is done. It also prevents future calls.
|
||||
WSDebug("Conn: Acquiring write lock...")
|
||||
|
||||
// Acquire the write lock forever.
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write lock acquired; closing.")
|
||||
|
||||
// Close the WS.
|
||||
err := c.closeWS()
|
||||
err := c.Conn.Close()
|
||||
|
||||
WSDebug("Conn: Websocket closed; error:", err)
|
||||
WSDebug("Conn: Flusing events...")
|
||||
|
@ -148,29 +133,9 @@ func (c *Conn) Close() error {
|
|||
|
||||
WSDebug("Flushed events.")
|
||||
|
||||
// Mark c.Conn as empty.
|
||||
c.Conn = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) closeWS() error {
|
||||
// We can't close with a write control here, since it will invalidate the
|
||||
// old session, breaking resumes.
|
||||
|
||||
// // Quick deadline:
|
||||
// deadline := time.Now().Add(CloseDeadline)
|
||||
|
||||
// // Make a closure message:
|
||||
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
|
||||
|
||||
// // Send a close message before closing the connection. We're not error
|
||||
// // checking this because it's not important.
|
||||
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
|
||||
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// loopState is a thread-unsafe disposable state container for the read loop.
|
||||
// It's made to completely separate the read loop of any synchronization that
|
||||
// doesn't involve the websocket connection itself.
|
||||
|
@ -224,7 +189,7 @@ func (state *loopState) handle() ([]byte, error) {
|
|||
}
|
||||
|
||||
if t == websocket.BinaryMessage {
|
||||
// Probably a zlib payload
|
||||
// Probably a zlib payload.
|
||||
|
||||
if state.zlib == nil {
|
||||
z, err := zlib.NewReader(r)
|
||||
|
|
|
@ -5,7 +5,7 @@ package wsutil
|
|||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -33,9 +33,16 @@ type Event struct {
|
|||
Error error
|
||||
}
|
||||
|
||||
// Websocket is a wrapper around a websocket Conn with thread safety and rate
|
||||
// limiting for sending and throttling.
|
||||
type Websocket struct {
|
||||
Conn Connection
|
||||
Addr string
|
||||
mutex sync.Mutex
|
||||
conn Connection
|
||||
addr string
|
||||
closed bool
|
||||
|
||||
// Constants. These must not be changed after the Websocket instance is used
|
||||
// once, as they are not thread-safe.
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
|
@ -45,6 +52,7 @@ type Websocket struct {
|
|||
DialLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
// New creates a default Websocket with the given address.
|
||||
func New(addr string) *Websocket {
|
||||
return NewCustom(NewConn(), addr)
|
||||
}
|
||||
|
@ -52,8 +60,9 @@ func New(addr string) *Websocket {
|
|||
// NewCustom creates a new undialed Websocket.
|
||||
func NewCustom(conn Connection, addr string) *Websocket {
|
||||
return &Websocket{
|
||||
Conn: conn,
|
||||
Addr: addr,
|
||||
conn: conn,
|
||||
addr: addr,
|
||||
closed: true,
|
||||
|
||||
Timeout: WSTimeout,
|
||||
|
||||
|
@ -62,6 +71,7 @@ func NewCustom(conn Connection, addr string) *Websocket {
|
|||
}
|
||||
}
|
||||
|
||||
// Dial waits until the rate limiter allows then dials the websocket.
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
if ws.Timeout > 0 {
|
||||
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
|
||||
|
@ -75,46 +85,85 @@ func (ws *Websocket) Dial(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to wait")
|
||||
}
|
||||
|
||||
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
if !ws.closed {
|
||||
WSDebug("Old connection not yet closed while dialog; closing it.")
|
||||
ws.conn.Close()
|
||||
}
|
||||
|
||||
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
|
||||
return errors.Wrap(err, "failed to dial")
|
||||
}
|
||||
|
||||
ws.closed = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Listen returns the inner event channel or nil if the Websocket connection is
|
||||
// not alive.
|
||||
func (ws *Websocket) Listen() <-chan Event {
|
||||
return ws.Conn.Listen()
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
if ws.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ws.conn.Listen()
|
||||
}
|
||||
|
||||
// Send sends b over the Websocket without a timeout.
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
return ws.SendCtx(context.Background(), b)
|
||||
}
|
||||
|
||||
// SendCtx sends b over the Websocket with a deadline. It closes the internal
|
||||
// Websocket if the Send method errors out.
|
||||
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
return ws.Conn.Send(ctx, b)
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
if ws.closed {
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
if err := ws.conn.Send(ctx, b); err != nil {
|
||||
ws.close()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the websocket connection. It assumes that the Websocket is
|
||||
// closed even when it returns an error. If the Websocket was already closed
|
||||
// before, nil will be returned.
|
||||
func (ws *Websocket) Close() error {
|
||||
return ws.Conn.Close()
|
||||
WSDebug("Conn: Acquiring mutex lock to close...")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write mutex acquired; closing.")
|
||||
|
||||
return ws.close()
|
||||
}
|
||||
|
||||
func InjectValues(rawurl string, values url.Values) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
// Unknown URL, return as-is.
|
||||
return rawurl
|
||||
// close closes the Websocket without acquiring the mutex. Refer to Close for
|
||||
// more information.
|
||||
func (ws *Websocket) close() error {
|
||||
if ws.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append additional parameters:
|
||||
var q = u.Query()
|
||||
for k, v := range values {
|
||||
q[k] = append(q[k], v...)
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
err := ws.conn.Close()
|
||||
ws.closed = true
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue