From 3cb993aff9de3e6ebd89fec8b86691d7784121dc Mon Sep 17 00:00:00 2001 From: diamondburned Date: Fri, 4 Aug 2023 14:14:00 -0700 Subject: [PATCH] cmdroute: Allow routing ComponentInteraction --- api/cmdroute/fntypes.go | 29 +++++ api/cmdroute/router.go | 210 ++++++++++++++++++++++-------------- api/cmdroute/router_test.go | 42 ++++++-- 3 files changed, 192 insertions(+), 89 deletions(-) diff --git a/api/cmdroute/fntypes.go b/api/cmdroute/fntypes.go index e48cd43..813e266 100644 --- a/api/cmdroute/fntypes.go +++ b/api/cmdroute/fntypes.go @@ -90,3 +90,32 @@ var _ Autocompleter = (AutocompleterFunc)(nil) func (f AutocompleterFunc) Autocomplete(ctx context.Context, data AutocompleteData) api.AutocompleteChoices { return f(ctx, data) } + +/* + * Component + */ + +// ComponentData is passed to a ComponentHandler's HandleComponent method. +type ComponentData struct { + discord.ComponentInteraction + Event *discord.InteractionEvent +} + +// ComponentHandler is a type for a component handler. +type ComponentHandler interface { + // HandleComponent is expected to return a response synchronously, either + // to be followed-up later by deferring the response or to be responded + // immediately. + HandleComponent(ctx context.Context, data ComponentData) *api.InteractionResponse +} + +// ComponentHandlerFunc is a function that implements the ComponentHandler +// interface. +type ComponentHandlerFunc func(ctx context.Context, data ComponentData) *api.InteractionResponse + +var _ ComponentHandler = (ComponentHandlerFunc)(nil) + +// HandleComponent implements ComponentHandler. +func (f ComponentHandlerFunc) HandleComponent(ctx context.Context, data ComponentData) *api.InteractionResponse { + return f(ctx, data) +} diff --git a/api/cmdroute/router.go b/api/cmdroute/router.go index 6e56971..70239fb 100644 --- a/api/cmdroute/router.go +++ b/api/cmdroute/router.go @@ -15,12 +15,25 @@ type Router struct { stack []*Router } -type routeNode struct { - sub *Router - cmd CommandHandler - com Autocompleter +type routeNode interface { + isRouteNode() } +type routeNodeSub struct{ *Router } + +type routeNodeCommand struct { + command CommandHandler + autocomplete Autocompleter +} + +type routeNodeComponent struct { + component ComponentHandler +} + +func (routeNodeSub) isRouteNode() {} +func (routeNodeCommand) isRouteNode() {} +func (routeNodeComponent) isRouteNode() {} + var _ webhook.InteractionHandler = (*Router)(nil) // NewRouter creates a new Router. @@ -39,6 +52,17 @@ func (r *Router) init() { } } +func (r *Router) add(name string, node routeNode) { + r.init() + + _, ok := r.nodes[name] + if ok { + panic("cmdroute: node " + name + " already exists") + } + + r.nodes[name] = node +} + // Use adds a middleware to the router. The middleware is applied to all // subcommands and subrouters. Middlewares are applied in the order they are // added, with the middlewares in the parent router being applied first. @@ -50,31 +74,16 @@ func (r *Router) Use(mws ...Middleware) { // Sub creates a subrouter that handles all subcommands that are under the // parent command of the given name. func (r *Router) Sub(name string, f func(r *Router)) { - r.init() - - node, ok := r.nodes[name] - if ok && node.sub == nil { - panic("cmdroute: command " + name + " already exists") - } - sub := NewRouter() sub.stack = append(append([]*Router(nil), r.stack...), sub) f(sub) - r.nodes[name] = routeNode{sub: sub} + r.add(name, routeNodeSub{sub}) } // Add registers a slash command handler for the given command name. func (r *Router) Add(name string, h CommandHandler) { - r.init() - - node, ok := r.nodes[name] - if ok { - panic("cmdroute: command " + name + " already exists") - } - - node.cmd = h - r.nodes[name] = node + r.add(name, routeNodeCommand{command: h}) } // AddFunc is a convenience function that calls Handle with a @@ -91,12 +100,14 @@ func (r *Router) HandleInteraction(ev *discord.InteractionEvent) *api.Interactio return r.HandleCommand(ev, data) case *discord.AutocompleteInteraction: return r.HandleAutocompletion(ev, data) + case discord.ComponentInteraction: + return r.handleComponent(ev, data) default: return nil } } -func (r *Router) handleInteraction(ev *discord.InteractionEvent, fn InteractionHandlerFunc) *api.InteractionResponse { +func (r *Router) callHandler(ev *discord.InteractionEvent, fn InteractionHandlerFunc) *api.InteractionResponse { h := InteractionHandler(fn) // Apply middlewares, parent last, first one added last. This ensures that @@ -114,13 +125,16 @@ func (r *Router) handleInteraction(ev *discord.InteractionEvent, fn InteractionH // HandleCommand implements CommandHandler. It applies middlewares onto the // handler to be executed. +// +// Deprecated: This function should not be used directly. Use HandleInteraction +// instead. func (r *Router) HandleCommand(ev *discord.InteractionEvent, data *discord.CommandInteraction) *api.InteractionResponse { cmdType := discord.SubcommandOptionType if cmdIsGroup(data) { cmdType = discord.SubcommandGroupOptionType } - found, ok := r.findHandler(ev, discord.CommandInteractionOption{ + found, ok := r.findCommandHandler(ev, discord.CommandInteractionOption{ Type: cmdType, Name: data.Name, Options: data.Options, @@ -129,26 +143,7 @@ func (r *Router) HandleCommand(ev *discord.InteractionEvent, data *discord.Comma return nil } - return found.router.handleCommand(ev, found) -} - -func (r *Router) handleCommand(ev *discord.InteractionEvent, found handlerData) *api.InteractionResponse { - return r.handleInteraction(ev, - func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { - data := found.handler.HandleCommand(ctx, CommandData{ - CommandInteractionOption: found.data, - Event: ev, - }) - if data == nil { - return nil - } - - return &api.InteractionResponse{ - Type: api.MessageInteractionWithSource, - Data: data, - } - }, - ) + return found.router.callCommandHandler(ev, found) } func cmdIsGroup(data *discord.CommandInteraction) bool { @@ -167,25 +162,25 @@ type handlerData struct { data discord.CommandInteractionOption } -func (r *Router) findHandler(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) { +func (r *Router) findCommandHandler(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) { node, ok := r.nodes[data.Name] if !ok { return handlerData{}, false } - switch { - case node.sub != nil: + switch node := node.(type) { + case routeNodeSub: if len(data.Options) != 1 || data.Type != discord.SubcommandGroupOptionType { break } - return node.sub.findHandler(ev, data.Options[0]) - case node.cmd != nil: + return node.findCommandHandler(ev, data.Options[0]) + case routeNodeCommand: if data.Type != discord.SubcommandOptionType { break } return handlerData{ router: r, - handler: node.cmd, + handler: node.command, data: data, }, true } @@ -193,16 +188,35 @@ func (r *Router) findHandler(ev *discord.InteractionEvent, data discord.CommandI return handlerData{}, false } +func (r *Router) callCommandHandler(ev *discord.InteractionEvent, found handlerData) *api.InteractionResponse { + return r.callHandler(ev, + func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { + data := found.handler.HandleCommand(ctx, CommandData{ + CommandInteractionOption: found.data, + Event: ev, + }) + if data == nil { + return nil + } + + return &api.InteractionResponse{ + Type: api.MessageInteractionWithSource, + Data: data, + } + }, + ) +} + // AddAutocompleter registers an autocompleter for the given command name. func (r *Router) AddAutocompleter(name string, ac Autocompleter) { r.init() - node, ok := r.nodes[name] - if !ok || node.cmd == nil { - panic("cmdroute: command " + name + " does not exist or is not a (sub)command") + node, ok := r.nodes[name].(routeNodeCommand) + if !ok { + panic("cmdroute: cannot add autocompleter to unknown command " + name) } - node.com = ac + node.autocomplete = ac r.nodes[name] = node } @@ -213,6 +227,9 @@ func (r *Router) AddAutocompleterFunc(name string, f AutocompleterFunc) { } // HandleAutocompletion handles an autocompletion event. +// +// Deprecated: This function should not be used directly. Use HandleInteraction +// instead. func (r *Router) HandleAutocompletion(ev *discord.InteractionEvent, data *discord.AutocompleteInteraction) *api.InteractionResponse { cmdType := discord.SubcommandOptionType if autocompIsGroup(data) { @@ -228,28 +245,7 @@ func (r *Router) HandleAutocompletion(ev *discord.InteractionEvent, data *discor return nil } - return found.router.handleAutocompletion(ev, found) -} - -func (r *Router) handleAutocompletion(ev *discord.InteractionEvent, found autocompleterData) *api.InteractionResponse { - return r.handleInteraction(ev, - func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { - choices := found.handler.Autocomplete(ctx, AutocompleteData{ - AutocompleteOption: found.data, - Event: ev, - }) - if choices == nil { - return nil - } - - return &api.InteractionResponse{ - Type: api.AutocompleteResult, - Data: &api.InteractionResponseData{ - Choices: choices, - }, - } - }, - ) + return found.router.callAutocompletion(ev, found) } func autocompIsGroup(data *discord.AutocompleteInteraction) bool { @@ -274,22 +270,76 @@ func (r *Router) findAutocompleter(ev *discord.InteractionEvent, data discord.Au return autocompleterData{}, false } - switch { - case node.sub != nil: + switch node := node.(type) { + case routeNodeSub: if len(data.Options) != 1 || data.Type != discord.SubcommandGroupOptionType { break } - return node.sub.findAutocompleter(ev, data.Options[0]) - case node.com != nil: + return node.findAutocompleter(ev, data.Options[0]) + case routeNodeCommand: + if node.autocomplete == nil { + break + } if data.Type != discord.SubcommandOptionType { break } return autocompleterData{ router: r, - handler: node.com, + handler: node.autocomplete, data: data, }, true } return autocompleterData{}, false } + +func (r *Router) callAutocompletion(ev *discord.InteractionEvent, found autocompleterData) *api.InteractionResponse { + return r.callHandler(ev, + func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { + choices := found.handler.Autocomplete(ctx, AutocompleteData{ + AutocompleteOption: found.data, + Event: ev, + }) + if choices == nil { + return nil + } + + return &api.InteractionResponse{ + Type: api.AutocompleteResult, + Data: &api.InteractionResponseData{ + Choices: choices, + }, + } + }, + ) +} + +// AddComponent registers a component handler for the given component ID. +func (r *Router) AddComponent(id string, f ComponentHandler) { + r.add(id, routeNodeComponent{f}) +} + +// AddComponentFunc is a convenience function that calls Handle with a +// ComponentHandlerFunc. +func (r *Router) AddComponentFunc(id string, f ComponentHandlerFunc) { + r.AddComponent(id, f) +} + +func (r *Router) handleComponent(ev *discord.InteractionEvent, component discord.ComponentInteraction) *api.InteractionResponse { + node, ok := r.nodes[string(component.ID())].(routeNodeComponent) + if ok { + return r.callComponentHandler(ev, node.component) + } + return nil +} + +func (r *Router) callComponentHandler(ev *discord.InteractionEvent, handler ComponentHandler) *api.InteractionResponse { + return r.callHandler(ev, + func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse { + return handler.HandleComponent(ctx, ComponentData{ + Event: ev, + ComponentInteraction: ev.Data.(discord.ComponentInteraction), + }) + }, + ) +} diff --git a/api/cmdroute/router_test.go b/api/cmdroute/router_test.go index f8646be..fd588d7 100644 --- a/api/cmdroute/router_test.go +++ b/api/cmdroute/router_test.go @@ -20,7 +20,7 @@ func TestRouter(t *testing.T) { t.Run("command", func(t *testing.T) { r := NewRouter() r.Add("test", assertHandler(t, mockOptions)) - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "test", Options: mockOptions, @@ -30,7 +30,7 @@ func TestRouter(t *testing.T) { t.Run("subcommand", func(t *testing.T) { r := NewRouter() r.Sub("test", func(r *Router) { r.Add("sub", assertHandler(t, mockOptions)) }) - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "test", Options: []discord.CommandInteractionOption{ @@ -49,7 +49,7 @@ func TestRouter(t *testing.T) { t.Fatal("unexpected call") return nil }) - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "unknown", })) @@ -64,7 +64,7 @@ func TestRouter(t *testing.T) { r.AddFunc("ping", func(_ context.Context, _ CommandData) *api.InteractionResponseData { return data }) - resp := r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + resp := r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "ping", Options: mockOptions, @@ -140,6 +140,30 @@ func TestRouter(t *testing.T) { ) }) + t.Run("component", func(t *testing.T) { + r := NewRouter() + r.AddComponentFunc("ping", func(ctx context.Context, data ComponentData) *api.InteractionResponse { + button := data.ComponentInteraction.(*discord.ButtonInteraction) + return &api.InteractionResponse{ + Type: api.MessageInteractionWithSource, + Data: &api.InteractionResponseData{ + Content: option.NewNullableString(string(button.CustomID)), + }, + } + }) + resp := r.HandleInteraction(newInteractionEvent(&discord.ButtonInteraction{ + CustomID: "ping", + })) + if !reflect.DeepEqual(resp, &api.InteractionResponse{ + Type: api.MessageInteractionWithSource, + Data: &api.InteractionResponseData{ + Content: option.NewNullableString("ping"), + }, + }) { + t.Fatal("unexpected response") + } + }) + t.Run("middlewares", func(t *testing.T) { var stack []string pushStack := func(s string) Middleware { @@ -163,7 +187,7 @@ func TestRouter(t *testing.T) { r.Add("sub2", assertHandler(t, mockOptions)) }) }) - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "test", Options: []discord.CommandInteractionOption{ @@ -255,7 +279,7 @@ func TestRouter(t *testing.T) { }) assertInteractionResp(t, - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "ping", Options: mockOptions, @@ -271,7 +295,7 @@ func TestRouter(t *testing.T) { wg.Add(1) assertInteractionResp(t, - r.HandleInteraction(newInteractionEvent(discord.CommandInteraction{ + r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{ ID: 4, Name: "ping-defer", Options: mockOptions, @@ -288,13 +312,13 @@ func TestRouter(t *testing.T) { }) } -func newInteractionEvent(data discord.CommandInteraction) *discord.InteractionEvent { +func newInteractionEvent(data discord.InteractionData) *discord.InteractionEvent { return &discord.InteractionEvent{ ID: 100, AppID: 200, ChannelID: 300, Token: "mock token", - Data: &data, + Data: data, } }