Gateway: Switched to gorilla/websocket, fixes #11

This commit is contained in:
diamondburned (Forefront) 2020-04-06 12:03:42 -07:00
parent b5f7af70f3
commit 9f5c2ac958
10 changed files with 416 additions and 165 deletions

View File

@ -58,8 +58,8 @@ func (g *Gateway) Resume() error {
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
g.available.RLock()
defer g.available.RUnlock()
// g.available.RLock()
// defer g.available.RUnlock()
return g.Send(HeartbeatOP, g.Sequence.Get())
}

View File

@ -99,7 +99,7 @@ type Gateway struct {
// Mutex to hold off calls when the WS is not available. Doesn't block if
// Start() is not called or Close() is called. Also doesn't block for
// Identify or Resume.
available sync.RWMutex
// available sync.RWMutex
// Filled by methods, internal use
paceDeath chan error
@ -131,19 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
g.FatalError = g.fatalError
// Parameters for the gateway
param := url.Values{}
param.Set("v", Version)
param.Set("encoding", Encoding)
// param.Set("compress", Compress)
param := url.Values{
"v": {Version},
"encoding": {Encoding},
}
// Append the form to the URL
URL += "?" + param.Encode()
// Create a new undialed Websocket.
ws, err := wsutil.NewCustom(wsutil.NewConn(driver), URL)
if err != nil {
return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL)
}
g.WS = ws
g.WS = wsutil.NewCustom(wsutil.NewConn(driver), URL)
// Try and dial it
return g, nil
@ -151,6 +148,12 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
// Check if the WS is already closed:
if g.waitGroup == nil && g.paceDeath == nil {
WSDebug("Gateway is already closed.")
return nil
}
// If the pacemaker is running:
if g.paceDeath != nil {
WSDebug("Stopping pacemaker...")
@ -167,22 +170,22 @@ func (g *Gateway) Close() error {
// would also exit our event loop. Both would be 2.
g.waitGroup.Wait()
WSDebug("WaitGroup is done.")
// Mark g.waitGroup as empty:
g.waitGroup = nil
// Stop the Websocket
return g.WS.Close(nil)
return g.WS.Close()
}
// Reconnects and resumes.
func (g *Gateway) Reconnect() error {
WSDebug("Reconnecting...")
// If the event loop is not dead:
if g.paceDeath != nil {
WSDebug("Gateway is not closed, closing before reconnecting...")
g.Close()
WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.")
// Guarantee the gateway is already closed:
if err := g.Close(); err != nil {
return errors.Wrap(err, "Failed to close Gateway before reconnecting")
}
for i := 0; i < WSRetries; i++ {
@ -204,8 +207,11 @@ func (g *Gateway) Reconnect() error {
return ErrWSMaxTries
}
// Open connects to the Websocket and authenticate it. You should usually use
// this function over Start().
func (g *Gateway) Open() error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
@ -224,13 +230,17 @@ func (g *Gateway) Open() error {
}
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block.
// connection. This function doesn't block. You wouldn't usually use this
// function, but Open() instead.
func (g *Gateway) Start() error {
g.available.Lock()
defer g.available.Unlock()
// g.available.Lock()
// defer g.available.Unlock()
if err := g.start(); err != nil {
WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the
// pacemaker hasn't started yet.
if err := g.Close(); err != nil {
WSDebug("Failed to close after start fail:", err)
}
@ -375,14 +385,11 @@ func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
return errors.Wrap(err, "Failed to encode payload")
}
// ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
// defer cancel()
ctx := context.Background()
// if lock {
// g.available.RLock()
// defer g.available.RUnlock()
// }
if lock {
g.available.RLock()
defer g.available.RUnlock()
}
return g.WS.Send(ctx, b)
// WS should already be thread-safe.
return g.WS.Send(b)
}

View File

@ -5,12 +5,17 @@ package gateway
import (
"log"
"os"
"strings"
"testing"
"time"
"nhooyr.io/websocket"
)
func init() {
WSDebug = func(v ...interface{}) {
log.Println(append([]interface{}{"Debug:"}, v...)...)
}
}
func TestInvalidToken(t *testing.T) {
g, err := NewGateway("bad token")
if err != nil {
@ -23,7 +28,7 @@ func TestInvalidToken(t *testing.T) {
}
// 4004 Authentication Failed.
if websocket.CloseStatus(err) == 4004 {
if strings.Contains(err.Error(), "4004") {
return
}
@ -65,6 +70,9 @@ func TestIntegration(t *testing.T) {
log.Println("Bot's username is", ready.User.Username)
// Sleep past the rate limiter before reconnecting:
time.Sleep(5 * time.Second)
// Try and reconnect
if err := gateway.Reconnect(); err != nil {
t.Fatal("Failed to reconnect:", err)
@ -77,8 +85,11 @@ Main:
select {
case ev := <-gateway.Events:
switch ev.(type) {
case *ResumedEvent, *ReadyEvent:
// Accept only a Resumed event.
case *ResumedEvent:
break Main
case *ReadyEvent:
t.Fatal("Ready event received instead of Resumed.")
}
case <-timeout:
t.Fatal("Timed out waiting for ResumedEvent")

View File

@ -45,6 +45,10 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
return nil, ev.Error
}
if len(ev.Data) == 0 {
return nil, errors.New("Empty payload")
}
var op *OP
if err := driver.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "Failed to decode payload")
@ -170,8 +174,7 @@ func HandleOP(g *Gateway, op *OP) error {
return nil
default:
return fmt.Errorf(
"Unknown OP code %d (event %s)", op.Code, op.EventName)
return fmt.Errorf("Unknown OP code %d (event %s)", op.Code, op.EventName)
}
return nil

3
go.mod
View File

@ -4,10 +4,9 @@ go 1.13
require (
github.com/gorilla/schema v1.1.0
github.com/klauspost/compress v1.10.3 // indirect
github.com/gorilla/websocket v1.4.2
github.com/pkg/errors v0.9.1
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
nhooyr.io/websocket v1.7.4
)

43
go.sum
View File

@ -1,54 +1,15 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY=
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa h1:xiD6U6h+QMkAwI195dFwdku2N+enlCy9XwFTnEXaCQo=
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa/go.mod h1:KKzWrLiWu6EpzxZBPmPisPgq6oL+do2yLa0C0BTx5fA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
nhooyr.io/websocket v1.7.4 h1:w/LGB2sZT0RV8lZYR7nfyaYz4PUbYZ5oF7NBon2M0NY=
nhooyr.io/websocket v1.7.4/go.mod h1:PxYxCwFdFYQ0yRvtQz3s/dC+VEm7CSuC/4b9t8MQQxw=

View File

@ -1,24 +1,29 @@
package wsutil
import (
"bytes"
"compress/zlib"
"context"
"io"
"io/ioutil"
"net/http"
"sync"
"time"
stderr "errors"
"github.com/diamondburned/arikawa/internal/json"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"nhooyr.io/websocket"
)
var WSReadLimit int64 = 8192000 // 8 MiB
const CopyBufferSize = 2048
// CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
@ -28,15 +33,12 @@ type Connection interface {
// nil, so check for Error first.
Listen() <-chan Event
// Send allows the caller to send bytes. Context needs to be passed in order
// to re-use the context that's already used for the limiter.
Send(context.Context, []byte) error
// Send allows the caller to send bytes. Thread safety is a requirement.
Send([]byte) error
// Close should close the websocket connection. The connection will not be
// reused.
// If error is nil, the connection should close with a StatusNormalClosure
// (1000). If not, it should close with a StatusProtocolError (1002).
Close(err error) error
// reused. Code should be sent as the status code for the close frame.
Close(code int) error
}
// Conn is the default Websocket connection. It compresses all payloads using
@ -45,8 +47,14 @@ type Conn struct {
Conn *websocket.Conn
json.Driver
mut sync.Mutex
dialer *websocket.Dialer
mut sync.RWMutex
events chan Event
buf bytes.Buffer
// zlib *zlib.Inflator // zlib.NewReader
// buf []byte // io.Copy buffer
}
var _ Connection = (*Conn)(nil)
@ -54,30 +62,40 @@ var _ Connection = (*Conn)(nil)
func NewConn(driver json.Driver) *Conn {
return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: DefaultTimeout,
EnableCompression: true,
},
events: make(chan Event),
// zlib: zlib.NewInflator(),
// buf: make([]byte, CopyBufferSize),
}
}
func (c *Conn) Dial(ctx context.Context, addr string) error {
var err error
// Enable compression:
headers := http.Header{}
headers.Set("Accept-Encoding", "zlib") // enable
headers.Set("Accept-Encoding", "zlib")
// BUG: https://github.com/golang/go/issues/31514
// // Enable stream compression:
// addr = InjectValues(addr, url.Values{
// "compress": {"zlib-stream"},
// })
c.mut.Lock()
defer c.mut.Unlock()
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
HTTPHeader: headers,
})
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "Failed to dial WS")
}
c.Conn.SetReadLimit(WSReadLimit)
c.events = make(chan Event)
c.readLoop()
go c.readLoop()
return err
}
@ -86,94 +104,149 @@ func (c *Conn) Listen() <-chan Event {
}
func (c *Conn) readLoop() {
conn := c.Conn
// Acquire the read lock throughout the span of the loop. This would still
// allow Send to acquire another RLock, but wouldn't allow Close to
// prematurely exit, as Close acquires a write lock.
c.mut.RLock()
defer c.mut.RUnlock()
go func() {
defer close(c.events)
// Clean up the events channel in the end.
defer close(c.events)
for {
b, err := readAll(conn, context.Background())
if err != nil {
// Is the error an EOF?
if stderr.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 {
// Is the exit normal?
if code == websocket.StatusNormalClosure {
return
}
}
// Unusual error; log:
c.events <- Event{nil, errors.Wrap(err, "WS error")}
for {
b, err := c.handle()
if err != nil {
// Is the error an EOF?
if stderr.Is(err, io.EOF) {
// Yes it is, exit.
return
}
c.events <- Event{b, nil}
// Check if the error is a normal one:
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return
}
// Unusual error; log and exit:
c.events <- Event{nil, errors.Wrap(err, "WS error")}
return
}
}()
// If nil bytes, then it's an incomplete payload.
if b == nil {
continue
}
c.events <- Event{b, nil}
}
}
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx)
func (c *Conn) handle() ([]byte, error) {
// skip message type
t, r, err := c.Conn.NextReader()
if err != nil {
return nil, err
}
if t == websocket.MessageBinary {
if t == websocket.BinaryMessage {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
c.CloseRead(ctx)
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
return nil, errors.Wrap(err, "Failed to create a zlib reader")
}
defer z.Close()
r = z
}
b, err := ioutil.ReadAll(r)
if err != nil {
c.CloseRead(ctx)
return nil, err
}
return readAll(&c.buf, r)
return b, nil
// if t is a text message, then handle it normally.
// if t == websocket.TextMessage {
// return readAll(&c.buf, r)
// }
// // Write to the zlib writer.
// c.zlib.Write(r)
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
// // return nil, errors.Wrap(err, "Failed to write to zlib")
// // }
// if !c.zlib.CanFlush() {
// return nil, nil
// }
// // Flush and get the uncompressed payload.
// b, err := c.zlib.Flush()
// if err != nil {
// return nil, errors.Wrap(err, "Failed to flush zlib")
// }
// return nil, errors.New("Unexpected binary message.")
}
func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream
return c.Conn.Write(ctx, websocket.MessageText, b)
func (c *Conn) Send(b []byte) error {
c.mut.RLock()
defer c.mut.RUnlock()
if c.Conn == nil {
return errors.New("Websocket is closed.")
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
}
func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting.
defer c.close()
func (c *Conn) Close(code int) error {
// Wait for the read loop to exit at the end.
err := c.writeClose(code)
c.close()
return err
}
if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "")
}
func (c *Conn) writeClose(code int) error {
c.mut.RLock()
defer c.mut.RUnlock()
var msg = err.Error()
if len(msg) > 125 {
msg = msg[:125] // truncate
}
// Quick deadline:
deadline := time.Now().Add(CloseDeadline)
return c.Conn.Close(websocket.StatusProtocolError, msg)
// Make a closure message:
msg := websocket.FormatCloseMessage(code, "")
// Send a close message before closing the connection. We're not error
// checking this because it's not important.
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
// Safe to close now.
return c.Conn.Close()
}
func (c *Conn) close() {
// Flush all events:
for range c.events {
}
// This blocks until the events channel is dead.
c.mut.Lock()
defer c.mut.Unlock()
<-c.events
// Clean up.
c.events = nil
// Set the connection to nil.
c.Conn = nil
}
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
defer buf.Reset()
if _, err := buf.ReadFrom(r); err != nil {
return nil, err
}
// Copy the bytes so we could empty the buffer for reuse.
p := buf.Bytes()
cpy := make([]byte, len(p))
copy(cpy, p)
return cpy, nil
}

View File

@ -4,9 +4,11 @@ package wsutil
import (
"context"
"net/url"
"time"
"github.com/diamondburned/arikawa/internal/json"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
@ -26,26 +28,21 @@ type Websocket struct {
SendLimiter *rate.Limiter
DialLimiter *rate.Limiter
listener <-chan Event
dialed bool
}
func New(addr string) (*Websocket, error) {
func New(addr string) *Websocket {
return NewCustom(NewConn(json.Default{}), addr)
}
// NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) (*Websocket, error) {
ws := &Websocket{
func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{
Conn: conn,
Addr: addr,
SendLimiter: NewSendLimiter(),
DialLimiter: NewDialLimiter(),
}
return ws, nil
}
func (ws *Websocket) Dial(ctx context.Context) error {
@ -68,14 +65,31 @@ func (ws *Websocket) Listen() <-chan Event {
return ws.Conn.Listen()
}
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil {
func (ws *Websocket) Send(b []byte) error {
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
return errors.Wrap(err, "SendLimiter failed")
}
return ws.Conn.Send(ctx, b)
return ws.Conn.Send(b)
}
func (ws *Websocket) Close(err error) error {
return ws.Conn.Close(err)
func (ws *Websocket) Close() error {
return ws.Conn.Close(websocket.CloseGoingAway)
}
func InjectValues(rawurl string, values url.Values) string {
u, err := url.Parse(rawurl)
if err != nil {
// Unknown URL, return as-is.
return rawurl
}
// Append additional parameters:
var q = u.Query()
for k, v := range values {
q[k] = append(q[k], v...)
}
u.RawQuery = q.Encode()
return u.String()
}

39
internal/zlib/flate.go Normal file
View File

@ -0,0 +1,39 @@
package zlib
import (
"compress/flate"
"compress/zlib"
"io"
)
type Reader interface {
io.ReadCloser
zlib.Resetter
}
func zlibStreamer(r flate.Reader) (Reader, error) {
// verify header
h := make([]byte, 2)
if _, err := io.ReadFull(r, h); err != nil {
return nil, err
}
// verify header
if err := verifyHeader(h); err != nil {
return nil, err
}
return flate.NewReader(r).(Reader), nil
}
// https://golang.org/src/compress/zlib/reader.go#L35
const zlibDeflate = 8
func verifyHeader(scratch []byte) error {
h := uint(scratch[0])<<8 | uint(scratch[1])
if (scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) {
return zlib.ErrHeader
}
return nil
}

144
internal/zlib/zlib.go Normal file
View File

@ -0,0 +1,144 @@
// Package zlib provides abstractions on top of compress/zlib to work with
// Discord's method of compressing websocket packets.
package zlib
import (
"bytes"
"log"
"github.com/pkg/errors"
)
var Suffix = [4]byte{'\x00', '\x00', '\xff', '\xff'}
var ErrPartial = errors.New("only partial payload in buffer")
type Inflator struct {
zlib Reader
wbuf bytes.Buffer // write buffer for writing compressed bytes
rbuf bytes.Buffer // read buffer for writing uncompressed bytes
}
func NewInflator() *Inflator {
return &Inflator{
wbuf: bytes.Buffer{},
rbuf: bytes.Buffer{},
}
}
func (i *Inflator) Write(p []byte) (n int, err error) {
log.Println(p)
// Write to buffer normally.
return i.wbuf.Write(p)
}
// CanFlush returns if Flush() should be called.
func (i *Inflator) CanFlush() bool {
if i.wbuf.Len() < 4 {
return false
}
p := i.wbuf.Bytes()
return bytes.Equal(p[len(p)-4:], Suffix[:])
}
func (i *Inflator) Flush() ([]byte, error) {
// Check if close frames are there:
// if !i.CanFlush() {
// return nil, ErrPartial
// }
// log.Println(i.wbuf.Bytes())
// We should reset the write buffer after flushing.
// defer i.wbuf.Reset()
// We can reset the read buffer while returning its byte slice. This works
// as long as we copy the byte slice before resetting.
defer i.rbuf.Reset()
// Guarantee there's a zlib writer. Since Discord streams zlib, we have to
// reuse the same Reader. Only the first packet has the zlib header.
if i.zlib == nil {
r, err := zlibStreamer(&i.wbuf)
if err != nil {
return nil, errors.Wrap(err, "Failed to make a FLATE reader")
}
// safe assertion
i.zlib = r
// } else {
// // Reset the FLATE reader for future use:
// if err := i.zlib.Reset(&i.wbuf, nil); err != nil {
// return nil, errors.Wrap(err, "Failed to reset zlib reader")
// }
}
// We can ignore zlib.Read's error, as zlib.Close would return them.
_, err := i.rbuf.ReadFrom(i.zlib)
// ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes
// to verify checksum. Discord doesn't send this.
if err != nil {
// Unexpected error, try and close.
return nil, errors.Wrap(err, "Failed to read from FLATE reader")
}
// if err := i.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
// // Try and close anyway.
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
// }
// Copy the bytes.
return bytecopy(i.rbuf.Bytes()), nil
}
// func (d *Deflator) TryFlush() ([]byte, error) {
// // Check if the buffer ends with the zlib close suffix.
// if d.wbuf.Len() < 4 {
// return nil, nil
// }
// if p := d.wbuf.Bytes(); !bytes.Equal(p[len(p)-4:], Suffix[:]) {
// return nil, nil
// }
// // Guarantee there's a zlib writer. Since Discord streams zlib, we have to
// // reuse the same Reader. Only the first packet has the zlib header.
// if d.zlib == nil {
// r, err := zlib.NewReader(&d.wbuf)
// if err != nil {
// return nil, errors.Wrap(err, "Failed to make a zlib reader")
// }
// // safe assertion
// d.zlib = r
// }
// // We can reset the read buffer while returning its byte slice. This works
// // as long as we copy the byte slice before resetting.
// defer d.rbuf.Reset()
// defer d.wbuf.Reset()
// // We can ignore zlib.Read's error, as zlib.Close would return them.
// _, err := d.rbuf.ReadFrom(d.zlib)
// log.Println("Read:", err, d.rbuf.String())
// // ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes
// // to verify checksum. Discord doesn't send this.
// // if err != nil && err != io.ErrUnexpectedEOF {
// // // Unexpected error, try and close.
// // return nil, errors.Wrap(err, "Failed to read from zlib reader")
// // }
// if err := d.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
// // Try and close anyway.
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
// }
// // Copy the bytes.
// return bytecopy(d.rbuf.Bytes()), nil
// }
func bytecopy(p []byte) []byte {
cpy := make([]byte, len(p))
copy(cpy, p)
return cpy
}