Fixed some race conditions
This commit is contained in:
parent
16cbbb4f25
commit
85b793a1a7
|
@ -160,6 +160,7 @@ func (g *Gateway) Close() error {
|
|||
if g.done != nil {
|
||||
// Wait for the event handler to fully exit
|
||||
<-g.done
|
||||
g.done = nil
|
||||
}
|
||||
|
||||
// Stop the Websocket
|
||||
|
@ -198,6 +199,7 @@ func (g *Gateway) Open() error {
|
|||
if err := g.WS.Dial(ctx); err != nil {
|
||||
// Save the error, retry again
|
||||
Lerr = errors.Wrap(err, "Failed to reconnect")
|
||||
g.ErrorLog(err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -295,7 +297,6 @@ func (g *Gateway) handleWS() {
|
|||
|
||||
defer func() {
|
||||
g.done <- struct{}{}
|
||||
g.done = nil
|
||||
}()
|
||||
|
||||
for {
|
||||
|
@ -306,6 +307,8 @@ func (g *Gateway) handleWS() {
|
|||
return
|
||||
}
|
||||
|
||||
g.ErrorLog(errors.Wrap(err, "Pacemaker died"))
|
||||
|
||||
// Pacemaker died, pretty fatal. We'll reconnect though.
|
||||
if err := g.Reconnect(); err != nil {
|
||||
// Very fatal if this fails. We'll warn the user.
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestIntegration(t *testing.T) {
|
|||
}
|
||||
|
||||
WSError = func(err error) {
|
||||
log.Println(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var gateway *Gateway
|
||||
|
@ -49,14 +49,20 @@ func TestIntegration(t *testing.T) {
|
|||
t.Fatal("Failed to reconnect:", err)
|
||||
}
|
||||
|
||||
/* TODO: We're not testing this, as Discord will replay events before it
|
||||
* sends the Resumed event.
|
||||
timeout := time.After(10 * time.Second)
|
||||
|
||||
resumed, ok := (<-gateway.Events).(*ResumedEvent)
|
||||
if !ok {
|
||||
t.Fatal("Event received is not of type Resumed:", resumed)
|
||||
Main:
|
||||
for {
|
||||
select {
|
||||
case ev := <-gateway.Events:
|
||||
switch ev.(type) {
|
||||
case *ResumedEvent, *ReadyEvent:
|
||||
break Main
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("Timed out waiting for ResumedEvent")
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
if err := g.Close(); err != nil {
|
||||
t.Fatal("Failed to close Gateway:", err)
|
||||
|
|
|
@ -99,7 +99,7 @@ func HandleEvent(g *Gateway, data []byte) error {
|
|||
// Parse the raw data into an OP struct
|
||||
var op *OP
|
||||
if err := g.Driver.Unmarshal(data, &op); err != nil {
|
||||
return errors.Wrap(err, "OP error")
|
||||
return errors.Wrap(err, "OP error: "+string(data))
|
||||
}
|
||||
|
||||
return HandleOP(g, op)
|
||||
|
|
|
@ -2,13 +2,11 @@ package gateway
|
|||
|
||||
import "sync/atomic"
|
||||
|
||||
type Sequence struct {
|
||||
seq int64
|
||||
}
|
||||
type Sequence int64
|
||||
|
||||
func NewSequence() *Sequence {
|
||||
return &Sequence{0}
|
||||
return (*Sequence)(new(int64))
|
||||
}
|
||||
|
||||
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) }
|
||||
func (s *Sequence) Get() int64 { return atomic.LoadInt64(&s.seq) }
|
||||
func (s *Sequence) Set(seq int64) { atomic.StoreInt64((*int64)(s), seq) }
|
||||
func (s *Sequence) Get() int64 { return atomic.LoadInt64((*int64)(s)) }
|
||||
|
|
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/diamondburned/arikawa
|
|||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/gorilla/schema v1.1.0
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
|
|
1
go.sum
1
go.sum
|
@ -37,6 +37,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
var WSBuffer = 12
|
||||
var WSReadLimit int64 = 8192000 // 8 MiB
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
|
@ -55,7 +54,7 @@ var _ Connection = (*Conn)(nil)
|
|||
func NewConn(driver json.Driver) *Conn {
|
||||
return &Conn{
|
||||
Driver: driver,
|
||||
events: make(chan Event, WSBuffer),
|
||||
events: make(chan Event),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,7 +72,7 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
})
|
||||
c.Conn.SetReadLimit(WSReadLimit)
|
||||
|
||||
c.events = make(chan Event, WSBuffer)
|
||||
c.events = make(chan Event)
|
||||
c.readLoop()
|
||||
return err
|
||||
}
|
||||
|
@ -83,11 +82,13 @@ func (c *Conn) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (c *Conn) readLoop() {
|
||||
conn := c.Conn
|
||||
|
||||
go func() {
|
||||
defer close(c.events)
|
||||
|
||||
for {
|
||||
b, err := c.readAll(context.Background())
|
||||
b, err := readAll(conn, context.Background())
|
||||
if err != nil {
|
||||
// Is the error an EOF?
|
||||
if stderr.Is(err, io.EOF) {
|
||||
|
@ -97,18 +98,15 @@ func (c *Conn) readLoop() {
|
|||
|
||||
// Check if the error is a fatal one
|
||||
if code := websocket.CloseStatus(err); code > -1 {
|
||||
// Is the exit unusual?
|
||||
if code != websocket.StatusNormalClosure {
|
||||
// Unusual error, log
|
||||
c.events <- Event{nil, errors.Wrap(err, "WS fatal")}
|
||||
// Is the exit normal?
|
||||
if code == websocket.StatusNormalClosure {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// or it's not fatal, we just log and continue
|
||||
// Unusual error; log:
|
||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
c.events <- Event{b, nil}
|
||||
|
@ -116,8 +114,8 @@ func (c *Conn) readLoop() {
|
|||
}()
|
||||
}
|
||||
|
||||
func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||
t, r, err := c.Conn.Reader(ctx)
|
||||
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
|
||||
t, r, err := c.Reader(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -126,7 +124,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
// Probably a zlib payload
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
c.Conn.CloseRead(ctx)
|
||||
c.CloseRead(ctx)
|
||||
return nil,
|
||||
errors.Wrap(err, "Failed to create a zlib reader")
|
||||
}
|
||||
|
@ -137,7 +135,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
c.Conn.CloseRead(ctx)
|
||||
c.CloseRead(ctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -145,9 +143,6 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
// TODO: zlib stream
|
||||
return c.Conn.Write(ctx, websocket.MessageText, b)
|
||||
}
|
||||
|
@ -155,14 +150,14 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
|
|||
func (c *Conn) Close(err error) error {
|
||||
// Wait for the read loop to exit after exiting.
|
||||
defer func() {
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
<-c.events
|
||||
c.events = nil
|
||||
|
||||
// Set the connection to nil.
|
||||
c.Conn = nil
|
||||
|
||||
// Flush all events.
|
||||
c.flush()
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
|
@ -176,14 +171,3 @@ func (c *Conn) Close(err error) error {
|
|||
|
||||
return c.Conn.Close(websocket.StatusProtocolError, msg)
|
||||
}
|
||||
|
||||
func (c *Conn) flush() {
|
||||
for {
|
||||
select {
|
||||
case <-c.events:
|
||||
continue
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue