wsutil: Improved internal code

This commit is contained in:
diamondburned 2020-08-20 14:15:52 -07:00
parent fd818e181e
commit 6b4e26e839
2 changed files with 73 additions and 70 deletions

View File

@ -14,8 +14,13 @@ import (
"github.com/pkg/errors"
)
// CopyBufferSize is used for the initial size of the internal WS' buffer.
const CopyBufferSize = 2048
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. This constant is 4MB.
const MaxCapUntilReset = 4 * (1 << 20)
// CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second
@ -52,8 +57,7 @@ type Conn struct {
events chan Event
// write channels
writes chan []byte
errors chan error
writeMu *sync.Mutex
buf bytes.Buffer
zlib io.ReadCloser // (compress/zlib).reader
@ -72,15 +76,23 @@ func NewConn() *Conn {
}
func NewConnWithDriver(driver json.Driver) *Conn {
writeMu := sync.Mutex{}
writeMu.Lock()
writeBuf := bytes.Buffer{}
writeBuf.Grow(CopyBufferSize)
return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
},
// zlib: zlib.NewInflator(),
// buf: make([]byte, CopyBufferSize),
writeMu: &writeMu,
buf: writeBuf,
}
}
@ -105,12 +117,11 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
// Set up the closer.
c.closeOnce = &sync.Once{}
c.events = make(chan Event)
c.events = make(chan Event, WSBuffer)
go c.readLoop()
c.writes = make(chan []byte)
c.errors = make(chan error)
go c.writeLoop()
// Unlock the mutex that would otherwise be acquired in NewConn and Close.
c.writeMu.Unlock()
return err
}
@ -120,12 +131,6 @@ func (c *Conn) Listen() <-chan Event {
}
func (c *Conn) readLoop() {
// Acquire the read lock throughout the span of the loop. This would still
// allow Send to acquire another RLock, but wouldn't allow Close to
// prematurely exit, as Close acquires a write lock.
// c.mut.RLock()
// defer c.mut.RUnlock()
// Clean up the events channel in the end.
defer close(c.events)
@ -157,27 +162,6 @@ func (c *Conn) readLoop() {
}
}
func (c *Conn) writeLoop() {
// Closing c.writes would break the loop immediately.
for b := range c.writes {
c.errors <- c.Conn.WriteMessage(websocket.TextMessage, b)
}
// Quick deadline:
deadline := time.Now().Add(CloseDeadline)
// Make a closure message:
msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// Send a close message before closing the connection. We're not error
// checking this because it's not important.
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
// Safe to close now.
c.errors <- c.Conn.Close()
close(c.errors)
}
func (c *Conn) handle() ([]byte, error) {
// skip message type
t, r, err := c.Conn.NextReader()
@ -230,50 +214,45 @@ func (c *Conn) handle() ([]byte, error) {
// return nil, errors.New("Unexpected binary message.")
}
// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error {
// If websocket is already closed.
if c.writes == nil {
return ErrWebsocketClosed
c.writeMu.Lock()
defer c.writeMu.Unlock()
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}
// Send the bytes.
select {
case c.writes <- b:
// continue
case <-ctx.Done():
return ctx.Err()
}
// Receive the error.
select {
case err := <-c.errors:
return err
case <-ctx.Done():
return ctx.Err()
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
}
func (c *Conn) Close() (err error) {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
c.closeOnce.Do(func() {
// Close c.writes. This should trigger the websocket to close itself.
close(c.writes)
// Mark c.writes as empty.
c.writes = nil
WSDebug("Conn: Acquiring write lock...")
// Wait for the write loop to exit by flusing the errors channel.
err = <-c.errors // get close error
for range c.errors { // then flush
}
// Acquire the write lock forever.
c.writeMu.Lock()
WSDebug("Conn: Write lock acquired; closing.")
// Close the WS.
err = c.closeWS()
WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...")
// Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed.
for range c.events {
}
// Mark c.events as empty.
c.events = nil
WSDebug("Flushed events.")
// Mark c.Conn as empty.
c.Conn = nil
@ -282,18 +261,45 @@ func (c *Conn) Close() (err error) {
return err
}
func (c *Conn) closeWS() (err error) {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.
// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)
// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
if err := c.Conn.Close(); err != nil {
return err
}
return
}
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
defer buf.Reset()
if _, err := buf.ReadFrom(r); err != nil {
return nil, err
}
// Copy the bytes so we could empty the buffer for reuse.
p := buf.Bytes()
cpy := make([]byte, len(p))
copy(cpy, p)
cpy := make([]byte, buf.Len())
copy(cpy, buf.Bytes())
// If the buffer's capacity is over the limit, then re-allocate a new one.
if buf.Cap() > MaxCapUntilReset {
*buf = bytes.Buffer{}
buf.Grow(CopyBufferSize)
}
return cpy, nil
}

View File

@ -15,15 +15,12 @@ import (
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = 5 * time.Minute
WSTimeout = 30 * time.Second
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}