1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-16 11:54:29 +00:00
arikawa/api/webhook/interactionserver.go
diamondburned 17c26bf488
webhook: Add InteractionServer
This commit adds Interaction webhook server support directly into the
library.

Bots can now support both receiving events through the Discord gateway
and the Interaction webhook handler within the same library.
2022-08-22 02:18:00 -07:00

190 lines
5.3 KiB
Go

package webhook
import (
"bytes"
"crypto/ed25519"
"encoding/hex"
"encoding/json"
"io"
"log"
"mime/multipart"
"net/http"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/pkg/errors"
)
// writeError writes a JSON error response of the form {"error": "..."} with
// the given HTTP status code. When err is nil, the standard status text for
// code (http.StatusText) is used as the message instead.
func writeError(w http.ResponseWriter, code int, err error) {
	var resp struct {
		Error string `json:"error"`
	}
	if err != nil {
		resp.Error = err.Error()
	} else {
		resp.Error = http.StatusText(code)
	}

	b, err := json.Marshal(resp)
	if err != nil {
		// Marshaling a struct with a single string field cannot fail; reaching
		// this indicates a programmer error.
		log.Panicln("cannot marshal error response:", err)
	}

	// Headers must be set before WriteHeader, and WriteHeader must precede
	// Write; otherwise the response silently goes out as 200 OK.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A write failure here means the client went away; there is nobody left
	// to report it to.
	_, _ = w.Write(b)
}
// InteractionHandler handles interaction events. Its method is invoked for
// every incoming event.
type InteractionHandler interface {
	// HandleInteraction must produce its response synchronously. The response
	// may answer the interaction immediately, or defer it so that it can be
	// followed up later.
	HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse
}
// alwaysDeferInteraction is the InteractionHandler implementation returned by
// AlwaysDeferInteraction.
type alwaysDeferInteraction struct {
	// f is the user callback, started in a new goroutine for each interaction.
	f func(*discord.InteractionEvent)
	// flags are the message flags placed on the deferred response.
	flags discord.MessageFlags
}
// AlwaysDeferInteraction returns an InteractionHandler that answers every
// interaction right away with a DeferredMessageInteractionWithSource response
// carrying the given flags, and runs f in the background. Because the response
// is always deferred, f can always use the follow-up functions.
func AlwaysDeferInteraction(flags discord.MessageFlags, f func(*discord.InteractionEvent)) InteractionHandler {
	handler := alwaysDeferInteraction{
		f:     f,
		flags: flags,
	}
	return handler
}
// HandleInteraction implements InteractionHandler. It launches the wrapped
// callback in its own goroutine and, without waiting for it, returns a
// deferred response carrying the configured message flags.
func (h alwaysDeferInteraction) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse {
	go h.f(ev)

	data := &api.InteractionResponseData{Flags: h.flags}
	return &api.InteractionResponse{
		Type: api.DeferredMessageInteractionWithSource,
		Data: data,
	}
}
// InteractionErrorFunc writes an error response for a request that could not
// be handled. status is the HTTP status code to report. err describes the
// failure and may be nil, in which case status is a non-2xx code.
type InteractionErrorFunc func(w http.ResponseWriter, req *http.Request, status int, err error)
// InteractionServer provides an HTTP handler that verifies Interaction Create
// events sent by Discord to an HTTP endpoint and hands them to an
// InteractionHandler.
type InteractionServer struct {
	// ErrorFunc is called to write error responses. NewInteractionServer sets
	// it to a function that writes a JSON body of the form {"error": "..."}.
	ErrorFunc InteractionErrorFunc

	// interactionHandler receives decoded interaction events.
	interactionHandler InteractionHandler
	// httpHandler is the handler chain (signature verification wrapping the
	// event handler) built by NewInteractionServer.
	httpHandler http.Handler
	// pubkey is the Ed25519 public key that request signatures are verified
	// against.
	pubkey ed25519.PublicKey
}
// NewInteractionServer creates a new InteractionServer instance.
//
// pubkey is the hex-encoded Ed25519 public key that incoming request
// signatures are verified against. handler receives every verified
// interaction event.
//
// An error is returned if pubkey is not valid hex, if it does not decode to an
// Ed25519 public key of the correct size, or if handler is nil.
func NewInteractionServer(pubkey string, handler InteractionHandler) (*InteractionServer, error) {
	pubkeyB, err := hex.DecodeString(pubkey)
	if err != nil {
		return nil, errors.Wrap(err, "cannot decode hex pubkey")
	}

	// ed25519.Verify panics on a public key of the wrong length, which would
	// turn every incoming request into a panic; reject such keys up front.
	if len(pubkeyB) != ed25519.PublicKeySize {
		return nil, errors.New("pubkey is not a valid ed25519 public key: wrong length")
	}

	// A nil handler would panic on the first request that reaches it.
	if handler == nil {
		return nil, errors.New("nil InteractionHandler")
	}

	s := &InteractionServer{
		ErrorFunc: func(w http.ResponseWriter, r *http.Request, code int, err error) {
			writeError(w, code, err)
		},
		interactionHandler: handler,
		pubkey:             ed25519.PublicKey(pubkeyB),
	}

	// Build the handler chain once: signature verification runs first, and
	// only verified requests reach s.handle.
	s.httpHandler = s.withVerification(http.HandlerFunc(s.handle))
	return s, nil
}
// ServeHTTP implements http.Handler. Each request must first pass Ed25519
// signature verification. Only then is it decoded and dispatched to the
// InteractionHandler.
func (s *InteractionServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Reuse the chain prepared by NewInteractionServer instead of rebuilding
	// it (two closures) on every request.
	h := s.httpHandler
	if h == nil {
		// The server was not built by NewInteractionServer; assemble the same
		// chain on the fly so behavior is identical.
		h = s.withVerification(http.HandlerFunc(s.handle))
	}
	h.ServeHTTP(w, r)
}
// handle decodes an already-verified interaction request and writes the reply.
//
// Ping interactions are answered with a Pong and are not passed to the
// InteractionHandler. Every other interaction goes to the InteractionHandler.
// A non-nil, non-Pong response is sent as JSON, or as multipart/form-data when
// the response reports that it needs multipart. A nil or Pong response from
// the handler results in an empty body. Methods other than POST are rejected
// with 405.
func (s *InteractionServer) handle(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		s.ErrorFunc(w, r, http.StatusMethodNotAllowed, errors.New("method not allowed"))
		return
	}

	var ev discord.InteractionEvent
	if err := json.NewDecoder(r.Body).Decode(&ev); err != nil {
		s.ErrorFunc(w, r, http.StatusBadRequest, errors.Wrap(err, "cannot decode interaction body"))
		return
	}

	var resp *api.InteractionResponse
	if _, isPing := ev.Data.(*discord.PingInteraction); isPing {
		// A Ping must be answered with exactly one Pong. It must not reach the
		// user handler, which could otherwise append a second response body.
		resp = &api.InteractionResponse{Type: api.PongInteraction}
	} else {
		resp = s.interactionHandler.HandleInteraction(&ev)
		// The handler may not answer with a Pong; such responses, like nil,
		// are dropped.
		if resp == nil || resp.Type == api.PongInteraction {
			return
		}
	}

	// Encode into memory first so that an encoding failure can still be
	// reported with a proper status code instead of a truncated 200 response.
	var body bytes.Buffer
	contentType := "application/json"

	if resp.NeedsMultipart() {
		mw := multipart.NewWriter(&body)
		if err := resp.WriteMultipart(mw); err != nil {
			s.ErrorFunc(w, r, http.StatusInternalServerError, errors.Wrap(err, "cannot encode multipart response"))
			return
		}
		// Close writes the terminating boundary; without it the multipart
		// body is malformed.
		if err := mw.Close(); err != nil {
			s.ErrorFunc(w, r, http.StatusInternalServerError, errors.Wrap(err, "cannot finish multipart response"))
			return
		}
		contentType = mw.FormDataContentType()
	} else if err := json.NewEncoder(&body).Encode(resp); err != nil {
		s.ErrorFunc(w, r, http.StatusInternalServerError, errors.Wrap(err, "cannot encode response"))
		return
	}

	w.Header().Set("Content-Type", contentType)
	// A write failure here means the client went away; nothing can be reported.
	_, _ = body.WriteTo(w)
}
// withVerification wraps next with Ed25519 request-signature verification.
//
// The signed message is the X-Signature-Timestamp header value followed by the
// raw request body. It is checked against the X-Signature-Ed25519 header
// (hex-encoded) using the server's public key. Requests with missing or
// malformed headers, or with a mismatching signature, are rejected through
// ErrorFunc and never reach next. On success the consumed body is restored on
// r so that next can read it.
//
// It was written thanks to @bsdlp and their code
// https://github.com/bsdlp/discord-interactions-go/blob/a2ba844/interactions/verify_example_test.go#L63.
func (s *InteractionServer) withVerification(next http.Handler) http.Handler {
	// maxPrealloc bounds the buffer space reserved up front from the
	// client-supplied Content-Length. That header is untrusted and is read
	// before the signature is checked, so a forged huge value must not be able
	// to force a huge allocation.
	const maxPrealloc = 1 << 20 // 1 MiB

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		signature := r.Header.Get("X-Signature-Ed25519")
		if signature == "" {
			s.ErrorFunc(w, r, http.StatusUnauthorized, errors.New("missing header X-Signature-Ed25519"))
			return
		}

		sig, err := hex.DecodeString(signature)
		if err != nil {
			s.ErrorFunc(w, r, http.StatusBadRequest, errors.Wrap(err, "X-Signature-Ed25519 is not valid hex-encoded"))
			return
		}

		// Reject signatures of the wrong size, and (as in the referenced code)
		// those whose last byte has any of its top three bits (mask 0xE0) set.
		if len(sig) != ed25519.SignatureSize || sig[ed25519.SignatureSize-1]&0xE0 != 0 {
			s.ErrorFunc(w, r, http.StatusBadRequest, errors.New("invalid X-Signature-Ed25519 data"))
			return
		}

		timestamp := r.Header.Get("X-Signature-Timestamp")
		if timestamp == "" {
			s.ErrorFunc(w, r, http.StatusUnauthorized, errors.New("missing header X-Signature-Timestamp"))
			return
		}

		// Pre-size only for a plausible Content-Length (it is -1 when
		// unknown). Larger bodies are still accepted; the buffer simply grows
		// as io.Copy reads.
		size := len(timestamp)
		if n := r.ContentLength; n > 0 && n <= maxPrealloc {
			size += int(n)
		}

		var msg bytes.Buffer
		msg.Grow(size)
		msg.WriteString(timestamp)
		if _, err := io.Copy(&msg, r.Body); err != nil {
			s.ErrorFunc(w, r, http.StatusInternalServerError, errors.Wrap(err, "cannot read body"))
			return
		}

		if !ed25519.Verify(s.pubkey, msg.Bytes(), sig) {
			s.ErrorFunc(w, r, http.StatusUnauthorized, errors.New("signature mismatch"))
			return
		}

		// Hand the already-read body (everything after the timestamp prefix)
		// back to the next handler.
		body := msg.Bytes()[len(timestamp):]
		r.Body = io.NopCloser(bytes.NewReader(body))

		next.ServeHTTP(w, r)
	})
}