mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-12-12 08:25:10 +00:00
diamondburned
9899f6073b
This commit adds automatic Intents detection into package bot. When the Start function is used, the intents will be OR'd after running the options callback. This commit also breaks the old "AddIntent" methods to rename them to "AddIntents" for correctness.
431 lines
9.7 KiB
Go
431 lines
9.7 KiB
Go
package bot
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/diamondburned/arikawa/v2/discord"
|
|
"github.com/diamondburned/arikawa/v2/gateway"
|
|
"github.com/diamondburned/arikawa/v2/state"
|
|
"github.com/diamondburned/arikawa/v2/utils/handler"
|
|
)
|
|
|
|
type testc struct {
|
|
Ctx *Context
|
|
Return chan interface{}
|
|
Counter uint64
|
|
Typed int8
|
|
}
|
|
|
|
func (t *testc) Setup(sub *Subcommand) {
|
|
sub.AddMiddleware("*,GetCounter", func(v interface{}) {
|
|
t.Counter++
|
|
})
|
|
sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
|
|
t.Counter++
|
|
})
|
|
// stub middleware for testing
|
|
sub.AddMiddleware("OnTyping", func(*gateway.TypingStartEvent) {
|
|
t.Typed = 2
|
|
})
|
|
sub.Hide("Hidden")
|
|
}
|
|
func (t *testc) Hidden(*gateway.MessageCreateEvent) {}
|
|
func (t *testc) Noop(*gateway.MessageCreateEvent) {}
|
|
func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
|
|
t.Return <- strconv.FormatUint(t.Counter, 10)
|
|
}
|
|
func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
|
|
t.Return <- args
|
|
return errors.New("oh no")
|
|
}
|
|
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) {
|
|
t.Return <- []string(*c)
|
|
}
|
|
func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
|
|
t.Return <- c[len(c)-1]
|
|
}
|
|
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) {
|
|
t.Return <- c
|
|
}
|
|
func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
|
|
t.Return <- c
|
|
}
|
|
func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
|
|
return errors.New("passed")
|
|
}
|
|
func (t *testc) OnTyping(*gateway.TypingStartEvent) {
|
|
t.Typed--
|
|
}
|
|
|
|
func TestNewContext(t *testing.T) {
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
c, err := New(s, &testc{})
|
|
if err != nil {
|
|
t.Fatal("Failed to create new context:", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(c.Subcommands(), c.subcommands) {
|
|
t.Fatal("Subcommands mismatch.")
|
|
}
|
|
}
|
|
|
|
func TestContext(t *testing.T) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
Handler: handler.New(),
|
|
}
|
|
|
|
sub, err := NewSubcommand(given)
|
|
if err != nil {
|
|
t.Fatal("Failed to create subcommand:", err)
|
|
}
|
|
|
|
var ctx = &Context{
|
|
Name: "arikawa/bot test",
|
|
Description: "Just a test.",
|
|
|
|
Subcommand: sub,
|
|
State: s,
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
t.Run("init commands", func(t *testing.T) {
|
|
if err := ctx.Subcommand.InitCommands(ctx); err != nil {
|
|
t.Fatal("Failed to init commands:", err)
|
|
}
|
|
|
|
if given.Ctx == nil {
|
|
t.Fatal("given'sub Context field is nil")
|
|
}
|
|
|
|
if given.Ctx.State.Store == nil {
|
|
t.Fatal("given'sub State is nil")
|
|
}
|
|
})
|
|
|
|
t.Run("find commands", func(t *testing.T) {
|
|
cmd := ctx.FindCommand("", "NoArgs")
|
|
if cmd == nil {
|
|
t.Fatal("Failed to find NoArgs")
|
|
}
|
|
})
|
|
|
|
t.Run("help", func(t *testing.T) {
|
|
ctx.MustRegisterSubcommandCustom(&testc{}, "helper")
|
|
|
|
h := ctx.Help()
|
|
if h == "" {
|
|
t.Fatal("Empty help?")
|
|
}
|
|
|
|
if strings.Contains(h, "hidden") {
|
|
t.Fatal("Hidden command shown in help.")
|
|
}
|
|
|
|
if !strings.Contains(h, "arikawa/bot test") {
|
|
t.Fatal("Name not found.")
|
|
}
|
|
if !strings.Contains(h, "Just a test.") {
|
|
t.Fatal("Description not found.")
|
|
}
|
|
})
|
|
|
|
t.Run("middleware", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("pls do ")
|
|
|
|
// This should trigger the middleware first.
|
|
if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("derive intents", func(t *testing.T) {
|
|
intents := ctx.DeriveIntents()
|
|
|
|
assertIntents := func(target gateway.Intents, name string) {
|
|
if !intents.Has(target) {
|
|
t.Error("Derived intents do not have", name)
|
|
}
|
|
}
|
|
|
|
assertIntents(gateway.IntentGuildMessages, "guild messages")
|
|
assertIntents(gateway.IntentDirectMessages, "direct messages")
|
|
assertIntents(gateway.IntentGuildMessageTyping, "guild typing")
|
|
assertIntents(gateway.IntentDirectMessageTyping, "direct message typing")
|
|
})
|
|
|
|
t.Run("typing event", func(t *testing.T) {
|
|
typing := &gateway.TypingStartEvent{}
|
|
|
|
if err := ctx.callCmd(typing); err != nil {
|
|
t.Fatal("Failed to call with TypingStart:", err)
|
|
}
|
|
|
|
// -1 none ran
|
|
if given.Typed != 1 {
|
|
t.Fatal("Typed bool is false")
|
|
}
|
|
})
|
|
|
|
t.Run("call command", func(t *testing.T) {
|
|
// Set a custom prefix
|
|
ctx.HasPrefix = NewPrefix("~")
|
|
|
|
var (
|
|
send = "hacka doll no. 3"
|
|
expects = []string{"hacka", "doll", "no.", "3"}
|
|
)
|
|
|
|
if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command rawarguments", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := RawArguments("just things")
|
|
|
|
if err := expect(ctx, given, expects, "!content just things"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom manual parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := []string{"arg1", ":)"}
|
|
|
|
if err := expect(ctx, given, expects, "!custom arg1 :)"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom variadic parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := &customParsed{true}
|
|
|
|
if err := expect(ctx, given, expects, "!variadic bruh moment"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("call command custom trailing manual parser", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
expects := ArgumentParts{"arikawa"}
|
|
|
|
if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
|
|
if expects.Length() != 1 {
|
|
t.Fatal("Unexpected ArgumentParts length.")
|
|
}
|
|
if expects.After(1)+expects.After(2)+expects.After(-1) != "" {
|
|
t.Fatal("Unexpected ArgumentsParts after.")
|
|
}
|
|
if expects.String() != "arikawa" {
|
|
t.Fatal("Unexpected ArgumentsParts string.")
|
|
}
|
|
if expects.Arg(0) != "arikawa" {
|
|
t.Fatal("Unexpected ArgumentParts arg 0")
|
|
}
|
|
if expects.Arg(1) != "" {
|
|
t.Fatal("Unexpected ArgumentParts arg 1")
|
|
}
|
|
})
|
|
|
|
testMessage := func(content string) error {
|
|
// Mock a messageCreate event
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: content,
|
|
},
|
|
}
|
|
|
|
return ctx.callCmd(m)
|
|
}
|
|
|
|
t.Run("call command without args", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("")
|
|
|
|
if err := testMessage("noArgs"); err.Error() != "passed" {
|
|
t.Fatal("unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
// Test error cases
|
|
|
|
t.Run("call unknown command", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("joe pls ")
|
|
|
|
err := testMessage("joe pls no")
|
|
|
|
if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") {
|
|
t.Fatal("unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
// Test subcommands
|
|
|
|
t.Run("register subcommand", func(t *testing.T) {
|
|
ctx.HasPrefix = NewPrefix("run ")
|
|
|
|
sub := &testc{}
|
|
ctx.MustRegisterSubcommand(sub)
|
|
|
|
if err := testMessage("run testc noop"); err != nil {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
|
|
expects := RawArguments("hackadoll no. 3")
|
|
|
|
if err := expect(ctx, sub, expects, "run testc content hackadoll no. 3"); err != nil {
|
|
t.Fatal("Unexpected call error:", err)
|
|
}
|
|
|
|
if cmd := ctx.FindCommand("testc", "Noop"); cmd == nil {
|
|
t.Fatal("Failed to find subcommand Noop")
|
|
}
|
|
})
|
|
|
|
t.Run("register subcommand custom", func(t *testing.T) {
|
|
ctx.MustRegisterSubcommandCustom(&testc{}, "arikawa")
|
|
})
|
|
|
|
t.Run("duplicate subcommand", func(t *testing.T) {
|
|
_, err := ctx.RegisterSubcommandCustom(&testc{}, "arikawa")
|
|
if err := err.Error(); !strings.Contains(err, "duplicate") {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
})
|
|
|
|
t.Run("start", func(t *testing.T) {
|
|
cancel := ctx.Start()
|
|
defer cancel()
|
|
|
|
ctx.HasPrefix = NewPrefix("!")
|
|
given.Return = make(chan interface{})
|
|
|
|
ctx.Handler.Call(&gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: "!content hime arikawa best trap",
|
|
},
|
|
})
|
|
|
|
if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" {
|
|
t.Fatal("Unexpected content:", c)
|
|
}
|
|
})
|
|
}
|
|
|
|
func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) {
|
|
var v interface{}
|
|
if call = sendMsg(ctx, given, &v, content); call != nil {
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(v, expects) {
|
|
return fmt.Errorf("returned argument is invalid: %v", v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) {
|
|
// Return channel for testing
|
|
ret := make(chan interface{})
|
|
given.Return = ret
|
|
|
|
// Mock a messageCreate event
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: content,
|
|
},
|
|
}
|
|
|
|
var callCh = make(chan error)
|
|
go func() {
|
|
callCh <- ctx.Call(m)
|
|
}()
|
|
|
|
select {
|
|
case arg := <-ret:
|
|
call = <-callCh
|
|
reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg))
|
|
return
|
|
|
|
case call = <-callCh:
|
|
return fmt.Errorf("expected return before error: %w", call)
|
|
|
|
case <-time.After(time.Second):
|
|
return errors.New("timed out while waiting")
|
|
}
|
|
}
|
|
|
|
func BenchmarkConstructor(b *testing.B) {
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
_, _ = New(s, &testc{})
|
|
}
|
|
}
|
|
|
|
func BenchmarkCall(b *testing.B) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
sub, _ := NewSubcommand(given)
|
|
|
|
var ctx = &Context{
|
|
Subcommand: sub,
|
|
State: s,
|
|
HasPrefix: NewPrefix("~"),
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
m := &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
Content: "~noop",
|
|
},
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
ctx.callCmd(m)
|
|
}
|
|
}
|
|
|
|
func BenchmarkHelp(b *testing.B) {
|
|
var given = &testc{}
|
|
var s = &state.State{
|
|
Store: state.NewDefaultStore(nil),
|
|
}
|
|
|
|
sub, _ := NewSubcommand(given)
|
|
|
|
var ctx = &Context{
|
|
Subcommand: sub,
|
|
State: s,
|
|
HasPrefix: NewPrefix("~"),
|
|
ParseArgs: DefaultArgsParser(),
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
_ = ctx.Help()
|
|
}
|
|
}
|