Fixed some race conditions

This commit is contained in:
diamondburned (Forefront) 2020-02-02 14:12:54 -08:00
parent 16cbbb4f25
commit 85b793a1a7
7 changed files with 41 additions and 48 deletions

View File

@ -160,6 +160,7 @@ func (g *Gateway) Close() error {
if g.done != nil { if g.done != nil {
// Wait for the event handler to fully exit // Wait for the event handler to fully exit
<-g.done <-g.done
g.done = nil
} }
// Stop the Websocket // Stop the Websocket
@ -198,6 +199,7 @@ func (g *Gateway) Open() error {
if err := g.WS.Dial(ctx); err != nil { if err := g.WS.Dial(ctx); err != nil {
// Save the error, retry again // Save the error, retry again
Lerr = errors.Wrap(err, "Failed to reconnect") Lerr = errors.Wrap(err, "Failed to reconnect")
g.ErrorLog(err)
continue continue
} }
@ -295,7 +297,6 @@ func (g *Gateway) handleWS() {
defer func() { defer func() {
g.done <- struct{}{} g.done <- struct{}{}
g.done = nil
}() }()
for { for {
@ -306,6 +307,8 @@ func (g *Gateway) handleWS() {
return return
} }
g.ErrorLog(errors.Wrap(err, "Pacemaker died"))
// Pacemaker died, pretty fatal. We'll reconnect though. // Pacemaker died, pretty fatal. We'll reconnect though.
if err := g.Reconnect(); err != nil { if err := g.Reconnect(); err != nil {
// Very fatal if this fails. We'll warn the user. // Very fatal if this fails. We'll warn the user.

View File

@ -16,7 +16,7 @@ func TestIntegration(t *testing.T) {
} }
WSError = func(err error) { WSError = func(err error) {
log.Println(err) t.Fatal(err)
} }
var gateway *Gateway var gateway *Gateway
@ -49,14 +49,20 @@ func TestIntegration(t *testing.T) {
t.Fatal("Failed to reconnect:", err) t.Fatal("Failed to reconnect:", err)
} }
/* TODO: We're not testing this, as Discord will replay events before it timeout := time.After(10 * time.Second)
* sends the Resumed event.
resumed, ok := (<-gateway.Events).(*ResumedEvent) Main:
if !ok { for {
t.Fatal("Event received is not of type Resumed:", resumed) select {
case ev := <-gateway.Events:
switch ev.(type) {
case *ResumedEvent, *ReadyEvent:
break Main
}
case <-timeout:
t.Fatal("Timed out waiting for ResumedEvent")
}
} }
*/
if err := g.Close(); err != nil { if err := g.Close(); err != nil {
t.Fatal("Failed to close Gateway:", err) t.Fatal("Failed to close Gateway:", err)

View File

@ -99,7 +99,7 @@ func HandleEvent(g *Gateway, data []byte) error {
// Parse the raw data into an OP struct // Parse the raw data into an OP struct
var op *OP var op *OP
if err := g.Driver.Unmarshal(data, &op); err != nil { if err := g.Driver.Unmarshal(data, &op); err != nil {
return errors.Wrap(err, "OP error") return errors.Wrap(err, "OP error: "+string(data))
} }
return HandleOP(g, op) return HandleOP(g, op)

View File

@ -2,13 +2,11 @@ package gateway
import "sync/atomic" import "sync/atomic"
type Sequence struct { type Sequence int64
seq int64
}
func NewSequence() *Sequence { func NewSequence() *Sequence {
return &Sequence{0} return (*Sequence)(new(int64))
} }
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) } func (s *Sequence) Set(seq int64) { atomic.StoreInt64((*int64)(s), seq) }
func (s *Sequence) Get() int64 { return atomic.LoadInt64(&s.seq) } func (s *Sequence) Get() int64 { return atomic.LoadInt64((*int64)(s)) }

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/diamondburned/arikawa
go 1.13 go 1.13
require ( require (
github.com/davecgh/go-spew v1.1.1
github.com/gorilla/schema v1.1.0 github.com/gorilla/schema v1.1.0
github.com/pkg/errors v0.8.1 github.com/pkg/errors v0.8.1
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa

1
go.sum
View File

@ -37,6 +37,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -15,7 +15,6 @@ import (
"nhooyr.io/websocket" "nhooyr.io/websocket"
) )
var WSBuffer = 12
var WSReadLimit int64 = 8192000 // 8 MiB var WSReadLimit int64 = 8192000 // 8 MiB
// Connection is an interface that abstracts around a generic Websocket driver. // Connection is an interface that abstracts around a generic Websocket driver.
@ -55,7 +54,7 @@ var _ Connection = (*Conn)(nil)
func NewConn(driver json.Driver) *Conn { func NewConn(driver json.Driver) *Conn {
return &Conn{ return &Conn{
Driver: driver, Driver: driver,
events: make(chan Event, WSBuffer), events: make(chan Event),
} }
} }
@ -73,7 +72,7 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
}) })
c.Conn.SetReadLimit(WSReadLimit) c.Conn.SetReadLimit(WSReadLimit)
c.events = make(chan Event, WSBuffer) c.events = make(chan Event)
c.readLoop() c.readLoop()
return err return err
} }
@ -83,11 +82,13 @@ func (c *Conn) Listen() <-chan Event {
} }
func (c *Conn) readLoop() { func (c *Conn) readLoop() {
conn := c.Conn
go func() { go func() {
defer close(c.events) defer close(c.events)
for { for {
b, err := c.readAll(context.Background()) b, err := readAll(conn, context.Background())
if err != nil { if err != nil {
// Is the error an EOF? // Is the error an EOF?
if stderr.Is(err, io.EOF) { if stderr.Is(err, io.EOF) {
@ -97,18 +98,15 @@ func (c *Conn) readLoop() {
// Check if the error is a fatal one // Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 { if code := websocket.CloseStatus(err); code > -1 {
// Is the exit unusual? // Is the exit normal?
if code != websocket.StatusNormalClosure { if code == websocket.StatusNormalClosure {
// Unusual error, log return
c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
} }
return
} }
// or it's not fatal, we just log and continue // Unusual error; log:
c.events <- Event{nil, errors.Wrap(err, "WS error")} c.events <- Event{nil, errors.Wrap(err, "WS error")}
continue return
} }
c.events <- Event{b, nil} c.events <- Event{b, nil}
@ -116,8 +114,8 @@ func (c *Conn) readLoop() {
}() }()
} }
func (c *Conn) readAll(ctx context.Context) ([]byte, error) { func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
t, r, err := c.Conn.Reader(ctx) t, r, err := c.Reader(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -126,7 +124,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload // Probably a zlib payload
z, err := zlib.NewReader(r) z, err := zlib.NewReader(r)
if err != nil { if err != nil {
c.Conn.CloseRead(ctx) c.CloseRead(ctx)
return nil, return nil,
errors.Wrap(err, "Failed to create a zlib reader") errors.Wrap(err, "Failed to create a zlib reader")
} }
@ -137,7 +135,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
c.Conn.CloseRead(ctx) c.CloseRead(ctx)
return nil, err return nil, err
} }
@ -145,9 +143,6 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
} }
func (c *Conn) Send(ctx context.Context, b []byte) error { func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mut.Lock()
defer c.mut.Unlock()
// TODO: zlib stream // TODO: zlib stream
return c.Conn.Write(ctx, websocket.MessageText, b) return c.Conn.Write(ctx, websocket.MessageText, b)
} }
@ -155,14 +150,14 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
func (c *Conn) Close(err error) error { func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting. // Wait for the read loop to exit after exiting.
defer func() { defer func() {
c.mut.Lock()
defer c.mut.Unlock()
<-c.events <-c.events
c.events = nil c.events = nil
// Set the connection to nil. // Set the connection to nil.
c.Conn = nil c.Conn = nil
// Flush all events.
c.flush()
}() }()
if err == nil { if err == nil {
@ -176,14 +171,3 @@ func (c *Conn) Close(err error) error {
return c.Conn.Close(websocket.StatusProtocolError, msg) return c.Conn.Close(websocket.StatusProtocolError, msg)
} }
func (c *Conn) flush() {
for {
select {
case <-c.events:
continue
default:
return
}
}
}