Bot: Added variadic arguments support

This commit is contained in:
diamondburned (Forefront) 2020-05-03 15:59:10 -07:00
parent af682d3f35
commit 3b903dce68
7 changed files with 162 additions and 66 deletions

View File

@ -93,7 +93,7 @@ func (c *Content) CustomParse(content string) error {
type Argument struct {
String string
// Rule: pointer for structs, direct for primitives
Type reflect.Type
rtype reflect.Type
// indicates if the type is referenced, meaning it's a pointer but not the
// original call.
@ -105,6 +105,10 @@ type Argument struct {
custom *reflect.Method
}
func (a *Argument) Type() reflect.Type {
return a.rtype
}
var ShellwordsEscaper = strings.NewReplacer(
"\\", "\\\\",
)
@ -116,7 +120,12 @@ var ParseArgs = func(args string) ([]string, error) {
// nilV, only used to return an error
var nilV = reflect.Value{}
func getArgumentValueFn(t reflect.Type) (*Argument, error) {
func getArgumentValueFn(t reflect.Type, variadic bool) (*Argument, error) {
// Allow array types if varidic is true.
if variadic && t.Kind() == reflect.Slice {
t = t.Elem()
}
var typeI = t
var ptr = false
@ -152,7 +161,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) {
return &Argument{
String: fromUsager(typeI),
Type: typeI,
rtype: typeI,
pointer: ptr,
fn: avfn,
}, nil
@ -166,17 +175,13 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) {
return reflect.ValueOf(s), nil
}
case reflect.Int, reflect.Int8,
reflect.Int16, reflect.Int32, reflect.Int64:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fn = func(s string) (reflect.Value, error) {
i, err := strconv.ParseInt(s, 10, 64)
return quickRet(i, err, t)
}
case reflect.Uint, reflect.Uint8,
reflect.Uint16, reflect.Uint32, reflect.Uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
fn = func(s string) (reflect.Value, error) {
u, err := strconv.ParseUint(s, 10, 64)
return quickRet(u, err, t)
@ -196,7 +201,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) {
case "False", "FALSE", "false", "F", "f", "no", "n", "N", "0":
return reflect.ValueOf(false), nil
default:
return nilV, errors.New("invalid bool [true/false]")
return nilV, errors.New("invalid bool [true|false]")
}
}
}
@ -207,7 +212,7 @@ func getArgumentValueFn(t reflect.Type) (*Argument, error) {
return &Argument{
String: t.String(),
Type: t,
rtype: t,
fn: fn,
}, nil
}
@ -232,9 +237,11 @@ func fromUsager(typeI reflect.Type) string {
if !ok {
panic("BUG: type IUsager does not implement Usage")
}
vs := mt.Func.Call([]reflect.Value{reflect.New(typeI.Elem())})
return vs[0].String()
}
s := strings.Split(typeI.String(), ".")
return s[len(s)-1]
}

View File

@ -27,15 +27,14 @@ func TestArguments(t *testing.T) {
testArgs(t, mockParse("testString"), "testString")
testArgs(t, *mockParse("testString"), "testString")
_, err := getArgumentValueFn(reflect.TypeOf(struct{}{}))
_, err := getArgumentValueFn(reflect.TypeOf(struct{}{}), false)
if !strings.HasPrefix(err.Error(), "invalid type: ") {
t.Fatal("Unexpected error:", err)
}
}
func testArgs(t *testing.T, expect interface{}, input string) {
f, err := getArgumentValueFn(reflect.TypeOf(expect))
f, err := getArgumentValueFn(reflect.TypeOf(expect), false)
if err != nil {
t.Fatal("Failed to get argument value function:", err)
}
@ -49,3 +48,23 @@ func testArgs(t *testing.T, expect interface{}, input string) {
t.Fatal("Value :", v, "\nExpects:", expect)
}
}
// used for ctx_test.go
type customManualParsed struct {
args []string
}
func (c *customManualParsed) ParseContent(args []string) error {
c.args = args
return nil
}
type customParsed struct {
parsed bool
}
func (c *customParsed) Parse(string) error {
c.parsed = true
return nil
}

View File

@ -265,7 +265,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
// Check manual or parser
if cmd.Arguments[0].fn == nil {
// Create a zero value instance of this:
v := reflect.New(cmd.Arguments[0].Type)
v := reflect.New(cmd.Arguments[0].rtype)
ret := []reflect.Value{}
switch {
@ -313,35 +313,78 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
goto Call
}
// Not enough arguments given
if delta := len(args[start:]) - len(cmd.Arguments); delta != 0 {
var err = "Not enough arguments given"
if delta > 0 {
err = "Too many arguments given"
// Argument count check.
if argdelta := len(args[start:]) - len(cmd.Arguments); argdelta != 0 {
var err error // no err if nil
switch {
// If there aren't enough arguments given.
case argdelta < 0:
err = ErrNotEnoughArgs
// If there are too many arguments, then check if the command supports
// variadic arguments. We already did a length check above.
case argdelta > 0 && !cmd.Variadic:
// If it's not variadic, then we can't accept it.
err = ErrTooManyArgs
}
return &ErrInvalidUsage{
Args: args,
Index: len(args) - 1,
Err: err,
Ctx: cmd,
if err != nil {
return &ErrInvalidUsage{
Prefix: pf,
Args: args,
Index: len(args) - 1,
Wrap: err,
Ctx: cmd,
}
}
}
// Allocate a new slice the length of function arguments.
argv = make([]reflect.Value, len(cmd.Arguments))
for i := start; i < len(args); i++ {
v, err := cmd.Arguments[i-start].fn(args[i])
for i := 0; i < len(argv); i++ {
v, err := cmd.Arguments[i].fn(args[start+i])
if err != nil {
return &ErrInvalidUsage{
Args: args,
Index: i,
Err: err.Error(),
Ctx: cmd,
Prefix: pf,
Args: args,
Index: i,
Wrap: err,
Ctx: cmd,
}
}
argv[i-start] = v
argv[i] = v
}
// Parse the rest with variadic arguments. Go's reflect states that varidic
// parameters will automatically be copied, which is good.
if len(args) > len(argv) {
// The location to continue parsing from args.
argc := len(argv)
// Allocate a new slice to append into. We start 1-off from the start,
// as the first argument of the variadic slice is already parsed.
vars := make([]reflect.Value, len(args)-len(argv)-1)
last := cmd.Arguments[len(cmd.Arguments)-1]
// Continue the above loop, where i stops before len(argv).
for i := 0; i < len(vars); i++ {
v, err := last.fn(args[argc+i+1])
if err != nil {
return &ErrInvalidUsage{
Prefix: pf,
Args: args,
Index: i,
Wrap: err,
Ctx: cmd,
}
}
vars[i] = v
}
argv = append(argv, vars...)
}
Call:
@ -426,10 +469,12 @@ func callWith(
caller reflect.Value,
ev interface{}, values ...reflect.Value) (interface{}, error) {
return errorReturns(caller.Call(append(
values = append(
[]reflect.Value{reflect.ValueOf(ev)},
values...,
)))
)
return errorReturns(caller.Call(values))
}
func errorReturns(returns []reflect.Value) (interface{}, error) {

View File

@ -29,16 +29,21 @@ func (t *testCommands) GetCounter(_ *gateway.MessageCreateEvent) error {
return nil
}
func (t *testCommands) Send(_ *gateway.MessageCreateEvent, arg string) error {
t.Return <- arg
func (t *testCommands) Send(_ *gateway.MessageCreateEvent, args ...string) error {
t.Return <- args
return errors.New("oh no")
}
func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *customParseable) error {
func (t *testCommands) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) error {
t.Return <- c.args
return nil
}
func (t *testCommands) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) error {
t.Return <- c[len(c)-1]
return nil
}
func (t *testCommands) NoArgs(_ *gateway.MessageCreateEvent) error {
return errors.New("passed")
}
@ -52,15 +57,6 @@ func (t *testCommands) OnTyping(_ *gateway.TypingStartEvent) error {
return nil
}
type customParseable struct {
args []string
}
func (c *customParseable) ParseContent(args []string) error {
c.args = args
return nil
}
func TestNewContext(t *testing.T) {
var state = &state.State{
Store: state.NewDefaultStore(nil),
@ -181,12 +177,17 @@ func TestContext(t *testing.T) {
// Set a custom prefix
ctx.HasPrefix = NewPrefix("~")
if err := testReturn("test", "~send test"); err.Error() != "oh no" {
var (
strings = "hacka doll no. 3"
expects = []string{"hacka", "doll", "no.", "3"}
)
if err := testReturn(expects, "~send "+strings); err.Error() != "oh no" {
t.Fatal("Unexpected error:", err)
}
})
t.Run("call command custom parser", func(t *testing.T) {
t.Run("call command custom manual parser", func(t *testing.T) {
ctx.HasPrefix = NewPrefix("!")
expects := []string{"custom", "arg1", ":)"}
@ -195,6 +196,15 @@ func TestContext(t *testing.T) {
}
})
t.Run("call command custom variadic parser", func(t *testing.T) {
ctx.HasPrefix = NewPrefix("!")
expects := &customParsed{true}
if err := testReturn(expects, "!variadic bruh moment"); err != nil {
t.Fatal("Unexpected call error:", err)
}
})
testMessage := func(content string) error {
// Mock a messageCreate event
m := &gateway.MessageCreateEvent{

View File

@ -1,10 +1,12 @@
package bot
import (
"errors"
"strings"
)
type ErrUnknownCommand struct {
Prefix string
Command string
Parent string
@ -18,7 +20,7 @@ func (err *ErrUnknownCommand) Error() string {
}
var UnknownCommandString = func(err *ErrUnknownCommand) string {
var header = "Unknown command: "
var header = "Unknown command: " + err.Prefix
if err.Parent != "" {
header += err.Parent + " " + err.Command
} else {
@ -28,10 +30,16 @@ var UnknownCommandString = func(err *ErrUnknownCommand) string {
return header
}
var (
ErrTooManyArgs = errors.New("Too many arguments given")
ErrNotEnoughArgs = errors.New("Not enough arguments given")
)
type ErrInvalidUsage struct {
Args []string
Index int
Err string
Prefix string
Args []string
Index int
Wrap error
// TODO: usage generator?
// Here, as a reminder
@ -42,9 +50,13 @@ func (err *ErrInvalidUsage) Error() string {
return InvalidUsageString(err)
}
func (err *ErrInvalidUsage) Unwrap() error {
return err.Wrap
}
var InvalidUsageString = func(err *ErrInvalidUsage) string {
if err.Index == 0 {
return "Invalid usage, error: " + err.Err
return "Invalid usage, error: " + err.Wrap.Error() + "."
}
if len(err.Args) == 0 {
@ -52,6 +64,8 @@ var InvalidUsageString = func(err *ErrInvalidUsage) string {
}
body := "Invalid usage at " +
// Write the prefix.
err.Prefix +
// Write the first part
strings.Join(err.Args[:err.Index], " ") +
// Write the wrong part
@ -59,8 +73,8 @@ var InvalidUsageString = func(err *ErrInvalidUsage) string {
// Write the last part
strings.Join(err.Args[err.Index+1:], " ")
if err.Err != "" {
body += "\nError: " + err.Err
if err.Wrap != nil {
body += "\nError: " + err.Wrap.Error() + "."
}
return body

View File

@ -120,6 +120,9 @@ type CommandContext struct {
// Hidden is true if the method has a hidden nameflag.
Hidden bool
// Variadic is true if the function is a variadic one.
Variadic bool
value reflect.Value // Func
event reflect.Type // gateway.*Event
method reflect.Method
@ -389,9 +392,10 @@ func (sub *Subcommand) parseCommands() error {
}
var command = CommandContext{
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
Variadic: methodT.IsVariadic(),
}
// Parse the method name
@ -460,7 +464,7 @@ func (sub *Subcommand) parseCommands() error {
command.Arguments = []Argument{{
String: t.String(),
Type: t,
rtype: t,
pointer: ptr,
custom: &mt,
}}
@ -478,7 +482,7 @@ func (sub *Subcommand) parseCommands() error {
command.Arguments = []Argument{{
String: t.String(),
Type: t,
rtype: t,
pointer: ptr,
manual: &mt,
}}
@ -491,7 +495,7 @@ func (sub *Subcommand) parseCommands() error {
// Fill up arguments
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := getArgumentValueFn(t)
a, err := getArgumentValueFn(t, command.Variadic)
if err != nil {
return errors.Wrap(err, "Error parsing argument "+t.String())
}

View File

@ -27,7 +27,7 @@ func TestSubcommand(t *testing.T) {
}
// !!! CHANGE ME
if len(sub.Commands) != 5 {
if len(sub.Commands) != 6 {
t.Fatal("invalid ctx.commands len", len(sub.Commands))
}
@ -57,11 +57,8 @@ func TestSubcommand(t *testing.T) {
t.Fatal("expected 0 arguments, got non-zero")
}
case "noop", "getCounter":
case "noop", "getCounter", "variadic":
// Found, but whatever
default:
t.Fatal("Unexpected command:", this.Command)
}
if this.event != typeMessageCreate {