diff --git a/gateway/gateway.go b/gateway/gateway.go index 6c01da1..a11b32b 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -86,7 +86,8 @@ type Gateway struct { ErrorLog func(err error) // default to log.Println // FatalError is where Reconnect errors will go to. When an error is sent - // here, the Gateway is already dead. This channel is buffered once. + // here, the Gateway is already dead, so Close() shouldn't be called. + // This channel is buffered once. FatalError <-chan error fatalError chan error @@ -250,7 +251,8 @@ func (g *Gateway) Start() error { } // Wait blocks until the Gateway fatally exits when it couldn't reconnect -// anymore. To use this withh other channels, check out g.FatalError. +// anymore. To use this withh other channels, check out g.FatalError. If a +// non-nil error is returned, Close() shouldn't be called again. func (g *Gateway) Wait() error { return <-g.FatalError } @@ -281,17 +283,23 @@ func (g *Gateway) start() error { } } - // Expect at least one event - ev := <-ch + // Expect either READY or RESUMED before continuing. + WSDebug("Waiting for either READY or RESUMED.") - // Check for error - if ev.Error != nil { - return errors.Wrap(ev.Error, "First error") - } + err := WaitForEvent(g, ch, func(op *OP) bool { + switch op.EventName { + case "READY": + WSDebug("Found READY event.") + return true + case "RESUMED": + WSDebug("Found RESUMED event.") + return true + } + return false + }) - // Handle the event - if err := HandleEvent(g, ev.Data); err != nil { - return errors.Wrap(err, "WS handler error on first event") + if err != nil { + return errors.Wrap(err, "First error") } // Start the pacemaker with the heartrate received from Hello, after @@ -345,17 +353,8 @@ func (g *Gateway) eventLoop() error { return errors.Wrap(err, "Pacemaker died, reconnecting") case ev := <-ch: - // Check for error - if ev.Error != nil { - return ev.Error - } - - if len(ev.Data) == 0 { - return errors.New("Event data is empty, reconnecting.") - } - // Handle the event - if err := HandleEvent(g, ev.Data); err != nil { + if err := HandleEvent(g, ev); err != nil { g.ErrorLog(errors.Wrap(err, "WS handler error")) } } diff --git a/gateway/op.go b/gateway/op.go index 1f56db1..3e6c8ef 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -40,30 +40,7 @@ type OP struct { var ErrInvalidSession = errors.New("Invalid session") -func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { - if ev.Error != nil { - return nil, ev.Error - } - - if len(ev.Data) == 0 { - return nil, errors.New("Empty payload") - } - - var op *OP - if err := driver.Unmarshal(ev.Data, &op); err != nil { - return nil, errors.Wrap(err, "Failed to decode payload") - } - - if op.Code == InvalidSessionOP { - return op, ErrInvalidSession - } - - return op, nil -} - -func DecodeEvent(driver json.Driver, - ev wsutil.Event, v interface{}) (OPCode, error) { - +func DecodeEvent(driver json.Driver, ev wsutil.Event, v interface{}) (OPCode, error) { op, err := DecodeOP(driver, ev) if err != nil { return 0, err @@ -76,9 +53,7 @@ func DecodeEvent(driver json.Driver, return op.Code, nil } -func AssertEvent(driver json.Driver, - ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { - +func AssertEvent(driver json.Driver, ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { op, err := DecodeOP(driver, ev) if err != nil { return nil, err @@ -98,18 +73,60 @@ func AssertEvent(driver json.Driver, return op, nil } -func HandleEvent(g *Gateway, data []byte) error { - if len(data) == 0 { - return ErrInvalidSession +func HandleEvent(g *Gateway, ev wsutil.Event) error { + o, err := DecodeOP(g.Driver, ev) + if err != nil { + return err + } + + return HandleOP(g, o) +} + +// WaitForEvent blocks until fn() returns true. All incoming events are handled +// regardless. +func WaitForEvent(g *Gateway, ch <-chan wsutil.Event, fn func(*OP) bool) error { + for ev := range ch { + o, err := DecodeOP(g.Driver, ev) + if err != nil { + return err + } + + // Are these events what we're looking for? + found := fn(o) + + // Handle the *OP anyway. + if err := HandleOP(g, o); err != nil { + return err + } + + // If we found the event, return. + if found { + return nil + } + } + + return errors.New("Event not found and event channel is closed.") +} + +func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { + if ev.Error != nil { + return nil, ev.Error + } + + if len(ev.Data) == 0 { + return nil, errors.New("Empty payload") } - // Parse the raw data into an OP struct var op *OP - if err := g.Driver.Unmarshal(data, &op); err != nil { - return errors.Wrap(err, "OP error: "+string(data)) + if err := driver.Unmarshal(ev.Data, &op); err != nil { + return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) } - return HandleOP(g, op) + if op.Code == InvalidSessionOP { + return op, ErrInvalidSession + } + + return op, nil } func HandleOP(g *Gateway, op *OP) error { diff --git a/handler/handler.go b/handler/handler.go index eb264a2..bb0e858 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -121,6 +121,8 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca return } +// AddHandler adds the handler, returning a function that would remove this +// handler when called. func (h *Handler) AddHandler(handler interface{}) (rm func()) { rm, err := h.addHandler(handler) if err != nil { diff --git a/session/session.go b/session/session.go index 7c21d22..bf4768f 100644 --- a/session/session.go +++ b/session/session.go @@ -87,14 +87,15 @@ func NewWithGateway(gw *gateway.Gateway) *Session { } func (s *Session) Open() error { - if err := s.Gateway.Open(); err != nil { - return errors.Wrap(err, "Failed to start gateway") - } - + // Start the handler beforehand so no events are missed. stop := make(chan struct{}) s.hstop = stop go s.startHandler(stop) + if err := s.Gateway.Open(); err != nil { + return errors.Wrap(err, "Failed to start gateway") + } + return nil }