1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-02-08 20:47:13 +00:00

Gateway: Added a retry limit

State: Event handlers now handle all of Ready's Guilds field
Session: Added Wait, which blocks until SIGINT or Gateway error
This commit is contained in:
diamondburned (Forefront) 2020-02-29 18:13:58 -08:00
parent 220eb5ff42
commit f0102d765f
11 changed files with 167 additions and 56 deletions

View file

@ -17,7 +17,7 @@ func main() {
commands := &Bot{}
stop, err := bot.Start(token, commands, func(ctx *bot.Context) error {
wait, err := bot.Start(token, commands, func(ctx *bot.Context) error {
ctx.Prefix = "!"
// Subcommand demo, but this can be in another package.
@ -30,10 +30,13 @@ func main() {
log.Fatalln(err)
}
defer stop()
log.Println("Bot started")
// Automatically block until SIGINT.
bot.Wait()
// As of this commit, wait() will block until SIGINT or fatal. The past
// versions close on call, but this one will block.
// If for some reason you want the Cancel() function, manually make a new
// context.
if err := wait(); err != nil {
log.Fatalln("Gateway fatal error:", err)
}
}

View file

@ -4,7 +4,6 @@ import (
"log"
"os"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/session"
)
@ -39,6 +38,8 @@ func main() {
log.Println("Started as", u.Username)
// Block until SIGINT. Optional.
bot.Wait()
// Block until a fatal error or SIGINT.
if err := s.Wait(); err != nil {
log.Fatalln("Gateway fatal error:", err)
}
}

View file

@ -4,7 +4,6 @@ import (
"log"
"os"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
"github.com/diamondburned/arikawa/state"
@ -49,6 +48,8 @@ func main() {
log.Println("Started as", u.Username)
// Block until SIGINT. Optional.
bot.Wait()
// Block until a fatal error or SIGINT.
if err := s.Wait(); err != nil {
log.Fatalln("Gateway fatal error:", err)
}
}

View file

@ -88,7 +88,7 @@ type Context struct {
// Start quickly starts a bot with the given command. It will prepend "Bot"
// into the token automatically. Refer to example/ for usage.
func Start(token string, cmd interface{},
opts func(*Context) error) (stop func() error, err error) {
opts func(*Context) error) (wait func() error, err error) {
s, err := state.New("Bot " + token)
if err != nil {
@ -118,11 +118,11 @@ func Start(token string, cmd interface{},
return func() error {
cancel()
return s.Close()
return s.Wait()
}, nil
}
// Wait is a convenient function that blocks until a SIGINT is sent.
// Wait is deprecated. Use (*Context).Wait().
func Wait() {
sigs := make(chan os.Signal)
signal.Notify(sigs, os.Interrupt)
@ -170,6 +170,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
return ctx, nil
}
// Wait blocks until either the Gateway fatally exits or a SIGINT is received.
// Check the Gateway documentation for more information.
func (ctx *Context) Wait() error {
return ctx.Session.Wait()
}
func (ctx *Context) Subcommands() []*Subcommand {
// Getter is not useless, refer to the struct doc for reason.
return ctx.subcommands

View file

@ -26,7 +26,7 @@ func (g *Gateway) Identify() error {
return errors.Wrap(err, "Can't wait for identify()")
}
return g.Send(IdentifyOP, g.Identifier)
return g.send(false, IdentifyOP, g.Identifier)
}
type ResumeData struct {
@ -47,7 +47,7 @@ func (g *Gateway) Resume() error {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
return g.send(false, ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
@ -58,6 +58,8 @@ func (g *Gateway) Resume() error {
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
g.available.RLock()
defer g.available.RUnlock()
return g.Send(HeartbeatOP, g.Sequence.Get())
}

View file

@ -11,7 +11,6 @@ import (
"context"
"log"
"net/url"
"runtime"
"sync"
"time"
@ -45,6 +44,8 @@ var (
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSRetries controls the number of Reconnects before erroring out.
WSRetries = 3
WSDebug = func(v ...interface{}) {}
)
@ -64,13 +65,6 @@ func GatewayURL() (string, error) {
&Gateway, "GET", EndpointGateway)
}
// Identity is used as the default identity when initializing a new Gateway.
var Identity = IdentifyProperties{
OS: runtime.GOOS,
Browser: "Arikawa",
Device: "Arikawa",
}
type Gateway struct {
WS *wsutil.Websocket
json.Driver
@ -91,7 +85,11 @@ type Gateway struct {
Sequence *Sequence
ErrorLog func(err error) // default to log.Println
FatalLog func(err error) // called when the WS can't reconnect and resume
// FatalError is where Reconnect errors will go to. When an error is sent
// here, the Gateway is already dead. This channel is buffered once.
FatalError <-chan error
fatalError chan error
// Only use for debugging
@ -99,6 +97,11 @@ type Gateway struct {
// here. This should be buffered, so to not block the main loop.
OP chan *OP
// Mutex to hold off calls when the WS is not available. Doesn't block if
// Start() is not called or Close() is called. Also doesn't block for
// Identify or Resume.
available sync.RWMutex
// Filled by methods, internal use
paceDeath chan error
waitGroup *sync.WaitGroup
@ -124,8 +127,9 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
Identifier: DefaultIdentifier(token),
Sequence: NewSequence(),
ErrorLog: WSError,
FatalLog: WSFatal,
fatalError: make(chan error, 1),
}
g.FatalError = g.fatalError
// Parameters for the gateway
param := url.Values{}
@ -170,6 +174,9 @@ func (g *Gateway) Close() error {
func (g *Gateway) Reconnect() error {
WSDebug("Reconnecting...")
g.available.Lock()
defer g.available.Unlock()
// If the event loop is not dead:
if g.paceDeath != nil {
WSDebug("Gateway is not closed, closing before reconnecting...")
@ -177,7 +184,7 @@ func (g *Gateway) Reconnect() error {
WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.")
}
for i := 0; ; i++ {
for i := 0; i < WSRetries; i++ {
WSDebug("Trying to dial, attempt", i)
// Condition: err == ErrInvalidSession:
@ -190,10 +197,10 @@ func (g *Gateway) Reconnect() error {
}
WSDebug("Started after attempt:", i)
break
return nil
}
return nil
return ErrWSMaxTries
}
func (g *Gateway) Open() error {
@ -218,6 +225,9 @@ func (g *Gateway) Open() error {
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block.
func (g *Gateway) Start() error {
g.available.Lock()
defer g.available.Unlock()
if err := g.start(); err != nil {
WSDebug("Start failed:", err)
if err := g.Close(); err != nil {
@ -228,6 +238,12 @@ func (g *Gateway) Start() error {
return nil
}
// Wait blocks until the Gateway fatally exits when it couldn't reconnect
// anymore. To use this withh other channels, check out g.FatalError.
func (g *Gateway) Wait() error {
return <-g.FatalError
}
func (g *Gateway) start() error {
// This is where we'll get our events
ch := g.WS.Listen()
@ -291,10 +307,9 @@ func (g *Gateway) handleWS() {
g.waitGroup.Done()
if err != nil {
if err := g.Reconnect(); err != nil {
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
}
g.ErrorLog(err)
g.fatalError <- errors.Wrap(g.Reconnect(), "Failed to reconnect")
// Reconnect should spawn another eventLoop in its Start function.
}
}
@ -319,8 +334,7 @@ func (g *Gateway) eventLoop() error {
case ev := <-ch:
// Check for error
if ev.Error != nil {
g.ErrorLog(ev.Error)
continue
return ev.Error
}
if len(ev.Data) == 0 {
@ -336,6 +350,10 @@ func (g *Gateway) eventLoop() error {
}
func (g *Gateway) Send(code OPCode, v interface{}) error {
return g.send(true, code, v)
}
func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
var op = OP{
Code: code,
}
@ -357,5 +375,10 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
if lock {
g.available.RLock()
defer g.available.RUnlock()
}
return g.WS.Send(ctx, b)
}

View file

@ -2,12 +2,23 @@ package gateway
import (
"context"
"runtime"
"time"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
// Identity is used as the default identity when initializing a new Gateway.
var Identity = IdentifyProperties{
OS: runtime.GOOS,
Browser: "Arikawa",
Device: "Arikawa",
}
// Presence is used as the default presence when initializing a new Gateway.
var Presence *UpdateStatusData
type IdentifyProperties struct {
// Required
OS string `json:"os"` // GOOS
@ -71,6 +82,7 @@ func DefaultIdentifier(token string) *Identifier {
Token: token,
Properties: Identity,
Shard: DefaultShard(),
Presence: Presence,
Compress: true,
LargeThreshold: 50,

View file

@ -8,8 +8,8 @@ type ReadyEvent struct {
User discord.User `json:"user"`
SessionID string `json:"session_id"`
PrivateChannels []discord.Channel `json:"private_channels"`
Guilds []discord.Guild `json:"guilds"`
PrivateChannels []discord.Channel `json:"private_channels"`
Guilds []GuildCreateEvent `json:"guilds"`
Shard *Shard `json:"shard"`

View file

@ -4,6 +4,9 @@
package session
import (
"os"
"os/signal"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
@ -108,10 +111,29 @@ func (s *Session) startHandler(stop <-chan struct{}) {
func (s *Session) Close() error {
// Stop the event handler
if s.hstop != nil {
close(s.hstop)
}
s.close()
// Close the websocket
return s.Gateway.Close()
}
// Wait blocks until either a SIGINT or a Gateway fatal error is received.
// Check the Gateway documentation for more information.
func (s *Session) Wait() error {
sigint := make(chan os.Signal)
signal.Notify(sigint, os.Interrupt)
select {
case <-sigint:
return s.Close()
case err := <-s.Gateway.FatalError:
s.close()
return err
}
}
func (s *Session) close() {
if s.hstop != nil {
close(s.hstop)
}
}

View file

@ -435,22 +435,38 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
////
func (s *State) Presence(
guildID, userID discord.Snowflake) (*discord.Presence, error) {
// Presence checks the state for user presences. If no guildID is given, it will
// look for the presence in all guilds.
func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) {
p, err := s.Store.Presence(guildID, userID)
if err == nil {
return p, nil
}
return s.Store.Presence(guildID, userID)
// If there's no guild ID, look in all guilds
if !guildID.Valid() {
g, err := s.Guilds()
if err != nil {
return nil, err
}
for _, g := range g {
if p, err := s.Store.Presence(g.ID, userID); err == nil {
return p, nil
}
}
}
return nil, err
}
func (s *State) Presences(
guildID discord.Snowflake) ([]discord.Presence, error) {
func (s *State) Presences(guildID discord.Snowflake) ([]discord.Presence, error) {
return s.Store.Presences(guildID)
}
////
func (s *State) Role(
guildID, roleID discord.Snowflake) (*discord.Role, error) {
func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
r, err := s.Store.Role(guildID, roleID)
if err == nil {

View file

@ -26,20 +26,45 @@ func (s *State) onEvent(iface interface{}) {
// Set Ready to the state
s.Ready = *ev
// Handle guilds
for _, g := range ev.Guilds {
g := g
// Handle presences
for _, p := range ev.Presences {
p := p
if err := s.Store.GuildSet(&g); err != nil {
s.stateErr(err, "Failed to set guild in state")
if err := s.Store.PresenceSet(0, &p); err != nil {
s.stateErr(err, "Failed to set global presence")
}
}
// Handle guilds
for i := range ev.Guilds {
guild := ev.Guilds[i]
if err := s.Store.GuildSet(&guild.Guild); err != nil {
s.stateErr(err, "Failed to set guild in Ready")
}
for i := range guild.Members {
if err := s.Store.MemberSet(guild.ID, &guild.Members[i]); err != nil {
s.stateErr(err, "Failed to set guild member in Ready")
}
}
for i := range guild.Channels {
if err := s.Store.ChannelSet(&guild.Channels[i]); err != nil {
s.stateErr(err, "Failed to set guild channel in Ready")
}
}
for i := range guild.Presences {
if err := s.Store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil {
s.stateErr(err, "Failed to set guild presence in Ready")
}
}
}
// Handle private channels
for _, ch := range ev.PrivateChannels {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
for i := range ev.PrivateChannels {
if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil {
s.stateErr(err, "Failed to set channel in state")
}
}