diff --git a/api/webhook/interactionserver.go b/api/webhook/interactionserver.go new file mode 100644 index 0000000..02741b8 --- /dev/null +++ b/api/webhook/interactionserver.go @@ -0,0 +1,189 @@ +package webhook + +import ( + "bytes" + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "io" + "log" + "mime/multipart" + "net/http" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/discord" + "github.com/pkg/errors" +) + +func writeError(w http.ResponseWriter, code int, err error) { + var resp struct { + Error string `json:"error"` + } + + if err != nil { + resp.Error = err.Error() + } else { + resp.Error = http.StatusText(code) + } + + b, err := json.Marshal(resp) + if err != nil { + log.Panicln("cannot marshal error response:", err) + } + + w.Write(b) +} + +// InteractionHandler is a type whose method is called on every incoming event. +type InteractionHandler interface { + // HandleInteraction is expected to return a response synchronously, either + // to be followed-up later by deferring the response or to be responded + // immediately. + HandleInteraction(*discord.InteractionEvent) *api.InteractionResponse +} + +type alwaysDeferInteraction struct { + f func(*discord.InteractionEvent) + flags discord.MessageFlags +} + +// AlwaysDeferInteraction always returns a DeferredMessageInteractionWithSource +// then invokes f in the background. This allows f to always use the follow-up +// functions. +func AlwaysDeferInteraction(flags discord.MessageFlags, f func(*discord.InteractionEvent)) InteractionHandler { + return alwaysDeferInteraction{f, flags} +} + +func (f alwaysDeferInteraction) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse { + go f.f(ev) + return &api.InteractionResponse{ + Type: api.DeferredMessageInteractionWithSource, + Data: &api.InteractionResponseData{ + Flags: f.flags, + }, + } +} + +// InteractionErrorFunc is called to write an error. err may be nil with a +// non-2xx code. +type InteractionErrorFunc func(w http.ResponseWriter, r *http.Request, code int, err error) + +// InteractionServer provides a HTTP handler to verify and handle Interaction +// Create events sent by Discord into a HTTP endpoint.. +type InteractionServer struct { + ErrorFunc InteractionErrorFunc + + interactionHandler InteractionHandler + httpHandler http.Handler + pubkey ed25519.PublicKey +} + +// NewInteractionServer creates a new InteractionServer instance. pubkey should +// be hex-encoded. +func NewInteractionServer(pubkey string, handler InteractionHandler) (*InteractionServer, error) { + pubkeyB, err := hex.DecodeString(pubkey) + if err != nil { + return nil, errors.Wrap(err, "cannot decode hex pubkey") + } + + s := InteractionServer{ + ErrorFunc: func(w http.ResponseWriter, r *http.Request, code int, err error) { + writeError(w, code, err) + }, + interactionHandler: handler, + httpHandler: nil, + pubkey: pubkeyB, + } + + s.httpHandler = http.HandlerFunc(s.handle) + s.httpHandler = s.withVerification(s.httpHandler) + + return &s, nil +} + +// ServeHTTP implements http.Handler. +func (s *InteractionServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.withVerification(http.HandlerFunc(s.handle)).ServeHTTP(w, r) +} + +func (s *InteractionServer) handle(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "POST": + var ev discord.InteractionEvent + + if err := json.NewDecoder(r.Body).Decode(&ev); err != nil { + s.ErrorFunc(w, r, 400, errors.Wrap(err, "cannot decode interaction body")) + return + } + + switch ev.Data.(type) { + case *discord.PingInteraction: + json.NewEncoder(w).Encode(api.InteractionResponse{ + Type: api.PongInteraction, + }) + } + + resp := s.interactionHandler.HandleInteraction(&ev) + if resp != nil && resp.Type != api.PongInteraction { + if resp.NeedsMultipart() { + body := multipart.NewWriter(w) + w.Header().Set("Content-Type", body.FormDataContentType()) + resp.WriteMultipart(body) + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } + } + default: + s.ErrorFunc(w, r, http.StatusMethodNotAllowed, errors.New("method not allowed")) + } +} + +// withVerification was written thanks to @bsdlp and their code +// https://github.com/bsdlp/discord-interactions-go/blob/a2ba844/interactions/verify_example_test.go#L63. +func (s *InteractionServer) withVerification(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + signature := r.Header.Get("X-Signature-Ed25519") + if signature == "" { + s.ErrorFunc(w, r, 401, errors.New("missing header X-Signature-Ed25519")) + return + } + + sig, err := hex.DecodeString(signature) + if err != nil { + s.ErrorFunc(w, r, 400, errors.Wrap(err, "X-Signature-Ed25519 is not valid hex-encoded")) + return + } + + if len(sig) != ed25519.SignatureSize || sig[63]&224 != 0 { + s.ErrorFunc(w, r, 400, errors.New("invalid X-Signature-Ed25519 data")) + return + } + + timestamp := r.Header.Get("X-Signature-Timestamp") + if timestamp == "" { + s.ErrorFunc(w, r, 401, errors.New("missing header X-Signature-Timestamp")) + return + } + + var msg bytes.Buffer + msg.Grow(int(r.ContentLength+1) + len(timestamp)) + msg.WriteString(timestamp) + + if _, err := io.Copy(&msg, r.Body); err != nil { + s.ErrorFunc(w, r, 500, errors.Wrap(err, "cannot read body")) + return + } + + if !ed25519.Verify(s.pubkey, msg.Bytes(), sig) { + s.ErrorFunc(w, r, 401, errors.New("signature mismatch")) + return + } + + // Return the request body for use. + body := msg.Bytes()[len(timestamp):] + r.Body = io.NopCloser(bytes.NewReader(body)) + + next.ServeHTTP(w, r) + }) +}