Fixed some race conditions

This commit is contained in:
diamondburned (Forefront) 2020-02-02 14:12:54 -08:00
parent 16cbbb4f25
commit 85b793a1a7
7 changed files with 41 additions and 48 deletions

View File

@ -160,6 +160,7 @@ func (g *Gateway) Close() error {
if g.done != nil {
// Wait for the event handler to fully exit
<-g.done
g.done = nil
}
// Stop the Websocket
@ -198,6 +199,7 @@ func (g *Gateway) Open() error {
if err := g.WS.Dial(ctx); err != nil {
// Save the error, retry again
Lerr = errors.Wrap(err, "Failed to reconnect")
g.ErrorLog(err)
continue
}
@ -295,7 +297,6 @@ func (g *Gateway) handleWS() {
defer func() {
g.done <- struct{}{}
g.done = nil
}()
for {
@ -306,6 +307,8 @@ func (g *Gateway) handleWS() {
return
}
g.ErrorLog(errors.Wrap(err, "Pacemaker died"))
// Pacemaker died, pretty fatal. We'll reconnect though.
if err := g.Reconnect(); err != nil {
// Very fatal if this fails. We'll warn the user.

View File

@ -16,7 +16,7 @@ func TestIntegration(t *testing.T) {
}
WSError = func(err error) {
log.Println(err)
t.Fatal(err)
}
var gateway *Gateway
@ -49,14 +49,20 @@ func TestIntegration(t *testing.T) {
t.Fatal("Failed to reconnect:", err)
}
/* TODO: We're not testing this, as Discord will replay events before it
* sends the Resumed event.
timeout := time.After(10 * time.Second)
resumed, ok := (<-gateway.Events).(*ResumedEvent)
if !ok {
t.Fatal("Event received is not of type Resumed:", resumed)
Main:
for {
select {
case ev := <-gateway.Events:
switch ev.(type) {
case *ResumedEvent, *ReadyEvent:
break Main
}
case <-timeout:
t.Fatal("Timed out waiting for ResumedEvent")
}
}
*/
if err := g.Close(); err != nil {
t.Fatal("Failed to close Gateway:", err)

View File

@ -99,7 +99,7 @@ func HandleEvent(g *Gateway, data []byte) error {
// Parse the raw data into an OP struct
var op *OP
if err := g.Driver.Unmarshal(data, &op); err != nil {
return errors.Wrap(err, "OP error")
return errors.Wrap(err, "OP error: "+string(data))
}
return HandleOP(g, op)

View File

@ -2,13 +2,11 @@ package gateway
import "sync/atomic"
type Sequence struct {
seq int64
}
type Sequence int64
func NewSequence() *Sequence {
return &Sequence{0}
return (*Sequence)(new(int64))
}
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) }
func (s *Sequence) Get() int64 { return atomic.LoadInt64(&s.seq) }
func (s *Sequence) Set(seq int64) { atomic.StoreInt64((*int64)(s), seq) }
func (s *Sequence) Get() int64 { return atomic.LoadInt64((*int64)(s)) }

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/diamondburned/arikawa
go 1.13
require (
github.com/davecgh/go-spew v1.1.1
github.com/gorilla/schema v1.1.0
github.com/pkg/errors v0.8.1
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa

1
go.sum
View File

@ -37,6 +37,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -15,7 +15,6 @@ import (
"nhooyr.io/websocket"
)
var WSBuffer = 12
var WSReadLimit int64 = 8192000 // 8 MiB
// Connection is an interface that abstracts around a generic Websocket driver.
@ -55,7 +54,7 @@ var _ Connection = (*Conn)(nil)
func NewConn(driver json.Driver) *Conn {
return &Conn{
Driver: driver,
events: make(chan Event, WSBuffer),
events: make(chan Event),
}
}
@ -73,7 +72,7 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
})
c.Conn.SetReadLimit(WSReadLimit)
c.events = make(chan Event, WSBuffer)
c.events = make(chan Event)
c.readLoop()
return err
}
@ -83,11 +82,13 @@ func (c *Conn) Listen() <-chan Event {
}
func (c *Conn) readLoop() {
conn := c.Conn
go func() {
defer close(c.events)
for {
b, err := c.readAll(context.Background())
b, err := readAll(conn, context.Background())
if err != nil {
// Is the error an EOF?
if stderr.Is(err, io.EOF) {
@ -97,18 +98,15 @@ func (c *Conn) readLoop() {
// Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 {
// Is the exit unusual?
if code != websocket.StatusNormalClosure {
// Unusual error, log
c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
// Is the exit normal?
if code == websocket.StatusNormalClosure {
return
}
return
}
// or it's not fatal, we just log and continue
// Unusual error; log:
c.events <- Event{nil, errors.Wrap(err, "WS error")}
continue
return
}
c.events <- Event{b, nil}
@ -116,8 +114,8 @@ func (c *Conn) readLoop() {
}()
}
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
t, r, err := c.Conn.Reader(ctx)
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx)
if err != nil {
return nil, err
}
@ -126,7 +124,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
c.Conn.CloseRead(ctx)
c.CloseRead(ctx)
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
}
@ -137,7 +135,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
c.Conn.CloseRead(ctx)
c.CloseRead(ctx)
return nil, err
}
@ -145,9 +143,6 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
}
func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mut.Lock()
defer c.mut.Unlock()
// TODO: zlib stream
return c.Conn.Write(ctx, websocket.MessageText, b)
}
@ -155,14 +150,14 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting.
defer func() {
c.mut.Lock()
defer c.mut.Unlock()
<-c.events
c.events = nil
// Set the connection to nil.
c.Conn = nil
// Flush all events.
c.flush()
}()
if err == nil {
@ -176,14 +171,3 @@ func (c *Conn) Close(err error) error {
return c.Conn.Close(websocket.StatusProtocolError, msg)
}
func (c *Conn) flush() {
for {
select {
case <-c.events:
continue
default:
return
}
}
}