diff --git a/api/message.go b/api/message.go index b9e07d1..ecaea9c 100644 --- a/api/message.go +++ b/api/message.go @@ -1,7 +1,7 @@ package api import ( - "io" + "mime/multipart" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/internal/httputil" @@ -132,12 +132,11 @@ func (c *Client) SendMessageComplex( httputil.WithJSONBody(c, data)) } - writer := func(w io.Writer) error { - return data.WriteMultipart(c, w) + writer := func(mw *multipart.Writer) error { + return data.WriteMultipart(c, mw) } - resp, err := c.MeanwhileBody(writer, "POST", URL, - httputil.MultipartRequest) + resp, err := c.MeanwhileMultipart(writer, "POST", URL) if err != nil { return nil, err } diff --git a/api/message_send.go b/api/message_send.go index 87e2fe7..5afc8b3 100644 --- a/api/message_send.go +++ b/api/message_send.go @@ -1,12 +1,8 @@ package api import ( - "fmt" "io" - "log" "mime/multipart" - "net/http" - "net/textproto" "strconv" "strings" @@ -15,12 +11,13 @@ import ( "github.com/pkg/errors" ) +const AttachmentSpoilerPrefix = "SPOILER_" + var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`) type SendMessageFile struct { - Name string - ContentType string // auto-detect if empty - Reader io.Reader + Name string + Reader io.Reader } type SendMessageData struct { @@ -34,9 +31,9 @@ type SendMessageData struct { } func (data *SendMessageData) WriteMultipart( - c json.Driver, w io.Writer) error { + c json.Driver, body *multipart.Writer) error { - return writeMultipart(c, w, data, data.Files) + return writeMultipart(c, body, data, data.Files) } type ExecuteWebhookData struct { @@ -47,87 +44,39 @@ type ExecuteWebhookData struct { } func (data *ExecuteWebhookData) WriteMultipart( - c json.Driver, w io.Writer) error { + c json.Driver, body *multipart.Writer) error { - return writeMultipart(c, w, data, data.Files) + return writeMultipart(c, body, data, data.Files) } func writeMultipart( - c json.Driver, w io.Writer, + c json.Driver, body *multipart.Writer, item interface{}, files []SendMessageFile) error { - body := multipart.NewWriter(w) + defer body.Close() // Encode the JSON body first - h := textproto.MIMEHeader{} - h.Set("Content-Disposition", `form-data; name="payload_json"`) - h.Set("Content-Type", "application/json") - - w, err := body.CreatePart(h) + w, err := body.CreateFormField("payload_json") if err != nil { return errors.Wrap(err, "Failed to create bodypart for JSON") } - j, err := c.Marshal(item) - log.Println(string(j), err) - if err := c.EncodeStream(w, item); err != nil { return errors.Wrap(err, "Failed to encode JSON") } - // Content-Type buffer - var buf []byte - for i, file := range files { - h := textproto.MIMEHeader{} - h.Set("Content-Disposition", fmt.Sprintf( - `form-data; name="file%d"; filename="%s"`, - i, quoteEscaper.Replace(file.Name), - )) + num := strconv.Itoa(i) - var bufUsed int - - if file.ContentType == "" { - if buf == nil { - buf = make([]byte, 512) - } - - n, err := file.Reader.Read(buf) - if err != nil { - return errors.Wrap(err, "Failed to read first 512 bytes for "+ - strconv.Itoa(i)) - } - - file.ContentType = http.DetectContentType(buf[:n]) - files[i] = file - bufUsed = n - } - - h.Set("Content-Type", file.ContentType) - - w, err := body.CreatePart(h) + w, err := body.CreateFormFile("file"+num, file.Name) if err != nil { - return errors.Wrap(err, "Failed to create bodypart for "+ - strconv.Itoa(i)) - } - - if bufUsed > 0 { - // Prematurely write - if _, err := w.Write(buf[:bufUsed]); err != nil { - return errors.Wrap(err, "Failed to write buffer for "+ - strconv.Itoa(i)) - } + return errors.Wrap(err, "Failed to create bodypart for "+num) } if _, err := io.Copy(w, file.Reader); err != nil { - return errors.Wrap(err, "Failed to write file for "+ - strconv.Itoa(i)) + return errors.Wrap(err, "Failed to write for file "+num) } } - if err := body.Close(); err != nil { - return errors.Wrap(err, "Failed to close body writer") - } - return nil } diff --git a/api/webhook.go b/api/webhook.go index e645113..6f3c9c5 100644 --- a/api/webhook.go +++ b/api/webhook.go @@ -1,7 +1,7 @@ package api import ( - "io" + "mime/multipart" "net/url" "github.com/diamondburned/arikawa/discord" @@ -120,11 +120,11 @@ func (c *Client) ExecuteWebhook( httputil.WithJSONBody(c, data)) } - writer := func(w io.Writer) error { - return data.WriteMultipart(c, w) + writer := func(mw *multipart.Writer) error { + return data.WriteMultipart(c, mw) } - resp, err := c.MeanwhileBody(writer, "POST", URL) + resp, err := c.MeanwhileMultipart(writer, "POST", URL) if err != nil { return nil, err } diff --git a/internal/httputil/client.go b/internal/httputil/client.go index ac0e9ab..3c636e5 100644 --- a/internal/httputil/client.go +++ b/internal/httputil/client.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "log" + "mime/multipart" "net/http" "time" @@ -31,7 +32,8 @@ func NewClient() Client { } } -func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error, +func (c *Client) MeanwhileMultipart( + multipartWriter func(*multipart.Writer) error, method, url string, opts ...RequestOption) (*http.Response, error) { // We want to cancel the request if our bodyWriter fails @@ -39,11 +41,12 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error, defer cancel() r, w := io.Pipe() + body := multipart.NewWriter(w) var bgErr error go func() { - if err := bodyWriter(w); err != nil { + if err := multipartWriter(body); err != nil { bgErr = err cancel() } @@ -53,7 +56,10 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error, }() resp, err := c.RequestCtx(ctx, method, url, - append([]RequestOption{WithBody(r)}, opts...)...) + append([]RequestOption{ + WithBody(r), + WithContentType(body.FormDataContentType()), + }, opts...)...) if err != nil && bgErr != nil { if resp.Body != nil { diff --git a/internal/httputil/options.go b/internal/httputil/options.go index bc75f81..e612fe9 100644 --- a/internal/httputil/options.go +++ b/internal/httputil/options.go @@ -21,6 +21,13 @@ func MultipartRequest(r *http.Request) error { return nil } +func WithContentType(ctype string) RequestOption { + return func(r *http.Request) error { + r.Header.Set("Content-Type", ctype) + return nil + } +} + func WithSchema(schema SchemaEncoder, v interface{}) RequestOption { return func(r *http.Request) error { params, err := schema.Encode(v)