// Mirror of https://github.com/diamondburned/arikawa.git
// (synced 2025-01-09 13:37:02 +00:00; 254 lines, 6.5 KiB, Go).

package wsutil
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/zlib"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
var (
	// CopyBufferSize is the initial size, in bytes, of the internal WS
	// buffers: 4 KiB.
	CopyBufferSize = 4096

	// MaxCapUntilReset is the buffer capacity, in bytes, above which the read
	// buffer is re-allocated instead of being reused. It is 16 KiB by default,
	// four times CopyBufferSize. It is computed once at package
	// initialization, so changing CopyBufferSize later does not update it.
	MaxCapUntilReset = CopyBufferSize * 4

	// CloseDeadline controls the deadline to wait for sending the Close frame.
	CloseDeadline = time.Second

	// ErrWebsocketClosed is returned if the websocket is already closed.
	ErrWebsocketClosed = errors.New("websocket is closed")
)
|
|
|
|
// Connection is an interface that abstracts around a generic Websocket driver.
|
|
// This connection expects the driver to handle compression by itself, including
|
|
// modifying the connection URL. The implementation doesn't have to be safe for
|
|
// concurrent use.
|
|
type Connection interface {
|
|
// Dial dials the address (string). Context needs to be passed in for
|
|
// timeout. This method should also be re-usable after Close is called.
|
|
Dial(context.Context, string) error
|
|
|
|
// Listen returns an event channel that sends over events constantly. It can
|
|
// return nil if there isn't an ongoing connection.
|
|
Listen() <-chan Event
|
|
|
|
// Send allows the caller to send bytes. It does not need to clean itself
|
|
// up on errors, as the Websocket wrapper will do that.
|
|
Send(context.Context, []byte) error
|
|
|
|
// Close should close the websocket connection. The underlying connection
|
|
// may be reused, but this Connection instance will be reused with Dial. The
|
|
// Connection must still be reusable even if Close returns an error.
|
|
Close() error
|
|
}
|
|
|
|
// Conn is the default Websocket connection. It tries to compresses all payloads
|
|
// using zlib.
|
|
type Conn struct {
|
|
Dialer websocket.Dialer
|
|
Header http.Header
|
|
Conn *websocket.Conn
|
|
events chan Event
|
|
}
|
|
|
|
var _ Connection = (*Conn)(nil)
|
|
|
|
// NewConn creates a new default websocket connection with a default dialer.
|
|
func NewConn() *Conn {
|
|
return NewConnWithDialer(websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: WSTimeout,
|
|
ReadBufferSize: CopyBufferSize,
|
|
WriteBufferSize: CopyBufferSize,
|
|
EnableCompression: true,
|
|
})
|
|
}
|
|
|
|
// NewConn creates a new default websocket connection with a custom dialer.
|
|
func NewConnWithDialer(dialer websocket.Dialer) *Conn {
|
|
return &Conn{
|
|
Dialer: dialer,
|
|
Header: http.Header{
|
|
"Accept-Encoding": {"zlib"},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
|
|
// BUG which prevents stream compression.
|
|
// See https://github.com/golang/go/issues/31514.
|
|
|
|
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to dial WS")
|
|
}
|
|
|
|
// Reset the deadline.
|
|
c.Conn.SetWriteDeadline(resetDeadline)
|
|
|
|
c.events = make(chan Event, WSBuffer)
|
|
go startReadLoop(c.Conn, c.events)
|
|
|
|
return err
|
|
}
|
|
|
|
// Listen returns an event channel if there is a connection associated with it.
|
|
// It returns nil if there is none.
|
|
func (c *Conn) Listen() <-chan Event {
|
|
return c.events
|
|
}
|
|
|
|
// resetDeadline is used to reset the write deadline after using the context's.
|
|
var resetDeadline = time.Time{}
|
|
|
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
|
d, ok := ctx.Deadline()
|
|
if ok {
|
|
c.Conn.SetWriteDeadline(d)
|
|
defer c.Conn.SetWriteDeadline(resetDeadline)
|
|
}
|
|
|
|
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
WSDebug("Conn: Close is called; shutting down the Websocket connection.")
|
|
|
|
// Have a deadline before closing.
|
|
var deadline = time.Now().Add(5 * time.Second)
|
|
c.Conn.SetWriteDeadline(deadline)
|
|
|
|
// Close the WS.
|
|
err := c.Conn.Close()
|
|
|
|
c.Conn.SetWriteDeadline(resetDeadline)
|
|
|
|
WSDebug("Conn: Websocket closed; error:", err)
|
|
WSDebug("Conn: Flusing events...")
|
|
|
|
// Flush all events before closing the channel. This will return as soon as
|
|
// c.events is closed, or after closed.
|
|
for range c.events {
|
|
}
|
|
|
|
WSDebug("Flushed events.")
|
|
|
|
return err
|
|
}
|
|
|
|
// loopState holds the private resources of a single read loop. It is not safe
// for concurrent use. startReadLoop creates one as a local value and discards
// it when the loop exits, so the loop needs no synchronization beyond what
// the websocket connection itself provides.
type loopState struct {
	// conn is the connection messages are read from.
	conn *websocket.Conn
	// zlib is created lazily on the first binary message and Reset for each
	// later one, so the reader is reused across messages.
	zlib io.ReadCloser
	// buf is the scratch buffer that message bodies are read into.
	buf bytes.Buffer
}
|
|
|
|
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
|
|
// Clean up the events channel in the end.
|
|
defer close(eventCh)
|
|
|
|
// Allocate the read loop its own private resources.
|
|
state := loopState{conn: conn}
|
|
state.buf.Grow(CopyBufferSize)
|
|
|
|
for {
|
|
b, err := state.handle()
|
|
if err != nil {
|
|
WSDebug("Conn: Read error:", err)
|
|
|
|
// Is the error an EOF?
|
|
if errors.Is(err, io.EOF) {
|
|
// Yes it is, exit.
|
|
return
|
|
}
|
|
|
|
// Is the error an intentional close call? Go 1.16 exposes
|
|
// ErrClosing, but we have to do this for now.
|
|
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
|
return
|
|
}
|
|
|
|
// Check if the error is a normal one:
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
return
|
|
}
|
|
|
|
// Unusual error; log and exit:
|
|
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
|
|
return
|
|
}
|
|
|
|
// If the payload length is 0, skip it.
|
|
if len(b) == 0 {
|
|
continue
|
|
}
|
|
|
|
eventCh <- Event{b, nil}
|
|
}
|
|
}
|
|
|
|
func (state *loopState) handle() ([]byte, error) {
|
|
// skip message type
|
|
t, r, err := state.conn.NextReader()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if t == websocket.BinaryMessage {
|
|
// Probably a zlib payload.
|
|
|
|
if state.zlib == nil {
|
|
z, err := zlib.NewReader(r)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
|
}
|
|
state.zlib = z
|
|
} else {
|
|
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
|
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
|
}
|
|
}
|
|
|
|
defer state.zlib.Close()
|
|
r = state.zlib
|
|
}
|
|
|
|
return state.readAll(r)
|
|
}
|
|
|
|
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
|
// buffer.
|
|
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
|
|
defer state.buf.Reset()
|
|
|
|
if _, err := state.buf.ReadFrom(r); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Copy the bytes so we could empty the buffer for reuse.
|
|
cpy := make([]byte, state.buf.Len())
|
|
copy(cpy, state.buf.Bytes())
|
|
|
|
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
|
if state.buf.Cap() > MaxCapUntilReset {
|
|
state.buf = bytes.Buffer{}
|
|
state.buf.Grow(CopyBufferSize)
|
|
}
|
|
|
|
return cpy, nil
|
|
}
|