wsutil: Refactored and decoupled structures for better thread safety

This commit is contained in:
diamondburned 2020-10-28 10:19:22 -07:00
parent 6c332ac145
commit 16a408bf30
1 changed files with 131 additions and 170 deletions

View File

@ -9,17 +9,17 @@ import (
"sync" "sync"
"time" "time"
"github.com/diamondburned/arikawa/utils/json"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// CopyBufferSize is used for the initial size of the internal WS' buffer. // CopyBufferSize is used for the initial size of the internal WS' buffer. Its
const CopyBufferSize = 2048 // size is 4KB.
var CopyBufferSize = 4096
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is // MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. This constant is 4MB. // re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
const MaxCapUntilReset = 4 * (1 << 20) var MaxCapUntilReset = CopyBufferSize * 4
// CloseDeadline controls the deadline to wait for sending the Close frame. // CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second var CloseDeadline = time.Second
@ -50,78 +50,54 @@ type Connection interface {
// Conn is the default Websocket connection. It compresses all payloads using // Conn is the default Websocket connection. It compresses all payloads using
// zlib. // zlib.
type Conn struct { type Conn struct {
mutex sync.Mutex
Conn *websocket.Conn Conn *websocket.Conn
json.Driver
dialer *websocket.Dialer dialer *websocket.Dialer
events chan Event events chan Event
// write channels
writeMu *sync.Mutex
buf bytes.Buffer
zlib io.ReadCloser // (compress/zlib).reader
// nil until Dial().
closeOnce *sync.Once
// zlib *zlib.Inflator // zlib.NewReader
// buf []byte // io.Copy buffer
} }
var _ Connection = (*Conn)(nil) var _ Connection = (*Conn)(nil)
// NewConn creates a new default websocket connection with a default dialer.
func NewConn() *Conn { func NewConn() *Conn {
return NewConnWithDriver(json.Default) return NewConnWithDialer(&websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
})
} }
func NewConnWithDriver(driver json.Driver) *Conn { // NewConn creates a new default websocket connection with a custom dialer.
writeMu := sync.Mutex{} func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
writeMu.Lock() return &Conn{dialer: dialer}
writeBuf := bytes.Buffer{}
writeBuf.Grow(CopyBufferSize)
return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
},
writeMu: &writeMu,
buf: writeBuf,
}
} }
func (c *Conn) Dial(ctx context.Context, addr string) error { func (c *Conn) Dial(ctx context.Context, addr string) error {
var err error
// Enable compression: // Enable compression:
headers := http.Header{} headers := http.Header{
headers.Set("Accept-Encoding", "zlib") "Accept-Encoding": {"zlib"},
}
// BUG: https://github.com/golang/go/issues/31514 // BUG which prevents stream compression.
// // Enable stream compression: // See https://github.com/golang/go/issues/31514.
// addr = InjectValues(addr, url.Values{
// "compress": {"zlib-stream"},
// })
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers) conn, _, err := c.dialer.DialContext(ctx, addr, headers)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to dial WS") return errors.Wrap(err, "failed to dial WS")
} }
// Set up the closer. events := make(chan Event, WSBuffer)
c.closeOnce = &sync.Once{} go startReadLoop(conn, events)
c.events = make(chan Event, WSBuffer) c.mutex.Lock()
go c.readLoop() defer c.mutex.Unlock()
// Unlock the mutex that would otherwise be acquired in NewConn and Close. c.Conn = conn
c.writeMu.Unlock() c.events = events
return err return err
} }
@ -130,96 +106,12 @@ func (c *Conn) Listen() <-chan Event {
return c.events return c.events
} }
func (c *Conn) readLoop() {
// Clean up the events channel in the end.
defer close(c.events)
for {
b, err := c.handle()
if err != nil {
// Is the error an EOF?
if errors.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Check if the error is a normal one:
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return
}
// Unusual error; log and exit:
c.events <- Event{nil, errors.Wrap(err, "WS error")}
return
}
// If the payload length is 0, skip it.
if len(b) == 0 {
continue
}
c.events <- Event{b, nil}
}
}
func (c *Conn) handle() ([]byte, error) {
// skip message type
t, r, err := c.Conn.NextReader()
if err != nil {
return nil, err
}
if t == websocket.BinaryMessage {
// Probably a zlib payload
if c.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
return nil, errors.Wrap(err, "failed to create a zlib reader")
}
c.zlib = z
} else {
if err := c.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
return nil, errors.Wrap(err, "failed to reset zlib reader")
}
}
defer c.zlib.Close()
r = c.zlib
}
return readAll(&c.buf, r)
// if t is a text message, then handle it normally.
// if t == websocket.TextMessage {
// return readAll(&c.buf, r)
// }
// // Write to the zlib writer.
// c.zlib.Write(r)
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
// // return nil, errors.Wrap(err, "Failed to write to zlib")
// // }
// if !c.zlib.CanFlush() {
// return nil, nil
// }
// // Flush and get the uncompressed payload.
// b, err := c.zlib.Flush()
// if err != nil {
// return nil, errors.Wrap(err, "Failed to flush zlib")
// }
// return nil, errors.New("Unexpected binary message.")
}
// resetDeadline is used to reset the write deadline after using the context's. // resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{} var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Send(ctx context.Context, b []byte) error {
c.writeMu.Lock() c.mutex.Lock()
defer c.writeMu.Unlock() defer c.mutex.Unlock()
d, ok := ctx.Deadline() d, ok := ctx.Deadline()
if ok { if ok {
@ -230,38 +122,37 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
return c.Conn.WriteMessage(websocket.TextMessage, b) return c.Conn.WriteMessage(websocket.TextMessage, b)
} }
func (c *Conn) Close() (err error) { func (c *Conn) Close() error {
// Use a sync.Once to guarantee that other Close() calls block until the // Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls. // main call is done. It also prevents future calls.
c.closeOnce.Do(func() { WSDebug("Conn: Acquiring write lock...")
WSDebug("Conn: Acquiring write lock...")
// Acquire the write lock forever. // Acquire the write lock forever.
c.writeMu.Lock() c.mutex.Lock()
defer c.mutex.Unlock()
WSDebug("Conn: Write lock acquired; closing.") WSDebug("Conn: Write lock acquired; closing.")
// Close the WS. // Close the WS.
err = c.closeWS() err := c.closeWS()
WSDebug("Conn: Websocket closed; error:", err) WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...") WSDebug("Conn: Flusing events...")
// Flush all events before closing the channel. This will return as soon as // Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed. // c.events is closed, or after closed.
for range c.events { for range c.events {
} }
WSDebug("Flushed events.") WSDebug("Flushed events.")
// Mark c.Conn as empty. // Mark c.Conn as empty.
c.Conn = nil c.Conn = nil
})
return err return err
} }
func (c *Conn) closeWS() (err error) { func (c *Conn) closeWS() error {
// We can't close with a write control here, since it will invalidate the // We can't close with a write control here, since it will invalidate the
// old session, breaking resumes. // old session, breaking resumes.
@ -275,30 +166,100 @@ func (c *Conn) closeWS() (err error) {
// // checking this because it's not important. // // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline) // err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
if err := c.Conn.Close(); err != nil { return c.Conn.Close()
return err }
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
type loopState struct {
conn *websocket.Conn
zlib io.ReadCloser
buf bytes.Buffer
}
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
// Clean up the events channel in the end.
defer close(eventCh)
// Allocate the read loop its own private resources.
state := loopState{conn: conn}
state.buf.Grow(CopyBufferSize)
for {
b, err := state.handle()
if err != nil {
// Is the error an EOF?
if errors.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Check if the error is a normal one:
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return
}
// Unusual error; log and exit:
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
return
}
// If the payload length is 0, skip it.
if len(b) == 0 {
continue
}
eventCh <- Event{b, nil}
}
}
func (state *loopState) handle() ([]byte, error) {
// skip message type
t, r, err := state.conn.NextReader()
if err != nil {
return nil, err
} }
return if t == websocket.BinaryMessage {
// Probably a zlib payload
if state.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
return nil, errors.Wrap(err, "failed to create a zlib reader")
}
state.zlib = z
} else {
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
return nil, errors.Wrap(err, "failed to reset zlib reader")
}
}
defer state.zlib.Close()
r = state.zlib
}
return state.readAll(r)
} }
// readAll reads bytes into an existing buffer, copy it over, then wipe the old // readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer. // buffer.
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) { func (state *loopState) readAll(r io.Reader) ([]byte, error) {
defer buf.Reset() defer state.buf.Reset()
if _, err := buf.ReadFrom(r); err != nil { if _, err := state.buf.ReadFrom(r); err != nil {
return nil, err return nil, err
} }
// Copy the bytes so we could empty the buffer for reuse. // Copy the bytes so we could empty the buffer for reuse.
cpy := make([]byte, buf.Len()) cpy := make([]byte, state.buf.Len())
copy(cpy, buf.Bytes()) copy(cpy, state.buf.Bytes())
// If the buffer's capacity is over the limit, then re-allocate a new one. // If the buffer's capacity is over the limit, then re-allocate a new one.
if buf.Cap() > MaxCapUntilReset { if state.buf.Cap() > MaxCapUntilReset {
*buf = bytes.Buffer{} state.buf = bytes.Buffer{}
buf.Grow(CopyBufferSize) state.buf.Grow(CopyBufferSize)
} }
return cpy, nil return cpy, nil