mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 18:53:30 +00:00
wsutil: Refactored and decoupled structures for better thread safety
This commit is contained in:
parent
6c332ac145
commit
16a408bf30
|
@ -9,17 +9,17 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/utils/json"
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CopyBufferSize is used for the initial size of the internal WS' buffer.
|
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
|
||||||
const CopyBufferSize = 2048
|
// size is 4KB.
|
||||||
|
var CopyBufferSize = 4096
|
||||||
|
|
||||||
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
|
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
|
||||||
// re-allocated. This constant is 4MB.
|
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
|
||||||
const MaxCapUntilReset = 4 * (1 << 20)
|
var MaxCapUntilReset = CopyBufferSize * 4
|
||||||
|
|
||||||
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
||||||
var CloseDeadline = time.Second
|
var CloseDeadline = time.Second
|
||||||
|
@ -50,78 +50,54 @@ type Connection interface {
|
||||||
// Conn is the default Websocket connection. It compresses all payloads using
|
// Conn is the default Websocket connection. It compresses all payloads using
|
||||||
// zlib.
|
// zlib.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
|
||||||
Conn *websocket.Conn
|
Conn *websocket.Conn
|
||||||
json.Driver
|
|
||||||
|
|
||||||
dialer *websocket.Dialer
|
dialer *websocket.Dialer
|
||||||
events chan Event
|
events chan Event
|
||||||
|
|
||||||
// write channels
|
|
||||||
writeMu *sync.Mutex
|
|
||||||
|
|
||||||
buf bytes.Buffer
|
|
||||||
zlib io.ReadCloser // (compress/zlib).reader
|
|
||||||
|
|
||||||
// nil until Dial().
|
|
||||||
closeOnce *sync.Once
|
|
||||||
|
|
||||||
// zlib *zlib.Inflator // zlib.NewReader
|
|
||||||
// buf []byte // io.Copy buffer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Connection = (*Conn)(nil)
|
var _ Connection = (*Conn)(nil)
|
||||||
|
|
||||||
|
// NewConn creates a new default websocket connection with a default dialer.
|
||||||
func NewConn() *Conn {
|
func NewConn() *Conn {
|
||||||
return NewConnWithDriver(json.Default)
|
return NewConnWithDialer(&websocket.Dialer{
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnWithDriver(driver json.Driver) *Conn {
|
|
||||||
writeMu := sync.Mutex{}
|
|
||||||
writeMu.Lock()
|
|
||||||
|
|
||||||
writeBuf := bytes.Buffer{}
|
|
||||||
writeBuf.Grow(CopyBufferSize)
|
|
||||||
|
|
||||||
return &Conn{
|
|
||||||
Driver: driver,
|
|
||||||
dialer: &websocket.Dialer{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
HandshakeTimeout: WSTimeout,
|
HandshakeTimeout: WSTimeout,
|
||||||
ReadBufferSize: CopyBufferSize,
|
ReadBufferSize: CopyBufferSize,
|
||||||
WriteBufferSize: CopyBufferSize,
|
WriteBufferSize: CopyBufferSize,
|
||||||
EnableCompression: true,
|
EnableCompression: true,
|
||||||
},
|
})
|
||||||
writeMu: &writeMu,
|
}
|
||||||
buf: writeBuf,
|
|
||||||
}
|
// NewConn creates a new default websocket connection with a custom dialer.
|
||||||
|
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
|
||||||
|
return &Conn{dialer: dialer}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||||
var err error
|
|
||||||
|
|
||||||
// Enable compression:
|
// Enable compression:
|
||||||
headers := http.Header{}
|
headers := http.Header{
|
||||||
headers.Set("Accept-Encoding", "zlib")
|
"Accept-Encoding": {"zlib"},
|
||||||
|
}
|
||||||
|
|
||||||
// BUG: https://github.com/golang/go/issues/31514
|
// BUG which prevents stream compression.
|
||||||
// // Enable stream compression:
|
// See https://github.com/golang/go/issues/31514.
|
||||||
// addr = InjectValues(addr, url.Values{
|
|
||||||
// "compress": {"zlib-stream"},
|
|
||||||
// })
|
|
||||||
|
|
||||||
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
conn, _, err := c.dialer.DialContext(ctx, addr, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to dial WS")
|
return errors.Wrap(err, "failed to dial WS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the closer.
|
events := make(chan Event, WSBuffer)
|
||||||
c.closeOnce = &sync.Once{}
|
go startReadLoop(conn, events)
|
||||||
|
|
||||||
c.events = make(chan Event, WSBuffer)
|
c.mutex.Lock()
|
||||||
go c.readLoop()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
// Unlock the mutex that would otherwise be acquired in NewConn and Close.
|
c.Conn = conn
|
||||||
c.writeMu.Unlock()
|
c.events = events
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -130,96 +106,12 @@ func (c *Conn) Listen() <-chan Event {
|
||||||
return c.events
|
return c.events
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readLoop() {
|
|
||||||
// Clean up the events channel in the end.
|
|
||||||
defer close(c.events)
|
|
||||||
|
|
||||||
for {
|
|
||||||
b, err := c.handle()
|
|
||||||
if err != nil {
|
|
||||||
// Is the error an EOF?
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
// Yes it is, exit.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the error is a normal one:
|
|
||||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unusual error; log and exit:
|
|
||||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the payload length is 0, skip it.
|
|
||||||
if len(b) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c.events <- Event{b, nil}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) handle() ([]byte, error) {
|
|
||||||
// skip message type
|
|
||||||
t, r, err := c.Conn.NextReader()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if t == websocket.BinaryMessage {
|
|
||||||
// Probably a zlib payload
|
|
||||||
|
|
||||||
if c.zlib == nil {
|
|
||||||
z, err := zlib.NewReader(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
|
||||||
}
|
|
||||||
c.zlib = z
|
|
||||||
} else {
|
|
||||||
if err := c.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defer c.zlib.Close()
|
|
||||||
r = c.zlib
|
|
||||||
}
|
|
||||||
|
|
||||||
return readAll(&c.buf, r)
|
|
||||||
|
|
||||||
// if t is a text message, then handle it normally.
|
|
||||||
// if t == websocket.TextMessage {
|
|
||||||
// return readAll(&c.buf, r)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Write to the zlib writer.
|
|
||||||
// c.zlib.Write(r)
|
|
||||||
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
|
|
||||||
// // return nil, errors.Wrap(err, "Failed to write to zlib")
|
|
||||||
// // }
|
|
||||||
|
|
||||||
// if !c.zlib.CanFlush() {
|
|
||||||
// return nil, nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Flush and get the uncompressed payload.
|
|
||||||
// b, err := c.zlib.Flush()
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, errors.Wrap(err, "Failed to flush zlib")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return nil, errors.New("Unexpected binary message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetDeadline is used to reset the write deadline after using the context's.
|
// resetDeadline is used to reset the write deadline after using the context's.
|
||||||
var resetDeadline = time.Time{}
|
var resetDeadline = time.Time{}
|
||||||
|
|
||||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
c.writeMu.Lock()
|
c.mutex.Lock()
|
||||||
defer c.writeMu.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
d, ok := ctx.Deadline()
|
d, ok := ctx.Deadline()
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -230,19 +122,19 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
return c.Conn.WriteMessage(websocket.TextMessage, b)
|
return c.Conn.WriteMessage(websocket.TextMessage, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() error {
|
||||||
// Use a sync.Once to guarantee that other Close() calls block until the
|
// Use a sync.Once to guarantee that other Close() calls block until the
|
||||||
// main call is done. It also prevents future calls.
|
// main call is done. It also prevents future calls.
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
WSDebug("Conn: Acquiring write lock...")
|
WSDebug("Conn: Acquiring write lock...")
|
||||||
|
|
||||||
// Acquire the write lock forever.
|
// Acquire the write lock forever.
|
||||||
c.writeMu.Lock()
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
WSDebug("Conn: Write lock acquired; closing.")
|
WSDebug("Conn: Write lock acquired; closing.")
|
||||||
|
|
||||||
// Close the WS.
|
// Close the WS.
|
||||||
err = c.closeWS()
|
err := c.closeWS()
|
||||||
|
|
||||||
WSDebug("Conn: Websocket closed; error:", err)
|
WSDebug("Conn: Websocket closed; error:", err)
|
||||||
WSDebug("Conn: Flusing events...")
|
WSDebug("Conn: Flusing events...")
|
||||||
|
@ -256,12 +148,11 @@ func (c *Conn) Close() (err error) {
|
||||||
|
|
||||||
// Mark c.Conn as empty.
|
// Mark c.Conn as empty.
|
||||||
c.Conn = nil
|
c.Conn = nil
|
||||||
})
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) closeWS() (err error) {
|
func (c *Conn) closeWS() error {
|
||||||
// We can't close with a write control here, since it will invalidate the
|
// We can't close with a write control here, since it will invalidate the
|
||||||
// old session, breaking resumes.
|
// old session, breaking resumes.
|
||||||
|
|
||||||
|
@ -275,30 +166,100 @@ func (c *Conn) closeWS() (err error) {
|
||||||
// // checking this because it's not important.
|
// // checking this because it's not important.
|
||||||
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
|
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
|
||||||
|
|
||||||
if err := c.Conn.Close(); err != nil {
|
return c.Conn.Close()
|
||||||
return err
|
}
|
||||||
|
|
||||||
|
// loopState is a thread-unsafe disposable state container for the read loop.
|
||||||
|
// It's made to completely separate the read loop of any synchronization that
|
||||||
|
// doesn't involve the websocket connection itself.
|
||||||
|
type loopState struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
zlib io.ReadCloser
|
||||||
|
buf bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
|
||||||
|
// Clean up the events channel in the end.
|
||||||
|
defer close(eventCh)
|
||||||
|
|
||||||
|
// Allocate the read loop its own private resources.
|
||||||
|
state := loopState{conn: conn}
|
||||||
|
state.buf.Grow(CopyBufferSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
b, err := state.handle()
|
||||||
|
if err != nil {
|
||||||
|
// Is the error an EOF?
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
// Yes it is, exit.
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the error is a normal one:
|
||||||
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unusual error; log and exit:
|
||||||
|
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the payload length is 0, skip it.
|
||||||
|
if len(b) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventCh <- Event{b, nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *loopState) handle() ([]byte, error) {
|
||||||
|
// skip message type
|
||||||
|
t, r, err := state.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if t == websocket.BinaryMessage {
|
||||||
|
// Probably a zlib payload
|
||||||
|
|
||||||
|
if state.zlib == nil {
|
||||||
|
z, err := zlib.NewReader(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
||||||
|
}
|
||||||
|
state.zlib = z
|
||||||
|
} else {
|
||||||
|
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer state.zlib.Close()
|
||||||
|
r = state.zlib
|
||||||
|
}
|
||||||
|
|
||||||
|
return state.readAll(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
||||||
// buffer.
|
// buffer.
|
||||||
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
|
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
|
||||||
defer buf.Reset()
|
defer state.buf.Reset()
|
||||||
|
|
||||||
if _, err := buf.ReadFrom(r); err != nil {
|
if _, err := state.buf.ReadFrom(r); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the bytes so we could empty the buffer for reuse.
|
// Copy the bytes so we could empty the buffer for reuse.
|
||||||
cpy := make([]byte, buf.Len())
|
cpy := make([]byte, state.buf.Len())
|
||||||
copy(cpy, buf.Bytes())
|
copy(cpy, state.buf.Bytes())
|
||||||
|
|
||||||
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
||||||
if buf.Cap() > MaxCapUntilReset {
|
if state.buf.Cap() > MaxCapUntilReset {
|
||||||
*buf = bytes.Buffer{}
|
state.buf = bytes.Buffer{}
|
||||||
buf.Grow(CopyBufferSize)
|
state.buf.Grow(CopyBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cpy, nil
|
return cpy, nil
|
||||||
|
|
Loading…
Reference in a new issue