Gateway: Switched to gorilla/websocket, fixes #11
This commit is contained in:
parent
b5f7af70f3
commit
9f5c2ac958
|
@ -58,8 +58,8 @@ func (g *Gateway) Resume() error {
|
|||
type HeartbeatData int
|
||||
|
||||
func (g *Gateway) Heartbeat() error {
|
||||
g.available.RLock()
|
||||
defer g.available.RUnlock()
|
||||
// g.available.RLock()
|
||||
// defer g.available.RUnlock()
|
||||
return g.Send(HeartbeatOP, g.Sequence.Get())
|
||||
}
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ type Gateway struct {
|
|||
// Mutex to hold off calls when the WS is not available. Doesn't block if
|
||||
// Start() is not called or Close() is called. Also doesn't block for
|
||||
// Identify or Resume.
|
||||
available sync.RWMutex
|
||||
// available sync.RWMutex
|
||||
|
||||
// Filled by methods, internal use
|
||||
paceDeath chan error
|
||||
|
@ -131,19 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
g.FatalError = g.fatalError
|
||||
|
||||
// Parameters for the gateway
|
||||
param := url.Values{}
|
||||
param.Set("v", Version)
|
||||
param.Set("encoding", Encoding)
|
||||
// param.Set("compress", Compress)
|
||||
param := url.Values{
|
||||
"v": {Version},
|
||||
"encoding": {Encoding},
|
||||
}
|
||||
|
||||
// Append the form to the URL
|
||||
URL += "?" + param.Encode()
|
||||
|
||||
// Create a new undialed Websocket.
|
||||
ws, err := wsutil.NewCustom(wsutil.NewConn(driver), URL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL)
|
||||
}
|
||||
g.WS = ws
|
||||
g.WS = wsutil.NewCustom(wsutil.NewConn(driver), URL)
|
||||
|
||||
// Try and dial it
|
||||
return g, nil
|
||||
|
@ -151,6 +148,12 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if g.waitGroup == nil && g.paceDeath == nil {
|
||||
WSDebug("Gateway is already closed.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the pacemaker is running:
|
||||
if g.paceDeath != nil {
|
||||
WSDebug("Stopping pacemaker...")
|
||||
|
@ -167,22 +170,22 @@ func (g *Gateway) Close() error {
|
|||
// would also exit our event loop. Both would be 2.
|
||||
g.waitGroup.Wait()
|
||||
|
||||
WSDebug("WaitGroup is done.")
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
g.waitGroup = nil
|
||||
|
||||
// Stop the Websocket
|
||||
return g.WS.Close(nil)
|
||||
return g.WS.Close()
|
||||
}
|
||||
|
||||
// Reconnects and resumes.
|
||||
func (g *Gateway) Reconnect() error {
|
||||
WSDebug("Reconnecting...")
|
||||
|
||||
// If the event loop is not dead:
|
||||
if g.paceDeath != nil {
|
||||
WSDebug("Gateway is not closed, closing before reconnecting...")
|
||||
g.Close()
|
||||
WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.")
|
||||
// Guarantee the gateway is already closed:
|
||||
if err := g.Close(); err != nil {
|
||||
return errors.Wrap(err, "Failed to close Gateway before reconnecting")
|
||||
}
|
||||
|
||||
for i := 0; i < WSRetries; i++ {
|
||||
|
@ -204,8 +207,11 @@ func (g *Gateway) Reconnect() error {
|
|||
return ErrWSMaxTries
|
||||
}
|
||||
|
||||
// Open connects to the Websocket and authenticate it. You should usually use
|
||||
// this function over Start().
|
||||
func (g *Gateway) Open() error {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
|
@ -224,13 +230,17 @@ func (g *Gateway) Open() error {
|
|||
}
|
||||
|
||||
// Start authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. This function doesn't block.
|
||||
// connection. This function doesn't block. You wouldn't usually use this
|
||||
// function, but Open() instead.
|
||||
func (g *Gateway) Start() error {
|
||||
g.available.Lock()
|
||||
defer g.available.Unlock()
|
||||
// g.available.Lock()
|
||||
// defer g.available.Unlock()
|
||||
|
||||
if err := g.start(); err != nil {
|
||||
WSDebug("Start failed:", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := g.Close(); err != nil {
|
||||
WSDebug("Failed to close after start fail:", err)
|
||||
}
|
||||
|
@ -375,14 +385,11 @@ func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
|
|||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
// defer cancel()
|
||||
ctx := context.Background()
|
||||
// if lock {
|
||||
// g.available.RLock()
|
||||
// defer g.available.RUnlock()
|
||||
// }
|
||||
|
||||
if lock {
|
||||
g.available.RLock()
|
||||
defer g.available.RUnlock()
|
||||
}
|
||||
|
||||
return g.WS.Send(ctx, b)
|
||||
// WS should already be thread-safe.
|
||||
return g.WS.Send(b)
|
||||
}
|
||||
|
|
|
@ -5,12 +5,17 @@ package gateway
|
|||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
WSDebug = func(v ...interface{}) {
|
||||
log.Println(append([]interface{}{"Debug:"}, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidToken(t *testing.T) {
|
||||
g, err := NewGateway("bad token")
|
||||
if err != nil {
|
||||
|
@ -23,7 +28,7 @@ func TestInvalidToken(t *testing.T) {
|
|||
}
|
||||
|
||||
// 4004 Authentication Failed.
|
||||
if websocket.CloseStatus(err) == 4004 {
|
||||
if strings.Contains(err.Error(), "4004") {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -65,6 +70,9 @@ func TestIntegration(t *testing.T) {
|
|||
|
||||
log.Println("Bot's username is", ready.User.Username)
|
||||
|
||||
// Sleep past the rate limiter before reconnecting:
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Try and reconnect
|
||||
if err := gateway.Reconnect(); err != nil {
|
||||
t.Fatal("Failed to reconnect:", err)
|
||||
|
@ -77,8 +85,11 @@ Main:
|
|||
select {
|
||||
case ev := <-gateway.Events:
|
||||
switch ev.(type) {
|
||||
case *ResumedEvent, *ReadyEvent:
|
||||
// Accept only a Resumed event.
|
||||
case *ResumedEvent:
|
||||
break Main
|
||||
case *ReadyEvent:
|
||||
t.Fatal("Ready event received instead of Resumed.")
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("Timed out waiting for ResumedEvent")
|
||||
|
|
|
@ -45,6 +45,10 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
|||
return nil, ev.Error
|
||||
}
|
||||
|
||||
if len(ev.Data) == 0 {
|
||||
return nil, errors.New("Empty payload")
|
||||
}
|
||||
|
||||
var op *OP
|
||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode payload")
|
||||
|
@ -170,8 +174,7 @@ func HandleOP(g *Gateway, op *OP) error {
|
|||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"Unknown OP code %d (event %s)", op.Code, op.EventName)
|
||||
return fmt.Errorf("Unknown OP code %d (event %s)", op.Code, op.EventName)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
3
go.mod
3
go.mod
|
@ -4,10 +4,9 @@ go 1.13
|
|||
|
||||
require (
|
||||
github.com/gorilla/schema v1.1.0
|
||||
github.com/klauspost/compress v1.10.3 // indirect
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 // indirect
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
nhooyr.io/websocket v1.7.4
|
||||
)
|
||||
|
|
43
go.sum
43
go.sum
|
@ -1,54 +1,15 @@
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
|
||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
|
||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY=
|
||||
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
|
||||
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa h1:xiD6U6h+QMkAwI195dFwdku2N+enlCy9XwFTnEXaCQo=
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa/go.mod h1:KKzWrLiWu6EpzxZBPmPisPgq6oL+do2yLa0C0BTx5fA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
nhooyr.io/websocket v1.7.4 h1:w/LGB2sZT0RV8lZYR7nfyaYz4PUbYZ5oF7NBon2M0NY=
|
||||
nhooyr.io/websocket v1.7.4/go.mod h1:PxYxCwFdFYQ0yRvtQz3s/dC+VEm7CSuC/4b9t8MQQxw=
|
||||
|
|
|
@ -1,24 +1,29 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
stderr "errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/internal/json"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
var WSReadLimit int64 = 8192000 // 8 MiB
|
||||
const CopyBufferSize = 2048
|
||||
|
||||
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
||||
var CloseDeadline = time.Second
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself.
|
||||
// This connection expects the driver to handle compression by itself, including
|
||||
// modifying the connection URL.
|
||||
type Connection interface {
|
||||
// Dial dials the address (string). Context needs to be passed in for
|
||||
// timeout. This method should also be re-usable after Close is called.
|
||||
|
@ -28,15 +33,12 @@ type Connection interface {
|
|||
// nil, so check for Error first.
|
||||
Listen() <-chan Event
|
||||
|
||||
// Send allows the caller to send bytes. Context needs to be passed in order
|
||||
// to re-use the context that's already used for the limiter.
|
||||
Send(context.Context, []byte) error
|
||||
// Send allows the caller to send bytes. Thread safety is a requirement.
|
||||
Send([]byte) error
|
||||
|
||||
// Close should close the websocket connection. The connection will not be
|
||||
// reused.
|
||||
// If error is nil, the connection should close with a StatusNormalClosure
|
||||
// (1000). If not, it should close with a StatusProtocolError (1002).
|
||||
Close(err error) error
|
||||
// reused. Code should be sent as the status code for the close frame.
|
||||
Close(code int) error
|
||||
}
|
||||
|
||||
// Conn is the default Websocket connection. It compresses all payloads using
|
||||
|
@ -45,8 +47,14 @@ type Conn struct {
|
|||
Conn *websocket.Conn
|
||||
json.Driver
|
||||
|
||||
mut sync.Mutex
|
||||
dialer *websocket.Dialer
|
||||
mut sync.RWMutex
|
||||
events chan Event
|
||||
|
||||
buf bytes.Buffer
|
||||
|
||||
// zlib *zlib.Inflator // zlib.NewReader
|
||||
// buf []byte // io.Copy buffer
|
||||
}
|
||||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
@ -54,30 +62,40 @@ var _ Connection = (*Conn)(nil)
|
|||
func NewConn(driver json.Driver) *Conn {
|
||||
return &Conn{
|
||||
Driver: driver,
|
||||
dialer: &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: DefaultTimeout,
|
||||
EnableCompression: true,
|
||||
},
|
||||
events: make(chan Event),
|
||||
// zlib: zlib.NewInflator(),
|
||||
// buf: make([]byte, CopyBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||
var err error
|
||||
|
||||
// Enable compression:
|
||||
headers := http.Header{}
|
||||
headers.Set("Accept-Encoding", "zlib") // enable
|
||||
headers.Set("Accept-Encoding", "zlib")
|
||||
|
||||
// BUG: https://github.com/golang/go/issues/31514
|
||||
// // Enable stream compression:
|
||||
// addr = InjectValues(addr, url.Values{
|
||||
// "compress": {"zlib-stream"},
|
||||
// })
|
||||
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
|
||||
HTTPHeader: headers,
|
||||
})
|
||||
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to dial WS")
|
||||
}
|
||||
|
||||
c.Conn.SetReadLimit(WSReadLimit)
|
||||
|
||||
c.events = make(chan Event)
|
||||
c.readLoop()
|
||||
go c.readLoop()
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -86,94 +104,149 @@ func (c *Conn) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (c *Conn) readLoop() {
|
||||
conn := c.Conn
|
||||
// Acquire the read lock throughout the span of the loop. This would still
|
||||
// allow Send to acquire another RLock, but wouldn't allow Close to
|
||||
// prematurely exit, as Close acquires a write lock.
|
||||
c.mut.RLock()
|
||||
defer c.mut.RUnlock()
|
||||
|
||||
go func() {
|
||||
defer close(c.events)
|
||||
// Clean up the events channel in the end.
|
||||
defer close(c.events)
|
||||
|
||||
for {
|
||||
b, err := readAll(conn, context.Background())
|
||||
if err != nil {
|
||||
// Is the error an EOF?
|
||||
if stderr.Is(err, io.EOF) {
|
||||
// Yes it is, exit.
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the error is a fatal one
|
||||
if code := websocket.CloseStatus(err); code > -1 {
|
||||
// Is the exit normal?
|
||||
if code == websocket.StatusNormalClosure {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Unusual error; log:
|
||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
for {
|
||||
b, err := c.handle()
|
||||
if err != nil {
|
||||
// Is the error an EOF?
|
||||
if stderr.Is(err, io.EOF) {
|
||||
// Yes it is, exit.
|
||||
return
|
||||
}
|
||||
|
||||
c.events <- Event{b, nil}
|
||||
// Check if the error is a normal one:
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
return
|
||||
}
|
||||
|
||||
// Unusual error; log and exit:
|
||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// If nil bytes, then it's an incomplete payload.
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
c.events <- Event{b, nil}
|
||||
}
|
||||
}
|
||||
|
||||
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
|
||||
t, r, err := c.Reader(ctx)
|
||||
func (c *Conn) handle() ([]byte, error) {
|
||||
// skip message type
|
||||
t, r, err := c.Conn.NextReader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t == websocket.MessageBinary {
|
||||
if t == websocket.BinaryMessage {
|
||||
// Probably a zlib payload
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
return nil,
|
||||
errors.Wrap(err, "Failed to create a zlib reader")
|
||||
return nil, errors.Wrap(err, "Failed to create a zlib reader")
|
||||
}
|
||||
|
||||
defer z.Close()
|
||||
r = z
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
return nil, err
|
||||
}
|
||||
return readAll(&c.buf, r)
|
||||
|
||||
return b, nil
|
||||
// if t is a text message, then handle it normally.
|
||||
// if t == websocket.TextMessage {
|
||||
// return readAll(&c.buf, r)
|
||||
// }
|
||||
|
||||
// // Write to the zlib writer.
|
||||
// c.zlib.Write(r)
|
||||
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
|
||||
// // return nil, errors.Wrap(err, "Failed to write to zlib")
|
||||
// // }
|
||||
|
||||
// if !c.zlib.CanFlush() {
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
// // Flush and get the uncompressed payload.
|
||||
// b, err := c.zlib.Flush()
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "Failed to flush zlib")
|
||||
// }
|
||||
|
||||
// return nil, errors.New("Unexpected binary message.")
|
||||
}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// TODO: zlib stream
|
||||
return c.Conn.Write(ctx, websocket.MessageText, b)
|
||||
func (c *Conn) Send(b []byte) error {
|
||||
c.mut.RLock()
|
||||
defer c.mut.RUnlock()
|
||||
|
||||
if c.Conn == nil {
|
||||
return errors.New("Websocket is closed.")
|
||||
}
|
||||
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close(err error) error {
|
||||
// Wait for the read loop to exit after exiting.
|
||||
defer c.close()
|
||||
func (c *Conn) Close(code int) error {
|
||||
// Wait for the read loop to exit at the end.
|
||||
err := c.writeClose(code)
|
||||
c.close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
||||
}
|
||||
func (c *Conn) writeClose(code int) error {
|
||||
c.mut.RLock()
|
||||
defer c.mut.RUnlock()
|
||||
|
||||
var msg = err.Error()
|
||||
if len(msg) > 125 {
|
||||
msg = msg[:125] // truncate
|
||||
}
|
||||
// Quick deadline:
|
||||
deadline := time.Now().Add(CloseDeadline)
|
||||
|
||||
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
||||
// Make a closure message:
|
||||
msg := websocket.FormatCloseMessage(code, "")
|
||||
|
||||
// Send a close message before closing the connection. We're not error
|
||||
// checking this because it's not important.
|
||||
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
|
||||
|
||||
// Safe to close now.
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) close() {
|
||||
// Flush all events:
|
||||
for range c.events {
|
||||
}
|
||||
|
||||
// This blocks until the events channel is dead.
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
<-c.events
|
||||
// Clean up.
|
||||
c.events = nil
|
||||
|
||||
// Set the connection to nil.
|
||||
c.Conn = nil
|
||||
}
|
||||
|
||||
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
||||
// buffer.
|
||||
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
|
||||
defer buf.Reset()
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy the bytes so we could empty the buffer for reuse.
|
||||
p := buf.Bytes()
|
||||
cpy := make([]byte, len(p))
|
||||
copy(cpy, p)
|
||||
|
||||
return cpy, nil
|
||||
}
|
||||
|
|
|
@ -4,9 +4,11 @@ package wsutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/internal/json"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
@ -26,26 +28,21 @@ type Websocket struct {
|
|||
|
||||
SendLimiter *rate.Limiter
|
||||
DialLimiter *rate.Limiter
|
||||
|
||||
listener <-chan Event
|
||||
dialed bool
|
||||
}
|
||||
|
||||
func New(addr string) (*Websocket, error) {
|
||||
func New(addr string) *Websocket {
|
||||
return NewCustom(NewConn(json.Default{}), addr)
|
||||
}
|
||||
|
||||
// NewCustom creates a new undialed Websocket.
|
||||
func NewCustom(conn Connection, addr string) (*Websocket, error) {
|
||||
ws := &Websocket{
|
||||
func NewCustom(conn Connection, addr string) *Websocket {
|
||||
return &Websocket{
|
||||
Conn: conn,
|
||||
Addr: addr,
|
||||
|
||||
SendLimiter: NewSendLimiter(),
|
||||
DialLimiter: NewDialLimiter(),
|
||||
}
|
||||
|
||||
return ws, nil
|
||||
}
|
||||
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
|
@ -68,14 +65,31 @@ func (ws *Websocket) Listen() <-chan Event {
|
|||
return ws.Conn.Listen()
|
||||
}
|
||||
|
||||
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(context.Background()); err != nil {
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
return ws.Conn.Send(ctx, b)
|
||||
return ws.Conn.Send(b)
|
||||
}
|
||||
|
||||
func (ws *Websocket) Close(err error) error {
|
||||
return ws.Conn.Close(err)
|
||||
func (ws *Websocket) Close() error {
|
||||
return ws.Conn.Close(websocket.CloseGoingAway)
|
||||
}
|
||||
|
||||
func InjectValues(rawurl string, values url.Values) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
// Unknown URL, return as-is.
|
||||
return rawurl
|
||||
}
|
||||
|
||||
// Append additional parameters:
|
||||
var q = u.Query()
|
||||
for k, v := range values {
|
||||
q[k] = append(q[k], v...)
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
package zlib
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Reader interface {
|
||||
io.ReadCloser
|
||||
zlib.Resetter
|
||||
}
|
||||
|
||||
func zlibStreamer(r flate.Reader) (Reader, error) {
|
||||
// verify header
|
||||
h := make([]byte, 2)
|
||||
|
||||
if _, err := io.ReadFull(r, h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// verify header
|
||||
if err := verifyHeader(h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flate.NewReader(r).(Reader), nil
|
||||
}
|
||||
|
||||
// https://golang.org/src/compress/zlib/reader.go#L35
|
||||
const zlibDeflate = 8
|
||||
|
||||
func verifyHeader(scratch []byte) error {
|
||||
h := uint(scratch[0])<<8 | uint(scratch[1])
|
||||
if (scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) {
|
||||
return zlib.ErrHeader
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
// Package zlib provides abstractions on top of compress/zlib to work with
|
||||
// Discord's method of compressing websocket packets.
|
||||
package zlib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var Suffix = [4]byte{'\x00', '\x00', '\xff', '\xff'}
|
||||
|
||||
var ErrPartial = errors.New("only partial payload in buffer")
|
||||
|
||||
type Inflator struct {
|
||||
zlib Reader
|
||||
wbuf bytes.Buffer // write buffer for writing compressed bytes
|
||||
rbuf bytes.Buffer // read buffer for writing uncompressed bytes
|
||||
}
|
||||
|
||||
func NewInflator() *Inflator {
|
||||
return &Inflator{
|
||||
wbuf: bytes.Buffer{},
|
||||
rbuf: bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Inflator) Write(p []byte) (n int, err error) {
|
||||
log.Println(p)
|
||||
// Write to buffer normally.
|
||||
return i.wbuf.Write(p)
|
||||
}
|
||||
|
||||
// CanFlush returns if Flush() should be called.
|
||||
func (i *Inflator) CanFlush() bool {
|
||||
if i.wbuf.Len() < 4 {
|
||||
return false
|
||||
}
|
||||
p := i.wbuf.Bytes()
|
||||
return bytes.Equal(p[len(p)-4:], Suffix[:])
|
||||
}
|
||||
|
||||
func (i *Inflator) Flush() ([]byte, error) {
|
||||
// Check if close frames are there:
|
||||
// if !i.CanFlush() {
|
||||
// return nil, ErrPartial
|
||||
// }
|
||||
|
||||
// log.Println(i.wbuf.Bytes())
|
||||
|
||||
// We should reset the write buffer after flushing.
|
||||
// defer i.wbuf.Reset()
|
||||
|
||||
// We can reset the read buffer while returning its byte slice. This works
|
||||
// as long as we copy the byte slice before resetting.
|
||||
defer i.rbuf.Reset()
|
||||
|
||||
// Guarantee there's a zlib writer. Since Discord streams zlib, we have to
|
||||
// reuse the same Reader. Only the first packet has the zlib header.
|
||||
if i.zlib == nil {
|
||||
r, err := zlibStreamer(&i.wbuf)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to make a FLATE reader")
|
||||
}
|
||||
// safe assertion
|
||||
i.zlib = r
|
||||
// } else {
|
||||
// // Reset the FLATE reader for future use:
|
||||
// if err := i.zlib.Reset(&i.wbuf, nil); err != nil {
|
||||
// return nil, errors.Wrap(err, "Failed to reset zlib reader")
|
||||
// }
|
||||
}
|
||||
|
||||
// We can ignore zlib.Read's error, as zlib.Close would return them.
|
||||
_, err := i.rbuf.ReadFrom(i.zlib)
|
||||
|
||||
// ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes
|
||||
// to verify checksum. Discord doesn't send this.
|
||||
if err != nil {
|
||||
// Unexpected error, try and close.
|
||||
return nil, errors.Wrap(err, "Failed to read from FLATE reader")
|
||||
}
|
||||
|
||||
// if err := i.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
||||
// // Try and close anyway.
|
||||
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
|
||||
// }
|
||||
|
||||
// Copy the bytes.
|
||||
return bytecopy(i.rbuf.Bytes()), nil
|
||||
}
|
||||
|
||||
// func (d *Deflator) TryFlush() ([]byte, error) {
|
||||
// // Check if the buffer ends with the zlib close suffix.
|
||||
// if d.wbuf.Len() < 4 {
|
||||
// return nil, nil
|
||||
// }
|
||||
// if p := d.wbuf.Bytes(); !bytes.Equal(p[len(p)-4:], Suffix[:]) {
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
// // Guarantee there's a zlib writer. Since Discord streams zlib, we have to
|
||||
// // reuse the same Reader. Only the first packet has the zlib header.
|
||||
// if d.zlib == nil {
|
||||
// r, err := zlib.NewReader(&d.wbuf)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "Failed to make a zlib reader")
|
||||
// }
|
||||
// // safe assertion
|
||||
// d.zlib = r
|
||||
// }
|
||||
|
||||
// // We can reset the read buffer while returning its byte slice. This works
|
||||
// // as long as we copy the byte slice before resetting.
|
||||
// defer d.rbuf.Reset()
|
||||
|
||||
// defer d.wbuf.Reset()
|
||||
|
||||
// // We can ignore zlib.Read's error, as zlib.Close would return them.
|
||||
// _, err := d.rbuf.ReadFrom(d.zlib)
|
||||
// log.Println("Read:", err, d.rbuf.String())
|
||||
|
||||
// // ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes
|
||||
// // to verify checksum. Discord doesn't send this.
|
||||
// // if err != nil && err != io.ErrUnexpectedEOF {
|
||||
// // // Unexpected error, try and close.
|
||||
// // return nil, errors.Wrap(err, "Failed to read from zlib reader")
|
||||
// // }
|
||||
|
||||
// if err := d.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
|
||||
// // Try and close anyway.
|
||||
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
|
||||
// }
|
||||
|
||||
// // Copy the bytes.
|
||||
// return bytecopy(d.rbuf.Bytes()), nil
|
||||
// }
|
||||
|
||||
func bytecopy(p []byte) []byte {
|
||||
cpy := make([]byte, len(p))
|
||||
copy(cpy, p)
|
||||
return cpy
|
||||
}
|
Loading…
Reference in New Issue