API: Added AllowedMentions and more tests

This commit is contained in:
diamondburned (Forefront) 2020-04-19 16:35:37 -07:00
parent 2afe683b7d
commit 748ad5641c
17 changed files with 438 additions and 220 deletions

View File

@ -14,7 +14,7 @@ stages:
unit_test:
stage: test
script:
- go test -tags unit -v -coverprofile $COV ./...
- go test -v -coverprofile $COV ./...
- go tool cover -func $COV
| grep -F 'total:'
| sed -E 's|total:\s+\(.*?\)\s+([0-9]+\.[0-9]+%)|TEST_COVERAGE=\1|'

View File

@ -8,9 +8,7 @@ import (
// Messages gets all mesesages, automatically paginating. Use with care, as
// this could get as many as hundred thousands of messages, making a lot of
// queries.
func (c *Client) Messages(
channelID discord.Snowflake, max uint) ([]discord.Message, error) {
func (c *Client) Messages(channelID discord.Snowflake, max uint) ([]discord.Message, error) {
var msgs []discord.Message
var after discord.Snowflake = 0
@ -64,8 +62,9 @@ func (c *Client) MessagesAfter(
return c.messagesRange(channelID, 0, after, 0, limit)
}
func (c *Client) messagesRange(channelID, before, after,
around discord.Snowflake, limit uint) ([]discord.Message, error) {
func (c *Client) messagesRange(
channelID, before, after, around discord.Snowflake,
limit uint) ([]discord.Message, error) {
switch {
case limit == 0:
@ -95,9 +94,7 @@ func (c *Client) messagesRange(channelID, before, after,
)
}
func (c *Client) Message(
channelID, messageID discord.Snowflake) (*discord.Message, error) {
func (c *Client) Message(channelID, messageID discord.Snowflake) (*discord.Message, error) {
var msg *discord.Message
return msg, c.RequestJSON(&msg, "GET",
EndpointChannels+channelID.String()+"/messages/"+messageID.String())
@ -146,9 +143,7 @@ func (c *Client) DeleteMessage(channelID, messageID discord.Snowflake) error {
// DeleteMessages only works for bots. It can't delete messages older than 2
// weeks, and will fail if tried. This endpoint requires MANAGE_MESSAGES.
func (c *Client) DeleteMessages(
channelID discord.Snowflake, messageIDs []discord.Snowflake) error {
func (c *Client) DeleteMessages(channelID discord.Snowflake, messageIDs []discord.Snowflake) error {
var param struct {
Messages []discord.Snowflake `json:"messages"`
}

View File

@ -1,120 +0,0 @@
package api
import (
"io"
"mime/multipart"
"strconv"
"strings"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/diamondburned/arikawa/utils/json"
"github.com/pkg/errors"
)
func (c *Client) SendMessageComplex(
channelID discord.Snowflake,
data SendMessageData) (*discord.Message, error) {
if data.Embed != nil {
if err := data.Embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error")
}
}
var URL = EndpointChannels + channelID.String() + "/messages"
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil {
return nil, err
}
var body = resp.GetBody()
defer body.Close()
return msg, c.DecodeStream(body, &msg)
}
const AttachmentSpoilerPrefix = "SPOILER_"
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
type SendMessageFile struct {
Name string
Reader io.Reader
}
type SendMessageData struct {
Content string `json:"content,omitempty"`
Nonce string `json:"nonce,omitempty"`
TTS bool `json:"tts"`
Embed *discord.Embed `json:"embed,omitempty"`
Files []SendMessageFile `json:"-"`
}
func (data *SendMessageData) WriteMultipart(
c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, body, data, data.Files)
}
type ExecuteWebhookData struct {
Content string `json:"content,omitempty"`
Nonce string `json:"nonce,omitempty"`
TTS bool `json:"tts"`
Embeds []discord.Embed `json:"embeds,omitempty"`
Files []SendMessageFile `json:"-"`
Username string `json:"username,omitempty"`
AvatarURL discord.URL `json:"avatar_url,omitempty"`
}
func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, body, data, data.Files)
}
func writeMultipart(
c json.Driver, body *multipart.Writer,
item interface{}, files []SendMessageFile) error {
defer body.Close()
// Encode the JSON body first
w, err := body.CreateFormField("payload_json")
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for JSON")
}
if err := c.EncodeStream(w, item); err != nil {
return errors.Wrap(err, "Failed to encode JSON")
}
for i, file := range files {
num := strconv.Itoa(i)
w, err := body.CreateFormFile("file"+num, file.Name)
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+num)
}
if _, err := io.Copy(w, file.Reader); err != nil {
return errors.Wrap(err, "Failed to write for file "+num)
}
}
return nil
}

View File

@ -1,5 +1,3 @@
// +build unit
package rate
import "testing"

View File

@ -1,5 +1,3 @@
// +build unit
package rate
import "testing"

View File

@ -1,5 +1,3 @@
// +build unit
package rate
import (

266
api/send.go Normal file
View File

@ -0,0 +1,266 @@
package api
import (
"io"
"mime/multipart"
"net/url"
"strconv"
"strings"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/diamondburned/arikawa/utils/json"
"github.com/pkg/errors"
)
const AttachmentSpoilerPrefix = "SPOILER_"
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
// AllowedMentions is a whitelist of mentions for a message.
// https://discordapp.com/developers/docs/resources/channel#allowed-mentions-object
//
// Whitelists
//
// Roles and Users are slices that act as whitelists for IDs that are allowed to
// be mentioned. For example, if only 1 ID is provided in Users, then only that
// ID will be parsed in the message. No other IDs will be. The same example also
// applies for roles.
//
// If Parse is an empty slice and both Users and Roles are empty slices, then no
// mentions will be parsed.
//
// Constraints
//
// If the Users slice is not empty, then Parse must not have AllowUserMention.
// Likewise, if the Roles slice is not empty, then Parse must not have
// AllowRoleMention. This is because everything provided in Parse will make
// Discord parse it completely, meaning they would be mutually exclusive with
// whitelist slices, Roles and Users.
type AllowedMentions struct {
Parse []AllowedMentionType `json:"parse"`
Roles []discord.Snowflake `json:"roles,omitempty"` // max 100
Users []discord.Snowflake `json:"users,omitempty"` // max 100
}
// AllowedMentionType is a constant that tells Discord what is allowed to parse
// from a message content. This can help prevent things such as an unintentional
// @everyone mention.
type AllowedMentionType string
const (
// AllowRoleMention makes Discord parse roles in the content.
AllowRoleMention AllowedMentionType = "roles"
// AllowUserMention makes Discord parse user mentions in the content.
AllowUserMention AllowedMentionType = "users"
// AllowEveryoneMention makes Discord parse @everyone mentions.
AllowEveryoneMention AllowedMentionType = "everyone"
)
// Verify checks the AllowedMentions against constraints mentioned in
// AllowedMentions' documentation. This will be called on SendMessageComplex.
func (am AllowedMentions) Verify() error {
if len(am.Roles) > 100 {
return errors.Errorf("Roles slice length %d is over 100", len(am.Roles))
}
if len(am.Users) > 100 {
return errors.Errorf("Users slice length %d is over 100", len(am.Users))
}
for _, allowed := range am.Parse {
switch allowed {
case AllowRoleMention:
if len(am.Roles) > 0 {
return errors.New(`Parse has AllowRoleMention and Roles slice is not empty`)
}
case AllowUserMention:
if len(am.Users) > 0 {
return errors.New(`Parse has AllowUserMention and Users slice is not empty`)
}
}
}
return nil
}
// ErrEmptyMessage is returned if either a SendMessageData or an
// ExecuteWebhookData has both an empty Content and no Embed(s).
var ErrEmptyMessage = errors.New("Message is empty")
// SendMessageFile represents a file to be uploaded to Discord.
type SendMessageFile struct {
Name string
Reader io.Reader
}
// SendMessageData is the full structure to send a new message to Discord with.
type SendMessageData struct {
// Either of these fields must not be empty.
Content string `json:"content,omitempty"`
Nonce string `json:"nonce,omitempty"`
TTS bool `json:"tts,omitempty"`
Embed *discord.Embed `json:"embed,omitempty"`
Files []SendMessageFile `json:"-"`
AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"`
}
func (data *SendMessageData) WriteMultipart(c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, body, data, data.Files)
}
func (c *Client) SendMessageComplex(
channelID discord.Snowflake, data SendMessageData) (*discord.Message, error) {
if data.Content == "" && data.Embed == nil {
return nil, ErrEmptyMessage
}
if data.AllowedMentions != nil {
if err := data.AllowedMentions.Verify(); err != nil {
return nil, errors.Wrap(err, "AllowedMentions error")
}
}
if data.Embed != nil {
if err := data.Embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error")
}
}
var URL = EndpointChannels + channelID.String() + "/messages"
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil {
return nil, err
}
var body = resp.GetBody()
defer body.Close()
return msg, c.DecodeStream(body, &msg)
}
type ExecuteWebhookData struct {
// Either of these fields must not be empty.
Content string `json:"content,omitempty"`
Nonce string `json:"nonce,omitempty"`
TTS bool `json:"tts,omitempty"`
Embeds []discord.Embed `json:"embeds,omitempty"`
Files []SendMessageFile `json:"-"`
AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"`
// Optional fields specific to Webhooks.
Username string `json:"username,omitempty"`
AvatarURL discord.URL `json:"avatar_url,omitempty"`
}
func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, body, data, data.Files)
}
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
// wait for the message to be delivered and will return the message body. This
// also means the returned message will only be there if wait is true.
func (c *Client) ExecuteWebhook(
webhookID discord.Snowflake,
token string,
wait bool, // if false, then nil returned for *Message.
data ExecuteWebhookData) (*discord.Message, error) {
if data.Content == "" && len(data.Embeds) == 0 {
return nil, ErrEmptyMessage
}
if data.AllowedMentions != nil {
if err := data.AllowedMentions.Verify(); err != nil {
return nil, errors.Wrap(err, "AllowedMentions error")
}
}
for i, embed := range data.Embeds {
if err := embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error at "+strconv.Itoa(i))
}
}
var param = url.Values{}
if wait {
param.Set("wait", "true")
}
var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode()
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(c, data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil {
return nil, err
}
var body = resp.GetBody()
defer body.Close()
if !wait {
// Since we didn't tell Discord to wait, we have nothing to parse.
return nil, nil
}
return msg, c.DecodeStream(body, &msg)
}
func writeMultipart(
c json.Driver, body *multipart.Writer,
item interface{}, files []SendMessageFile) error {
defer body.Close()
// Encode the JSON body first
w, err := body.CreateFormField("payload_json")
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for JSON")
}
if err := c.EncodeStream(w, item); err != nil {
return errors.Wrap(err, "Failed to encode JSON")
}
for i, file := range files {
num := strconv.Itoa(i)
w, err := body.CreateFormFile("file"+num, file.Name)
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+num)
}
if _, err := io.Copy(w, file.Reader); err != nil {
return errors.Wrap(err, "Failed to write for file "+num)
}
}
return nil
}

156
api/send_test.go Normal file
View File

@ -0,0 +1,156 @@
package api
import (
"encoding/json"
"strings"
"testing"
"github.com/diamondburned/arikawa/discord"
)
func TestMarshalAllowedMentions(t *testing.T) {
t.Run("parse nothing", func(t *testing.T) {
var data = SendMessageData{
AllowedMentions: &AllowedMentions{
Parse: []AllowedMentionType{},
},
}
if j := mustMarshal(t, data); j != `{"allowed_mentions":{"parse":[]}}` {
t.Fatal("Unexpected JSON:", j)
}
})
t.Run("allow everything", func(t *testing.T) {
var data = SendMessageData{
Content: "a",
}
if j := mustMarshal(t, data); j != `{"content":"a"}` {
t.Fatal("Unexpected JSON:", j)
}
})
t.Run("allow certain user IDs", func(t *testing.T) {
var data = SendMessageData{
AllowedMentions: &AllowedMentions{
Users: []discord.Snowflake{1, 2},
},
}
if j := mustMarshal(t, data); j != `{"allowed_mentions":{"parse":null,"users":["1","2"]}}` {
t.Fatal("Unexpected JSON:", j)
}
})
}
func TestVerifyAllowedMentions(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
var am = AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Users: []discord.Snowflake{69, 420},
}
err := am.Verify()
errMustContain(t, err, "Users slice is not empty")
})
t.Run("users too long", func(t *testing.T) {
var am = AllowedMentions{
Users: make([]discord.Snowflake, 101),
}
err := am.Verify()
errMustContain(t, err, "Users slice length 101 is over 100")
})
t.Run("roles too long", func(t *testing.T) {
var am = AllowedMentions{
Roles: make([]discord.Snowflake, 101),
}
err := am.Verify()
errMustContain(t, err, "Roles slice length 101 is over 100")
})
t.Run("valid", func(t *testing.T) {
var am = AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Roles: []discord.Snowflake{1337},
Users: []discord.Snowflake{},
}
if err := am.Verify(); err != nil {
t.Fatal("Unexpected error:", err)
}
})
}
func TestSendMessage(t *testing.T) {
send := func(data SendMessageData) error {
// shouldn't matter
client := (*Client)(nil)
_, err := client.SendMessageComplex(0, data)
return err
}
t.Run("empty", func(t *testing.T) {
var empty = SendMessageData{
Content: "",
Embed: nil,
}
if err := send(empty); err != ErrEmptyMessage {
t.Fatal("Unexpected error:", err)
}
})
t.Run("invalid allowed mentions", func(t *testing.T) {
var data = SendMessageData{
Content: "hime arikawa",
AllowedMentions: &AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Users: []discord.Snowflake{69, 420},
},
}
err := send(data)
errMustContain(t, err, "AllowedMentions error")
})
t.Run("invalid embed", func(t *testing.T) {
var data = SendMessageData{
Embed: &discord.Embed{
// max 256
Title: spaces(257),
},
}
err := send(data)
errMustContain(t, err, "Embed error")
})
}
func errMustContain(t *testing.T, err error, contains string) {
// mark function as helper so line traces are accurate.
t.Helper()
if err != nil && strings.Contains(err.Error(), contains) {
return
}
t.Fatal("Unexpected error:", err)
}
func spaces(length int) string {
return strings.Repeat(" ", length)
}
func mustMarshal(t *testing.T, v interface{}) string {
t.Helper()
j, err := json.Marshal(v)
if err != nil {
t.Fatal("Failed to marshal data:", err)
}
return string(j)
}

View File

@ -1,13 +1,8 @@
package api
import (
"mime/multipart"
"net/url"
"strconv"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/pkg/errors"
)
var EndpointWebhooks = Endpoint + "webhooks/"
@ -81,52 +76,3 @@ func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error {
func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token)
}
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
// wait for the message to be delivered and will return the message body. This
// also means the returned message will only be there if wait is true.
func (c *Client) ExecuteWebhook(
webhookID discord.Snowflake,
token string,
wait bool,
data ExecuteWebhookData) (*discord.Message, error) {
for i, embed := range data.Embeds {
if err := embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error at "+strconv.Itoa(i))
}
}
var param = url.Values{}
if wait {
param.Set("wait", "true")
}
var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode()
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(c, data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(c, mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil {
return nil, err
}
var body = resp.GetBody()
defer body.Close()
if !wait {
// Since we didn't tell Discord to wait, we have nothing to parse.
return nil, nil
}
return msg, c.DecodeStream(body, &msg)
}

View File

@ -1,5 +1,3 @@
// +build unit
package bot
import (

View File

@ -1,5 +1,3 @@
// +build unit
package bot
import (

View File

@ -1,5 +1,3 @@
// +build unit
package bot
import "testing"

View File

@ -1,5 +1,3 @@
// +build unit
package bot
import "testing"

View File

@ -11,13 +11,6 @@ import (
// Identify structure is at identify.go
func (i *IdentifyData) SetShard(id, num int) {
if i.Shard == nil {
i.Shard = new(Shard)
}
i.Shard[0], i.Shard[1] = id, num
}
func (g *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
@ -26,7 +19,7 @@ func (g *Gateway) Identify() error {
return errors.Wrap(err, "Can't wait for identify()")
}
return g.send(false, IdentifyOP, g.Identifier)
return g.Send(IdentifyOP, g.Identifier)
}
type ResumeData struct {
@ -47,7 +40,7 @@ func (g *Gateway) Resume() error {
return ErrMissingForResume
}
return g.send(false, ResumeOP, ResumeData{
return g.Send(ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,

View File

@ -389,10 +389,6 @@ func (g *Gateway) eventLoop() error {
}
func (g *Gateway) Send(code OPCode, v interface{}) error {
return g.send(true, code, v)
}
func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
var op = OP{
Code: code,
}
@ -411,11 +407,6 @@ func (g *Gateway) send(lock bool, code OPCode, v interface{}) error {
return errors.Wrap(err, "Failed to encode payload")
}
// if lock {
// g.available.RLock()
// defer g.available.RUnlock()
// }
// WS should already be thread-safe.
return g.WS.Send(b)
}

View File

@ -48,6 +48,13 @@ type IdentifyData struct {
Intents Intents `json:"intents,omitempty"`
}
func (i *IdentifyData) SetShard(id, num int) {
if i.Shard == nil {
i.Shard = new(Shard)
}
i.Shard[0], i.Shard[1] = id, num
}
// Intents is a new Discord API feature that's documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32

View File

@ -1,5 +1,3 @@
// +build unit
package handler
import (