mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-12-11 07:54:58 +00:00
diamondburned
273fcf1418
This commit adds subcommand aliases, as well as additional code in HelpGenerate to cover both subcommand and command aliases. A breaking change is that the {,Must}RegisterSubcommandCustom methods are now replaced by the normal {,Must}RegisterSubcommand methods, because those now take variadic strings, which could be 0, 1 or more arguments. This commit also allows AddMiddleware and similar methods to be given a method directly, e.g. `sub.Plumb(cmds.PlumbedHandler)` or `sub.AddMiddleware(cmds.PlumbedHandler, cmds.plumbMiddleware)`. This change closes issue #146.
439 lines
9.9 KiB
Go
439 lines
9.9 KiB
Go
package bot
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/diamondburned/arikawa/v2/discord"
|
|
"github.com/diamondburned/arikawa/v2/gateway"
|
|
"github.com/diamondburned/arikawa/v2/state"
|
|
"github.com/diamondburned/arikawa/v2/utils/handler"
|
|
)
|
|
|
|
type testc struct {
|
|
Ctx *Context
|
|
Return chan interface{}
|
|
Counter uint64
|
|
Typed int8
|
|
}
|
|
|
|
func (t *testc) Setup(sub *Subcommand) {
|
|
sub.AddMiddleware([]string{"*", "GetCounter"}, func(v interface{}) {
|
|
t.Counter++
|
|
})
|
|
sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
|
|
t.Counter++
|
|
})
|
|
// stub middleware for testing
|
|
sub.AddMiddleware(t.OnTyping, func(*gateway.TypingStartEvent) {
|
|
t.Typed = 2
|
|
})
|
|
sub.Hide(t.Hidden)
|
|
}
|
|
// Hidden is a no-op command; Setup passes it to Subcommand.Hide.
func (t *testc) Hidden(_ *gateway.MessageCreateEvent) {
}
|
|
// Noop is a command that does nothing and reports nothing.
func (t *testc) Noop(_ *gateway.MessageCreateEvent) {
}
|
|
func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
|
|
t.Return <- strconv.FormatUint(t.Counter, 10)
|
|
}
|
|
func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
|
|
t.Return <- args
|
|
return errors.New("oh no")
|
|
}
|
|
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) {
|
|
t.Return <- []string(*c)
|
|
}
|
|
func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
|
|
t.Return <- c[len(c)-1]
|
|
}
|
|
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) {
|
|
t.Return <- c
|
|
}
|
|
func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
|
|
t.Return <- c
|
|
}
|
|
func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
|
|
return errors.New("passed")
|
|
}
|
|
func (t *testc) OnTyping(*gateway.TypingStartEvent) {
|
|
t.Typed--
|
|
}
|
|
|
|
func TestNewContext(t *testing.T) {
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
c, err := New(s, &testc{})
|
|
if err != nil {
|
|
t.Fatal("Failed to create new context:", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(c.Subcommands(), c.subcommands) {
|
|
t.Fatal("Subcommands mismatch.")
|
|
}
|
|
}
|
|
|
|
func TestContext(t *testing.T) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
Handler: handler.New(),
|
|
}
|
|
|
|
sub, err := NewSubcommand(given)
|
|
if err != nil {
|
|
t.Fatal("Failed to create subcommand:", err)
|
|
}
|
|
|
|
var ctx = &Context{
|
|
Name: "arikawa/bot test",
|
|
Description: "Just a test.",
|
|
|
|
Subcommand: sub,
|
|
State: s,
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
t.Run("init commands", func(t *testing.T) {
|
|
if err := ctx.Subcommand.InitCommands(ctx); err != nil {
|
|
t.Fatal("Failed to init commands:", err)
|
|
}
|
|
|
|
if given.Ctx == nil {
|
|
t.Fatal("given'sub Context field is nil")
|
|
}
|
|
|
|
if given.Ctx.State.Store == nil {
|
|
t.Fatal("given'sub State is nil")
|
|
}
|
|
})
|
|
|
|
t.Run("find commands", func(t *testing.T) {
|
|
cmd := ctx.FindCommand("", "NoArgs")
|
|
if cmd == nil {
|
|
t.Fatal("Failed to find NoArgs")
|
|
}
|
|
})
|
|
|
|
t.Run("middleware", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("pls do ")
|
|
|
|
// This should trigger the middleware first.
|
|
if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("derive intents", func(t *testing.T) {
|
|
intents := ctx.DeriveIntents()
|
|
|
|
assertIntents := func(target gateway.Intents, name string) {
|
|
if !intents.Has(target) {
|
|
t.Error("Derived intents do not have", name)
|
|
}
|
|
}
|
|
|
|
assertIntents(gateway.IntentGuildMessages, "guild messages")
|
|
assertIntents(gateway.IntentDirectMessages, "direct messages")
|
|
assertIntents(gateway.IntentGuildMessageTyping, "guild typing")
|
|
assertIntents(gateway.IntentDirectMessageTyping, "direct message typing")
|
|
})
|
|
|
|
t.Run("typing event", func(t *testing.T) {
|
|
typing := &gateway.TypingStartEvent{}
|
|
|
|
if err := ctx.callCmd(typing); err != nil {
|
|
t.Fatal("Failed to call with TypingStart:", err)
|
|
}
|
|
|
|
// -1 none ran
|
|
if given.Typed != 1 {
|
|
t.Fatal("Typed bool is false")
|
|
}
|
|
})
|
|
|
|
t.Run("call command", func(t *testing.T) {
|
|
// Set a custom prefix
|
|
ctx.HasPrefix = NewPrefix("~")
|
|
|
|
var (
|
|
send = "hacka doll no. 3"
|
|
expects = []string{"hacka", "doll", "no.", "3"}
|
|
)
|
|
|
|
if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command rawarguments", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := RawArguments("just things")
|
|
|
|
if err := expect(ctx, given, expects, "!content just things"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom manual parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := []string{"arg1", ":)"}
|
|
|
|
if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom variadic parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := &customParsed{true}
|
|
|
|
if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom trailing manual parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := ArgumentParts{"arikawa"}
|
|
|
|
if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
|
|
if expects.Length() != 1 {
|
|
t.Fatal("Unexpected ArgumentParts length.")
|
|
}
|
|
if expects.After(1)+expects.After(2)+expects.After(-1) != "" {
|
|
t.Fatal("Unexpected ArgumentsParts after.")
|
|
}
|
|
if expects.String() != "arikawa" {
|
|
t.Fatal("Unexpected ArgumentsParts string.")
|
|
}
|
|
if expects.Arg(0) != "arikawa" {
|
|
t.Fatal("Unexpected ArgumentParts arg 0")
|
|
}
|
|
if expects.Arg(1) != "" {
|
|
t.Fatal("Unexpected ArgumentParts arg 1")
|
|
}
|
|
})
|
|
|
|
testMessage := func(content string) error {
|
|
// Mock a messageCreate event
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: content,
|
|
},
|
|
}
|
|
|
|
return ctx.callCmd(m)
|
|
}
|
|
|
|
t.Run("call command without args", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("")
|
|
|
|
if err := testMessage("noArgs"); err.Error() != "passed" {
|
|
t.Fatal("unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
// Test error cases
|
|
|
|
t.Run("call unknown command", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("joe pls ")
|
|
|
|
err := testMessage("joe pls no")
|
|
|
|
if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") {
|
|
t.Fatal("unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
// Test subcommands
|
|
|
|
t.Run("register subcommand", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("run ")
|
|
|
|
sub := &testc{}
|
|
ctx.MustRegisterSubcommand(sub)
|
|
|
|
if err := testMessage("run testc noop"); err != nil {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
|
|
expects := RawArguments("hackadoll no. 3")
|
|
|
|
if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
|
|
if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil {
|
|
t.Fatal("Failed to find subcommand Noop")
|
|
}
|
|
})
|
|
|
|
t.Run("register subcommand custom", func(t *testing.T) {
|
|
ctx.MustRegisterSubcommand(&testc{}, "arikawa", "a")
|
|
})
|
|
|
|
t.Run("duplicate subcommand", func(t *testing.T) {
|
|
_, err := ctx.RegisterSubcommand(&testc{}, "arikawa")
|
|
if err := err.Error(); !strings.Contains(err, "duplicate") {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
|
|
_, err = ctx.RegisterSubcommand(&testc{}, "a")
|
|
if err := err.Error(); !strings.Contains(err, "duplicate") {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("help", func(t *testing.T) {
|
|
ctx.MustRegisterSubcommand(&testc{}, "helper")
|
|
|
|
h := ctx.Help()
|
|
if h == "" {
|
|
t.Fatal("Empty help?")
|
|
}
|
|
|
|
if strings.Contains(h, "hidden") {
|
|
t.Fatal("Hidden command shown in help.")
|
|
}
|
|
|
|
if !strings.Contains(h, "arikawa/bot test") {
|
|
t.Fatal("Name not found.")
|
|
}
|
|
if !strings.Contains(h, "Just a test.") {
|
|
t.Fatal("Description not found.")
|
|
}
|
|
if !strings.Contains(h, "**a**") {
|
|
t.Fatal("arikawa alias `a' not found.")
|
|
}
|
|
})
|
|
|
|
t.Run("start", func(t *testing.T) {
|
|
cancel := ctx.Start()
|
|
defer cancel()
|
|
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
given.Return = make(chan interface{})
|
|
|
|
ctx.Handler.Call(&gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: "!content hime arikawa best trap",
|
|
},
|
|
})
|
|
|
|
if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" {
|
|
t.Fatal("Unexpected content:", c)
|
|
}
|
|
})
|
|
}
|
|
|
|
func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) {
|
|
var v interface{}
|
|
if call = sendMsg(ctx, given, &v, content); call != nil {
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(v, expects) {
|
|
return fmt.Errorf("returned argument is invalid: %v", v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) {
|
|
// Return channel for testing
|
|
ret := make(chan interface{})
|
|
given.Return = ret
|
|
|
|
// Mock a messageCreate event
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: content,
|
|
},
|
|
}
|
|
|
|
var callCh = make(chan error)
|
|
go func() {
|
|
callCh <- ctx.Call(m)
|
|
}()
|
|
|
|
select {
|
|
case arg := <-ret:
|
|
call = <-callCh
|
|
reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg))
|
|
return
|
|
|
|
case call = <-callCh:
|
|
return fmt.Errorf("expected return before error: %w", call)
|
|
|
|
case <-time.After(time.Second):
|
|
return errors.New("timed out while waiting")
|
|
}
|
|
}
|
|
|
|
func BenchmarkConstructor(b *testing.B) {
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
_, _ = New(s, &testc{})
|
|
}
|
|
}
|
|
|
|
func BenchmarkCall(b *testing.B) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
sub, _ := NewSubcommand(given)
|
|
|
|
var ctx = &Context{
|
|
Subcommand: sub,
|
|
State: s,
|
|
HasPrefix: NewPrefix("~"),
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: "~noop",
|
|
},
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
ctx.callCmd(m)
|
|
}
|
|
}
|
|
|
|
func BenchmarkHelp(b *testing.B) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
sub, _ := NewSubcommand(given)
|
|
|
|
var ctx = &Context{
|
|
Subcommand: sub,
|
|
State: s,
|
|
HasPrefix: NewPrefix("~"),
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
_ = ctx.Help()
|
|
}
|
|
}
|