1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-02-08 04:28:32 +00:00

MeanwhileBody tests

This commit is contained in:
diamondburned (Forefront) 2020-01-18 18:27:30 -08:00
parent a82d71ad3c
commit e98c533114
8 changed files with 301 additions and 28 deletions

View file

@ -16,7 +16,6 @@ const (
Endpoint = BaseEndpoint + "/v" + APIVersion + "/"
EndpointGateway = Endpoint + "gateway"
EndpointGatewayBot = EndpointGateway + "/bot"
EndpointWebhooks = Endpoint + "webhooks/"
)
var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"

View file

@ -123,11 +123,11 @@ func (c *Client) SendMessageComplex(
}
}
var URL = EndpointChannels + channelID.String()
var URL = EndpointChannels + channelID.String() + "/messages"
var msg *discord.Message
if len(data.Files) == 0 {
// No files, no need for streaming
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(c, data))
}
@ -136,7 +136,8 @@ func (c *Client) SendMessageComplex(
return data.WriteMultipart(c, w)
}
resp, err := c.MeanwhileBody(writer, "POST", URL)
resp, err := c.MeanwhileBody(writer, "POST", URL,
httputil.MultipartRequest)
if err != nil {
return nil, err
}

View file

@ -3,6 +3,7 @@ package api
import (
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"net/textproto"
@ -14,15 +15,7 @@ import (
"github.com/pkg/errors"
)
type SendMessageData struct {
Content string `json:"content"`
Nonce string `json:"nonce"`
TTS bool `json:"tts"`
Embed *discord.Embed `json:"embed"`
Files []SendMessageFile `json:"-"`
}
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
type SendMessageFile struct {
Name string
@ -30,9 +23,39 @@ type SendMessageFile struct {
Reader io.Reader
}
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
type SendMessageData struct {
Content string `json:"content,omitempty"`
Nonce string `json:"nonce,omitempty"`
TTS bool `json:"tts"`
Embed *discord.Embed `json:"embed,omitempty"`
Files []SendMessageFile `json:"-"`
}
func (data *SendMessageData) WriteMultipart(
c json.Driver, w io.Writer) error {
return writeMultipart(c, w, data, data.Files)
}
type ExecuteWebhookData struct {
SendMessageData
Username string `json:"username,omitempty"`
AvatarURL discord.URL `json:"avatar_url,omitempty"`
}
func (data *ExecuteWebhookData) WriteMultipart(
c json.Driver, w io.Writer) error {
return writeMultipart(c, w, data, data.Files)
}
func writeMultipart(
c json.Driver, w io.Writer,
item interface{}, files []SendMessageFile) error {
func (data *SendMessageData) WriteMultipart(c json.Driver, w io.Writer) error {
body := multipart.NewWriter(w)
// Encode the JSON body first
@ -45,25 +68,24 @@ func (data *SendMessageData) WriteMultipart(c json.Driver, w io.Writer) error {
return errors.Wrap(err, "Failed to create bodypart for JSON")
}
if err := c.EncodeStream(w, data); err != nil {
j, err := c.Marshal(item)
log.Println(string(j), err)
if err := c.EncodeStream(w, item); err != nil {
return errors.Wrap(err, "Failed to encode JSON")
}
// Content-Type buffer
var buf []byte
for i, file := range data.Files {
for i, file := range files {
h := textproto.MIMEHeader{}
h.Set("Content-Disposition", fmt.Sprintf(
`form-data; name="file%d"; filename="%s"`,
i, quoteEscaper.Replace(file.Name),
))
w, err := body.CreatePart(h)
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+
strconv.Itoa(i))
}
var bufUsed int
if file.ContentType == "" {
if buf == nil {
@ -77,18 +99,24 @@ func (data *SendMessageData) WriteMultipart(c json.Driver, w io.Writer) error {
}
file.ContentType = http.DetectContentType(buf[:n])
data.Files[i] = file
files[i] = file
bufUsed = n
}
h.Set("Content-Type", file.ContentType)
h.Set("Content-Type", file.ContentType)
w, err := body.CreatePart(h)
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+
strconv.Itoa(i))
}
if bufUsed > 0 {
// Prematurely write
if _, err := w.Write(buf[:n]); err != nil {
if _, err := w.Write(buf[:bufUsed]); err != nil {
return errors.Wrap(err, "Failed to write buffer for "+
strconv.Itoa(i))
}
} else {
h.Set("Content-Type", file.ContentType)
}
if _, err := io.Copy(w, file.Reader); err != nil {

140
api/webhook.go Normal file
View file

@ -0,0 +1,140 @@
package api
import (
"io"
"net/url"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/internal/httputil"
"github.com/pkg/errors"
)
const EndpointWebhooks = Endpoint + "webhooks/"
// CreateWebhook creates a new webhook; avatar hash is optional. Requires
// MANAGE_WEBHOOKS.
func (c *Client) CreateWebhook(
channelID discord.Snowflake,
name string, avatar discord.Hash) (*discord.Webhook, error) {
var param struct {
Name string `json:"name"`
Avatar discord.Hash `json:"avatar"`
}
param.Name = name
param.Avatar = avatar
var w *discord.Webhook
return w, c.RequestJSON(
&w, "POST",
EndpointChannels+channelID.String()+"/webhooks",
httputil.WithJSONBody(c, param),
)
}
// Webhooks requires MANAGE_WEBHOOKS.
func (c *Client) Webhooks(
guildID discord.Snowflake) ([]discord.Webhook, error) {
var ws []discord.Webhook
return ws, c.RequestJSON(&ws, "GET",
EndpointGuilds+guildID.String()+"/webhooks")
}
func (c *Client) Webhook(
webhookID discord.Snowflake) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET",
EndpointWebhooks+webhookID.String())
}
func (c *Client) WebhookWithToken(
webhookID discord.Snowflake, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET",
EndpointWebhooks+webhookID.String()+"/"+token)
}
type ModifyWebhookData struct {
Name string `json:"name,omitempty"`
Avatar discord.Hash `json:"avatar,omitempty"` // TODO: clear avatar how?
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
}
func (c *Client) ModifyWebhook(
webhookID discord.Snowflake,
data ModifyWebhookData) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "PATCH",
EndpointWebhooks+webhookID.String())
}
func (c *Client) ModifyWebhookWithToken(
webhookID discord.Snowflake,
data ModifyWebhookData, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "PATCH",
EndpointWebhooks+webhookID.String()+"/"+token)
}
func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String())
}
func (c *Client) DeleteWebhookWithToken(
webhookID discord.Snowflake, token string) error {
return c.FastRequest("DELETE",
EndpointWebhooks+webhookID.String()+"/"+token)
}
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
// wait for the message to be delivered and will return the message body. This
// also means the returned message will only be there if wait is true.
func (c *Client) ExecuteWebhook(
webhookID discord.Snowflake, token string, wait bool,
data ExecuteWebhookData) (*discord.Message, error) {
if data.Embed != nil {
if err := data.Embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error")
}
}
var param = url.Values{}
if wait {
param.Set("wait", "true")
}
var URL = EndpointWebhooks + webhookID.String() + "?" + param.Encode()
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(c, data))
}
writer := func(w io.Writer) error {
return data.WriteMultipart(c, w)
}
resp, err := c.MeanwhileBody(writer, "POST", URL)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if !wait {
// Since we didn't tell Discord to wait, we have nothing to parse.
return nil, nil
}
return msg, c.DecodeStream(resp.Body, &msg)
}

22
discord/webhook.go Normal file
View file

@ -0,0 +1,22 @@
package discord
type Webhook struct {
ID Snowflake `json:"id"`
Type WebhookType `json:"type"`
User User `json:"user"` // creator
GuildID Snowflake `json:"guild_id,omitempty"`
ChannelID Snowflake `json:"channel_id"`
Name string `json:"name"`
Avatar Hash `json:"avatar"`
Token string `json:"token"` // incoming webhooks only
}
type WebhookType uint8
const (
_ WebhookType = iota
IncomingWebhook
ChannelFollowerWebhook
)

View file

@ -6,6 +6,7 @@ import (
"context"
"io"
"io/ioutil"
"log"
"net/http"
"time"
@ -35,6 +36,8 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error,
// We want to cancel the request if our bodyWriter fails
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r, w := io.Pipe()
var bgErr error
@ -44,6 +47,9 @@ func (c *Client) MeanwhileBody(bodyWriter func(io.Writer) error,
bgErr = err
cancel()
}
// Close the writer so the body gets flushed to the HTTP reader.
w.Close()
}()
resp, err := c.RequestCtx(ctx, method, url,
@ -87,6 +93,7 @@ func (c *Client) RequestCtx(ctx context.Context,
r, err := c.Client.Do(req)
if err != nil {
log.Println("Do error", url, err)
return nil, RequestError{err}
}
@ -120,6 +127,11 @@ func (c *Client) RequestCtxJSON(ctx context.Context,
defer r.Body.Close()
// No content, working as intended (tm)
if r.StatusCode == http.StatusNoContent {
return nil
}
if err := c.DecodeStream(r.Body, to); err != nil {
return JSONError{err}
}

View file

@ -0,0 +1,63 @@
package httputil
import (
"io"
"io/ioutil"
"net"
"net/http"
"testing"
)
func TestMeanwhileBody(t *testing.T) {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal("Can't read body:", err)
}
if s := string(b); s != "Hime" {
t.Fatal("Unexpected body:", s)
}
w.Write([]byte("Arikawa"))
})
addr := startHTTP(t)
c := NewClient()
w := func(w io.Writer) error {
w.Write([]byte("Hime"))
return nil
}
r, err := c.MeanwhileBody(w, "GET", "http://"+addr)
if err != nil {
t.Fatal("Failed to send request:", err)
}
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal("Can't read body:", err)
}
if s := string(b); s != "Arikawa" {
t.Fatal("Unexpected body:", s)
}
}
func startHTTP(t *testing.T) string {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("TCP error:", err)
}
go func() {
if err := http.Serve(listener, nil); err != nil {
t.Fatal("HTTP error:", err)
}
}()
return listener.Addr().(*net.TCPAddr).String()
}

View file

@ -16,6 +16,11 @@ func JSONRequest(r *http.Request) error {
return nil
}
func MultipartRequest(r *http.Request) error {
r.Header.Set("Content-Type", "multipart/form-data")
return nil
}
func WithSchema(schema SchemaEncoder, v interface{}) RequestOption {
return func(r *http.Request) error {
params, err := schema.Encode(v)
@ -35,7 +40,10 @@ func WithSchema(schema SchemaEncoder, v interface{}) RequestOption {
func WithBody(body io.ReadCloser) RequestOption {
return func(r *http.Request) error {
// tee := io.TeeReader(body, os.Stderr)
// r.Body = ioutil.NopCloser(tee)
r.Body = body
r.ContentLength = -1
return nil
}
}