mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-22 18:09:21 +00:00
session: Add Connect and Wait
This commit adds the Connect and Wait methods into session. This gives the user a way to block the program until the session runs into an error or the given ctx is done. In most cases, Connect is useful when combined with signal.NotifyContext, and so Connect is preferred over Open.
This commit is contained in:
parent
0ae6f6690a
commit
258b6149d7
|
@ -19,6 +19,10 @@ import (
|
|||
// ErrMFA is returned if the account requires a 2FA code to log in.
|
||||
var ErrMFA = errors.New("account has 2FA enabled")
|
||||
|
||||
// ErrClosed is returned if the Session is closed, either because it's already
|
||||
// closed (and Close is being called again) or it was never started.
|
||||
var ErrClosed = errors.New("Session is closed")
|
||||
|
||||
// Session manages both the API and Gateway. As such, Session inherits all of
|
||||
// API's methods, as well has the Handler used for Gateway.
|
||||
type Session struct {
|
||||
|
@ -189,8 +193,64 @@ func (s *Session) gatewayIsAlive() bool {
|
|||
}
|
||||
}
|
||||
|
||||
// Connect opens the Discord gateway and waits until an unrecoverable error
|
||||
// occurs. Always prefer this method over Open. Note that Connect will return
|
||||
// when ctx is done or when s.Close is called.
|
||||
//
|
||||
// As an odd case, when ctx is done and if the gateway is already finished
|
||||
// connecting, then a nil error will be returned (unless the gateway has an
|
||||
// error). This is contrary to the common behavior of a ctx function returning
|
||||
// ctx.Err().
|
||||
func (s *Session) Connect(ctx context.Context) error {
|
||||
if err := s.Open(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.Wait(ctx); err != nil && ctx.Err() == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) initConnect(ctx context.Context) (<-chan struct{}, error) {
|
||||
evCh := make(chan interface{})
|
||||
|
||||
s.state.Lock()
|
||||
defer s.state.Unlock()
|
||||
|
||||
if s.state.cancel != nil {
|
||||
if err := s.close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if s.state.gateway == nil {
|
||||
g, err := gateway.NewWithIdentifier(ctx, s.state.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.state.gateway = g
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.state.ctx = ctx
|
||||
s.state.cancel = cancel
|
||||
|
||||
// TODO: change this to AddSyncHandler.
|
||||
rm := s.AddHandler(evCh)
|
||||
defer rm()
|
||||
|
||||
opCh := s.state.gateway.Connect(s.state.ctx)
|
||||
|
||||
doneCh := ophandler.Loop(opCh, s.Handler)
|
||||
s.state.doneCh = doneCh
|
||||
|
||||
return doneCh, nil
|
||||
}
|
||||
|
||||
// Open opens the Discord gateway and its handler, then waits until either the
|
||||
// Ready or Resumed event gets through.
|
||||
// Ready or Resumed event gets through. Prefer using Connect instead of Open.
|
||||
func (s *Session) Open(ctx context.Context) error {
|
||||
evCh := make(chan interface{})
|
||||
|
||||
|
@ -198,7 +258,7 @@ func (s *Session) Open(ctx context.Context) error {
|
|||
defer s.state.Unlock()
|
||||
|
||||
if s.state.cancel != nil {
|
||||
if err := s.close(ctx); err != nil {
|
||||
if err := s.close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -225,7 +285,7 @@ func (s *Session) Open(ctx context.Context) error {
|
|||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.close(ctx)
|
||||
s.close()
|
||||
return ctx.Err()
|
||||
|
||||
case <-s.state.doneCh:
|
||||
|
@ -241,6 +301,34 @@ func (s *Session) Open(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Wait blocks until either ctx is done or the gateway stumbles on an
|
||||
// unrecoverable error.
|
||||
func (s *Session) Wait(ctx context.Context) error {
|
||||
s.state.Lock()
|
||||
doneCh := s.state.doneCh
|
||||
s.state.Unlock()
|
||||
|
||||
if doneCh == nil {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Close()
|
||||
// Prefer gateway errors over context errors.
|
||||
if err := s.GatewayError(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.Err()
|
||||
|
||||
case <-doneCh:
|
||||
// Event loop died.
|
||||
return s.GatewayError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of Session with the context replaced in
|
||||
// the API client. All methods called on the returned Session will use this
|
||||
// given context.
|
||||
|
@ -261,26 +349,19 @@ func (s *Session) Close() error {
|
|||
s.state.Lock()
|
||||
defer s.state.Unlock()
|
||||
|
||||
return s.close(context.Background())
|
||||
return s.close()
|
||||
}
|
||||
|
||||
func (s *Session) close(ctx context.Context) error {
|
||||
func (s *Session) close() error {
|
||||
if s.state.cancel == nil {
|
||||
return errors.New("Session is already closed")
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
s.state.cancel()
|
||||
s.state.cancel = nil
|
||||
s.state.ctx = nil
|
||||
|
||||
// Wait until we've successfully disconnected.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "cannot wait for gateway exit")
|
||||
case <-s.state.doneCh:
|
||||
// ok
|
||||
}
|
||||
|
||||
<-s.state.doneCh
|
||||
s.state.doneCh = nil
|
||||
|
||||
return s.state.gateway.LastError()
|
||||
|
|
|
@ -34,7 +34,7 @@ func TestSession(t *testing.T) {
|
|||
}
|
||||
|
||||
if ready, ok := <-readyCh; !ok {
|
||||
t.Fatal("ready not received")
|
||||
t.Error("ready not received")
|
||||
} else {
|
||||
now := time.Now()
|
||||
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
|
||||
|
@ -48,3 +48,47 @@ func TestSession(t *testing.T) {
|
|||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionConnect(t *testing.T) {
|
||||
attempts := 1
|
||||
timeout := 15 * time.Second
|
||||
|
||||
if !testing.Short() {
|
||||
attempts = 5
|
||||
timeout = time.Minute // 5s-10s each reconnection
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
env := testenv.Must(t)
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent, 1)
|
||||
|
||||
s := NewWithIntents(env.BotToken, gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
|
||||
for i := 0; i < attempts; i++ {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case ready := <-readyCh:
|
||||
now := time.Now()
|
||||
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Error("ready not received")
|
||||
}
|
||||
}()
|
||||
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
t.Fatal("failed to open:", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
// Hold for an additional one second.
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue