1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-12-02 01:29:47 +00:00

Fixed SendComplex

This commit is contained in:
diamondburned (Forefront) 2020-01-18 19:12:08 -08:00
parent e98c533114
commit 05c8932166
5 changed files with 39 additions and 78 deletions

View file

@ -1,7 +1,7 @@
package api package api
import ( import (
"io" "mime/multipart"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/internal/httputil" "github.com/diamondburned/arikawa/internal/httputil"
@ -132,12 +132,11 @@ func (c *Client) SendMessageComplex(
httputil.WithJSONBody(c, data)) httputil.WithJSONBody(c, data))
} }
writer := func(w io.Writer) error { writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, w) return data.WriteMultipart(c, mw)
} }
resp, err := c.MeanwhileBody(writer, "POST", URL, resp, err := c.MeanwhileMultipart(writer, "POST", URL)
httputil.MultipartRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,12 +1,8 @@
package api package api
import ( import (
"fmt"
"io" "io"
"log"
"mime/multipart" "mime/multipart"
"net/http"
"net/textproto"
"strconv" "strconv"
"strings" "strings"
@ -15,12 +11,13 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const AttachmentSpoilerPrefix = "SPOILER_"
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`) var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
type SendMessageFile struct { type SendMessageFile struct {
Name string Name string
ContentType string // auto-detect if empty Reader io.Reader
Reader io.Reader
} }
type SendMessageData struct { type SendMessageData struct {
@ -34,9 +31,9 @@ type SendMessageData struct {
} }
func (data *SendMessageData) WriteMultipart( func (data *SendMessageData) WriteMultipart(
c json.Driver, w io.Writer) error { c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, w, data, data.Files) return writeMultipart(c, body, data, data.Files)
} }
type ExecuteWebhookData struct { type ExecuteWebhookData struct {
@ -47,87 +44,39 @@ type ExecuteWebhookData struct {
} }
func (data *ExecuteWebhookData) WriteMultipart( func (data *ExecuteWebhookData) WriteMultipart(
c json.Driver, w io.Writer) error { c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, w, data, data.Files) return writeMultipart(c, body, data, data.Files)
} }
func writeMultipart( func writeMultipart(
c json.Driver, w io.Writer, c json.Driver, body *multipart.Writer,
item interface{}, files []SendMessageFile) error { item interface{}, files []SendMessageFile) error {
body := multipart.NewWriter(w) defer body.Close()
// Encode the JSON body first // Encode the JSON body first
h := textproto.MIMEHeader{} w, err := body.CreateFormField("payload_json")
h.Set("Content-Disposition", `form-data; name="payload_json"`)
h.Set("Content-Type", "application/json")
w, err := body.CreatePart(h)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to create bodypart for JSON") return errors.Wrap(err, "Failed to create bodypart for JSON")
} }
j, err := c.Marshal(item)
log.Println(string(j), err)
if err := c.EncodeStream(w, item); err != nil { if err := c.EncodeStream(w, item); err != nil {
return errors.Wrap(err, "Failed to encode JSON") return errors.Wrap(err, "Failed to encode JSON")
} }
// Content-Type buffer
var buf []byte
for i, file := range files { for i, file := range files {
h := textproto.MIMEHeader{} num := strconv.Itoa(i)
h.Set("Content-Disposition", fmt.Sprintf(
`form-data; name="file%d"; filename="%s"`,
i, quoteEscaper.Replace(file.Name),
))
var bufUsed int w, err := body.CreateFormFile("file"+num, file.Name)
if file.ContentType == "" {
if buf == nil {
buf = make([]byte, 512)
}
n, err := file.Reader.Read(buf)
if err != nil {
return errors.Wrap(err, "Failed to read first 512 bytes for "+
strconv.Itoa(i))
}
file.ContentType = http.DetectContentType(buf[:n])
files[i] = file
bufUsed = n
}
h.Set("Content-Type", file.ContentType)
w, err := body.CreatePart(h)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+ return errors.Wrap(err, "Failed to create bodypart for "+num)
strconv.Itoa(i))
}
if bufUsed > 0 {
// Prematurely write
if _, err := w.Write(buf[:bufUsed]); err != nil {
return errors.Wrap(err, "Failed to write buffer for "+
strconv.Itoa(i))
}
} }
if _, err := io.Copy(w, file.Reader); err != nil { if _, err := io.Copy(w, file.Reader); err != nil {
return errors.Wrap(err, "Failed to write file for "+ return errors.Wrap(err, "Failed to write for file "+num)
strconv.Itoa(i))
} }
} }
if err := body.Close(); err != nil {
return errors.Wrap(err, "Failed to close body writer")
}
return nil return nil
} }

View file

@ -1,7 +1,7 @@
package api package api
import ( import (
"io" "mime/multipart"
"net/url" "net/url"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
@ -120,11 +120,11 @@ func (c *Client) ExecuteWebhook(
httputil.WithJSONBody(c, data)) httputil.WithJSONBody(c, data))
} }
writer := func(w io.Writer) error { writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, w) return data.WriteMultipart(c, mw)
} }
resp, err := c.MeanwhileBody(writer, "POST", URL) resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"mime/multipart"
"net/http" "net/http"
"time" "time"
@ -31,7 +32,8 @@ func NewClient() Client {
} }
} }
func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error, func (c *Client) MeanwhileMultipart(
multipartWriter func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (*http.Response, error) { method, url string, opts ...RequestOption) (*http.Response, error) {
// We want to cancel the request if our bodyWriter fails // We want to cancel the request if our bodyWriter fails
@ -39,11 +41,12 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error,
defer cancel() defer cancel()
r, w := io.Pipe() r, w := io.Pipe()
body := multipart.NewWriter(w)
var bgErr error var bgErr error
go func() { go func() {
if err := bodyWriter(w); err != nil { if err := multipartWriter(body); err != nil {
bgErr = err bgErr = err
cancel() cancel()
} }
@ -53,7 +56,10 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error,
}() }()
resp, err := c.RequestCtx(ctx, method, url, resp, err := c.RequestCtx(ctx, method, url,
append([]RequestOption{WithBody(r)}, opts...)...) append([]RequestOption{
WithBody(r),
WithContentType(body.FormDataContentType()),
}, opts...)...)
if err != nil && bgErr != nil { if err != nil && bgErr != nil {
if resp.Body != nil { if resp.Body != nil {

View file

@ -21,6 +21,13 @@ func MultipartRequest(r *http.Request) error {
return nil return nil
} }
func WithContentType(ctype string) RequestOption {
return func(r *http.Request) error {
r.Header.Set("Content-Type", ctype)
return nil
}
}
func WithSchema(schema SchemaEncoder, v interface{}) RequestOption { func WithSchema(schema SchemaEncoder, v interface{}) RequestOption {
return func(r *http.Request) error { return func(r *http.Request) error {
params, err := schema.Encode(v) params, err := schema.Encode(v)