mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-07 12:38:05 +00:00
diamondburned
7668fe940c
This commit removes the thread safety requirement that Conn implementations must satisfy. It moves the mutex guards as well as the multiple close wrapper over to the Websocket wrapper instead.
234 lines
6 KiB
Go
234 lines
6 KiB
Go
package wsutil
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/zlib"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
|
|
// size is 4KB.
|
|
var CopyBufferSize = 4096
|
|
|
|
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
|
|
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
|
|
var MaxCapUntilReset = CopyBufferSize * 4
|
|
|
|
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
|
var CloseDeadline = time.Second
|
|
|
|
// ErrWebsocketClosed is returned if the websocket is already closed.
|
|
var ErrWebsocketClosed = errors.New("websocket is closed")
|
|
|
|
// Connection is an interface that abstracts around a generic Websocket driver.
|
|
// This connection expects the driver to handle compression by itself, including
|
|
// modifying the connection URL. The implementation doesn't have to be safe for
|
|
// concurrent use.
|
|
type Connection interface {
|
|
// Dial dials the address (string). Context needs to be passed in for
|
|
// timeout. This method should also be re-usable after Close is called.
|
|
Dial(context.Context, string) error
|
|
|
|
// Listen returns an event channel that sends over events constantly. It can
|
|
// return nil if there isn't an ongoing connection.
|
|
Listen() <-chan Event
|
|
|
|
// Send allows the caller to send bytes. It does not need to clean itself
|
|
// up on errors, as the Websocket wrapper will do that.
|
|
Send(context.Context, []byte) error
|
|
|
|
// Close should close the websocket connection. The underlying connection
|
|
// may be reused, but this Connection instance will be reused with Dial. The
|
|
// Connection must still be reusable even if Close returns an error.
|
|
Close() error
|
|
}
|
|
|
|
// Conn is the default Websocket connection. It tries to compresses all payloads
|
|
// using zlib.
|
|
type Conn struct {
|
|
Dialer *websocket.Dialer
|
|
Conn *websocket.Conn
|
|
events chan Event
|
|
}
|
|
|
|
var _ Connection = (*Conn)(nil)
|
|
|
|
// NewConn creates a new default websocket connection with a default dialer.
|
|
func NewConn() *Conn {
|
|
return NewConnWithDialer(&websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: WSTimeout,
|
|
ReadBufferSize: CopyBufferSize,
|
|
WriteBufferSize: CopyBufferSize,
|
|
EnableCompression: true,
|
|
})
|
|
}
|
|
|
|
// NewConn creates a new default websocket connection with a custom dialer.
|
|
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
|
|
return &Conn{Dialer: dialer}
|
|
}
|
|
|
|
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
|
|
// BUG which prevents stream compression.
|
|
// See https://github.com/golang/go/issues/31514.
|
|
|
|
// Enable compression:
|
|
headers := http.Header{
|
|
"Accept-Encoding": {"zlib"},
|
|
}
|
|
|
|
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to dial WS")
|
|
}
|
|
|
|
c.events = make(chan Event, WSBuffer)
|
|
go startReadLoop(c.Conn, c.events)
|
|
|
|
return err
|
|
}
|
|
|
|
// Listen returns an event channel if there is a connection associated with it.
|
|
// It returns nil if there is none.
|
|
func (c *Conn) Listen() <-chan Event {
|
|
return c.events
|
|
}
|
|
|
|
// resetDeadline is used to reset the write deadline after using the context's.
|
|
var resetDeadline = time.Time{}
|
|
|
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
|
d, ok := ctx.Deadline()
|
|
if ok {
|
|
c.Conn.SetWriteDeadline(d)
|
|
defer c.Conn.SetWriteDeadline(resetDeadline)
|
|
}
|
|
|
|
// We need to clean up ourselves if things are erroring out.
|
|
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
// Close the WS.
|
|
err := c.Conn.Close()
|
|
|
|
WSDebug("Conn: Websocket closed; error:", err)
|
|
WSDebug("Conn: Flusing events...")
|
|
|
|
// Flush all events before closing the channel. This will return as soon as
|
|
// c.events is closed, or after closed.
|
|
for range c.events {
|
|
}
|
|
|
|
WSDebug("Flushed events.")
|
|
|
|
return err
|
|
}
|
|
|
|
// loopState is a thread-unsafe disposable state container for the read loop.
|
|
// It's made to completely separate the read loop of any synchronization that
|
|
// doesn't involve the websocket connection itself.
|
|
type loopState struct {
|
|
conn *websocket.Conn
|
|
zlib io.ReadCloser
|
|
buf bytes.Buffer
|
|
}
|
|
|
|
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
|
|
// Clean up the events channel in the end.
|
|
defer close(eventCh)
|
|
|
|
// Allocate the read loop its own private resources.
|
|
state := loopState{conn: conn}
|
|
state.buf.Grow(CopyBufferSize)
|
|
|
|
for {
|
|
b, err := state.handle()
|
|
if err != nil {
|
|
// Is the error an EOF?
|
|
if errors.Is(err, io.EOF) {
|
|
// Yes it is, exit.
|
|
return
|
|
}
|
|
|
|
// Check if the error is a normal one:
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
return
|
|
}
|
|
|
|
// Unusual error; log and exit:
|
|
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
|
|
return
|
|
}
|
|
|
|
// If the payload length is 0, skip it.
|
|
if len(b) == 0 {
|
|
continue
|
|
}
|
|
|
|
eventCh <- Event{b, nil}
|
|
}
|
|
}
|
|
|
|
func (state *loopState) handle() ([]byte, error) {
|
|
// skip message type
|
|
t, r, err := state.conn.NextReader()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if t == websocket.BinaryMessage {
|
|
// Probably a zlib payload.
|
|
|
|
if state.zlib == nil {
|
|
z, err := zlib.NewReader(r)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
|
}
|
|
state.zlib = z
|
|
} else {
|
|
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
|
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
|
}
|
|
}
|
|
|
|
defer state.zlib.Close()
|
|
r = state.zlib
|
|
}
|
|
|
|
return state.readAll(r)
|
|
}
|
|
|
|
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
|
// buffer.
|
|
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
|
|
defer state.buf.Reset()
|
|
|
|
if _, err := state.buf.ReadFrom(r); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Copy the bytes so we could empty the buffer for reuse.
|
|
cpy := make([]byte, state.buf.Len())
|
|
copy(cpy, state.buf.Bytes())
|
|
|
|
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
|
if state.buf.Cap() > MaxCapUntilReset {
|
|
state.buf = bytes.Buffer{}
|
|
state.buf.Grow(CopyBufferSize)
|
|
}
|
|
|
|
return cpy, nil
|
|
}
|