diff --git a/api/cmdroute/router.go b/api/cmdroute/router.go index 1141f2c..4a264cd 100644 --- a/api/cmdroute/router.go +++ b/api/cmdroute/router.go @@ -10,9 +10,10 @@ import ( // Router is a router for slash commands. A zero-value Router is a valid router. type Router struct { - nodes map[string]routeNode - mws []Middleware - stack []*Router + nodes map[string]routeNode + mws []Middleware + parent *Router // parent router, if any + groups []Router // next routers to check, if any } type routeNode interface { @@ -44,9 +45,6 @@ func NewRouter() *Router { } func (r *Router) init() { - if r.stack == nil { - r.stack = []*Router{r} - } if r.nodes == nil { r.nodes = make(map[string]routeNode, 4) } @@ -75,7 +73,7 @@ func (r *Router) Use(mws ...Middleware) { // parent command of the given name. func (r *Router) Sub(name string, f func(r *Router)) { sub := NewRouter() - sub.stack = append(append([]*Router(nil), r.stack...), sub) + sub.parent = r f(sub) r.add(name, routeNodeSub{sub}) @@ -92,6 +90,39 @@ func (r *Router) AddFunc(name string, f CommandHandlerFunc) { r.Add(name, f) } +// Group creates a subrouter that handles certain commands within the parent +// command. This is useful for assigning middlewares to a group of commands that +// belong to the same parent command. +// +// For example, consider the following: +// +// r := cmdroute.NewRouter() +// r.Group(func(r *cmdroute.Router) { +// r.Use(cmdroute.Deferrable(client, cmdroute.DeferOpts{})) +// r.Add("foo", handleFoo) +// }) +// r.Add("bar", handleBar) +// +// In this example, the middleware is only applied to the "foo" command, and not +// the "bar" command. +func (r *Router) Group(f func(r *Router)) { + f(r.With()) +} + +// With is similar to Group, but it returns a new router instead of calling a +// function with a new router. This is useful for chaining middlewares once, +// such as: +// +// r := cmdroute.NewRouter() +// r.With(cmdroute.Deferrable(client, cmdroute.DeferOpts{})).Add("foo", handleFoo) +func (r *Router) With(mws ...Middleware) *Router { + r.groups = append(r.groups, Router{}) + sub := &r.groups[len(r.groups)-1] + sub.parent = r + sub.mws = append(sub.mws, mws...) + return sub +} + // HandleInteraction implements webhook.InteractionHandler. It only handles // events of type CommandInteraction, otherwise nil is returned. func (r *Router) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse { @@ -113,11 +144,11 @@ func (r *Router) callHandler(ev *discord.InteractionEvent, fn InteractionHandler // Apply middlewares, parent last, first one added last. This ensures that // when we call the handler, the middlewares are applied in the order they // were added. - for i := len(r.stack) - 1; i >= 0; i-- { - r := r.stack[i] - for j := len(r.mws) - 1; j >= 0; j-- { - h = r.mws[j](h) + for r != nil { + for i := len(r.mws) - 1; i >= 0; i-- { + h = r.mws[i](h) } + r = r.parent } return h.HandleInteraction(context.Background(), ev) @@ -162,7 +193,27 @@ type handlerData struct { data discord.CommandInteractionOption } +// findCommandHandler finds the command handler for the given command name. +// It checks the current router and its groups. func (r *Router) findCommandHandler(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) { + found, ok := r.findCommandHandlerOnce(ev, data) + if ok { + return found, true + } + + for _, sub := range r.groups { + found, ok = sub.findCommandHandlerOnce(ev, data) + if ok { + return found, true + } + } + + return handlerData{}, false +} + +// findCommandHandlerOnce finds the command handler for the given command name. +// It only checks the current router and not its groups. +func (r *Router) findCommandHandlerOnce(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) { node, ok := r.nodes[data.Name] if !ok { return handlerData{}, false diff --git a/api/cmdroute/router_test.go b/api/cmdroute/router_test.go index fd588d7..88a349d 100644 --- a/api/cmdroute/router_test.go +++ b/api/cmdroute/router_test.go @@ -165,34 +165,45 @@ func TestRouter(t *testing.T) { }) t.Run("middlewares", func(t *testing.T) { - var stack []string - pushStack := func(s string) Middleware { - return func(next InteractionHandler) InteractionHandler { - return InteractionHandlerFunc(func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { - stack = append(stack, s) - return next.HandleInteraction(ctx, ev) - }) - } - } + var stack middlewareStacker r := NewRouter() - r.Use(pushStack("root1")) - r.Use(pushStack("root2")) + r.Use(stack.pusher("root1")) + r.Use(stack.pusher("root2")) r.Sub("test", func(r *Router) { - r.Use(pushStack("sub1.1")) - r.Use(pushStack("sub1.2")) - r.Sub("sub1", func(r *Router) { - r.Use(pushStack("sub2.1")) - r.Use(pushStack("sub2.2")) - r.Add("sub2", assertHandler(t, mockOptions)) + // We put test 1 at the start, but test 2 at the end. + // The order should be preserved. + r.Use(stack.pusher("test.1")) + + // unused + r.Group(func(r *Router) { + r.Use(stack.pusher("test.3")) }) + + // unused + r.With(stack.pusher("test.4")) + + r.Group(func(r *Router) { + r.Use(stack.pusher("test.5")) + + r.Sub("sub", func(r *Router) { + r.Use(stack.pusher("test.sub.1")) + r.Use(stack.pusher("test.sub.2")) + + r.Add("sub2", assertHandler(t, mockOptions)) + }) + }) + + // Test 2 goes here. + r.Use(stack.pusher("test.2")) }) + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "test", Options: []discord.CommandInteractionOption{ { - Name: "sub1", + Name: "sub", Type: discord.SubcommandGroupOptionType, Options: []discord.CommandInteractionOption{ { @@ -205,23 +216,15 @@ func TestRouter(t *testing.T) { }, })) - expects := []string{ + stack.expect(t, []string{ "root1", "root2", - "sub1.1", - "sub1.2", - "sub2.1", - "sub2.2", - } - if len(stack) != len(expects) { - t.Fatalf("expected stack to have %d elements, got %d", len(expects), len(stack)) - } - - for i := range expects { - if stack[i] != expects[i] { - t.Fatalf("expected stack[%d] to be %q, got %q", i, expects[i], stack[i]) - } - } + "test.1", + "test.2", + "test.5", + "test.sub.1", + "test.sub.2", + }) }) t.Run("deferred", func(t *testing.T) { @@ -335,6 +338,7 @@ var mockOptions = []discord.CommandInteractionOption{ }, } +// assertHandler asserts that the given options are equal to the expected options. func assertHandler(t *testing.T, opts discord.CommandInteractionOptions) CommandHandler { return CommandHandlerFunc(func(ctx context.Context, data CommandData) *api.InteractionResponseData { if len(data.Options) != len(opts) { @@ -410,3 +414,25 @@ func strInteractionResp(resp *api.InteractionResponse) string { } return fmt.Sprintf("%d:%#v", resp.Type, resp.Data) } + +type middlewareStacker []string + +func (m *middlewareStacker) pusher(s string) Middleware { + return func(next InteractionHandler) InteractionHandler { + return InteractionHandlerFunc(func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { + *m = append(*m, s) + return next.HandleInteraction(ctx, ev) + }) + } +} + +func (m middlewareStacker) expect(t *testing.T, expects []string) { + if len(m) != len(expects) { + t.Fatalf("expected stack to have %d elements, got %d: %v", len(expects), len(m), m) + } + for i := range expects { + if m[i] != expects[i] { + t.Fatalf("expected stack[%d] to be %q, got %q", i, expects[i], m[i]) + } + } +}