mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-10 00:45:48 +00:00
e058e7c7b7
* feat: add support for disabling signature verification * feat: refactor solution to not break api * feat: make InteractionServer.ServeHTTP use InteractionServer.httpHandler
204 lines
5.7 KiB
Go
204 lines
5.7 KiB
Go
package webhook
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ed25519"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"mime/multipart"
|
|
"net/http"
|
|
|
|
"github.com/diamondburned/arikawa/v3/api"
|
|
"github.com/diamondburned/arikawa/v3/discord"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// writeError responds on w with the given status code and a JSON body of the
// form {"error": "<message>"}. The message is err's text, or the standard
// status text for code when err is nil.
func writeError(w http.ResponseWriter, code int, err error) {
	msg := http.StatusText(code)
	if err != nil {
		msg = err.Error()
	}

	payload, marshalErr := json.Marshal(struct {
		Error string `json:"error"`
	}{Error: msg})
	if marshalErr != nil {
		log.Panicln("cannot marshal error response:", marshalErr)
	}

	// The header must be set before WriteHeader for it to be sent.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	w.Write(payload)
}
|
|
|
|
// InteractionHandler is a type whose method is called on every incoming event.
|
|
type InteractionHandler interface {
|
|
// HandleInteraction is expected to return a response synchronously, either
|
|
// to be followed-up later by deferring the response or to be responded
|
|
// immediately.
|
|
HandleInteraction(*discord.InteractionEvent) *api.InteractionResponse
|
|
}
|
|
|
|
// InteractionHandlerFunc is a function type that implements the interface.
|
|
type InteractionHandlerFunc func(*discord.InteractionEvent) *api.InteractionResponse
|
|
|
|
var _ InteractionHandler = InteractionHandlerFunc(nil)
|
|
|
|
func (f InteractionHandlerFunc) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse {
|
|
return f(ev)
|
|
}
|
|
|
|
type alwaysDeferInteraction struct {
|
|
f func(*discord.InteractionEvent)
|
|
flags discord.MessageFlags
|
|
}
|
|
|
|
// AlwaysDeferInteraction always returns a DeferredMessageInteractionWithSource
|
|
// then invokes f in the background. This allows f to always use the follow-up
|
|
// functions.
|
|
func AlwaysDeferInteraction(flags discord.MessageFlags, f func(*discord.InteractionEvent)) InteractionHandler {
|
|
return alwaysDeferInteraction{f, flags}
|
|
}
|
|
|
|
func (f alwaysDeferInteraction) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse {
|
|
go f.f(ev)
|
|
return &api.InteractionResponse{
|
|
Type: api.DeferredMessageInteractionWithSource,
|
|
Data: &api.InteractionResponseData{
|
|
Flags: f.flags,
|
|
},
|
|
}
|
|
}
|
|
|
|
// InteractionErrorFunc is the callback used to write an error response for a
// request. It is given the response writer, the request being served, the HTTP
// status code to respond with and the error itself. err may be nil even though
// code is not a 2xx status.
type InteractionErrorFunc func(
	w http.ResponseWriter,
	r *http.Request,
	code int,
	err error,
)
|
|
|
|
// InteractionServer provides a HTTP handler to verify and handle Interaction
|
|
// Create events sent by Discord into a HTTP endpoint..
|
|
type InteractionServer struct {
|
|
ErrorFunc InteractionErrorFunc
|
|
|
|
interactionHandler InteractionHandler
|
|
httpHandler http.Handler
|
|
pubkey ed25519.PublicKey
|
|
}
|
|
|
|
// NewInteractionServer creates a new InteractionServer instance. pubkey should
|
|
// be hex-encoded.
|
|
func NewInteractionServer(pubkey string, handler InteractionHandler) (*InteractionServer, error) {
|
|
pubkeyB, err := hex.DecodeString(pubkey)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "cannot decode hex pubkey")
|
|
}
|
|
|
|
s := InteractionServer{
|
|
ErrorFunc: func(w http.ResponseWriter, r *http.Request, code int, err error) {
|
|
writeError(w, code, err)
|
|
},
|
|
interactionHandler: handler,
|
|
httpHandler: nil,
|
|
pubkey: pubkeyB,
|
|
}
|
|
|
|
s.httpHandler = http.HandlerFunc(s.handle)
|
|
if len(s.pubkey) != 0 {
|
|
s.httpHandler = s.withVerification(s.httpHandler)
|
|
}
|
|
|
|
return &s, nil
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler.
|
|
func (s *InteractionServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.httpHandler.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (s *InteractionServer) handle(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case "POST":
|
|
var ev discord.InteractionEvent
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&ev); err != nil {
|
|
s.ErrorFunc(w, r, 400, errors.Wrap(err, "cannot decode interaction body"))
|
|
return
|
|
}
|
|
|
|
switch ev.Data.(type) {
|
|
case *discord.PingInteraction:
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(api.InteractionResponse{
|
|
Type: api.PongInteraction,
|
|
})
|
|
}
|
|
|
|
resp := s.interactionHandler.HandleInteraction(&ev)
|
|
if resp != nil && resp.Type != api.PongInteraction {
|
|
if resp.NeedsMultipart() {
|
|
body := multipart.NewWriter(w)
|
|
w.Header().Set("Content-Type", body.FormDataContentType())
|
|
resp.WriteMultipart(body)
|
|
} else {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
}
|
|
default:
|
|
s.ErrorFunc(w, r, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
}
|
|
}
|
|
|
|
// withVerification was written thanks to @bsdlp and their code
|
|
// https://github.com/bsdlp/discord-interactions-go/blob/a2ba844/interactions/verify_example_test.go#L63.
|
|
func (s *InteractionServer) withVerification(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
signature := r.Header.Get("X-Signature-Ed25519")
|
|
if signature == "" {
|
|
s.ErrorFunc(w, r, 401, errors.New("missing header X-Signature-Ed25519"))
|
|
return
|
|
}
|
|
|
|
sig, err := hex.DecodeString(signature)
|
|
if err != nil {
|
|
s.ErrorFunc(w, r, 400, errors.Wrap(err, "X-Signature-Ed25519 is not valid hex-encoded"))
|
|
return
|
|
}
|
|
|
|
if len(sig) != ed25519.SignatureSize || sig[63]&224 != 0 {
|
|
s.ErrorFunc(w, r, 400, errors.New("invalid X-Signature-Ed25519 data"))
|
|
return
|
|
}
|
|
|
|
timestamp := r.Header.Get("X-Signature-Timestamp")
|
|
if timestamp == "" {
|
|
s.ErrorFunc(w, r, 401, errors.New("missing header X-Signature-Timestamp"))
|
|
return
|
|
}
|
|
|
|
var msg bytes.Buffer
|
|
msg.Grow(int(r.ContentLength+1) + len(timestamp))
|
|
msg.WriteString(timestamp)
|
|
|
|
if _, err := io.Copy(&msg, r.Body); err != nil {
|
|
s.ErrorFunc(w, r, 500, errors.Wrap(err, "cannot read body"))
|
|
return
|
|
}
|
|
|
|
if !ed25519.Verify(s.pubkey, msg.Bytes(), sig) {
|
|
s.ErrorFunc(w, r, 401, errors.New("signature mismatch"))
|
|
return
|
|
}
|
|
|
|
// Return the request body for use.
|
|
body := msg.Bytes()[len(timestamp):]
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|