diff --git a/session/session.go b/session/session.go index a0c9195..e24010a 100644 --- a/session/session.go +++ b/session/session.go @@ -19,6 +19,10 @@ import ( // ErrMFA is returned if the account requires a 2FA code to log in. var ErrMFA = errors.New("account has 2FA enabled") +// ErrClosed is returned if the Session is closed, either because it's already +// closed (and Close is being called again) or it was never started. +var ErrClosed = errors.New("Session is closed") + // Session manages both the API and Gateway. As such, Session inherits all of // API's methods, as well has the Handler used for Gateway. type Session struct { @@ -189,8 +193,64 @@ func (s *Session) gatewayIsAlive() bool { } } +// Connect opens the Discord gateway and waits until an unrecoverable error +// occurs. Always prefer this method over Open. Note that Connect will return +// when ctx is done or when s.Close is called. +// +// As an odd case, when ctx is done and if the gateway is already finished +// connecting, then a nil error will be returned (unless the gateway has an +// error). This is contrary to the common behavior of a ctx function returning +// ctx.Err(). +func (s *Session) Connect(ctx context.Context) error { + if err := s.Open(ctx); err != nil { + return err + } + + if err := s.Wait(ctx); err != nil && ctx.Err() == nil { + return err + } + + return nil +} + +func (s *Session) initConnect(ctx context.Context) (<-chan struct{}, error) { + evCh := make(chan interface{}) + + s.state.Lock() + defer s.state.Unlock() + + if s.state.cancel != nil { + if err := s.close(); err != nil { + return nil, err + } + } + + if s.state.gateway == nil { + g, err := gateway.NewWithIdentifier(ctx, s.state.id) + if err != nil { + return nil, err + } + s.state.gateway = g + } + + ctx, cancel := context.WithCancel(context.Background()) + s.state.ctx = ctx + s.state.cancel = cancel + + // TODO: change this to AddSyncHandler. + rm := s.AddHandler(evCh) + defer rm() + + opCh := s.state.gateway.Connect(s.state.ctx) + + doneCh := ophandler.Loop(opCh, s.Handler) + s.state.doneCh = doneCh + + return doneCh, nil +} + // Open opens the Discord gateway and its handler, then waits until either the -// Ready or Resumed event gets through. +// Ready or Resumed event gets through. Prefer using Connect instead of Open. func (s *Session) Open(ctx context.Context) error { evCh := make(chan interface{}) @@ -198,7 +258,7 @@ func (s *Session) Open(ctx context.Context) error { defer s.state.Unlock() if s.state.cancel != nil { - if err := s.close(ctx); err != nil { + if err := s.close(); err != nil { return err } } @@ -225,7 +285,7 @@ func (s *Session) Open(ctx context.Context) error { for { select { case <-ctx.Done(): - s.close(ctx) + s.close() return ctx.Err() case <-s.state.doneCh: @@ -241,6 +301,34 @@ func (s *Session) Open(ctx context.Context) error { } } +// Wait blocks until either ctx is done or the gateway stumbles on an +// unrecoverable error. +func (s *Session) Wait(ctx context.Context) error { + s.state.Lock() + doneCh := s.state.doneCh + s.state.Unlock() + + if doneCh == nil { + return ErrClosed + } + + for { + select { + case <-ctx.Done(): + s.Close() + // Prefer gateway errors over context errors. + if err := s.GatewayError(); err != nil { + return err + } + return ctx.Err() + + case <-doneCh: + // Event loop died. + return s.GatewayError() + } + } +} + // WithContext returns a shallow copy of Session with the context replaced in // the API client. All methods called on the returned Session will use this // given context. @@ -261,26 +349,19 @@ func (s *Session) Close() error { s.state.Lock() defer s.state.Unlock() - return s.close(context.Background()) + return s.close() } -func (s *Session) close(ctx context.Context) error { +func (s *Session) close() error { if s.state.cancel == nil { - return errors.New("Session is already closed") + return ErrClosed } s.state.cancel() s.state.cancel = nil s.state.ctx = nil - // Wait until we've successfully disconnected. - select { - case <-ctx.Done(): - return errors.Wrap(ctx.Err(), "cannot wait for gateway exit") - case <-s.state.doneCh: - // ok - } - + <-s.state.doneCh s.state.doneCh = nil return s.state.gateway.LastError() diff --git a/session/session_test.go b/session/session_test.go index a7b9ceb..6b5fe05 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -34,7 +34,7 @@ func TestSession(t *testing.T) { } if ready, ok := <-readyCh; !ok { - t.Fatal("ready not received") + t.Error("ready not received") } else { now := time.Now() t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username) @@ -48,3 +48,47 @@ func TestSession(t *testing.T) { time.Sleep(time.Second) } } + +func TestSessionConnect(t *testing.T) { + attempts := 1 + timeout := 15 * time.Second + + if !testing.Short() { + attempts = 5 + timeout = time.Minute // 5s-10s each reconnection + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + env := testenv.Must(t) + + readyCh := make(chan *gateway.ReadyEvent, 1) + + s := NewWithIntents(env.BotToken, gateway.IntentGuilds) + s.AddHandler(readyCh) + + for i := 0; i < attempts; i++ { + ctx, cancel := context.WithCancel(ctx) + + go func() { + select { + case ready := <-readyCh: + now := time.Now() + t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username) + cancel() + case <-ctx.Done(): + t.Error("ready not received") + } + }() + + if err := s.Connect(ctx); err != nil { + t.Fatal("failed to open:", err) + } + + cancel() + + // Hold for an additional one second. + time.Sleep(time.Second) + } +}