mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-02-08 20:47:13 +00:00
Gateway: Added a retry limit
State: Event handlers now handle all of Ready's Guilds field Session: Added Wait, which blocks until SIGINT or Gateway error
This commit is contained in:
parent
220eb5ff42
commit
f0102d765f
|
@ -17,7 +17,7 @@ func main() {
|
|||
|
||||
commands := &Bot{}
|
||||
|
||||
stop, err := bot.Start(token, commands, func(ctx *bot.Context) error {
|
||||
wait, err := bot.Start(token, commands, func(ctx *bot.Context) error {
|
||||
ctx.Prefix = "!"
|
||||
|
||||
// Subcommand demo, but this can be in another package.
|
||||
|
@ -30,10 +30,13 @@ func main() {
|
|||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
defer stop()
|
||||
|
||||
log.Println("Bot started")
|
||||
|
||||
// Automatically block until SIGINT.
|
||||
bot.Wait()
|
||||
// As of this commit, wait() will block until SIGINT or fatal. The past
|
||||
// versions close on call, but this one will block.
|
||||
// If for some reason you want the Cancel() function, manually make a new
|
||||
// context.
|
||||
if err := wait(); err != nil {
|
||||
log.Fatalln("Gateway fatal error:", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/diamondburned/arikawa/bot"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/session"
|
||||
)
|
||||
|
@ -39,6 +38,8 @@ func main() {
|
|||
|
||||
log.Println("Started as", u.Username)
|
||||
|
||||
// Block until SIGINT. Optional.
|
||||
bot.Wait()
|
||||
// Block until a fatal error or SIGINT.
|
||||
if err := s.Wait(); err != nil {
|
||||
log.Fatalln("Gateway fatal error:", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/diamondburned/arikawa/bot"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/handler"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
|
@ -49,6 +48,8 @@ func main() {
|
|||
|
||||
log.Println("Started as", u.Username)
|
||||
|
||||
// Block until SIGINT. Optional.
|
||||
bot.Wait()
|
||||
// Block until a fatal error or SIGINT.
|
||||
if err := s.Wait(); err != nil {
|
||||
log.Fatalln("Gateway fatal error:", err)
|
||||
}
|
||||
}
|
||||
|
|
12
bot/ctx.go
12
bot/ctx.go
|
@ -88,7 +88,7 @@ type Context struct {
|
|||
// Start quickly starts a bot with the given command. It will prepend "Bot"
|
||||
// into the token automatically. Refer to example/ for usage.
|
||||
func Start(token string, cmd interface{},
|
||||
opts func(*Context) error) (stop func() error, err error) {
|
||||
opts func(*Context) error) (wait func() error, err error) {
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
|
@ -118,11 +118,11 @@ func Start(token string, cmd interface{},
|
|||
|
||||
return func() error {
|
||||
cancel()
|
||||
return s.Close()
|
||||
return s.Wait()
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Wait is a convenient function that blocks until a SIGINT is sent.
|
||||
// Wait is deprecated. Use (*Context).Wait().
|
||||
func Wait() {
|
||||
sigs := make(chan os.Signal)
|
||||
signal.Notify(sigs, os.Interrupt)
|
||||
|
@ -170,6 +170,12 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
// Wait blocks until either the Gateway fatally exits or a SIGINT is received.
|
||||
// Check the Gateway documentation for more information.
|
||||
func (ctx *Context) Wait() error {
|
||||
return ctx.Session.Wait()
|
||||
}
|
||||
|
||||
func (ctx *Context) Subcommands() []*Subcommand {
|
||||
// Getter is not useless, refer to the struct doc for reason.
|
||||
return ctx.subcommands
|
||||
|
|
|
@ -26,7 +26,7 @@ func (g *Gateway) Identify() error {
|
|||
return errors.Wrap(err, "Can't wait for identify()")
|
||||
}
|
||||
|
||||
return g.Send(IdentifyOP, g.Identifier)
|
||||
return g.send(false, IdentifyOP, g.Identifier)
|
||||
}
|
||||
|
||||
type ResumeData struct {
|
||||
|
@ -47,7 +47,7 @@ func (g *Gateway) Resume() error {
|
|||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return g.Send(ResumeOP, ResumeData{
|
||||
return g.send(false, ResumeOP, ResumeData{
|
||||
Token: g.Identifier.Token,
|
||||
SessionID: ses,
|
||||
Sequence: seq,
|
||||
|
@ -58,6 +58,8 @@ func (g *Gateway) Resume() error {
|
|||
type HeartbeatData int
|
||||
|
||||
func (g *Gateway) Heartbeat() error {
|
||||
g.available.RLock()
|
||||
defer g.available.RUnlock()
|
||||
return g.Send(HeartbeatOP, g.Sequence.Get())
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -45,6 +44,8 @@ var (
|
|||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSRetries controls the number of Reconnects before erroring out.
|
||||
WSRetries = 3
|
||||
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
@ -64,13 +65,6 @@ func GatewayURL() (string, error) {
|
|||
&Gateway, "GET", EndpointGateway)
|
||||
}
|
||||
|
||||
// Identity is used as the default identity when initializing a new Gateway.
|
||||
var Identity = IdentifyProperties{
|
||||
OS: runtime.GOOS,
|
||||
Browser: "Arikawa",
|
||||
Device: "Arikawa",
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
WS *wsutil.Websocket
|
||||
json.Driver
|
||||
|
@ -91,7 +85,11 @@ type Gateway struct {
|
|||
Sequence *Sequence
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
FatalLog func(err error) // called when the WS can't reconnect and resume
|
||||
|
||||
// FatalError is where Reconnect errors will go to. When an error is sent
|
||||
// here, the Gateway is already dead. This channel is buffered once.
|
||||
FatalError <-chan error
|
||||
fatalError chan error
|
||||
|
||||
// Only use for debugging
|
||||
|
||||
|
@ -99,6 +97,11 @@ type Gateway struct {
|
|||
// here. This should be buffered, so to not block the main loop.
|
||||
OP chan *OP
|
||||
|
||||
// Mutex to hold off calls when the WS is not available. Doesn't block if
|
||||
// Start() is not called or Close() is called. Also doesn't block for
|
||||
// Identify or Resume.
|
||||
available sync.RWMutex
|
||||
|
||||
// Filled by methods, internal use
|
||||
paceDeath chan error
|
||||
waitGroup *sync.WaitGroup
|
||||
|
@ -124,8 +127,9 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
Identifier: DefaultIdentifier(token),
|
||||
Sequence: NewSequence(),
|
||||
ErrorLog: WSError,
|
||||
FatalLog: WSFatal,
|
||||
fatalError: make(chan error, 1),
|
||||
}
|
||||
g.FatalError = g.fatalError
|
||||
|
||||
// Parameters for the gateway
|
||||
param := url.Values{}
|
||||
|
@ -170,6 +174,9 @@ func (g *Gateway) Close() error {
|
|||
func (g *Gateway) Reconnect() error {
|
||||
WSDebug("Reconnecting...")
|
||||
|
||||
g.available.Lock()
|
||||
defer g.available.Unlock()
|
||||
|
||||
// If the event loop is not dead:
|
||||
if g.paceDeath != nil {
|
||||
WSDebug("Gateway is not closed, closing before reconnecting...")
|
||||
|
@ -177,7 +184,7 @@ func (g *Gateway) Reconnect() error {
|
|||
WSDebug("Gateway is closed asynchronously. Goroutine may not be exited.")
|
||||
}
|
||||
|
||||
for i := 0; ; i++ {
|
||||
for i := 0; i < WSRetries; i++ {
|
||||
WSDebug("Trying to dial, attempt", i)
|
||||
|
||||
// Condition: err == ErrInvalidSession:
|
||||
|
@ -190,10 +197,10 @@ func (g *Gateway) Reconnect() error {
|
|||
}
|
||||
|
||||
WSDebug("Started after attempt:", i)
|
||||
break
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return ErrWSMaxTries
|
||||
}
|
||||
|
||||
func (g *Gateway) Open() error {
|
||||
|
@ -218,6 +225,9 @@ func (g *Gateway) Open() error {
|
|||
// Start authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. This function doesn't block.
|
||||
func (g *Gateway) Start() error {
|
||||
g.available.Lock()
|
||||
defer g.available.Unlock()
|
||||
|
||||
if err := g.start(); err != nil {
|
||||
WSDebug("Start failed:", err)
|
||||
if err := g.Close(); err != nil {
|
||||
|
@ -228,6 +238,12 @@ func (g *Gateway) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Wait blocks until the Gateway fatally exits when it couldn't reconnect
|
||||
// anymore. To use this withh other channels, check out g.FatalError.
|
||||
func (g *Gateway) Wait() error {
|
||||
return <-g.FatalError
|
||||
}
|
||||
|
||||
func (g *Gateway) start() error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
@ -291,10 +307,9 @@ func (g *Gateway) handleWS() {
|
|||
g.waitGroup.Done()
|
||||
|
||||
if err != nil {
|
||||
if err := g.Reconnect(); err != nil {
|
||||
g.FatalLog(errors.Wrap(err, "Failed to reconnect"))
|
||||
}
|
||||
g.ErrorLog(err)
|
||||
|
||||
g.fatalError <- errors.Wrap(g.Reconnect(), "Failed to reconnect")
|
||||
// Reconnect should spawn another eventLoop in its Start function.
|
||||
}
|
||||
}
|
||||
|
@ -319,8 +334,7 @@ func (g *Gateway) eventLoop() error {
|
|||
case ev := <-ch:
|
||||
// Check for error
|
||||
if ev.Error != nil {
|
||||
g.ErrorLog(ev.Error)
|
||||
continue
|
||||
return ev.Error
|
||||
}
|
||||
|
||||
if len(ev.Data) == 0 {
|
||||
|
@ -336,6 +350,10 @@ func (g *Gateway) eventLoop() error {
|
|||
}
|
||||
|
||||
func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||
return g.send(true, code, v)
|
||||
}
|
||||
|
||||
func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
|
||||
var op = OP{
|
||||
Code: code,
|
||||
}
|
||||
|
@ -357,5 +375,10 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if lock {
|
||||
g.available.RLock()
|
||||
defer g.available.RUnlock()
|
||||
}
|
||||
|
||||
return g.WS.Send(ctx, b)
|
||||
}
|
||||
|
|
|
@ -2,12 +2,23 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Identity is used as the default identity when initializing a new Gateway.
|
||||
var Identity = IdentifyProperties{
|
||||
OS: runtime.GOOS,
|
||||
Browser: "Arikawa",
|
||||
Device: "Arikawa",
|
||||
}
|
||||
|
||||
// Presence is used as the default presence when initializing a new Gateway.
|
||||
var Presence *UpdateStatusData
|
||||
|
||||
type IdentifyProperties struct {
|
||||
// Required
|
||||
OS string `json:"os"` // GOOS
|
||||
|
@ -71,6 +82,7 @@ func DefaultIdentifier(token string) *Identifier {
|
|||
Token: token,
|
||||
Properties: Identity,
|
||||
Shard: DefaultShard(),
|
||||
Presence: Presence,
|
||||
|
||||
Compress: true,
|
||||
LargeThreshold: 50,
|
||||
|
|
|
@ -8,8 +8,8 @@ type ReadyEvent struct {
|
|||
User discord.User `json:"user"`
|
||||
SessionID string `json:"session_id"`
|
||||
|
||||
PrivateChannels []discord.Channel `json:"private_channels"`
|
||||
Guilds []discord.Guild `json:"guilds"`
|
||||
PrivateChannels []discord.Channel `json:"private_channels"`
|
||||
Guilds []GuildCreateEvent `json:"guilds"`
|
||||
|
||||
Shard *Shard `json:"shard"`
|
||||
|
||||
|
|
|
@ -4,6 +4,9 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/handler"
|
||||
|
@ -108,10 +111,29 @@ func (s *Session) startHandler(stop <-chan struct{}) {
|
|||
|
||||
func (s *Session) Close() error {
|
||||
// Stop the event handler
|
||||
if s.hstop != nil {
|
||||
close(s.hstop)
|
||||
}
|
||||
s.close()
|
||||
|
||||
// Close the websocket
|
||||
return s.Gateway.Close()
|
||||
}
|
||||
|
||||
// Wait blocks until either a SIGINT or a Gateway fatal error is received.
|
||||
// Check the Gateway documentation for more information.
|
||||
func (s *Session) Wait() error {
|
||||
sigint := make(chan os.Signal)
|
||||
signal.Notify(sigint, os.Interrupt)
|
||||
|
||||
select {
|
||||
case <-sigint:
|
||||
return s.Close()
|
||||
case err := <-s.Gateway.FatalError:
|
||||
s.close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) close() {
|
||||
if s.hstop != nil {
|
||||
close(s.hstop)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -435,22 +435,38 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
|||
|
||||
////
|
||||
|
||||
func (s *State) Presence(
|
||||
guildID, userID discord.Snowflake) (*discord.Presence, error) {
|
||||
// Presence checks the state for user presences. If no guildID is given, it will
|
||||
// look for the presence in all guilds.
|
||||
func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) {
|
||||
p, err := s.Store.Presence(guildID, userID)
|
||||
if err == nil {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
return s.Store.Presence(guildID, userID)
|
||||
// If there's no guild ID, look in all guilds
|
||||
if !guildID.Valid() {
|
||||
g, err := s.Guilds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, g := range g {
|
||||
if p, err := s.Store.Presence(g.ID, userID); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *State) Presences(
|
||||
guildID discord.Snowflake) ([]discord.Presence, error) {
|
||||
|
||||
func (s *State) Presences(guildID discord.Snowflake) ([]discord.Presence, error) {
|
||||
return s.Store.Presences(guildID)
|
||||
}
|
||||
|
||||
////
|
||||
|
||||
func (s *State) Role(
|
||||
guildID, roleID discord.Snowflake) (*discord.Role, error) {
|
||||
func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
|
||||
|
||||
r, err := s.Store.Role(guildID, roleID)
|
||||
if err == nil {
|
||||
|
|
|
@ -26,20 +26,45 @@ func (s *State) onEvent(iface interface{}) {
|
|||
// Set Ready to the state
|
||||
s.Ready = *ev
|
||||
|
||||
// Handle guilds
|
||||
for _, g := range ev.Guilds {
|
||||
g := g
|
||||
// Handle presences
|
||||
for _, p := range ev.Presences {
|
||||
p := p
|
||||
|
||||
if err := s.Store.GuildSet(&g); err != nil {
|
||||
s.stateErr(err, "Failed to set guild in state")
|
||||
if err := s.Store.PresenceSet(0, &p); err != nil {
|
||||
s.stateErr(err, "Failed to set global presence")
|
||||
}
|
||||
}
|
||||
|
||||
// Handle guilds
|
||||
for i := range ev.Guilds {
|
||||
guild := ev.Guilds[i]
|
||||
|
||||
if err := s.Store.GuildSet(&guild.Guild); err != nil {
|
||||
s.stateErr(err, "Failed to set guild in Ready")
|
||||
}
|
||||
|
||||
for i := range guild.Members {
|
||||
if err := s.Store.MemberSet(guild.ID, &guild.Members[i]); err != nil {
|
||||
s.stateErr(err, "Failed to set guild member in Ready")
|
||||
}
|
||||
}
|
||||
|
||||
for i := range guild.Channels {
|
||||
if err := s.Store.ChannelSet(&guild.Channels[i]); err != nil {
|
||||
s.stateErr(err, "Failed to set guild channel in Ready")
|
||||
}
|
||||
}
|
||||
|
||||
for i := range guild.Presences {
|
||||
if err := s.Store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil {
|
||||
s.stateErr(err, "Failed to set guild presence in Ready")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle private channels
|
||||
for _, ch := range ev.PrivateChannels {
|
||||
ch := ch
|
||||
|
||||
if err := s.Store.ChannelSet(&ch); err != nil {
|
||||
for i := range ev.PrivateChannels {
|
||||
if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil {
|
||||
s.stateErr(err, "Failed to set channel in state")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue