diff --git a/api/send.go b/api/send.go index a7688af..ef1990b 100644 --- a/api/send.go +++ b/api/send.go @@ -1,16 +1,13 @@ package api import ( - "io" "mime/multipart" - "strconv" "github.com/pkg/errors" "github.com/diamondburned/arikawa/v2/discord" - "github.com/diamondburned/arikawa/v2/utils/httputil" - "github.com/diamondburned/arikawa/v2/utils/json" "github.com/diamondburned/arikawa/v2/utils/json/option" + "github.com/diamondburned/arikawa/v2/utils/sendpart" ) const AttachmentSpoilerPrefix = "SPOILER_" @@ -93,12 +90,6 @@ func (am AllowedMentions) Verify() error { // ExecuteWebhookData has both an empty Content and no Embed(s). var ErrEmptyMessage = errors.New("message is empty") -// SendMessageFile represents a file to be uploaded to Discord. -type SendMessageFile struct { - Name string - Reader io.Reader -} - // SendMessageData is the full structure to send a new message to Discord with. type SendMessageData struct { // Content are the message contents (up to 2000 characters). @@ -111,7 +102,7 @@ type SendMessageData struct { // Embed is embedded rich content. Embed *discord.Embed `json:"embed,omitempty"` - Files []SendMessageFile `json:"-"` + Files []sendpart.File `json:"-"` // AllowedMentions are the allowed mentions for a message. AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"` @@ -124,8 +115,13 @@ type SendMessageData struct { Reference *discord.MessageReference `json:"message_reference,omitempty"` } -func (data *SendMessageData) WriteMultipart(body *multipart.Writer) error { - return writeMultipart(body, data, data.Files) +// NeedsMultipart returns true if the SendMessageData has files. +func (data SendMessageData) NeedsMultipart() bool { + return len(data.Files) > 0 +} + +func (data SendMessageData) WriteMultipart(body *multipart.Writer) error { + return sendpart.Write(body, data, data.Files) } // SendMessageComplex posts a message to a guild text or DM channel. If @@ -168,77 +164,5 @@ func (c *Client) SendMessageComplex( var URL = EndpointChannels + channelID.String() + "/messages" var msg *discord.Message - - if len(data.Files) == 0 { - // No files, so no need for streaming. - return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(data)) - } - - resp, err := c.MeanwhileMultipart(data.WriteMultipart, "POST", URL) - if err != nil { - return nil, err - } - - var body = resp.GetBody() - defer body.Close() - - return msg, json.DecodeStream(body, &msg) -} - -// https://discord.com/developers/docs/resources/webhook#execute-webhook-jsonform-params -type ExecuteWebhookData struct { - // Content are the message contents (up to 2000 characters). - // - // Required: one of content, file, embeds - Content string `json:"content,omitempty"` - - // Username overrides the default username of the webhook - Username string `json:"username,omitempty"` - // AvatarURL overrides the default avatar of the webhook. - AvatarURL discord.URL `json:"avatar_url,omitempty"` - - // TTS is true if this is a TTS message. - TTS bool `json:"tts,omitempty"` - // Embeds contains embedded rich content. - // - // Required: one of content, file, embeds - Embeds []discord.Embed `json:"embeds,omitempty"` - - Files []SendMessageFile `json:"-"` - - // AllowedMentions are the allowed mentions for the message. - AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"` -} - -func (data *ExecuteWebhookData) WriteMultipart(body *multipart.Writer) error { - return writeMultipart(body, data, data.Files) -} - -func writeMultipart(body *multipart.Writer, item interface{}, files []SendMessageFile) error { - defer body.Close() - - // Encode the JSON body first - w, err := body.CreateFormField("payload_json") - if err != nil { - return errors.Wrap(err, "failed to create bodypart for JSON") - } - - if err := json.EncodeStream(w, item); err != nil { - return errors.Wrap(err, "failed to encode JSON") - } - - for i, file := range files { - num := strconv.Itoa(i) - - w, err := body.CreateFormFile("file"+num, file.Name) - if err != nil { - return errors.Wrap(err, "failed to create bodypart for "+num) - } - - if _, err := io.Copy(w, file.Reader); err != nil { - return errors.Wrap(err, "failed to write for file "+num) - } - } - - return nil + return msg, sendpart.POST(c.Client, data, &msg, URL) } diff --git a/api/send_test.go b/api/send_test.go index 516f7cc..11c863d 100644 --- a/api/send_test.go +++ b/api/send_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/diamondburned/arikawa/v2/discord" + "github.com/diamondburned/arikawa/v2/utils/sendpart" ) func TestMarshalAllowedMentions(t *testing.T) { @@ -112,7 +113,7 @@ func TestSendMessage(t *testing.T) { t.Run("files only", func(t *testing.T) { var empty = SendMessageData{ - Files: []SendMessageFile{{Name: "test.jpg"}}, + Files: []sendpart.File{{Name: "test.jpg"}}, } if err := send(empty); err != nil { diff --git a/api/webhook/webhook.go b/api/webhook/webhook.go index 4b91b18..7d46580 100644 --- a/api/webhook/webhook.go +++ b/api/webhook/webhook.go @@ -3,6 +3,7 @@ package webhook import ( + "mime/multipart" "net/url" "strconv" @@ -11,8 +12,8 @@ import ( "github.com/diamondburned/arikawa/v2/api" "github.com/diamondburned/arikawa/v2/discord" "github.com/diamondburned/arikawa/v2/utils/httputil" - "github.com/diamondburned/arikawa/v2/utils/json" "github.com/diamondburned/arikawa/v2/utils/json/option" + "github.com/diamondburned/arikawa/v2/utils/sendpart" ) // Client is the client used to interact with a webhook. @@ -61,21 +62,59 @@ func (c *Client) Delete() error { return c.FastRequest("DELETE", api.EndpointWebhooks+c.ID.String()+"/"+c.Token) } +// https://discord.com/developers/docs/resources/webhook#execute-webhook-jsonform-params +type ExecuteWebhookData struct { + // Content are the message contents (up to 2000 characters). + // + // Required: one of content, file, embeds + Content string `json:"content,omitempty"` + + // Username overrides the default username of the webhook + Username string `json:"username,omitempty"` + // AvatarURL overrides the default avatar of the webhook. + AvatarURL discord.URL `json:"avatar_url,omitempty"` + + // TTS is true if this is a TTS message. + TTS bool `json:"tts,omitempty"` + // Embeds contains embedded rich content. + // + // Required: one of content, file, embeds + Embeds []discord.Embed `json:"embeds,omitempty"` + + // Files represents a list of files to upload. This will not be JSON-encoded + // and will only be available through WriteMultipart. + Files []sendpart.File `json:"-"` + + // AllowedMentions are the allowed mentions for the message. + AllowedMentions *api.AllowedMentions `json:"allowed_mentions,omitempty"` +} + +// NeedsMultipart returns true if the ExecuteWebhookData has files. +func (data ExecuteWebhookData) NeedsMultipart() bool { + return len(data.Files) > 0 +} + +// WriteMultipart writes the webhook data into the given multipart body. It does +// not close body. +func (data ExecuteWebhookData) WriteMultipart(body *multipart.Writer) error { + return sendpart.Write(body, data, data.Files) +} + // Execute sends a message to the webhook, but doesn't wait for the message to // get created. This is generally faster, but only applicable if no further // interaction is required. -func (c *Client) Execute(data api.ExecuteWebhookData) (err error) { +func (c *Client) Execute(data ExecuteWebhookData) (err error) { _, err = c.execute(data, false) return } // ExecuteAndWait executes the webhook, and waits for the generated // discord.Message to be returned. -func (c *Client) ExecuteAndWait(data api.ExecuteWebhookData) (*discord.Message, error) { +func (c *Client) ExecuteAndWait(data ExecuteWebhookData) (*discord.Message, error) { return c.execute(data, true) } -func (c *Client) execute(data api.ExecuteWebhookData, wait bool) (*discord.Message, error) { +func (c *Client) execute(data ExecuteWebhookData, wait bool) (*discord.Message, error) { if data.Content == "" && len(data.Embeds) == 0 && len(data.Files) == 0 { return nil, api.ErrEmptyMessage } @@ -92,36 +131,20 @@ func (c *Client) execute(data api.ExecuteWebhookData, wait bool) (*discord.Messa } } - var param = url.Values{} + var param url.Values if wait { - param.Set("wait", "true") + param = url.Values{"wait": {"true"}} } var URL = api.EndpointWebhooks + c.ID.String() + "/" + c.Token + "?" + param.Encode() + var msg *discord.Message - - if len(data.Files) == 0 { - // No files, so no need for streaming. - return msg, c.RequestJSON(&msg, "POST", URL, - httputil.WithJSONBody(data)) + var ptr interface{} + if wait { + ptr = &msg } - writer := data.WriteMultipart - - resp, err := c.MeanwhileMultipart(writer, "POST", URL) - if err != nil { - return nil, err - } - - var body = resp.GetBody() - defer body.Close() - - if !wait { - // Since we didn't tell Discord to wait, we have nothing to parse. - return nil, nil - } - - return msg, json.DecodeStream(body, &msg) + return msg, sendpart.POST(c.Client, data, ptr, URL) } // https://discord.com/developers/docs/resources/webhook#edit-webhook-message-jsonform-params diff --git a/utils/httputil/client.go b/utils/httputil/client.go index b070490..1c32ca7 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -91,14 +91,34 @@ func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) (e er return } +// MultipartWriter is the interface for a data structure that can write into a +// multipart writer. +type MultipartWriter interface { + WriteMultipart(body *multipart.Writer) error +} + +// MeanwhileMultipart concurrently encodes and writes the given multipart writer +// at the same time. The writer will be called in another goroutine, but the +// writer will be closed when MeanwhileMultipart returns. func (c *Client) MeanwhileMultipart( - writer func(*multipart.Writer) error, + writer MultipartWriter, method, url string, opts ...RequestOption) (httpdriver.Response, error) { r, w := io.Pipe() body := multipart.NewWriter(w) - go func() { w.CloseWithError(writer(body)) }() + // Ensure the writer is closed by the time this function exits, so + // WriteMultipart will exit. + defer w.Close() + + go func() { + err := writer.WriteMultipart(body) + if err != nil { + err = body.Close() + } + + w.CloseWithError(err) + }() // Prepend the multipart writer and the correct Content-Type header options. opts = PrependOptions( @@ -135,6 +155,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request if status == httpdriver.NoContent { return nil } + // to is nil for some reason. Ignore. + if to == nil { + return nil + } if err := json.DecodeStream(body, to); err != nil { return JSONError{err} diff --git a/utils/sendpart/sendpart.go b/utils/sendpart/sendpart.go new file mode 100644 index 0000000..47936ce --- /dev/null +++ b/utils/sendpart/sendpart.go @@ -0,0 +1,80 @@ +package sendpart + +import ( + "io" + "mime/multipart" + "strconv" + + "github.com/diamondburned/arikawa/v2/utils/httputil" + "github.com/diamondburned/arikawa/v2/utils/json" + "github.com/pkg/errors" +) + +// File represents a file to be uploaded to Discord. +type File struct { + Name string + Reader io.Reader +} + +// DataMultipartWriter is a MultipartWriter that also contains data that's +// JSON-marshalable. +type DataMultipartWriter interface { + // NeedsMultipart returns true if the data interface must be sent using + // multipart form. + NeedsMultipart() bool + + httputil.MultipartWriter +} + +// POST sends a POST request using client to the given URL and unmarshal the +// body into v if it's not nil. It will only send using multipart if files is +// true. +func POST(c *httputil.Client, data DataMultipartWriter, v interface{}, url string) error { + if !data.NeedsMultipart() { + // No files, so no need for streaming. + return c.RequestJSON(v, "POST", url, httputil.WithJSONBody(data)) + } + + resp, err := c.MeanwhileMultipart(data, "POST", url) + if err != nil { + return err + } + + var body = resp.GetBody() + defer body.Close() + + if v == nil { + return nil + } + + return json.DecodeStream(body, v) +} + +// Write writes the item into payload_json and the list of files into the +// multipart writer. Write does not close the body. +func Write(body *multipart.Writer, item interface{}, files []File) error { + // Encode the JSON body first + w, err := body.CreateFormField("payload_json") + if err != nil { + return errors.Wrap(err, "failed to create bodypart for JSON") + } + + if err := json.EncodeStream(w, item); err != nil { + return errors.Wrap(err, "failed to encode JSON") + } + + for i, file := range files { + num := strconv.Itoa(i) + + w, err := body.CreateFormFile("file"+num, file.Name) + if err != nil { + return errors.Wrap(err, "failed to create bodypart for "+num) + } + + if _, err := io.Copy(w, file.Reader); err != nil { + return errors.Wrap(err, "failed to write for file "+num) + } + } + + return nil +}