mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-10-01 14:58:52 +00:00
Fixed a ridiculous race condition in the gateway
This commit is contained in:
parent
2cd9def778
commit
3ca1d352c9
|
@ -13,7 +13,7 @@ stages:
|
||||||
unit_test:
|
unit_test:
|
||||||
stage: test
|
stage: test
|
||||||
script:
|
script:
|
||||||
- go test ./...
|
- go test -race -v ./...
|
||||||
|
|
||||||
integration_test:
|
integration_test:
|
||||||
stage: test
|
stage: test
|
||||||
|
@ -21,5 +21,5 @@ integration_test:
|
||||||
variables:
|
variables:
|
||||||
- $BOT_TOKEN
|
- $BOT_TOKEN
|
||||||
script:
|
script:
|
||||||
- go test -tags integration ./...
|
- go test -tags integration -race -v ./...
|
||||||
|
|
||||||
|
|
|
@ -102,8 +102,8 @@ type Gateway struct {
|
||||||
OP chan Event
|
OP chan Event
|
||||||
|
|
||||||
// Filled by methods, internal use
|
// Filled by methods, internal use
|
||||||
paceDeath <-chan error
|
done chan struct{}
|
||||||
handler chan struct{}
|
paceDeath chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||||
|
@ -153,14 +153,16 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||||
|
|
||||||
// Close closes the underlying Websocket connection.
|
// Close closes the underlying Websocket connection.
|
||||||
func (g *Gateway) Close() error {
|
func (g *Gateway) Close() error {
|
||||||
// Stop the pacemaker
|
// If the pacemaker is running:
|
||||||
|
// Stop the pacemaker and the event handler
|
||||||
g.Pacemaker.Stop()
|
g.Pacemaker.Stop()
|
||||||
g.paceDeath = nil
|
|
||||||
|
|
||||||
// Stop the event handler
|
if g.done != nil {
|
||||||
if g.handler != nil {
|
// Wait for the event handler to fully exit
|
||||||
close(g.handler)
|
<-g.done
|
||||||
g.handler = nil
|
|
||||||
|
// Final clean-up
|
||||||
|
g.done = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the Websocket
|
// Stop the Websocket
|
||||||
|
@ -193,10 +195,7 @@ func (g *Gateway) Open() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reconnect to the Gateway
|
// Reconnect to the Gateway
|
||||||
if err := g.WS.Redial(ctx); err != nil {
|
if err := g.WS.Dial(ctx); err != nil {
|
||||||
// Close the connection
|
|
||||||
g.Close()
|
|
||||||
|
|
||||||
// Save the error, retry again
|
// Save the error, retry again
|
||||||
Lerr = errors.Wrap(err, "Failed to reconnect")
|
Lerr = errors.Wrap(err, "Failed to reconnect")
|
||||||
continue
|
continue
|
||||||
|
@ -204,9 +203,6 @@ func (g *Gateway) Open() error {
|
||||||
|
|
||||||
// Try to resume the connection
|
// Try to resume the connection
|
||||||
if err := g.Start(); err != nil {
|
if err := g.Start(); err != nil {
|
||||||
// Close the connection
|
|
||||||
g.Close()
|
|
||||||
|
|
||||||
// If the connection is rate limited (documented behavior):
|
// If the connection is rate limited (documented behavior):
|
||||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||||
if err == ErrInvalidSession {
|
if err == ErrInvalidSession {
|
||||||
|
@ -234,6 +230,14 @@ func (g *Gateway) Open() error {
|
||||||
// Start authenticates with the websocket, or resume from a dead Websocket
|
// Start authenticates with the websocket, or resume from a dead Websocket
|
||||||
// connection. This function doesn't block.
|
// connection. This function doesn't block.
|
||||||
func (g *Gateway) Start() error {
|
func (g *Gateway) Start() error {
|
||||||
|
if err := g.start(); err != nil {
|
||||||
|
g.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gateway) start() error {
|
||||||
// This is where we'll get our events
|
// This is where we'll get our events
|
||||||
ch := g.WS.Listen()
|
ch := g.WS.Listen()
|
||||||
|
|
||||||
|
@ -279,30 +283,35 @@ func (g *Gateway) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the event handler
|
// Start the event handler
|
||||||
g.handler = make(chan struct{})
|
g.done = make(chan struct{})
|
||||||
go g.handleWS(g.handler)
|
go g.handleWS(g.done)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleWS uses the Websocket and parses them into g.Events.
|
// handleWS uses the Websocket and parses them into g.Events.
|
||||||
func (g *Gateway) handleWS(stop <-chan struct{}) {
|
func (g *Gateway) handleWS(done chan struct{}) {
|
||||||
ch := g.WS.Listen()
|
ch := g.WS.Listen()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stop:
|
|
||||||
return
|
|
||||||
case err := <-g.paceDeath:
|
case err := <-g.paceDeath:
|
||||||
if err != nil {
|
if err == nil {
|
||||||
// Pacemaker died, pretty fatal. We'll reconnect though.
|
// No error, just exit normally.
|
||||||
if err := g.Reconnect(); err != nil {
|
return
|
||||||
// Very fatal if this fails. We'll warn the user.
|
}
|
||||||
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
|
|
||||||
|
|
||||||
// Then, we'll take the safe way and exit.
|
// Pacemaker died, pretty fatal. We'll reconnect though.
|
||||||
return
|
if err := g.Reconnect(); err != nil {
|
||||||
}
|
// Very fatal if this fails. We'll warn the user.
|
||||||
|
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
|
||||||
|
|
||||||
|
// Then, we'll take the safe way and exit.
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case ev := <-ch:
|
case ev := <-ch:
|
||||||
|
|
|
@ -24,7 +24,8 @@ type Pacemaker struct {
|
||||||
// Event
|
// Event
|
||||||
OnDead func() error
|
OnDead func() error
|
||||||
|
|
||||||
stop chan<- struct{}
|
stop chan<- struct{}
|
||||||
|
death chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pacemaker) Echo() {
|
func (p *Pacemaker) Echo() {
|
||||||
|
@ -92,16 +93,15 @@ func (p *Pacemaker) start(stop chan struct{}) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pacemaker) StartAsync() (death <-chan error) {
|
func (p *Pacemaker) StartAsync() (death chan error) {
|
||||||
var ch = make(chan error)
|
p.death = make(chan error)
|
||||||
|
|
||||||
stop := make(chan struct{})
|
stop := make(chan struct{})
|
||||||
p.stop = stop
|
p.stop = stop
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
ch <- p.start(stop)
|
p.death <- p.start(stop)
|
||||||
close(ch)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return ch
|
return p.death
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/internal/json"
|
"github.com/diamondburned/arikawa/internal/json"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -39,9 +40,11 @@ type Connection interface {
|
||||||
// Conn is the default Websocket connection. It compresses all payloads using
|
// Conn is the default Websocket connection. It compresses all payloads using
|
||||||
// zlib.
|
// zlib.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
*websocket.Conn
|
Conn *websocket.Conn
|
||||||
json.Driver
|
json.Driver
|
||||||
|
|
||||||
|
mut sync.Mutex
|
||||||
|
done chan struct{}
|
||||||
events chan Event
|
events chan Event
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,16 +63,15 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Set("Accept-Encoding", "zlib") // enable
|
headers.Set("Accept-Encoding", "zlib") // enable
|
||||||
|
|
||||||
|
c.mut.Lock()
|
||||||
|
defer c.mut.Unlock()
|
||||||
|
|
||||||
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
|
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
|
||||||
HTTPHeader: headers,
|
HTTPHeader: headers,
|
||||||
})
|
})
|
||||||
|
|
||||||
c.Conn.SetReadLimit(WSReadLimit)
|
c.Conn.SetReadLimit(WSReadLimit)
|
||||||
|
|
||||||
go func() {
|
c.readLoop(c.events)
|
||||||
c.readLoop(c.events)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,31 +80,36 @@ func (c *Conn) Listen() <-chan Event {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readLoop(ch chan Event) {
|
func (c *Conn) readLoop(ch chan Event) {
|
||||||
for {
|
c.done = make(chan struct{})
|
||||||
b, err := c.readAll(context.Background())
|
|
||||||
if err != nil {
|
go func() {
|
||||||
// Check if the error is a fatal one
|
for {
|
||||||
if code := websocket.CloseStatus(err); code > -1 {
|
b, err := c.readAll(context.Background())
|
||||||
// Is the exit unusual?
|
if err != nil {
|
||||||
if code != websocket.StatusNormalClosure {
|
// Check if the error is a fatal one
|
||||||
// Unusual error, log
|
if code := websocket.CloseStatus(err); code > -1 {
|
||||||
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
|
// Is the exit unusual?
|
||||||
|
if code != websocket.StatusNormalClosure {
|
||||||
|
// Unusual error, log
|
||||||
|
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.done <- struct{}{}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
// or it's not fatal, we just log and continue
|
||||||
|
ch <- Event{nil, errors.Wrap(err, "WS error")}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// or it's not fatal, we just log and continue
|
ch <- Event{b, nil}
|
||||||
ch <- Event{nil, errors.Wrap(err, "WS error")}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
ch <- Event{b, nil}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
t, r, err := c.Reader(ctx)
|
t, r, err := c.Conn.Reader(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -111,7 +118,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
// Probably a zlib payload
|
// Probably a zlib payload
|
||||||
z, err := zlib.NewReader(r)
|
z, err := zlib.NewReader(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.CloseRead(ctx)
|
c.Conn.CloseRead(ctx)
|
||||||
return nil,
|
return nil,
|
||||||
errors.Wrap(err, "Failed to create a zlib reader")
|
errors.Wrap(err, "Failed to create a zlib reader")
|
||||||
}
|
}
|
||||||
|
@ -122,7 +129,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.CloseRead(ctx)
|
c.Conn.CloseRead(ctx)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,10 +138,22 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
|
|
||||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
// TODO: zlib stream
|
// TODO: zlib stream
|
||||||
return c.Write(ctx, websocket.MessageText, b)
|
return c.Conn.Write(ctx, websocket.MessageText, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close(err error) error {
|
func (c *Conn) Close(err error) error {
|
||||||
|
// Wait for the read loop to exit after exiting.
|
||||||
|
defer func() {
|
||||||
|
<-c.done
|
||||||
|
close(c.done)
|
||||||
|
|
||||||
|
// Set the connection to nil.
|
||||||
|
c.Conn = nil
|
||||||
|
|
||||||
|
// Flush all events.
|
||||||
|
c.flush()
|
||||||
|
}()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
||||||
}
|
}
|
||||||
|
@ -146,3 +165,14 @@ func (c *Conn) Close(err error) error {
|
||||||
|
|
||||||
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) flush() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.events:
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -50,18 +50,12 @@ func NewCustom(
|
||||||
return ws, nil
|
return ws, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *Websocket) Redial(ctx context.Context) error {
|
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||||
// Expired, fatal error
|
// Expired, fatal error
|
||||||
return errors.Wrap(err, "Failed to wait")
|
return errors.Wrap(err, "Failed to wait")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the connection
|
|
||||||
if ws.dialed {
|
|
||||||
ws.Conn.Close(nil)
|
|
||||||
}
|
|
||||||
ws.dialed = true
|
|
||||||
|
|
||||||
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
||||||
return errors.Wrap(err, "Failed to dial")
|
return errors.Wrap(err, "Failed to dial")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue