From f0102d765f0993867408cea10f7ff774a41fd0e6 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sat, 29 Feb 2020 18:13:58 -0800 Subject: [PATCH] Gateway: Added a retry limit State: Event handlers now handle all of Ready's Guilds field Session: Added Wait, which blocks until SIGINT or Gateway error --- _example/advanced_bot/main.go | 13 +++++--- _example/simple/main.go | 7 +++-- _example/undeleter/main.go | 7 +++-- bot/ctx.go | 12 +++++-- gateway/commands.go | 6 ++-- gateway/gateway.go | 59 ++++++++++++++++++++++++----------- gateway/identify.go | 12 +++++++ gateway/ready.go | 4 +-- session/session.go | 28 +++++++++++++++-- state/state.go | 32 ++++++++++++++----- state/state_events.go | 43 +++++++++++++++++++------ 11 files changed, 167 insertions(+), 56 deletions(-) diff --git a/_example/advanced_bot/main.go b/_example/advanced_bot/main.go index 35b887b..937e426 100644 --- a/_example/advanced_bot/main.go +++ b/_example/advanced_bot/main.go @@ -17,7 +17,7 @@ func main() { commands := &Bot{} - stop, err := bot.Start(token, commands, func(ctx *bot.Context) error { + wait, err := bot.Start(token, commands, func(ctx *bot.Context) error { ctx.Prefix = "!" // Subcommand demo, but this can be in another package. @@ -30,10 +30,13 @@ func main() { log.Fatalln(err) } - defer stop() - log.Println("Bot started") - // Automatically block until SIGINT. - bot.Wait() + // As of this commit, wait() will block until SIGINT or fatal. The past + // versions close on call, but this one will block. + // If for some reason you want the Cancel() function, manually make a new + // context. + if err := wait(); err != nil { + log.Fatalln("Gateway fatal error:", err) + } } diff --git a/_example/simple/main.go b/_example/simple/main.go index 2d7bb75..0a32044 100644 --- a/_example/simple/main.go +++ b/_example/simple/main.go @@ -4,7 +4,6 @@ import ( "log" "os" - "github.com/diamondburned/arikawa/bot" "github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/session" ) @@ -39,6 +38,8 @@ func main() { log.Println("Started as", u.Username) - // Block until SIGINT. Optional. - bot.Wait() + // Block until a fatal error or SIGINT. + if err := s.Wait(); err != nil { + log.Fatalln("Gateway fatal error:", err) + } } diff --git a/_example/undeleter/main.go b/_example/undeleter/main.go index 2db6ced..b268ff2 100644 --- a/_example/undeleter/main.go +++ b/_example/undeleter/main.go @@ -4,7 +4,6 @@ import ( "log" "os" - "github.com/diamondburned/arikawa/bot" "github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/handler" "github.com/diamondburned/arikawa/state" @@ -49,6 +48,8 @@ func main() { log.Println("Started as", u.Username) - // Block until SIGINT. Optional. - bot.Wait() + // Block until a fatal error or SIGINT. + if err := s.Wait(); err != nil { + log.Fatalln("Gateway fatal error:", err) + } } diff --git a/bot/ctx.go b/bot/ctx.go index 71e609e..398c8fc 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -88,7 +88,7 @@ type Context struct { // Start quickly starts a bot with the given command. It will prepend "Bot" // into the token automatically. Refer to example/ for usage. func Start(token string, cmd interface{}, - opts func(*Context) error) (stop func() error, err error) { + opts func(*Context) error) (wait func() error, err error) { s, err := state.New("Bot " + token) if err != nil { @@ -118,11 +118,11 @@ func Start(token string, cmd interface{}, return func() error { cancel() - return s.Close() + return s.Wait() }, nil } -// Wait is a convenient function that blocks until a SIGINT is sent. +// Wait is deprecated. Use (*Context).Wait(). func Wait() { sigs := make(chan os.Signal) signal.Notify(sigs, os.Interrupt) @@ -170,6 +170,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) { return ctx, nil } +// Wait blocks until either the Gateway fatally exits or a SIGINT is received. +// Check the Gateway documentation for more information. +func (ctx *Context) Wait() error { + return ctx.Session.Wait() +} + func (ctx *Context) Subcommands() []*Subcommand { // Getter is not useless, refer to the struct doc for reason. return ctx.subcommands diff --git a/gateway/commands.go b/gateway/commands.go index 1a19e73..5b30f3b 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -26,7 +26,7 @@ func (g *Gateway) Identify() error { return errors.Wrap(err, "Can't wait for identify()") } - return g.Send(IdentifyOP, g.Identifier) + return g.send(false, IdentifyOP, g.Identifier) } type ResumeData struct { @@ -47,7 +47,7 @@ func (g *Gateway) Resume() error { return ErrMissingForResume } - return g.Send(ResumeOP, ResumeData{ + return g.send(false, ResumeOP, ResumeData{ Token: g.Identifier.Token, SessionID: ses, Sequence: seq, @@ -58,6 +58,8 @@ func (g *Gateway) Resume() error { type HeartbeatData int func (g *Gateway) Heartbeat() error { + g.available.RLock() + defer g.available.RUnlock() return g.Send(HeartbeatOP, g.Sequence.Get()) } diff --git a/gateway/gateway.go b/gateway/gateway.go index 7cd957a..bb05e6a 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -11,7 +11,6 @@ import ( "context" "log" "net/url" - "runtime" "sync" "time" @@ -45,6 +44,8 @@ var ( // WSExtraReadTimeout is the duration to be added to Hello, as a read // timeout for the websocket. WSExtraReadTimeout = time.Second + // WSRetries controls the number of Reconnects before erroring out. + WSRetries = 3 WSDebug = func(v ...interface{}) {} ) @@ -64,13 +65,6 @@ func GatewayURL() (string, error) { &Gateway, "GET", EndpointGateway) } -// Identity is used as the default identity when initializing a new Gateway. -var Identity = IdentifyProperties{ - OS: runtime.GOOS, - Browser: "Arikawa", - Device: "Arikawa", -} - type Gateway struct { WS *wsutil.Websocket json.Driver @@ -91,7 +85,11 @@ type Gateway struct { Sequence *Sequence ErrorLog func(err error) // default to log.Println - FatalLog func(err error) // called when the WS can't reconnect and resume + + // FatalError is where Reconnect errors will go to. When an error is sent + // here, the Gateway is already dead. This channel is buffered once. + FatalError <-chan error + fatalError chan error // Only use for debugging @@ -99,6 +97,11 @@ type Gateway struct { // here. This should be buffered, so to not block the main loop. OP chan *OP + // Mutex to hold off calls when the WS is not available. Doesn't block if + // Start() is not called or Close() is called. Also doesn't block for + // Identify or Resume. + available sync.RWMutex + // Filled by methods, internal use paceDeath chan error waitGroup *sync.WaitGroup @@ -124,8 +127,9 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { Identifier: DefaultIdentifier(token), Sequence: NewSequence(), ErrorLog: WSError, - FatalLog: WSFatal, + fatalError: make(chan error, 1), } + g.FatalError = g.fatalError // Parameters for the gateway param := url.Values{} @@ -170,6 +174,9 @@ func (g *Gateway) Close() error { func (g *Gateway) Reconnect() error { WSDebug("Reconnecting...") + g.available.Lock() + defer g.available.Unlock() + // If the event loop is not dead: if g.paceDeath != nil { WSDebug("Gateway is not closed, closing before reconnecting...") @@ -177,7 +184,7 @@ func (g *Gateway) Reconnect() error { WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.") } - for i := 0; ; i++ { + for i := 0; i < WSRetries; i++ { WSDebug("Trying to dial, attempt", i) // Condition: err == ErrInvalidSession: @@ -190,10 +197,10 @@ func (g *Gateway) Reconnect() error { } WSDebug("Started after attempt:", i) - break + return nil } - return nil + return ErrWSMaxTries } func (g *Gateway) Open() error { @@ -218,6 +225,9 @@ func (g *Gateway) Open() error { // Start authenticates with the websocket, or resume from a dead Websocket // connection. This function doesn't block. func (g *Gateway) Start() error { + g.available.Lock() + defer g.available.Unlock() + if err := g.start(); err != nil { WSDebug("Start failed:", err) if err := g.Close(); err != nil { @@ -228,6 +238,12 @@ func (g *Gateway) Start() error { return nil } +// Wait blocks until the Gateway fatally exits when it couldn't reconnect +// anymore. To use this withh other channels, check out g.FatalError. +func (g *Gateway) Wait() error { + return <-g.FatalError +} + func (g *Gateway) start() error { // This is where we'll get our events ch := g.WS.Listen() @@ -291,10 +307,9 @@ func (g *Gateway) handleWS() { g.waitGroup.Done() if err != nil { - if err := g.Reconnect(); err != nil { - g.FatalLog(errors.Wrap(err, "Failed to reconnect")) - } + g.ErrorLog(err) + g.fatalError <- errors.Wrap(g.Reconnect(), "Failed to reconnect") // Reconnect should spawn another eventLoop in its Start function. } } @@ -319,8 +334,7 @@ func (g *Gateway) eventLoop() error { case ev := <-ch: // Check for error if ev.Error != nil { - g.ErrorLog(ev.Error) - continue + return ev.Error } if len(ev.Data) == 0 { @@ -336,6 +350,10 @@ func (g *Gateway) eventLoop() error { } func (g *Gateway) Send(code OPCode, v interface{}) error { + return g.send(true, code, v) +} + +func (g *Gateway) send(lock bool, code OPCode, v interface{}) error { var op = OP{ Code: code, } @@ -357,5 +375,10 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() + if lock { + g.available.RLock() + defer g.available.RUnlock() + } + return g.WS.Send(ctx, b) } diff --git a/gateway/identify.go b/gateway/identify.go index dcc1013..72fa406 100644 --- a/gateway/identify.go +++ b/gateway/identify.go @@ -2,12 +2,23 @@ package gateway import ( "context" + "runtime" "time" "github.com/pkg/errors" "golang.org/x/time/rate" ) +// Identity is used as the default identity when initializing a new Gateway. +var Identity = IdentifyProperties{ + OS: runtime.GOOS, + Browser: "Arikawa", + Device: "Arikawa", +} + +// Presence is used as the default presence when initializing a new Gateway. +var Presence *UpdateStatusData + type IdentifyProperties struct { // Required OS string `json:"os"` // GOOS @@ -71,6 +82,7 @@ func DefaultIdentifier(token string) *Identifier { Token: token, Properties: Identity, Shard: DefaultShard(), + Presence: Presence, Compress: true, LargeThreshold: 50, diff --git a/gateway/ready.go b/gateway/ready.go index 29cf9fc..680af8a 100644 --- a/gateway/ready.go +++ b/gateway/ready.go @@ -8,8 +8,8 @@ type ReadyEvent struct { User discord.User `json:"user"` SessionID string `json:"session_id"` - PrivateChannels []discord.Channel `json:"private_channels"` - Guilds []discord.Guild `json:"guilds"` + PrivateChannels []discord.Channel `json:"private_channels"` + Guilds []GuildCreateEvent `json:"guilds"` Shard *Shard `json:"shard"` diff --git a/session/session.go b/session/session.go index 89a2804..7c21d22 100644 --- a/session/session.go +++ b/session/session.go @@ -4,6 +4,9 @@ package session import ( + "os" + "os/signal" + "github.com/diamondburned/arikawa/api" "github.com/diamondburned/arikawa/gateway" "github.com/diamondburned/arikawa/handler" @@ -108,10 +111,29 @@ func (s *Session) startHandler(stop <-chan struct{}) { func (s *Session) Close() error { // Stop the event handler - if s.hstop != nil { - close(s.hstop) - } + s.close() // Close the websocket return s.Gateway.Close() } + +// Wait blocks until either a SIGINT or a Gateway fatal error is received. +// Check the Gateway documentation for more information. +func (s *Session) Wait() error { + sigint := make(chan os.Signal) + signal.Notify(sigint, os.Interrupt) + + select { + case <-sigint: + return s.Close() + case err := <-s.Gateway.FatalError: + s.close() + return err + } +} + +func (s *Session) close() { + if s.hstop != nil { + close(s.hstop) + } +} diff --git a/state/state.go b/state/state.go index a62b129..1594c8d 100644 --- a/state/state.go +++ b/state/state.go @@ -435,22 +435,38 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) //// -func (s *State) Presence( - guildID, userID discord.Snowflake) (*discord.Presence, error) { +// Presence checks the state for user presences. If no guildID is given, it will +// look for the presence in all guilds. +func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) { + p, err := s.Store.Presence(guildID, userID) + if err == nil { + return p, nil + } - return s.Store.Presence(guildID, userID) + // If there's no guild ID, look in all guilds + if !guildID.Valid() { + g, err := s.Guilds() + if err != nil { + return nil, err + } + + for _, g := range g { + if p, err := s.Store.Presence(g.ID, userID); err == nil { + return p, nil + } + } + } + + return nil, err } -func (s *State) Presences( - guildID discord.Snowflake) ([]discord.Presence, error) { - +func (s *State) Presences(guildID discord.Snowflake) ([]discord.Presence, error) { return s.Store.Presences(guildID) } //// -func (s *State) Role( - guildID, roleID discord.Snowflake) (*discord.Role, error) { +func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) { r, err := s.Store.Role(guildID, roleID) if err == nil { diff --git a/state/state_events.go b/state/state_events.go index 3bef83f..d8c991e 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -26,20 +26,45 @@ func (s *State) onEvent(iface interface{}) { // Set Ready to the state s.Ready = *ev - // Handle guilds - for _, g := range ev.Guilds { - g := g + // Handle presences + for _, p := range ev.Presences { + p := p - if err := s.Store.GuildSet(&g); err != nil { - s.stateErr(err, "Failed to set guild in state") + if err := s.Store.PresenceSet(0, &p); err != nil { + s.stateErr(err, "Failed to set global presence") + } + } + + // Handle guilds + for i := range ev.Guilds { + guild := ev.Guilds[i] + + if err := s.Store.GuildSet(&guild.Guild); err != nil { + s.stateErr(err, "Failed to set guild in Ready") + } + + for i := range guild.Members { + if err := s.Store.MemberSet(guild.ID, &guild.Members[i]); err != nil { + s.stateErr(err, "Failed to set guild member in Ready") + } + } + + for i := range guild.Channels { + if err := s.Store.ChannelSet(&guild.Channels[i]); err != nil { + s.stateErr(err, "Failed to set guild channel in Ready") + } + } + + for i := range guild.Presences { + if err := s.Store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil { + s.stateErr(err, "Failed to set guild presence in Ready") + } } } // Handle private channels - for _, ch := range ev.PrivateChannels { - ch := ch - - if err := s.Store.ChannelSet(&ch); err != nil { + for i := range ev.PrivateChannels { + if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil { s.stateErr(err, "Failed to set channel in state") } }