1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-08 13:07:43 +00:00
arikawa/utils/wsutil/conn.go

234 lines
6 KiB
Go
Raw Normal View History

2020-01-09 05:24:45 +00:00
package wsutil
import (
"bytes"
2020-01-09 05:24:45 +00:00
"compress/zlib"
"context"
2020-01-29 03:54:22 +00:00
"io"
2020-01-09 05:24:45 +00:00
"net/http"
"time"
2020-01-09 05:24:45 +00:00
"github.com/gorilla/websocket"
2020-01-09 05:24:45 +00:00
"github.com/pkg/errors"
)
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
// size is 4KB.
//
// NewConn also hands this value to the default dialer as both its
// ReadBufferSize and WriteBufferSize, and the read loop pre-grows its scratch
// buffer to this many bytes.
var CopyBufferSize = 4096
2020-08-20 21:15:52 +00:00
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. It is quadruple CopyBufferSize, which is 16KB (16384 bytes)
// with the default CopyBufferSize.
//
// The expression is evaluated once during package initialization, so assigning
// a new CopyBufferSize afterwards does not change MaxCapUntilReset; set both
// when overriding.
var MaxCapUntilReset = CopyBufferSize * 4
2020-08-20 21:15:52 +00:00
// CloseDeadline controls the deadline to wait for sending the Close frame. It
// defaults to one second.
//
// NOTE(review): nothing in the visible part of this file reads CloseDeadline;
// confirm where (or whether) it is consumed.
var CloseDeadline = time.Second
2020-01-09 05:24:45 +00:00
// ErrWebsocketClosed is returned if the websocket is already closed.
2020-05-16 21:14:49 +00:00
// ErrWebsocketClosed is a package-level sentinel value, so callers can compare
// a returned error against it.
// NOTE(review): the code visible in this file never returns it; it is
// presumably returned by callers elsewhere in the package — confirm.
var ErrWebsocketClosed = errors.New("websocket is closed")
2020-01-09 05:24:45 +00:00
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
2020-01-09 05:24:45 +00:00
type Connection interface {
2020-01-15 04:43:34 +00:00
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
// Listen returns an event channel that sends over events constantly. It can
// return nil if there isn't an ongoing connection.
2020-01-09 05:24:45 +00:00
Listen() <-chan Event
// Send allows the caller to send bytes. It does not need to clean itself
// up on errors, as the Websocket wrapper will do that.
Send(context.Context, []byte) error
2020-01-09 05:24:45 +00:00
// Close should close the websocket connection. The underlying connection
// may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error.
Close() error
2020-01-09 05:24:45 +00:00
}
// Conn is the default Websocket connection. It tries to compresses all payloads
// using zlib.
2020-01-09 05:24:45 +00:00
type Conn struct {
Dialer *websocket.Dialer
Conn *websocket.Conn
2020-01-09 05:24:45 +00:00
events chan Event
}
var _ Connection = (*Conn)(nil)
// NewConn creates a new default websocket connection with a default dialer.
// The default dialer takes its proxy from the environment, uses WSTimeout as
// the handshake timeout, sizes both I/O buffers to CopyBufferSize and enables
// compression.
func NewConn() *Conn {
	dialer := websocket.Dialer{
		Proxy:             http.ProxyFromEnvironment,
		HandshakeTimeout:  WSTimeout,
		ReadBufferSize:    CopyBufferSize,
		WriteBufferSize:   CopyBufferSize,
		EnableCompression: true,
	}
	return NewConnWithDialer(&dialer)
}
// NewConn creates a new default websocket connection with a custom dialer.
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
return &Conn{Dialer: dialer}
2020-01-09 05:24:45 +00:00
}
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
// Enable compression:
headers := http.Header{
"Accept-Encoding": {"zlib"},
}
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
c.events = make(chan Event, WSBuffer)
go startReadLoop(c.Conn, c.events)
2020-01-09 05:24:45 +00:00
return err
}
// Listen returns an event channel if there is a connection associated with it.
// It returns nil if there is none.
2020-01-09 05:24:45 +00:00
func (c *Conn) Listen() <-chan Event {
return c.events
}
// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}
// Send writes b to the connection as a single text message.
//
// If ctx has a deadline, it is applied as the write deadline for this call and
// cleared again before Send returns. A context without a deadline leaves the
// write deadline as it was. The error from the write is returned unchanged.
func (c *Conn) Send(ctx context.Context, b []byte) error {
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		c.Conn.SetWriteDeadline(deadline)
		defer c.Conn.SetWriteDeadline(resetDeadline)
	}

	return c.Conn.WriteMessage(websocket.TextMessage, b)
}
// Close closes the underlying websocket connection, then drains the events
// channel until the read loop closes it, so that a read loop blocked on
// sending an event can finish and exit.
//
// It returns ErrWebsocketClosed if there is no connection to close (Dial was
// never called, or the last Dial stored a nil connection). Otherwise it
// returns the error from closing the underlying connection, if any. The Conn
// can be dialed again afterwards in either case.
func (c *Conn) Close() error {
	// Calling Close on a nil *websocket.Conn would panic, so report the
	// missing connection as an error instead.
	if c.Conn == nil {
		return ErrWebsocketClosed
	}

	// Close the WS.
	err := c.Conn.Close()
	WSDebug("Conn: Websocket closed; error:", err)

	WSDebug("Conn: Flusing events...")

	// Drain every pending event; the loop ends once the read loop closes
	// c.events, and returns immediately if it is already closed. A nil channel
	// is never closed, so ranging over it would block forever; skip it.
	if c.events != nil {
		for range c.events {
		}
	}

	WSDebug("Flushed events.")

	return err
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
type loopState struct {
	// conn is the websocket connection the loop reads messages from.
	conn *websocket.Conn
	// zlib is the decompressor for binary messages. It is created on the first
	// binary message and re-pointed at each later one instead of being
	// allocated again.
	zlib io.ReadCloser
	// buf is the scratch buffer each message is read into before being copied
	// out; it is emptied after every message.
	buf bytes.Buffer
}
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
// Clean up the events channel in the end.
defer close(eventCh)
// Allocate the read loop its own private resources.
state := loopState{conn: conn}
state.buf.Grow(CopyBufferSize)
for {
b, err := state.handle()
if err != nil {
// Is the error an EOF?
if errors.Is(err, io.EOF) {
// Yes it is, exit.
2020-02-02 22:12:54 +00:00
return
2020-01-09 05:24:45 +00:00
}
// Check if the error is a normal one:
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return
}
// Unusual error; log and exit:
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
return
}
// If the payload length is 0, skip it.
if len(b) == 0 {
continue
2020-01-09 05:24:45 +00:00
}
eventCh <- Event{b, nil}
}
2020-01-09 05:24:45 +00:00
}
func (state *loopState) handle() ([]byte, error) {
// skip message type
t, r, err := state.conn.NextReader()
2020-01-09 05:24:45 +00:00
if err != nil {
2020-01-16 03:28:21 +00:00
return nil, err
2020-01-09 05:24:45 +00:00
}
if t == websocket.BinaryMessage {
// Probably a zlib payload.
if state.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
2020-05-16 21:14:49 +00:00
return nil, errors.Wrap(err, "failed to create a zlib reader")
}
state.zlib = z
} else {
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
2020-05-16 21:14:49 +00:00
return nil, errors.Wrap(err, "failed to reset zlib reader")
}
2020-01-09 05:24:45 +00:00
}
defer state.zlib.Close()
r = state.zlib
2020-08-20 21:15:52 +00:00
}
return state.readAll(r)
2020-08-20 21:15:52 +00:00
}
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
defer state.buf.Reset()
2020-08-20 21:15:52 +00:00
if _, err := state.buf.ReadFrom(r); err != nil {
return nil, err
}
// Copy the bytes so we could empty the buffer for reuse.
cpy := make([]byte, state.buf.Len())
copy(cpy, state.buf.Bytes())
2020-08-20 21:15:52 +00:00
// If the buffer's capacity is over the limit, then re-allocate a new one.
if state.buf.Cap() > MaxCapUntilReset {
state.buf = bytes.Buffer{}
state.buf.Grow(CopyBufferSize)
2020-08-20 21:15:52 +00:00
}
return cpy, nil
}