2020-01-09 05:24:45 +00:00
|
|
|
package wsutil
|
|
|
|
|
|
|
|
import (
|
2020-04-06 19:03:42 +00:00
|
|
|
"bytes"
|
2020-01-09 05:24:45 +00:00
|
|
|
"compress/zlib"
|
|
|
|
"context"
|
2020-01-29 03:54:22 +00:00
|
|
|
"io"
|
2020-01-09 05:24:45 +00:00
|
|
|
"net/http"
|
2020-01-20 11:06:20 +00:00
|
|
|
"sync"
|
2020-04-06 19:03:42 +00:00
|
|
|
"time"
|
2020-01-09 05:24:45 +00:00
|
|
|
|
2020-04-09 02:28:40 +00:00
|
|
|
"github.com/diamondburned/arikawa/utils/json"
|
2020-04-06 19:03:42 +00:00
|
|
|
"github.com/gorilla/websocket"
|
2020-01-09 05:24:45 +00:00
|
|
|
"github.com/pkg/errors"
|
|
|
|
)
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
const CopyBufferSize = 2048
|
|
|
|
|
|
|
|
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
|
|
|
var CloseDeadline = time.Second
|
2020-01-09 05:24:45 +00:00
|
|
|
|
|
|
|
// Connection is an interface that abstracts around a generic Websocket driver.
|
2020-04-06 19:03:42 +00:00
|
|
|
// This connection expects the driver to handle compression by itself, including
|
|
|
|
// modifying the connection URL.
|
2020-01-09 05:24:45 +00:00
|
|
|
type Connection interface {
|
2020-01-15 04:43:34 +00:00
|
|
|
// Dial dials the address (string). Context needs to be passed in for
|
|
|
|
// timeout. This method should also be re-usable after Close is called.
|
|
|
|
Dial(context.Context, string) error
|
|
|
|
|
2020-01-09 05:24:45 +00:00
|
|
|
// Listen sends over events constantly. Error will be non-nil if Data is
|
|
|
|
// nil, so check for Error first.
|
|
|
|
Listen() <-chan Event
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Send allows the caller to send bytes. Thread safety is a requirement.
|
|
|
|
Send([]byte) error
|
2020-01-09 05:24:45 +00:00
|
|
|
|
|
|
|
// Close should close the websocket connection. The connection will not be
|
2020-04-06 19:03:42 +00:00
|
|
|
// reused. Code should be sent as the status code for the close frame.
|
|
|
|
Close(code int) error
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Conn is the default Websocket connection. It compresses all payloads using
|
|
|
|
// zlib.
|
|
|
|
type Conn struct {
|
2020-01-20 11:06:20 +00:00
|
|
|
Conn *websocket.Conn
|
2020-01-09 05:24:45 +00:00
|
|
|
json.Driver
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
dialer *websocket.Dialer
|
|
|
|
mut sync.RWMutex
|
2020-01-09 05:24:45 +00:00
|
|
|
events chan Event
|
2020-04-06 19:03:42 +00:00
|
|
|
|
2020-04-06 21:03:08 +00:00
|
|
|
// write channels
|
|
|
|
writes chan []byte
|
|
|
|
errors chan error
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
buf bytes.Buffer
|
|
|
|
|
|
|
|
// zlib *zlib.Inflator // zlib.NewReader
|
|
|
|
// buf []byte // io.Copy buffer
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
var _ Connection = (*Conn)(nil)
|
|
|
|
|
|
|
|
func NewConn(driver json.Driver) *Conn {
|
|
|
|
return &Conn{
|
2020-01-17 22:29:13 +00:00
|
|
|
Driver: driver,
|
2020-04-06 19:03:42 +00:00
|
|
|
dialer: &websocket.Dialer{
|
|
|
|
Proxy: http.ProxyFromEnvironment,
|
|
|
|
HandshakeTimeout: DefaultTimeout,
|
|
|
|
EnableCompression: true,
|
|
|
|
},
|
|
|
|
// zlib: zlib.NewInflator(),
|
|
|
|
// buf: make([]byte, CopyBufferSize),
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|
|
|
var err error
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Enable compression:
|
2020-01-09 05:24:45 +00:00
|
|
|
headers := http.Header{}
|
2020-04-06 19:03:42 +00:00
|
|
|
headers.Set("Accept-Encoding", "zlib")
|
|
|
|
|
|
|
|
// BUG: https://github.com/golang/go/issues/31514
|
|
|
|
// // Enable stream compression:
|
|
|
|
// addr = InjectValues(addr, url.Values{
|
|
|
|
// "compress": {"zlib-stream"},
|
|
|
|
// })
|
2020-01-09 05:24:45 +00:00
|
|
|
|
2020-01-20 11:06:20 +00:00
|
|
|
c.mut.Lock()
|
|
|
|
defer c.mut.Unlock()
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
2020-02-11 17:23:42 +00:00
|
|
|
if err != nil {
|
|
|
|
return errors.Wrap(err, "Failed to dial WS")
|
|
|
|
}
|
|
|
|
|
2020-02-02 22:12:54 +00:00
|
|
|
c.events = make(chan Event)
|
2020-04-06 19:03:42 +00:00
|
|
|
go c.readLoop()
|
2020-04-06 21:03:08 +00:00
|
|
|
|
|
|
|
c.writes = make(chan []byte)
|
|
|
|
c.errors = make(chan error)
|
|
|
|
go c.writeLoop()
|
|
|
|
|
2020-01-09 05:24:45 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Listen() <-chan Event {
|
|
|
|
return c.events
|
|
|
|
}
|
|
|
|
|
2020-01-29 03:54:22 +00:00
|
|
|
func (c *Conn) readLoop() {
|
2020-04-06 19:03:42 +00:00
|
|
|
// Acquire the read lock throughout the span of the loop. This would still
|
|
|
|
// allow Send to acquire another RLock, but wouldn't allow Close to
|
|
|
|
// prematurely exit, as Close acquires a write lock.
|
|
|
|
c.mut.RLock()
|
|
|
|
defer c.mut.RUnlock()
|
|
|
|
|
|
|
|
// Clean up the events channel in the end.
|
|
|
|
defer close(c.events)
|
|
|
|
|
|
|
|
for {
|
|
|
|
b, err := c.handle()
|
|
|
|
if err != nil {
|
|
|
|
// Is the error an EOF?
|
2020-04-09 23:19:52 +00:00
|
|
|
if errors.Is(err, io.EOF) {
|
2020-04-06 19:03:42 +00:00
|
|
|
// Yes it is, exit.
|
2020-02-02 22:12:54 +00:00
|
|
|
return
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Check if the error is a normal one:
|
|
|
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Unusual error; log and exit:
|
|
|
|
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2020-04-09 20:49:12 +00:00
|
|
|
// If the payload length is 0, skip it.
|
|
|
|
if len(b) == 0 {
|
2020-04-06 19:03:42 +00:00
|
|
|
continue
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
2020-04-06 19:03:42 +00:00
|
|
|
|
|
|
|
c.events <- Event{b, nil}
|
|
|
|
}
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
2020-04-06 21:03:08 +00:00
|
|
|
func (c *Conn) writeLoop() {
|
|
|
|
c.mut.RLock()
|
|
|
|
defer c.mut.RUnlock()
|
|
|
|
|
|
|
|
for bytes := range c.writes {
|
|
|
|
c.errors <- c.Conn.WriteMessage(websocket.TextMessage, bytes)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
func (c *Conn) handle() ([]byte, error) {
|
|
|
|
// skip message type
|
|
|
|
t, r, err := c.Conn.NextReader()
|
2020-01-09 05:24:45 +00:00
|
|
|
if err != nil {
|
2020-01-16 03:28:21 +00:00
|
|
|
return nil, err
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
if t == websocket.BinaryMessage {
|
2020-01-09 05:24:45 +00:00
|
|
|
// Probably a zlib payload
|
|
|
|
z, err := zlib.NewReader(r)
|
|
|
|
if err != nil {
|
2020-04-06 19:03:42 +00:00
|
|
|
return nil, errors.Wrap(err, "Failed to create a zlib reader")
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
defer z.Close()
|
|
|
|
r = z
|
|
|
|
}
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
return readAll(&c.buf, r)
|
|
|
|
|
|
|
|
// if t is a text message, then handle it normally.
|
|
|
|
// if t == websocket.TextMessage {
|
|
|
|
// return readAll(&c.buf, r)
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // Write to the zlib writer.
|
|
|
|
// c.zlib.Write(r)
|
|
|
|
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
|
|
|
|
// // return nil, errors.Wrap(err, "Failed to write to zlib")
|
|
|
|
// // }
|
|
|
|
|
|
|
|
// if !c.zlib.CanFlush() {
|
|
|
|
// return nil, nil
|
|
|
|
// }
|
|
|
|
|
|
|
|
// // Flush and get the uncompressed payload.
|
|
|
|
// b, err := c.zlib.Flush()
|
|
|
|
// if err != nil {
|
|
|
|
// return nil, errors.Wrap(err, "Failed to flush zlib")
|
|
|
|
// }
|
|
|
|
|
|
|
|
// return nil, errors.New("Unexpected binary message.")
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) Send(b []byte) error {
|
|
|
|
c.mut.RLock()
|
|
|
|
defer c.mut.RUnlock()
|
|
|
|
|
|
|
|
if c.Conn == nil {
|
|
|
|
return errors.New("Websocket is closed.")
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
2020-04-06 21:03:08 +00:00
|
|
|
c.writes <- b
|
|
|
|
return <-c.errors
|
2020-01-20 08:53:23 +00:00
|
|
|
}
|
2020-01-09 05:24:45 +00:00
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
func (c *Conn) Close(code int) error {
|
|
|
|
// Wait for the read loop to exit at the end.
|
|
|
|
err := c.writeClose(code)
|
|
|
|
c.close()
|
|
|
|
return err
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
func (c *Conn) writeClose(code int) error {
|
2020-04-06 21:03:08 +00:00
|
|
|
// Acquire a read lock instead, as the read and write loops are still alive.
|
2020-04-06 19:03:42 +00:00
|
|
|
c.mut.RLock()
|
|
|
|
defer c.mut.RUnlock()
|
2020-01-20 11:06:20 +00:00
|
|
|
|
2020-04-06 21:03:08 +00:00
|
|
|
// Keep the current write channel so we can close them when we're done.
|
|
|
|
wr := c.writes
|
|
|
|
defer close(wr)
|
|
|
|
|
|
|
|
// Stop future sends before closing. Nil channels block forever.
|
|
|
|
c.writes = nil
|
|
|
|
c.errors = nil
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Quick deadline:
|
|
|
|
deadline := time.Now().Add(CloseDeadline)
|
2020-01-09 05:24:45 +00:00
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Make a closure message:
|
|
|
|
msg := websocket.FormatCloseMessage(code, "")
|
|
|
|
|
|
|
|
// Send a close message before closing the connection. We're not error
|
|
|
|
// checking this because it's not important.
|
|
|
|
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
|
2020-01-09 05:24:45 +00:00
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Safe to close now.
|
|
|
|
return c.Conn.Close()
|
2020-01-09 05:24:45 +00:00
|
|
|
}
|
2020-02-08 06:17:27 +00:00
|
|
|
|
|
|
|
func (c *Conn) close() {
|
2020-04-06 19:03:42 +00:00
|
|
|
// Flush all events:
|
|
|
|
for range c.events {
|
|
|
|
}
|
|
|
|
|
|
|
|
// This blocks until the events channel is dead.
|
2020-02-08 06:17:27 +00:00
|
|
|
c.mut.Lock()
|
|
|
|
defer c.mut.Unlock()
|
|
|
|
|
2020-04-06 19:03:42 +00:00
|
|
|
// Clean up.
|
2020-02-08 06:17:27 +00:00
|
|
|
c.events = nil
|
|
|
|
c.Conn = nil
|
|
|
|
}
|
2020-04-06 19:03:42 +00:00
|
|
|
|
|
|
|
// readAll drains r into buf and returns an independent copy of what was read.
// buf is reset before returning, on success and on error, so its storage can
// be reused for the next call. On error it returns a nil slice. On success
// the slice is non-nil, even when nothing was read.
func readAll(buf *bytes.Buffer, r io.Reader) (out []byte, err error) {
	defer buf.Reset()

	if _, err = buf.ReadFrom(r); err == nil {
		// buf's backing array is reused later, so never hand it out directly.
		out = append(make([]byte, 0, buf.Len()), buf.Bytes()...)
	}

	return out, err
}
|