mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-23 10:29:30 +00:00
Bot: Allow func(T), added more documentation, minor fixes
This commit is contained in:
parent
922c32c0eb
commit
cc530ce7a2
15
bot/ctx.go
15
bot/ctx.go
|
@ -91,7 +91,8 @@ type Context struct {
|
||||||
// Message Create.
|
// Message Create.
|
||||||
ErrorLogger func(error)
|
ErrorLogger func(error)
|
||||||
|
|
||||||
// ReplyError when true replies to the user the error.
|
// ReplyError when true replies to the user the error. This only applies to
|
||||||
|
// MessageCreate events.
|
||||||
ReplyError bool
|
ReplyError bool
|
||||||
|
|
||||||
// Subcommands contains all the registered subcommands. This is not
|
// Subcommands contains all the registered subcommands. This is not
|
||||||
|
@ -135,6 +136,7 @@ func Start(token string, cmd interface{},
|
||||||
}
|
}
|
||||||
|
|
||||||
return func() error {
|
return func() error {
|
||||||
|
// Run cancel() last to remove handlers when the context exits.
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return s.Wait()
|
return s.Wait()
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -291,8 +293,11 @@ func (ctx *Context) Start() func() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log the main error if reply is disabled.
|
mc, isMessage := v.(*gateway.MessageCreateEvent)
|
||||||
if !ctx.ReplyError {
|
|
||||||
|
// Log the main error if reply is disabled or if the event isn't a
|
||||||
|
// message.
|
||||||
|
if !ctx.ReplyError || !isMessage {
|
||||||
// Ignore trivial errors:
|
// Ignore trivial errors:
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case *ErrInvalidUsage, *ErrUnknownCommand:
|
case *ErrInvalidUsage, *ErrUnknownCommand:
|
||||||
|
@ -304,8 +309,8 @@ func (ctx *Context) Start() func() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mc, ok := v.(*gateway.MessageCreateEvent)
|
// Only reply if the event is not a message.
|
||||||
if !ok {
|
if !isMessage {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,38 +10,44 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NonFatal is an interface that a method can implement to ignore all errors.
|
||||||
|
// This works similarly to Break.
|
||||||
|
type NonFatal interface {
|
||||||
|
error
|
||||||
|
IgnoreError() // noop method
|
||||||
|
}
|
||||||
|
|
||||||
|
func onlyFatal(err error) error {
|
||||||
|
if _, ok := err.(NonFatal); ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type _Break struct{ error }
|
||||||
|
|
||||||
|
// implement NonFatal.
|
||||||
|
func (_Break) IgnoreError() {}
|
||||||
|
|
||||||
|
// Break is a non-fatal error that could be returned from middlewares or
|
||||||
|
// handlers to stop the chain of execution.
|
||||||
|
//
|
||||||
|
// Middlewares are guaranteed to be executed before handlers, but the exact
|
||||||
|
// order of each are undefined. Main handlers are also guaranteed to be executed
|
||||||
|
// before all subcommands. If a main middleware cancels, no subcommand
|
||||||
|
// middlewares will be called.
|
||||||
|
//
|
||||||
|
// Break implements the NonFatal interface, which causes an error to be ignored.
|
||||||
|
var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")}
|
||||||
|
|
||||||
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
|
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
|
||||||
var callers []*CommandContext
|
var callers []*CommandContext
|
||||||
var middles []*CommandContext
|
var middles []*CommandContext
|
||||||
var found bool
|
var found bool
|
||||||
|
|
||||||
for _, cmd := range ctx.Events {
|
find := func(sub *Subcommand) {
|
||||||
// Check if middleware
|
|
||||||
if cmd.Flag.Is(Middleware) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.event == evT {
|
|
||||||
callers = append(callers, cmd)
|
|
||||||
found = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
// Search for middlewares with the same type:
|
|
||||||
for _, mw := range ctx.mwMethods {
|
|
||||||
if mw.event == evT {
|
|
||||||
middles = append(middles, mw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sub := range ctx.subcommands {
|
|
||||||
// Reset found status
|
|
||||||
found = false
|
|
||||||
|
|
||||||
for _, cmd := range sub.Events {
|
for _, cmd := range sub.Events {
|
||||||
// Check if middleware
|
// Search only for callers, so skip middlewares.
|
||||||
if cmd.Flag.Is(Middleware) {
|
if cmd.Flag.Is(Middleware) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -52,6 +58,7 @@ func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only get middlewares if we found handlers for that same event.
|
||||||
if found {
|
if found {
|
||||||
// Search for middlewares with the same type:
|
// Search for middlewares with the same type:
|
||||||
for _, mw := range sub.mwMethods {
|
for _, mw := range sub.mwMethods {
|
||||||
|
@ -62,6 +69,16 @@ func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find the main context first.
|
||||||
|
find(ctx.Subcommand)
|
||||||
|
|
||||||
|
for _, sub := range ctx.subcommands {
|
||||||
|
// Reset found status
|
||||||
|
found = false
|
||||||
|
// Find subcommands second.
|
||||||
|
find(sub)
|
||||||
|
}
|
||||||
|
|
||||||
return append(middles, callers...)
|
return append(middles, callers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +115,10 @@ func (ctx *Context) callCmd(ev interface{}) error {
|
||||||
for _, c := range filtered {
|
for _, c := range filtered {
|
||||||
_, err := callWith(c.value, ev)
|
_, err := callWith(c.value, ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.ErrorLogger(err)
|
if err = onlyFatal(err); err != nil {
|
||||||
|
ctx.ErrorLogger(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +126,8 @@ func (ctx *Context) callCmd(ev interface{}) error {
|
||||||
// slice, but we don't want to ignore those handlers either.
|
// slice, but we don't want to ignore those handlers either.
|
||||||
if evT == typeMessageCreate {
|
if evT == typeMessageCreate {
|
||||||
// safe assertion always
|
// safe assertion always
|
||||||
return ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent))
|
err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent))
|
||||||
|
return onlyFatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -402,16 +423,28 @@ func callWith(
|
||||||
}
|
}
|
||||||
|
|
||||||
func errorReturns(returns []reflect.Value) (interface{}, error) {
|
func errorReturns(returns []reflect.Value) (interface{}, error) {
|
||||||
// assume first is always error, since we checked for this in parseCommands
|
// Handlers may return nothing.
|
||||||
|
if len(returns) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assume first return is always error, since we checked for this in
|
||||||
|
// parseCommands.
|
||||||
v := returns[len(returns)-1].Interface()
|
v := returns[len(returns)-1].Interface()
|
||||||
|
// If the last return (error) is nil.
|
||||||
if v == nil {
|
if v == nil {
|
||||||
|
// If we only have 1 returns, that return must be the error. The error
|
||||||
|
// is nil, so nil is returned.
|
||||||
if len(returns) == 1 {
|
if len(returns) == 1 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the first argument as-is. The above returns[-1] check assumes
|
||||||
|
// 2 return values (T, error), meaning returns[0] is the T value.
|
||||||
return returns[0].Interface(), nil
|
return returns[0].Interface(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Treat the last return as an error.
|
||||||
return nil, v.(error)
|
return nil, v.(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,24 @@ var (
|
||||||
}()
|
}()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Subcommand is any form of command, which could be a top-level command or a
|
||||||
|
// subcommand.
|
||||||
|
//
|
||||||
|
// Allowed method signatures
|
||||||
|
//
|
||||||
|
// These are the acceptable function signatures that would be parsed as commands
|
||||||
|
// or events. A return type <T> implies that return value will be ignored.
|
||||||
|
//
|
||||||
|
// func(*gateway.MessageCreateEvent, ...) (string, error)
|
||||||
|
// func(*gateway.MessageCreateEvent, ...) (*discord.Embed, error)
|
||||||
|
// func(*gateway.MessageCreateEvent, ...) (*api.SendMessageData, error)
|
||||||
|
// func(*gateway.MessageCreateEvent, ...) (T, error)
|
||||||
|
// func(*gateway.MessageCreateEvent, ...) error
|
||||||
|
// func(*gateway.MessageCreateEvent, ...)
|
||||||
|
// func(<AnyEvent>) (T, error)
|
||||||
|
// func(<AnyEvent>) error
|
||||||
|
// func(<AnyEvent>)
|
||||||
|
//
|
||||||
type Subcommand struct {
|
type Subcommand struct {
|
||||||
Description string
|
Description string
|
||||||
|
|
||||||
|
@ -92,9 +110,6 @@ type CommandContext struct {
|
||||||
event reflect.Type // gateway.*Event
|
event reflect.Type // gateway.*Event
|
||||||
method reflect.Method
|
method reflect.Method
|
||||||
|
|
||||||
// return type
|
|
||||||
retType reflect.Type
|
|
||||||
|
|
||||||
Arguments []Argument
|
Arguments []Argument
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,6 +134,8 @@ func (cctx *CommandContext) Usage() []string {
|
||||||
return arguments
|
return arguments
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSubcommand is used to make a new subcommand. You usually wouldn't call
|
||||||
|
// this function, but instead use (*Context).RegisterSubcommand().
|
||||||
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
|
func NewSubcommand(cmd interface{}) (*Subcommand, error) {
|
||||||
var sub = Subcommand{
|
var sub = Subcommand{
|
||||||
command: cmd,
|
command: cmd,
|
||||||
|
@ -333,14 +350,21 @@ func (sub *Subcommand) parseCommands() error {
|
||||||
|
|
||||||
// Check number of returns:
|
// Check number of returns:
|
||||||
numOut := methodT.NumOut()
|
numOut := methodT.NumOut()
|
||||||
if numOut == 0 || numOut > 2 {
|
|
||||||
|
// Returns can either be:
|
||||||
|
// Nothing - func()
|
||||||
|
// An error - func() error
|
||||||
|
// An error and something else - func() (T, error)
|
||||||
|
if numOut > 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the last return's type:
|
// Check the last return's type if the method returns anything.
|
||||||
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
|
if numOut > 0 {
|
||||||
// Invalid, skip.
|
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
|
||||||
continue
|
// Invalid, skip.
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var command = CommandContext{
|
var command = CommandContext{
|
||||||
|
|
|
@ -58,8 +58,6 @@ func (g *Gateway) Resume() error {
|
||||||
type HeartbeatData int
|
type HeartbeatData int
|
||||||
|
|
||||||
func (g *Gateway) Heartbeat() error {
|
func (g *Gateway) Heartbeat() error {
|
||||||
// g.available.RLock()
|
|
||||||
// defer g.available.RUnlock()
|
|
||||||
return g.Send(HeartbeatOP, g.Sequence.Get())
|
return g.Send(HeartbeatOP, g.Sequence.Get())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -179,12 +179,11 @@ func (g *Gateway) Close() error {
|
||||||
// would also exit our event loop. Both would be 2.
|
// would also exit our event loop. Both would be 2.
|
||||||
g.waitGroup.Wait()
|
g.waitGroup.Wait()
|
||||||
|
|
||||||
WSDebug("WaitGroup is done.")
|
|
||||||
|
|
||||||
// Mark g.waitGroup as empty:
|
// Mark g.waitGroup as empty:
|
||||||
g.waitGroup = nil
|
g.waitGroup = nil
|
||||||
|
|
||||||
// Stop the Websocket
|
WSDebug("WaitGroup is done. Closing the websocket.")
|
||||||
|
|
||||||
err := g.WS.Close()
|
err := g.WS.Close()
|
||||||
g.AfterClose(err)
|
g.AfterClose(err)
|
||||||
return err
|
return err
|
||||||
|
@ -199,7 +198,7 @@ func (g *Gateway) Reconnect() error {
|
||||||
return errors.Wrap(err, "Failed to close Gateway before reconnecting")
|
return errors.Wrap(err, "Failed to close Gateway before reconnecting")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < WSRetries; i++ {
|
for i := 0; WSRetries < 0 || i < WSRetries; i++ {
|
||||||
WSDebug("Trying to dial, attempt", i)
|
WSDebug("Trying to dial, attempt", i)
|
||||||
|
|
||||||
// Condition: err == ErrInvalidSession:
|
// Condition: err == ErrInvalidSession:
|
||||||
|
|
|
@ -61,6 +61,9 @@ func (p *Pacemaker) Dead() bool {
|
||||||
func (p *Pacemaker) Stop() {
|
func (p *Pacemaker) Stop() {
|
||||||
if p.stop != nil {
|
if p.stop != nil {
|
||||||
p.stop <- struct{}{}
|
p.stop <- struct{}{}
|
||||||
|
WSDebug("(*Pacemaker).stop was sent a stop signal.")
|
||||||
|
} else {
|
||||||
|
WSDebug("(*Pacemaker).stop is nil, skipping.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,10 +104,10 @@ func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
p.death <- p.start()
|
p.death <- p.start()
|
||||||
// Mark the pacemaker loop as done.
|
|
||||||
wg.Done()
|
|
||||||
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
// Mark the stop channel as nil, so later Close() calls won't block forever.
|
||||||
p.stop = nil
|
p.stop = nil
|
||||||
|
// Mark the pacemaker loop as done.
|
||||||
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return p.death
|
return p.death
|
||||||
|
|
|
@ -71,9 +71,6 @@ func NewConn(driver json.Driver) *Conn {
|
||||||
HandshakeTimeout: DefaultTimeout,
|
HandshakeTimeout: DefaultTimeout,
|
||||||
EnableCompression: true,
|
EnableCompression: true,
|
||||||
},
|
},
|
||||||
events: make(chan Event),
|
|
||||||
writes: make(chan []byte),
|
|
||||||
errors: make(chan error),
|
|
||||||
// zlib: zlib.NewInflator(),
|
// zlib: zlib.NewInflator(),
|
||||||
// buf: make([]byte, CopyBufferSize),
|
// buf: make([]byte, CopyBufferSize),
|
||||||
}
|
}
|
||||||
|
@ -143,8 +140,8 @@ func (c *Conn) readLoop() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If nil bytes, then it's an incomplete payload.
|
// If the payload length is 0, skip it.
|
||||||
if b == nil {
|
if len(b) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue