From cc530ce7a22c2e631564e23a87eba0d858723ea1 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Thu, 9 Apr 2020 13:49:12 -0700 Subject: [PATCH] Bot: Allow func(T), added more documentation, minor fixes --- bot/ctx.go | 15 +++++--- bot/ctx_call.go | 91 ++++++++++++++++++++++++++++++-------------- bot/subcommand.go | 40 +++++++++++++++---- gateway/commands.go | 2 - gateway/gateway.go | 7 ++-- gateway/pacemaker.go | 7 +++- utils/wsutil/conn.go | 7 +--- 7 files changed, 114 insertions(+), 55 deletions(-) diff --git a/bot/ctx.go b/bot/ctx.go index 60ffa88..18e8141 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -91,7 +91,8 @@ type Context struct { // Message Create. ErrorLogger func(error) - // ReplyError when true replies to the user the error. + // ReplyError when true replies to the user the error. This only applies to + // MessageCreate events. ReplyError bool // Subcommands contains all the registered subcommands. This is not @@ -135,6 +136,7 @@ func Start(token string, cmd interface{}, } return func() error { + // Run cancel() last to remove handlers when the context exits. defer cancel() return s.Wait() }, nil @@ -291,8 +293,11 @@ func (ctx *Context) Start() func() { return } - // Log the main error if reply is disabled. - if !ctx.ReplyError { + mc, isMessage := v.(*gateway.MessageCreateEvent) + + // Log the main error if reply is disabled or if the event isn't a + // message. + if !ctx.ReplyError || !isMessage { // Ignore trivial errors: switch err.(type) { case *ErrInvalidUsage, *ErrUnknownCommand: @@ -304,8 +309,8 @@ func (ctx *Context) Start() func() { return } - mc, ok := v.(*gateway.MessageCreateEvent) - if !ok { + // Only reply if the event is not a message. + if !isMessage { return } diff --git a/bot/ctx_call.go b/bot/ctx_call.go index e5c59c0..eb503c9 100644 --- a/bot/ctx_call.go +++ b/bot/ctx_call.go @@ -10,38 +10,44 @@ import ( "github.com/pkg/errors" ) +// NonFatal is an interface that a method can implement to ignore all errors. +// This works similarly to Break. +type NonFatal interface { + error + IgnoreError() // noop method +} + +func onlyFatal(err error) error { + if _, ok := err.(NonFatal); ok { + return nil + } + return err +} + +type _Break struct{ error } + +// implement NonFatal. +func (_Break) IgnoreError() {} + +// Break is a non-fatal error that could be returned from middlewares or +// handlers to stop the chain of execution. +// +// Middlewares are guaranteed to be executed before handlers, but the exact +// order of each are undefined. Main handlers are also guaranteed to be executed +// before all subcommands. If a main middleware cancels, no subcommand +// middlewares will be called. +// +// Break implements the NonFatal interface, which causes an error to be ignored. +var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")} + func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext { var callers []*CommandContext var middles []*CommandContext var found bool - for _, cmd := range ctx.Events { - // Check if middleware - if cmd.Flag.Is(Middleware) { - continue - } - - if cmd.event == evT { - callers = append(callers, cmd) - found = true - } - } - - if found { - // Search for middlewares with the same type: - for _, mw := range ctx.mwMethods { - if mw.event == evT { - middles = append(middles, mw) - } - } - } - - for _, sub := range ctx.subcommands { - // Reset found status - found = false - + find := func(sub *Subcommand) { for _, cmd := range sub.Events { - // Check if middleware + // Search only for callers, so skip middlewares. if cmd.Flag.Is(Middleware) { continue } @@ -52,6 +58,7 @@ func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext { } } + // Only get middlewares if we found handlers for that same event. if found { // Search for middlewares with the same type: for _, mw := range sub.mwMethods { @@ -62,6 +69,16 @@ func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext { } } + // Find the main context first. + find(ctx.Subcommand) + + for _, sub := range ctx.subcommands { + // Reset found status + found = false + // Find subcommands second. + find(sub) + } + return append(middles, callers...) } @@ -98,7 +115,10 @@ func (ctx *Context) callCmd(ev interface{}) error { for _, c := range filtered { _, err := callWith(c.value, ev) if err != nil { - ctx.ErrorLogger(err) + if err = onlyFatal(err); err != nil { + ctx.ErrorLogger(err) + } + return err } } @@ -106,7 +126,8 @@ func (ctx *Context) callCmd(ev interface{}) error { // slice, but we don't want to ignore those handlers either. if evT == typeMessageCreate { // safe assertion always - return ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent)) + err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent)) + return onlyFatal(err) } return nil @@ -402,16 +423,28 @@ func callWith( } func errorReturns(returns []reflect.Value) (interface{}, error) { - // assume first is always error, since we checked for this in parseCommands + // Handlers may return nothing. + if len(returns) == 0 { + return nil, nil + } + + // assume first return is always error, since we checked for this in + // parseCommands. v := returns[len(returns)-1].Interface() + // If the last return (error) is nil. if v == nil { + // If we only have 1 returns, that return must be the error. The error + // is nil, so nil is returned. if len(returns) == 1 { return nil, nil } + // Return the first argument as-is. The above returns[-1] check assumes + // 2 return values (T, error), meaning returns[0] is the T value. return returns[0].Interface(), nil } + // Treat the last return as an error. return nil, v.(error) } diff --git a/bot/subcommand.go b/bot/subcommand.go index 486c79c..1cf2cc5 100644 --- a/bot/subcommand.go +++ b/bot/subcommand.go @@ -32,6 +32,24 @@ var ( }() ) +// Subcommand is any form of command, which could be a top-level command or a +// subcommand. +// +// Allowed method signatures +// +// These are the acceptable function signatures that would be parsed as commands +// or events. A return type implies that return value will be ignored. +// +// func(*gateway.MessageCreateEvent, ...) (string, error) +// func(*gateway.MessageCreateEvent, ...) (*discord.Embed, error) +// func(*gateway.MessageCreateEvent, ...) (*api.SendMessageData, error) +// func(*gateway.MessageCreateEvent, ...) (T, error) +// func(*gateway.MessageCreateEvent, ...) error +// func(*gateway.MessageCreateEvent, ...) +// func() (T, error) +// func() error +// func() +// type Subcommand struct { Description string @@ -92,9 +110,6 @@ type CommandContext struct { event reflect.Type // gateway.*Event method reflect.Method - // return type - retType reflect.Type - Arguments []Argument } @@ -119,6 +134,8 @@ func (cctx *CommandContext) Usage() []string { return arguments } +// NewSubcommand is used to make a new subcommand. You usually wouldn't call +// this function, but instead use (*Context).RegisterSubcommand(). func NewSubcommand(cmd interface{}) (*Subcommand, error) { var sub = Subcommand{ command: cmd, @@ -333,14 +350,21 @@ func (sub *Subcommand) parseCommands() error { // Check number of returns: numOut := methodT.NumOut() - if numOut == 0 || numOut > 2 { + + // Returns can either be: + // Nothing - func() + // An error - func() error + // An error and something else - func() (T, error) + if numOut > 2 { continue } - // Check the last return's type: - if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) { - // Invalid, skip. - continue + // Check the last return's type if the method returns anything. + if numOut > 0 { + if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) { + // Invalid, skip. + continue + } } var command = CommandContext{ diff --git a/gateway/commands.go b/gateway/commands.go index dac4079..a0fe52e 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -58,8 +58,6 @@ func (g *Gateway) Resume() error { type HeartbeatData int func (g *Gateway) Heartbeat() error { - // g.available.RLock() - // defer g.available.RUnlock() return g.Send(HeartbeatOP, g.Sequence.Get()) } diff --git a/gateway/gateway.go b/gateway/gateway.go index 2d545e0..a1e134f 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -179,12 +179,11 @@ func (g *Gateway) Close() error { // would also exit our event loop. Both would be 2. g.waitGroup.Wait() - WSDebug("WaitGroup is done.") - // Mark g.waitGroup as empty: g.waitGroup = nil - // Stop the Websocket + WSDebug("WaitGroup is done. Closing the websocket.") + err := g.WS.Close() g.AfterClose(err) return err @@ -199,7 +198,7 @@ func (g *Gateway) Reconnect() error { return errors.Wrap(err, "Failed to close Gateway before reconnecting") } - for i := 0; i < WSRetries; i++ { + for i := 0; WSRetries < 0 || i < WSRetries; i++ { WSDebug("Trying to dial, attempt", i) // Condition: err == ErrInvalidSession: diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index 593bac0..6503f2c 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -61,6 +61,9 @@ func (p *Pacemaker) Dead() bool { func (p *Pacemaker) Stop() { if p.stop != nil { p.stop <- struct{}{} + WSDebug("(*Pacemaker).stop was sent a stop signal.") + } else { + WSDebug("(*Pacemaker).stop is nil, skipping.") } } @@ -101,10 +104,10 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { go func() { p.death <- p.start() - // Mark the pacemaker loop as done. - wg.Done() // Mark the stop channel as nil, so later Close() calls won't block forever. p.stop = nil + // Mark the pacemaker loop as done. + wg.Done() }() return p.death diff --git a/utils/wsutil/conn.go b/utils/wsutil/conn.go index 524d48f..0c1ee1d 100644 --- a/utils/wsutil/conn.go +++ b/utils/wsutil/conn.go @@ -71,9 +71,6 @@ func NewConn(driver json.Driver) *Conn { HandshakeTimeout: DefaultTimeout, EnableCompression: true, }, - events: make(chan Event), - writes: make(chan []byte), - errors: make(chan error), // zlib: zlib.NewInflator(), // buf: make([]byte, CopyBufferSize), } @@ -143,8 +140,8 @@ func (c *Conn) readLoop() { return } - // If nil bytes, then it's an incomplete payload. - if b == nil { + // If the payload length is 0, skip it. + if len(b) == 0 { continue }