From 443ec791af61a50da9ed18b39485da3206e5c097 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Fri, 24 Apr 2020 15:09:05 -0700 Subject: [PATCH] Heart: Moved PacemakerLoop to wsutil, changed Gateway abstractions to generic ones --- gateway/gateway.go | 62 +++---------- gateway/op.go | 102 +-------------------- utils/heart/heart.go | 93 ------------------- utils/moreatomic/bool.go | 19 ++++ utils/moreatomic/mutex.go | 33 +++++++ utils/moreatomic/serial.go | 16 ++++ utils/moreatomic/snowflake.go | 17 ++++ utils/moreatomic/string.go | 18 ++++ utils/moreatomic/time.go | 46 ++++++++++ utils/wsutil/heart.go | 116 ++++++++++++++++++++++++ utils/wsutil/op.go | 164 ++++++++++++++++++++++++++++++++++ 11 files changed, 444 insertions(+), 242 deletions(-) create mode 100644 utils/moreatomic/bool.go create mode 100644 utils/moreatomic/mutex.go create mode 100644 utils/moreatomic/serial.go create mode 100644 utils/moreatomic/snowflake.go create mode 100644 utils/moreatomic/string.go create mode 100644 utils/moreatomic/time.go create mode 100644 utils/wsutil/heart.go create mode 100644 utils/wsutil/op.go diff --git a/gateway/gateway.go b/gateway/gateway.go index b48c685..889e2e2 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -88,8 +88,8 @@ type Gateway struct { SessionID string Identifier *Identifier - Pacemaker *Pacemaker Sequence *Sequence + PacerLoop *wsutil.PacemakerLoop ErrorLog func(err error) // default to log.Println @@ -102,7 +102,7 @@ type Gateway struct { // If this channel is non-nil, all incoming OP packets will also be sent // here. This should be buffered, so to not block the main loop. - OP chan *OP + OP chan *wsutil.OP // Mutex to hold off calls when the WS is not available. Doesn't block if // Start() is not called or Close() is called. Also doesn't block for @@ -110,7 +110,6 @@ type Gateway struct { // available sync.RWMutex // Filled by methods, internal use - paceDeath chan error waitGroup *sync.WaitGroup } @@ -148,7 +147,7 @@ func NewCustomGateway(gatewayURL, token string) *Gateway { // Close closes the underlying Websocket connection. func (g *Gateway) Close() error { // Check if the WS is already closed: - if g.waitGroup == nil && g.paceDeath == nil { + if g.waitGroup == nil && g.PacerLoop.Stopped() { wsutil.WSDebug("Gateway is already closed.") g.AfterClose(nil) @@ -156,11 +155,11 @@ func (g *Gateway) Close() error { } // If the pacemaker is running: - if g.paceDeath != nil { + if !g.PacerLoop.Stopped() { wsutil.WSDebug("Stopping pacemaker...") // Stop the pacemaker and the event handler - g.Pacemaker.Stop() + g.PacerLoop.Stop() wsutil.WSDebug("Stopped pacemaker.") } @@ -254,12 +253,6 @@ func (g *Gateway) Start() error { return nil } -// Wait is deprecated. The gateway will reconnect forever. This function will -// panic. -func (g *Gateway) Wait() error { - panic("Wait is deprecated. defer (*Gateway).Close() is required.") -} - func (g *Gateway) start() error { // This is where we'll get our events ch := g.WS.Listen() @@ -269,7 +262,7 @@ func (g *Gateway) start() error { // Wait for an OP 10 Hello var hello HelloEvent - if _, err := AssertEvent(<-ch, HelloOP, &hello); err != nil { + if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil { return errors.Wrap(err, "Error at Hello") } @@ -290,7 +283,7 @@ func (g *Gateway) start() error { wsutil.WSDebug("Waiting for either READY or RESUMED.") // WaitForEvent should - err := WaitForEvent(g, ch, func(op *OP) bool { + err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool { switch op.EventName { case "READY": wsutil.WSDebug("Found READY event.") @@ -306,15 +299,8 @@ func (g *Gateway) start() error { return errors.Wrap(err, "First error") } - // Start the pacemaker with the heartrate received from Hello, after - // initializing everything. This ensures we only heartbeat if the websocket - // is authenticated. - g.Pacemaker = &Pacemaker{ - Heartrate: hello.HeartbeatInterval.Duration(), - Pace: g.Heartbeat, - } - // Pacemaker dies here, only when it's fatal. - g.paceDeath = g.Pacemaker.StartAsync(g.waitGroup) + // Use the pacemaker loop. + g.PacerLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, g) // Start the event handler, which also handles the pacemaker death signal. g.waitGroup.Add(1) @@ -327,7 +313,7 @@ func (g *Gateway) start() error { // handleWS uses the Websocket and parses them into g.Events. func (g *Gateway) handleWS() { - err := g.eventLoop() + err := g.PacerLoop.Run() g.waitGroup.Done() // mark so Close() can exit. wsutil.WSDebug("Event loop stopped.") @@ -338,34 +324,8 @@ func (g *Gateway) handleWS() { } } -func (g *Gateway) eventLoop() error { - ch := g.WS.Listen() - - for { - select { - case err := <-g.paceDeath: - // Got a paceDeath, we're exiting from here on out. - g.paceDeath = nil // mark - - if err == nil { - wsutil.WSDebug("Pacemaker stopped without errors.") - // No error, just exit normally. - return nil - } - - return errors.Wrap(err, "Pacemaker died, reconnecting") - - case ev := <-ch: - // Handle the event - if err := HandleEvent(g, ev); err != nil { - g.ErrorLog(errors.Wrap(err, "WS handler error")) - } - } - } -} - func (g *Gateway) Send(code OPCode, v interface{}) error { - var op = OP{ + var op = wsutil.OP{ Code: code, } diff --git a/gateway/op.go b/gateway/op.go index 219e973..aa6b138 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -10,7 +10,7 @@ import ( "github.com/pkg/errors" ) -type OPCode uint8 +type OPCode = wsutil.OPCode const ( DispatchOP OPCode = 0 // recv @@ -29,101 +29,7 @@ const ( GuildSubscriptionsOP OPCode = 14 ) -type OP struct { - Code OPCode `json:"op"` - Data json.Raw `json:"d,omitempty"` - - // Only for Dispatch (op 0) - Sequence int64 `json:"s,omitempty"` - EventName string `json:"t,omitempty"` -} - -func DecodeEvent(ev wsutil.Event, v interface{}) (OPCode, error) { - op, err := DecodeOP(ev) - if err != nil { - return 0, err - } - - if err := json.Unmarshal(op.Data, v); err != nil { - return 0, errors.Wrap(err, "Failed to decode data") - } - - return op.Code, nil -} - -func AssertEvent(ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { - op, err := DecodeOP(ev) - if err != nil { - return nil, err - } - - if op.Code != code { - return op, fmt.Errorf( - "Unexpected OP Code: %d, expected %d (%s)", - op.Code, code, op.Data, - ) - } - - if err := json.Unmarshal(op.Data, v); err != nil { - return op, errors.Wrap(err, "Failed to decode data") - } - - return op, nil -} - -func HandleEvent(g *Gateway, ev wsutil.Event) error { - o, err := DecodeOP(ev) - if err != nil { - return err - } - - return HandleOP(g, o) -} - -// WaitForEvent blocks until fn() returns true. All incoming events are handled -// regardless. -func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error { - for ev := range ch { - o, err := DecodeOP(ev) - if err != nil { - return err - } - - // Handle the *OP first, in case it's an Invalid Session. This should - // also prevent a race condition with things that need Ready after - // Open(). - if err := HandleOP(g, o); err != nil { - return err - } - - // Are these events what we're looking for? If we've found the event, - // return. - if fn(o) { - return nil - } - } - - return errors.New("Event not found and event channel is closed.") -} - -func DecodeOP(ev wsutil.Event) (*OP, error) { - if ev.Error != nil { - return nil, ev.Error - } - - if len(ev.Data) == 0 { - return nil, errors.New("Empty payload") - } - - var op *OP - if err := json.Unmarshal(ev.Data, &op); err != nil { - return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) - } - - return op, nil -} - -func HandleOP(g *Gateway, op *OP) error { +func (g *Gateway) HandleOP(op *wsutil.OP) error { if g.OP != nil { g.OP <- op } @@ -131,11 +37,11 @@ func HandleOP(g *Gateway, op *OP) error { switch op.Code { case HeartbeatAckOP: // Heartbeat from the server? - g.Pacemaker.Echo() + g.PacerLoop.Echo() case HeartbeatOP: // Server requesting a heartbeat. - return g.Pacemaker.Pace() + return g.PacerLoop.Pace() case ReconnectOP: // Server requests to reconnect, die and retry. diff --git a/utils/heart/heart.go b/utils/heart/heart.go index c1f91ae..c21f39b 100644 --- a/utils/heart/heart.go +++ b/utils/heart/heart.go @@ -7,7 +7,6 @@ import ( "sync/atomic" "time" - "github.com/diamondburned/arikawa/utils/wsutil" "github.com/pkg/errors" ) @@ -151,95 +150,3 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { return p.death } - -// TODO API -type EventLoop interface { - Heartbeat() error - HandleEvent(ev wsutil.Event) error -} - -// PacemakerLoop provides an event loop with a pacemaker. -type PacemakerLoop struct { - pacemaker *Pacemaker // let's not copy this - pacedeath chan error - - events <-chan wsutil.Event - handler func(wsutil.Event) error - - ErrorLog func(error) -} - -func NewLoop(heartrate time.Duration, evs <-chan wsutil.Event, evl EventLoop) *PacemakerLoop { - pacemaker := NewPacemaker(heartrate, evl.Heartbeat) - - return &PacemakerLoop{ - pacemaker: pacemaker, - events: evs, - handler: evl.HandleEvent, - } -} - -func (p *PacemakerLoop) errorLog(err error) { - if p.ErrorLog == nil { - Debug("Uncaught error:", err) - return - } - - p.ErrorLog(err) -} - -func (p *PacemakerLoop) Echo() { - p.pacemaker.Echo() -} - -func (p *PacemakerLoop) Stop() { - p.pacemaker.Stop() -} - -func (p *PacemakerLoop) Stopped() bool { - return p.pacedeath == nil -} - -func (p *PacemakerLoop) Run() error { - // If the event loop is already running. - if p.pacedeath != nil { - return nil - } - // callers should explicitly handle waitgroups. - p.pacedeath = p.pacemaker.StartAsync(nil) - - defer func() { - // mark pacedeath once done - p.pacedeath = nil - - Debug("Pacemaker loop has exited.") - }() - - for { - select { - case err := <-p.pacedeath: - // Got a paceDeath, we're exiting from here on out. - p.pacedeath = nil // mark - - if err == nil { - // No error, just exit normally. - return nil - } - - return errors.Wrap(err, "Pacemaker died, reconnecting") - - case ev, ok := <-p.events: - if !ok { - // Events channel is closed. Kill the pacemaker manually and - // die. - p.pacemaker.Stop() - return <-p.pacedeath - } - - // Handle the event - if err := p.handler(ev); err != nil { - p.errorLog(errors.Wrap(err, "WS handler error")) - } - } - } -} diff --git a/utils/moreatomic/bool.go b/utils/moreatomic/bool.go new file mode 100644 index 0000000..0f89556 --- /dev/null +++ b/utils/moreatomic/bool.go @@ -0,0 +1,19 @@ +package moreatomic + +import "sync/atomic" + +type Bool struct { + val uint32 +} + +func (b *Bool) Get() bool { + return atomic.LoadUint32(&b.val) == 1 +} + +func (b *Bool) Set(val bool) { + var x = uint32(0) + if val { + x = 1 + } + atomic.StoreUint32(&b.val, x) +} diff --git a/utils/moreatomic/mutex.go b/utils/moreatomic/mutex.go new file mode 100644 index 0000000..4893997 --- /dev/null +++ b/utils/moreatomic/mutex.go @@ -0,0 +1,33 @@ +package moreatomic + +import "github.com/sasha-s/go-deadlock" + +type BusyMutex struct { + busy Bool + mut deadlock.Mutex +} + +func (m *BusyMutex) TryLock() bool { + if m.busy.Get() { + return false + } + + m.mut.Lock() + m.busy.Set(true) + + return true +} + +func (m *BusyMutex) IsBusy() bool { + return m.busy.Get() +} + +func (m *BusyMutex) Lock() { + m.mut.Lock() + m.busy.Set(true) +} + +func (m *BusyMutex) Unlock() { + m.busy.Set(false) + m.mut.Unlock() +} diff --git a/utils/moreatomic/serial.go b/utils/moreatomic/serial.go new file mode 100644 index 0000000..a2b9806 --- /dev/null +++ b/utils/moreatomic/serial.go @@ -0,0 +1,16 @@ +package moreatomic + +import "sync/atomic" + +type Serial struct { + serial uint32 +} + +func (s *Serial) Get() int { + return int(atomic.LoadUint32(&s.serial)) +} + +func (s *Serial) Incr() int { + atomic.AddUint32(&s.serial, 1) + return s.Get() +} diff --git a/utils/moreatomic/snowflake.go b/utils/moreatomic/snowflake.go new file mode 100644 index 0000000..6b8e762 --- /dev/null +++ b/utils/moreatomic/snowflake.go @@ -0,0 +1,17 @@ +package moreatomic + +import ( + "sync/atomic" + + "github.com/diamondburned/arikawa/discord" +) + +type Snowflake int64 + +func (s *Snowflake) Get() discord.Snowflake { + return discord.Snowflake(atomic.LoadInt64((*int64)(s))) +} + +func (s *Snowflake) Set(id discord.Snowflake) { + atomic.StoreInt64((*int64)(s), int64(id)) +} diff --git a/utils/moreatomic/string.go b/utils/moreatomic/string.go new file mode 100644 index 0000000..f0ac0c4 --- /dev/null +++ b/utils/moreatomic/string.go @@ -0,0 +1,18 @@ +package moreatomic + +import "sync/atomic" + +type String struct { + v atomic.Value +} + +func (s *String) Get() string { + if v, ok := s.v.Load().(string); ok { + return v + } + return "" +} + +func (s *String) Set(str string) { + s.v.Store(str) +} diff --git a/utils/moreatomic/time.go b/utils/moreatomic/time.go new file mode 100644 index 0000000..349010b --- /dev/null +++ b/utils/moreatomic/time.go @@ -0,0 +1,46 @@ +package moreatomic + +import ( + "sync/atomic" + "time" +) + +type Time struct { + unixnano int64 +} + +func Now() *Time { + return &Time{ + unixnano: time.Now().UnixNano(), + } +} + +func (t *Time) Get() time.Time { + nano := atomic.LoadInt64(&t.unixnano) + return time.Unix(0, nano) +} + +func (t *Time) Set(time time.Time) { + atomic.StoreInt64(&t.unixnano, time.UnixNano()) +} + +// HasBeen checks if it has been this long since the last time. If yes, it will +// set the time. +func (t *Time) HasBeen(dura time.Duration) bool { + now := time.Now() + nano := atomic.LoadInt64(&t.unixnano) + + // We have to be careful of zero values. + if nano != 0 { + // Subtract the duration to now. If subtracted now is before the stored + // time, that means it hasn't been that long yet. We also have to be careful + // of an unitialized time. + if now.Add(-dura).Before(time.Unix(0, nano)) { + return false + } + } + + // It has been that long, so store the variable. + t.Set(now) + return true +} diff --git a/utils/wsutil/heart.go b/utils/wsutil/heart.go new file mode 100644 index 0000000..432b5a4 --- /dev/null +++ b/utils/wsutil/heart.go @@ -0,0 +1,116 @@ +package wsutil + +import ( + "time" + + "github.com/diamondburned/arikawa/utils/heart" + "github.com/pkg/errors" +) + +// TODO API +type EventLoop interface { + Heartbeat() error + HandleOP(*OP) error + // HandleEvent(ev Event) error +} + +// PacemakerLoop provides an event loop with a pacemaker. +type PacemakerLoop struct { + pacemaker *heart.Pacemaker // let's not copy this + pacedeath chan error + + events <-chan Event + handler func(*OP) error + + Extras ExtraHandlers + + ErrorLog func(error) +} + +func NewLoop(heartrate time.Duration, evs <-chan Event, evl EventLoop) *PacemakerLoop { + pacemaker := heart.NewPacemaker(heartrate, evl.Heartbeat) + + return &PacemakerLoop{ + pacemaker: pacemaker, + events: evs, + handler: evl.HandleOP, + } +} + +func (p *PacemakerLoop) errorLog(err error) { + if p.ErrorLog == nil { + WSDebug("Uncaught error:", err) + return + } + + p.ErrorLog(err) +} + +func (p *PacemakerLoop) Pace() error { + return p.pacemaker.Pace() +} + +func (p *PacemakerLoop) Echo() { + p.pacemaker.Echo() +} + +func (p *PacemakerLoop) Stop() { + p.pacemaker.Stop() +} + +func (p *PacemakerLoop) Stopped() bool { + return p.pacedeath == nil +} + +func (p *PacemakerLoop) Run() error { + // If the event loop is already running. + if p.pacedeath != nil { + return nil + } + // callers should explicitly handle waitgroups. + p.pacedeath = p.pacemaker.StartAsync(nil) + + defer func() { + // mark pacedeath once done + p.pacedeath = nil + + WSDebug("Pacemaker loop has exited.") + }() + + for { + select { + case err := <-p.pacedeath: + // Got a paceDeath, we're exiting from here on out. + p.pacedeath = nil // mark + + if err == nil { + // No error, just exit normally. + return nil + } + + return errors.Wrap(err, "Pacemaker died, reconnecting") + + case ev, ok := <-p.events: + if !ok { + // Events channel is closed. Kill the pacemaker manually and + // die. + p.pacemaker.Stop() + return <-p.pacedeath + } + + o, err := DecodeOP(ev) + if err != nil { + p.errorLog(errors.Wrap(err, "Failed to decode OP")) + return err + } + + // Check the events before handling. + p.Extras.Check(o) + + // Handle the event + if err := p.handler(o); err != nil { + p.errorLog(errors.Wrap(err, "Handler failed")) + } + } + } +} diff --git a/utils/wsutil/op.go b/utils/wsutil/op.go new file mode 100644 index 0000000..ed080dc --- /dev/null +++ b/utils/wsutil/op.go @@ -0,0 +1,164 @@ +package wsutil + +import ( + "fmt" + "sync" + + "github.com/diamondburned/arikawa/utils/json" + "github.com/diamondburned/arikawa/utils/moreatomic" + "github.com/pkg/errors" +) + +var ErrEmptyPayload = errors.New("Empty payload") + +// OPCode is a generic type for websocket OP codes. +type OPCode uint8 + +type OP struct { + Code OPCode `json:"op"` + Data json.Raw `json:"d,omitempty"` + + // Only for Gateway Dispatch (op 0) + Sequence int64 `json:"s,omitempty"` + EventName string `json:"t,omitempty"` +} + +func (op *OP) UnmarshalData(v interface{}) error { + return json.Unmarshal(op.Data, v) +} + +func DecodeOP(ev Event) (*OP, error) { + if ev.Error != nil { + return nil, ev.Error + } + + if len(ev.Data) == 0 { + return nil, ErrEmptyPayload + } + + var op *OP + if err := json.Unmarshal(ev.Data, &op); err != nil { + return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) + } + + return op, nil +} + +func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) { + op, err := DecodeOP(ev) + if err != nil { + return nil, err + } + + if op.Code != code { + return op, fmt.Errorf( + "Unexpected OP Code: %d, expected %d (%s)", + op.Code, code, op.Data, + ) + } + + if err := json.Unmarshal(op.Data, v); err != nil { + return op, errors.Wrap(err, "Failed to decode data") + } + + return op, nil +} + +type EventHandler interface { + HandleOP(op *OP) error +} + +func HandleEvent(h EventHandler, ev Event) error { + o, err := DecodeOP(ev) + if err != nil { + return err + } + + return h.HandleOP(o) +} + +// WaitForEvent blocks until fn() returns true. All incoming events are handled +// regardless. +func WaitForEvent(h EventHandler, ch <-chan Event, fn func(*OP) bool) error { + for ev := range ch { + o, err := DecodeOP(ev) + if err != nil { + return err + } + + // Handle the *OP first, in case it's an Invalid Session. This should + // also prevent a race condition with things that need Ready after + // Open(). + if err := h.HandleOP(o); err != nil { + return err + } + + // Are these events what we're looking for? If we've found the event, + // return. + if fn(o) { + return nil + } + } + + return errors.New("Event not found and event channel is closed.") +} + +type ExtraHandlers struct { + mutex sync.Mutex + handlers map[uint32]*ExtraHandler + serial uint32 +} + +type ExtraHandler struct { + Check func(*OP) bool + send chan *OP + + closed moreatomic.Bool +} + +func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) { + handler := &ExtraHandler{ + Check: check, + send: make(chan *OP), + } + + ex.mutex.Lock() + defer ex.mutex.Unlock() + + i := ex.serial + ex.serial++ + + ex.handlers[i] = handler + + return handler.send, func() { + // Check the atomic bool before acquiring the mutex. Might help a bit in + // performance. + if handler.closed.Get() { + return + } + + ex.mutex.Lock() + defer ex.mutex.Unlock() + + delete(ex.handlers, i) + } +} + +// Check runs and sends OP data. It is not thread-safe. +func (ex *ExtraHandlers) Check(op *OP) { + ex.mutex.Lock() + defer ex.mutex.Unlock() + + for i, handler := range ex.handlers { + if handler.Check(op) { + // Attempt to send. + handler.send <- op + + // Mark the handler as closed. + handler.closed.Set(true) + + // Delete the handler. + delete(ex.handlers, i) + } + } +}