Fixed some race conditions
This commit is contained in:
parent
16cbbb4f25
commit
85b793a1a7
|
@ -160,6 +160,7 @@ func (g *Gateway) Close() error {
|
||||||
if g.done != nil {
|
if g.done != nil {
|
||||||
// Wait for the event handler to fully exit
|
// Wait for the event handler to fully exit
|
||||||
<-g.done
|
<-g.done
|
||||||
|
g.done = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the Websocket
|
// Stop the Websocket
|
||||||
|
@ -198,6 +199,7 @@ func (g *Gateway) Open() error {
|
||||||
if err := g.WS.Dial(ctx); err != nil {
|
if err := g.WS.Dial(ctx); err != nil {
|
||||||
// Save the error, retry again
|
// Save the error, retry again
|
||||||
Lerr = errors.Wrap(err, "Failed to reconnect")
|
Lerr = errors.Wrap(err, "Failed to reconnect")
|
||||||
|
g.ErrorLog(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +297,6 @@ func (g *Gateway) handleWS() {
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
g.done <- struct{}{}
|
g.done <- struct{}{}
|
||||||
g.done = nil
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -306,6 +307,8 @@ func (g *Gateway) handleWS() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
g.ErrorLog(errors.Wrap(err, "Pacemaker died"))
|
||||||
|
|
||||||
// Pacemaker died, pretty fatal. We'll reconnect though.
|
// Pacemaker died, pretty fatal. We'll reconnect though.
|
||||||
if err := g.Reconnect(); err != nil {
|
if err := g.Reconnect(); err != nil {
|
||||||
// Very fatal if this fails. We'll warn the user.
|
// Very fatal if this fails. We'll warn the user.
|
||||||
|
|
|
@ -16,7 +16,7 @@ func TestIntegration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
WSError = func(err error) {
|
WSError = func(err error) {
|
||||||
log.Println(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var gateway *Gateway
|
var gateway *Gateway
|
||||||
|
@ -49,14 +49,20 @@ func TestIntegration(t *testing.T) {
|
||||||
t.Fatal("Failed to reconnect:", err)
|
t.Fatal("Failed to reconnect:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* TODO: We're not testing this, as Discord will replay events before it
|
timeout := time.After(10 * time.Second)
|
||||||
* sends the Resumed event.
|
|
||||||
|
|
||||||
resumed, ok := (<-gateway.Events).(*ResumedEvent)
|
Main:
|
||||||
if !ok {
|
for {
|
||||||
t.Fatal("Event received is not of type Resumed:", resumed)
|
select {
|
||||||
|
case ev := <-gateway.Events:
|
||||||
|
switch ev.(type) {
|
||||||
|
case *ResumedEvent, *ReadyEvent:
|
||||||
|
break Main
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("Timed out waiting for ResumedEvent")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
if err := g.Close(); err != nil {
|
if err := g.Close(); err != nil {
|
||||||
t.Fatal("Failed to close Gateway:", err)
|
t.Fatal("Failed to close Gateway:", err)
|
||||||
|
|
|
@ -99,7 +99,7 @@ func HandleEvent(g *Gateway, data []byte) error {
|
||||||
// Parse the raw data into an OP struct
|
// Parse the raw data into an OP struct
|
||||||
var op *OP
|
var op *OP
|
||||||
if err := g.Driver.Unmarshal(data, &op); err != nil {
|
if err := g.Driver.Unmarshal(data, &op); err != nil {
|
||||||
return errors.Wrap(err, "OP error")
|
return errors.Wrap(err, "OP error: "+string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
return HandleOP(g, op)
|
return HandleOP(g, op)
|
||||||
|
|
|
@ -2,13 +2,11 @@ package gateway
|
||||||
|
|
||||||
import "sync/atomic"
|
import "sync/atomic"
|
||||||
|
|
||||||
type Sequence struct {
|
type Sequence int64
|
||||||
seq int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSequence() *Sequence {
|
func NewSequence() *Sequence {
|
||||||
return &Sequence{0}
|
return (*Sequence)(new(int64))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) }
|
func (s *Sequence) Set(seq int64) { atomic.StoreInt64((*int64)(s), seq) }
|
||||||
func (s *Sequence) Get() int64 { return atomic.LoadInt64(&s.seq) }
|
func (s *Sequence) Get() int64 { return atomic.LoadInt64((*int64)(s)) }
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/diamondburned/arikawa
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1
|
||||||
github.com/gorilla/schema v1.1.0
|
github.com/gorilla/schema v1.1.0
|
||||||
github.com/pkg/errors v0.8.1
|
github.com/pkg/errors v0.8.1
|
||||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||||
|
|
1
go.sum
1
go.sum
|
@ -37,6 +37,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
|
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
|
||||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
var WSBuffer = 12
|
|
||||||
var WSReadLimit int64 = 8192000 // 8 MiB
|
var WSReadLimit int64 = 8192000 // 8 MiB
|
||||||
|
|
||||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||||
|
@ -55,7 +54,7 @@ var _ Connection = (*Conn)(nil)
|
||||||
func NewConn(driver json.Driver) *Conn {
|
func NewConn(driver json.Driver) *Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
Driver: driver,
|
Driver: driver,
|
||||||
events: make(chan Event, WSBuffer),
|
events: make(chan Event),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +72,7 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||||
})
|
})
|
||||||
c.Conn.SetReadLimit(WSReadLimit)
|
c.Conn.SetReadLimit(WSReadLimit)
|
||||||
|
|
||||||
c.events = make(chan Event, WSBuffer)
|
c.events = make(chan Event)
|
||||||
c.readLoop()
|
c.readLoop()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -83,11 +82,13 @@ func (c *Conn) Listen() <-chan Event {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readLoop() {
|
func (c *Conn) readLoop() {
|
||||||
|
conn := c.Conn
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(c.events)
|
defer close(c.events)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
b, err := c.readAll(context.Background())
|
b, err := readAll(conn, context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Is the error an EOF?
|
// Is the error an EOF?
|
||||||
if stderr.Is(err, io.EOF) {
|
if stderr.Is(err, io.EOF) {
|
||||||
|
@ -97,18 +98,15 @@ func (c *Conn) readLoop() {
|
||||||
|
|
||||||
// Check if the error is a fatal one
|
// Check if the error is a fatal one
|
||||||
if code := websocket.CloseStatus(err); code > -1 {
|
if code := websocket.CloseStatus(err); code > -1 {
|
||||||
// Is the exit unusual?
|
// Is the exit normal?
|
||||||
if code != websocket.StatusNormalClosure {
|
if code == websocket.StatusNormalClosure {
|
||||||
// Unusual error, log
|
return
|
||||||
c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// or it's not fatal, we just log and continue
|
// Unusual error; log:
|
||||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.events <- Event{b, nil}
|
c.events <- Event{b, nil}
|
||||||
|
@ -116,8 +114,8 @@ func (c *Conn) readLoop() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
|
||||||
t, r, err := c.Conn.Reader(ctx)
|
t, r, err := c.Reader(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -126,7 +124,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
// Probably a zlib payload
|
// Probably a zlib payload
|
||||||
z, err := zlib.NewReader(r)
|
z, err := zlib.NewReader(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Conn.CloseRead(ctx)
|
c.CloseRead(ctx)
|
||||||
return nil,
|
return nil,
|
||||||
errors.Wrap(err, "Failed to create a zlib reader")
|
errors.Wrap(err, "Failed to create a zlib reader")
|
||||||
}
|
}
|
||||||
|
@ -137,7 +135,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Conn.CloseRead(ctx)
|
c.CloseRead(ctx)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,9 +143,6 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
c.mut.Lock()
|
|
||||||
defer c.mut.Unlock()
|
|
||||||
|
|
||||||
// TODO: zlib stream
|
// TODO: zlib stream
|
||||||
return c.Conn.Write(ctx, websocket.MessageText, b)
|
return c.Conn.Write(ctx, websocket.MessageText, b)
|
||||||
}
|
}
|
||||||
|
@ -155,14 +150,14 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
func (c *Conn) Close(err error) error {
|
func (c *Conn) Close(err error) error {
|
||||||
// Wait for the read loop to exit after exiting.
|
// Wait for the read loop to exit after exiting.
|
||||||
defer func() {
|
defer func() {
|
||||||
|
c.mut.Lock()
|
||||||
|
defer c.mut.Unlock()
|
||||||
|
|
||||||
<-c.events
|
<-c.events
|
||||||
c.events = nil
|
c.events = nil
|
||||||
|
|
||||||
// Set the connection to nil.
|
// Set the connection to nil.
|
||||||
c.Conn = nil
|
c.Conn = nil
|
||||||
|
|
||||||
// Flush all events.
|
|
||||||
c.flush()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -176,14 +171,3 @@ func (c *Conn) Close(err error) error {
|
||||||
|
|
||||||
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) flush() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.events:
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue