diff --git a/gateway/commands.go b/gateway/commands.go index 9eed32a..8fb1ebb 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -58,8 +58,8 @@ func (g *Gateway) Resume() error { type HeartbeatData int func (g *Gateway) Heartbeat() error { - g.available.RLock() - defer g.available.RUnlock() + // g.available.RLock() + // defer g.available.RUnlock() return g.Send(HeartbeatOP, g.Sequence.Get()) } diff --git a/gateway/gateway.go b/gateway/gateway.go index 0959bdb..6c01da1 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -99,7 +99,7 @@ type Gateway struct { // Mutex to hold off calls when the WS is not available. Doesn't block if // Start() is not called or Close() is called. Also doesn't block for // Identify or Resume. - available sync.RWMutex + // available sync.RWMutex // Filled by methods, internal use paceDeath chan error @@ -131,19 +131,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { g.FatalError = g.fatalError // Parameters for the gateway - param := url.Values{} - param.Set("v", Version) - param.Set("encoding", Encoding) - // param.Set("compress", Compress) + param := url.Values{ + "v": {Version}, + "encoding": {Encoding}, + } + // Append the form to the URL URL += "?" + param.Encode() // Create a new undialed Websocket. - ws, err := wsutil.NewCustom(wsutil.NewConn(driver), URL) - if err != nil { - return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL) - } - g.WS = ws + g.WS = wsutil.NewCustom(wsutil.NewConn(driver), URL) // Try and dial it return g, nil @@ -151,6 +148,12 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { + // Check if the WS is already closed: + if g.waitGroup == nil && g.paceDeath == nil { + WSDebug("Gateway is already closed.") + return nil + } + // If the pacemaker is running: if g.paceDeath != nil { WSDebug("Stopping pacemaker...") @@ -167,22 +170,22 @@ func (g *Gateway) Close() error { // would also exit our event loop. Both would be 2. g.waitGroup.Wait() + WSDebug("WaitGroup is done.") + // Mark g.waitGroup as empty: g.waitGroup = nil // Stop the Websocket - return g.WS.Close(nil) + return g.WS.Close() } // Reconnects and resumes. func (g *Gateway) Reconnect() error { WSDebug("Reconnecting...") - // If the event loop is not dead: - if g.paceDeath != nil { - WSDebug("Gateway is not closed, closing before reconnecting...") - g.Close() - WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.") + // Guarantee the gateway is already closed: + if err := g.Close(); err != nil { + return errors.Wrap(err, "Failed to close Gateway before reconnecting") } for i := 0; i < WSRetries; i++ { @@ -204,8 +207,11 @@ func (g *Gateway) Reconnect() error { return ErrWSMaxTries } +// Open connects to the Websocket and authenticate it. You should usually use +// this function over Start(). func (g *Gateway) Open() error { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() // Reconnect to the Gateway if err := g.WS.Dial(ctx); err != nil { @@ -224,13 +230,17 @@ func (g *Gateway) Open() error { } // Start authenticates with the websocket, or resume from a dead Websocket -// connection. This function doesn't block. +// connection. This function doesn't block. You wouldn't usually use this +// function, but Open() instead. func (g *Gateway) Start() error { - g.available.Lock() - defer g.available.Unlock() + // g.available.Lock() + // defer g.available.Unlock() if err := g.start(); err != nil { WSDebug("Start failed:", err) + + // Close can be called with the mutex still acquired here, as the + // pacemaker hasn't started yet. if err := g.Close(); err != nil { WSDebug("Failed to close after start fail:", err) } @@ -375,14 +385,11 @@ func (g *Gateway) send(lock bool, code OPCode, v interface{}) error { return errors.Wrap(err, "Failed to encode payload") } - // ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - // defer cancel() - ctx := context.Background() + // if lock { + // g.available.RLock() + // defer g.available.RUnlock() + // } - if lock { - g.available.RLock() - defer g.available.RUnlock() - } - - return g.WS.Send(ctx, b) + // WS should already be thread-safe. + return g.WS.Send(b) } diff --git a/gateway/integration_test.go b/gateway/integration_test.go index a4fb6ce..6522639 100644 --- a/gateway/integration_test.go +++ b/gateway/integration_test.go @@ -5,12 +5,17 @@ package gateway import ( "log" "os" + "strings" "testing" "time" - - "nhooyr.io/websocket" ) +func init() { + WSDebug = func(v ...interface{}) { + log.Println(append([]interface{}{"Debug:"}, v...)...) + } +} + func TestInvalidToken(t *testing.T) { g, err := NewGateway("bad token") if err != nil { @@ -23,7 +28,7 @@ func TestInvalidToken(t *testing.T) { } // 4004 Authentication Failed. - if websocket.CloseStatus(err) == 4004 { + if strings.Contains(err.Error(), "4004") { return } @@ -65,6 +70,9 @@ func TestIntegration(t *testing.T) { log.Println("Bot's username is", ready.User.Username) + // Sleep past the rate limiter before reconnecting: + time.Sleep(5 * time.Second) + // Try and reconnect if err := gateway.Reconnect(); err != nil { t.Fatal("Failed to reconnect:", err) @@ -77,8 +85,11 @@ Main: select { case ev := <-gateway.Events: switch ev.(type) { - case *ResumedEvent, *ReadyEvent: + // Accept only a Resumed event. + case *ResumedEvent: break Main + case *ReadyEvent: + t.Fatal("Ready event received instead of Resumed.") } case <-timeout: t.Fatal("Timed out waiting for ResumedEvent") diff --git a/gateway/op.go b/gateway/op.go index 563bfc1..1f56db1 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -45,6 +45,10 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { return nil, ev.Error } + if len(ev.Data) == 0 { + return nil, errors.New("Empty payload") + } + var op *OP if err := driver.Unmarshal(ev.Data, &op); err != nil { return nil, errors.Wrap(err, "Failed to decode payload") @@ -170,8 +174,7 @@ func HandleOP(g *Gateway, op *OP) error { return nil default: - return fmt.Errorf( - "Unknown OP code %d (event %s)", op.Code, op.EventName) + return fmt.Errorf("Unknown OP code %d (event %s)", op.Code, op.EventName) } return nil diff --git a/go.mod b/go.mod index 3e5a646..bc87594 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,9 @@ go 1.13 require ( github.com/gorilla/schema v1.1.0 - github.com/klauspost/compress v1.10.3 // indirect + github.com/gorilla/websocket v1.4.2 github.com/pkg/errors v0.9.1 github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa golang.org/x/net v0.0.0-20200202094626-16171245cfb2 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - nhooyr.io/websocket v1.7.4 ) diff --git a/go.sum b/go.sum index 08efd2a..5830f11 100644 --- a/go.sum +++ b/go.sum @@ -1,54 +1,15 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= -github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= -github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= -github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY= github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= -github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= -github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= -github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa h1:xiD6U6h+QMkAwI195dFwdku2N+enlCy9XwFTnEXaCQo= github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa/go.mod h1:KKzWrLiWu6EpzxZBPmPisPgq6oL+do2yLa0C0BTx5fA= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -nhooyr.io/websocket v1.7.4 h1:w/LGB2sZT0RV8lZYR7nfyaYz4PUbYZ5oF7NBon2M0NY= -nhooyr.io/websocket v1.7.4/go.mod h1:PxYxCwFdFYQ0yRvtQz3s/dC+VEm7CSuC/4b9t8MQQxw= diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index 68af118..3fc9342 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -1,24 +1,29 @@ package wsutil import ( + "bytes" "compress/zlib" "context" "io" - "io/ioutil" "net/http" "sync" + "time" stderr "errors" "github.com/diamondburned/arikawa/internal/json" + "github.com/gorilla/websocket" "github.com/pkg/errors" - "nhooyr.io/websocket" ) -var WSReadLimit int64 = 8192000 // 8 MiB +const CopyBufferSize = 2048 + +// CloseDeadline controls the deadline to wait for sending the Close frame. +var CloseDeadline = time.Second // Connection is an interface that abstracts around a generic Websocket driver. -// This connection expects the driver to handle compression by itself. +// This connection expects the driver to handle compression by itself, including +// modifying the connection URL. type Connection interface { // Dial dials the address (string). Context needs to be passed in for // timeout. This method should also be re-usable after Close is called. @@ -28,15 +33,12 @@ type Connection interface { // nil, so check for Error first. Listen() <-chan Event - // Send allows the caller to send bytes. Context needs to be passed in order - // to re-use the context that's already used for the limiter. - Send(context.Context, []byte) error + // Send allows the caller to send bytes. Thread safety is a requirement. + Send([]byte) error // Close should close the websocket connection. The connection will not be - // reused. - // If error is nil, the connection should close with a StatusNormalClosure - // (1000). If not, it should close with a StatusProtocolError (1002). - Close(err error) error + // reused. Code should be sent as the status code for the close frame. + Close(code int) error } // Conn is the default Websocket connection. It compresses all payloads using @@ -45,8 +47,14 @@ type Conn struct { Conn *websocket.Conn json.Driver - mut sync.Mutex + dialer *websocket.Dialer + mut sync.RWMutex events chan Event + + buf bytes.Buffer + + // zlib *zlib.Inflator // zlib.NewReader + // buf []byte // io.Copy buffer } var _ Connection = (*Conn)(nil) @@ -54,30 +62,40 @@ var _ Connection = (*Conn)(nil) func NewConn(driver json.Driver) *Conn { return &Conn{ Driver: driver, + dialer: &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: DefaultTimeout, + EnableCompression: true, + }, events: make(chan Event), + // zlib: zlib.NewInflator(), + // buf: make([]byte, CopyBufferSize), } } func (c *Conn) Dial(ctx context.Context, addr string) error { var err error + // Enable compression: headers := http.Header{} - headers.Set("Accept-Encoding", "zlib") // enable + headers.Set("Accept-Encoding", "zlib") + + // BUG: https://github.com/golang/go/issues/31514 + // // Enable stream compression: + // addr = InjectValues(addr, url.Values{ + // "compress": {"zlib-stream"}, + // }) c.mut.Lock() defer c.mut.Unlock() - c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{ - HTTPHeader: headers, - }) + c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers) if err != nil { return errors.Wrap(err, "Failed to dial WS") } - c.Conn.SetReadLimit(WSReadLimit) - c.events = make(chan Event) - c.readLoop() + go c.readLoop() return err } @@ -86,94 +104,149 @@ func (c *Conn) Listen() <-chan Event { } func (c *Conn) readLoop() { - conn := c.Conn + // Acquire the read lock throughout the span of the loop. This would still + // allow Send to acquire another RLock, but wouldn't allow Close to + // prematurely exit, as Close acquires a write lock. + c.mut.RLock() + defer c.mut.RUnlock() - go func() { - defer close(c.events) + // Clean up the events channel in the end. + defer close(c.events) - for { - b, err := readAll(conn, context.Background()) - if err != nil { - // Is the error an EOF? - if stderr.Is(err, io.EOF) { - // Yes it is, exit. - return - } - - // Check if the error is a fatal one - if code := websocket.CloseStatus(err); code > -1 { - // Is the exit normal? - if code == websocket.StatusNormalClosure { - return - } - } - - // Unusual error; log: - c.events <- Event{nil, errors.Wrap(err, "WS error")} + for { + b, err := c.handle() + if err != nil { + // Is the error an EOF? + if stderr.Is(err, io.EOF) { + // Yes it is, exit. return } - c.events <- Event{b, nil} + // Check if the error is a normal one: + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + return + } + + // Unusual error; log and exit: + c.events <- Event{nil, errors.Wrap(err, "WS error")} + return } - }() + + // If nil bytes, then it's an incomplete payload. + if b == nil { + continue + } + + c.events <- Event{b, nil} + } } -func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) { - t, r, err := c.Reader(ctx) +func (c *Conn) handle() ([]byte, error) { + // skip message type + t, r, err := c.Conn.NextReader() if err != nil { return nil, err } - if t == websocket.MessageBinary { + if t == websocket.BinaryMessage { // Probably a zlib payload z, err := zlib.NewReader(r) if err != nil { - c.CloseRead(ctx) - return nil, - errors.Wrap(err, "Failed to create a zlib reader") + return nil, errors.Wrap(err, "Failed to create a zlib reader") } defer z.Close() r = z } - b, err := ioutil.ReadAll(r) - if err != nil { - c.CloseRead(ctx) - return nil, err - } + return readAll(&c.buf, r) - return b, nil + // if t is a text message, then handle it normally. + // if t == websocket.TextMessage { + // return readAll(&c.buf, r) + // } + + // // Write to the zlib writer. + // c.zlib.Write(r) + // // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil { + // // return nil, errors.Wrap(err, "Failed to write to zlib") + // // } + + // if !c.zlib.CanFlush() { + // return nil, nil + // } + + // // Flush and get the uncompressed payload. + // b, err := c.zlib.Flush() + // if err != nil { + // return nil, errors.Wrap(err, "Failed to flush zlib") + // } + + // return nil, errors.New("Unexpected binary message.") } -func (c *Conn) Send(ctx context.Context, b []byte) error { - // TODO: zlib stream - return c.Conn.Write(ctx, websocket.MessageText, b) +func (c *Conn) Send(b []byte) error { + c.mut.RLock() + defer c.mut.RUnlock() + + if c.Conn == nil { + return errors.New("Websocket is closed.") + } + + return c.Conn.WriteMessage(websocket.TextMessage, b) } -func (c *Conn) Close(err error) error { - // Wait for the read loop to exit after exiting. - defer c.close() +func (c *Conn) Close(code int) error { + // Wait for the read loop to exit at the end. + err := c.writeClose(code) + c.close() + return err +} - if err == nil { - return c.Conn.Close(websocket.StatusNormalClosure, "") - } +func (c *Conn) writeClose(code int) error { + c.mut.RLock() + defer c.mut.RUnlock() - var msg = err.Error() - if len(msg) > 125 { - msg = msg[:125] // truncate - } + // Quick deadline: + deadline := time.Now().Add(CloseDeadline) - return c.Conn.Close(websocket.StatusProtocolError, msg) + // Make a closure message: + msg := websocket.FormatCloseMessage(code, "") + + // Send a close message before closing the connection. We're not error + // checking this because it's not important. + c.Conn.WriteControl(websocket.TextMessage, msg, deadline) + + // Safe to close now. + return c.Conn.Close() } func (c *Conn) close() { + // Flush all events: + for range c.events { + } + + // This blocks until the events channel is dead. c.mut.Lock() defer c.mut.Unlock() - <-c.events + // Clean up. c.events = nil - - // Set the connection to nil. c.Conn = nil } + +// readAll reads bytes into an existing buffer, copy it over, then wipe the old +// buffer. +func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) { + defer buf.Reset() + if _, err := buf.ReadFrom(r); err != nil { + return nil, err + } + + // Copy the bytes so we could empty the buffer for reuse. + p := buf.Bytes() + cpy := make([]byte, len(p)) + copy(cpy, p) + + return cpy, nil +} diff --git a/internal/wsutil/ws.go b/internal/wsutil/ws.go index ab1f3e3..2677c6c 100644 --- a/internal/wsutil/ws.go +++ b/internal/wsutil/ws.go @@ -4,9 +4,11 @@ package wsutil import ( "context" + "net/url" "time" "github.com/diamondburned/arikawa/internal/json" + "github.com/gorilla/websocket" "github.com/pkg/errors" "golang.org/x/time/rate" ) @@ -26,26 +28,21 @@ type Websocket struct { SendLimiter *rate.Limiter DialLimiter *rate.Limiter - - listener <-chan Event - dialed bool } -func New(addr string) (*Websocket, error) { +func New(addr string) *Websocket { return NewCustom(NewConn(json.Default{}), addr) } // NewCustom creates a new undialed Websocket. -func NewCustom(conn Connection, addr string) (*Websocket, error) { - ws := &Websocket{ +func NewCustom(conn Connection, addr string) *Websocket { + return &Websocket{ Conn: conn, Addr: addr, SendLimiter: NewSendLimiter(), DialLimiter: NewDialLimiter(), } - - return ws, nil } func (ws *Websocket) Dial(ctx context.Context) error { @@ -68,14 +65,31 @@ func (ws *Websocket) Listen() <-chan Event { return ws.Conn.Listen() } -func (ws *Websocket) Send(ctx context.Context, b []byte) error { - if err := ws.SendLimiter.Wait(ctx); err != nil { +func (ws *Websocket) Send(b []byte) error { + if err := ws.SendLimiter.Wait(context.Background()); err != nil { return errors.Wrap(err, "SendLimiter failed") } - return ws.Conn.Send(ctx, b) + return ws.Conn.Send(b) } -func (ws *Websocket) Close(err error) error { - return ws.Conn.Close(err) +func (ws *Websocket) Close() error { + return ws.Conn.Close(websocket.CloseGoingAway) +} + +func InjectValues(rawurl string, values url.Values) string { + u, err := url.Parse(rawurl) + if err != nil { + // Unknown URL, return as-is. + return rawurl + } + + // Append additional parameters: + var q = u.Query() + for k, v := range values { + q[k] = append(q[k], v...) + } + + u.RawQuery = q.Encode() + return u.String() } diff --git a/internal/zlib/flate.go b/internal/zlib/flate.go new file mode 100644 index 0000000..313ac0c --- /dev/null +++ b/internal/zlib/flate.go @@ -0,0 +1,39 @@ +package zlib + +import ( + "compress/flate" + "compress/zlib" + "io" +) + +type Reader interface { + io.ReadCloser + zlib.Resetter +} + +func zlibStreamer(r flate.Reader) (Reader, error) { + // verify header + h := make([]byte, 2) + + if _, err := io.ReadFull(r, h); err != nil { + return nil, err + } + + // verify header + if err := verifyHeader(h); err != nil { + return nil, err + } + + return flate.NewReader(r).(Reader), nil +} + +// https://golang.org/src/compress/zlib/reader.go#L35 +const zlibDeflate = 8 + +func verifyHeader(scratch []byte) error { + h := uint(scratch[0])<<8 | uint(scratch[1]) + if (scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) { + return zlib.ErrHeader + } + return nil +} diff --git a/internal/zlib/zlib.go b/internal/zlib/zlib.go new file mode 100644 index 0000000..606c8d7 --- /dev/null +++ b/internal/zlib/zlib.go @@ -0,0 +1,144 @@ +// Package zlib provides abstractions on top of compress/zlib to work with +// Discord's method of compressing websocket packets. +package zlib + +import ( + "bytes" + "log" + + "github.com/pkg/errors" +) + +var Suffix = [4]byte{'\x00', '\x00', '\xff', '\xff'} + +var ErrPartial = errors.New("only partial payload in buffer") + +type Inflator struct { + zlib Reader + wbuf bytes.Buffer // write buffer for writing compressed bytes + rbuf bytes.Buffer // read buffer for writing uncompressed bytes +} + +func NewInflator() *Inflator { + return &Inflator{ + wbuf: bytes.Buffer{}, + rbuf: bytes.Buffer{}, + } +} + +func (i *Inflator) Write(p []byte) (n int, err error) { + log.Println(p) + // Write to buffer normally. + return i.wbuf.Write(p) +} + +// CanFlush returns if Flush() should be called. +func (i *Inflator) CanFlush() bool { + if i.wbuf.Len() < 4 { + return false + } + p := i.wbuf.Bytes() + return bytes.Equal(p[len(p)-4:], Suffix[:]) +} + +func (i *Inflator) Flush() ([]byte, error) { + // Check if close frames are there: + // if !i.CanFlush() { + // return nil, ErrPartial + // } + + // log.Println(i.wbuf.Bytes()) + + // We should reset the write buffer after flushing. + // defer i.wbuf.Reset() + + // We can reset the read buffer while returning its byte slice. This works + // as long as we copy the byte slice before resetting. + defer i.rbuf.Reset() + + // Guarantee there's a zlib writer. Since Discord streams zlib, we have to + // reuse the same Reader. Only the first packet has the zlib header. + if i.zlib == nil { + r, err := zlibStreamer(&i.wbuf) + if err != nil { + return nil, errors.Wrap(err, "Failed to make a FLATE reader") + } + // safe assertion + i.zlib = r + // } else { + // // Reset the FLATE reader for future use: + // if err := i.zlib.Reset(&i.wbuf, nil); err != nil { + // return nil, errors.Wrap(err, "Failed to reset zlib reader") + // } + } + + // We can ignore zlib.Read's error, as zlib.Close would return them. + _, err := i.rbuf.ReadFrom(i.zlib) + + // ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes + // to verify checksum. Discord doesn't send this. + if err != nil { + // Unexpected error, try and close. + return nil, errors.Wrap(err, "Failed to read from FLATE reader") + } + + // if err := i.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF { + // // Try and close anyway. + // return nil, errors.Wrap(err, "Failed to read from zlib reader") + // } + + // Copy the bytes. + return bytecopy(i.rbuf.Bytes()), nil +} + +// func (d *Deflator) TryFlush() ([]byte, error) { +// // Check if the buffer ends with the zlib close suffix. +// if d.wbuf.Len() < 4 { +// return nil, nil +// } +// if p := d.wbuf.Bytes(); !bytes.Equal(p[len(p)-4:], Suffix[:]) { +// return nil, nil +// } + +// // Guarantee there's a zlib writer. Since Discord streams zlib, we have to +// // reuse the same Reader. Only the first packet has the zlib header. +// if d.zlib == nil { +// r, err := zlib.NewReader(&d.wbuf) +// if err != nil { +// return nil, errors.Wrap(err, "Failed to make a zlib reader") +// } +// // safe assertion +// d.zlib = r +// } + +// // We can reset the read buffer while returning its byte slice. This works +// // as long as we copy the byte slice before resetting. +// defer d.rbuf.Reset() + +// defer d.wbuf.Reset() + +// // We can ignore zlib.Read's error, as zlib.Close would return them. +// _, err := d.rbuf.ReadFrom(d.zlib) +// log.Println("Read:", err, d.rbuf.String()) + +// // ErrUnexpectedEOF happens because zlib tries to find the last 4 bytes +// // to verify checksum. Discord doesn't send this. +// // if err != nil && err != io.ErrUnexpectedEOF { +// // // Unexpected error, try and close. +// // return nil, errors.Wrap(err, "Failed to read from zlib reader") +// // } + +// if err := d.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF { +// // Try and close anyway. +// return nil, errors.Wrap(err, "Failed to read from zlib reader") +// } + +// // Copy the bytes. +// return bytecopy(d.rbuf.Bytes()), nil +// } + +func bytecopy(p []byte) []byte { + cpy := make([]byte, len(p)) + copy(cpy, p) + return cpy +}