diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6653043..acc4a8a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -14,7 +14,7 @@ stages: unit_test: stage: test script: - - go test -tags unit -v -coverprofile $COV ./... + - go test -v -coverprofile $COV ./... - go tool cover -func $COV | grep -F 'total:' | sed -E 's|total:\s+\(.*?\)\s+([0-9]+\.[0-9]+%)|TEST_COVERAGE=\1|' diff --git a/api/message.go b/api/message.go index cdec488..fa4872e 100644 --- a/api/message.go +++ b/api/message.go @@ -8,9 +8,7 @@ import ( // Messages gets all mesesages, automatically paginating. Use with care, as // this could get as many as hundred thousands of messages, making a lot of // queries. -func (c *Client) Messages( - channelID discord.Snowflake, max uint) ([]discord.Message, error) { - +func (c *Client) Messages(channelID discord.Snowflake, max uint) ([]discord.Message, error) { var msgs []discord.Message var after discord.Snowflake = 0 @@ -64,8 +62,9 @@ func (c *Client) MessagesAfter( return c.messagesRange(channelID, 0, after, 0, limit) } -func (c *Client) messagesRange(channelID, before, after, - around discord.Snowflake, limit uint) ([]discord.Message, error) { +func (c *Client) messagesRange( + channelID, before, after, around discord.Snowflake, + limit uint) ([]discord.Message, error) { switch { case limit == 0: @@ -95,9 +94,7 @@ func (c *Client) messagesRange(channelID, before, after, ) } -func (c *Client) Message( - channelID, messageID discord.Snowflake) (*discord.Message, error) { - +func (c *Client) Message(channelID, messageID discord.Snowflake) (*discord.Message, error) { var msg *discord.Message return msg, c.RequestJSON(&msg, "GET", EndpointChannels+channelID.String()+"/messages/"+messageID.String()) @@ -146,9 +143,7 @@ func (c *Client) DeleteMessage(channelID, messageID discord.Snowflake) error { // DeleteMessages only works for bots. It can't delete messages older than 2 // weeks, and will fail if tried. This endpoint requires MANAGE_MESSAGES. -func (c *Client) DeleteMessages( - channelID discord.Snowflake, messageIDs []discord.Snowflake) error { - +func (c *Client) DeleteMessages(channelID discord.Snowflake, messageIDs []discord.Snowflake) error { var param struct { Messages []discord.Snowflake `json:"messages"` } diff --git a/api/message_send.go b/api/message_send.go deleted file mode 100644 index 8739ecb..0000000 --- a/api/message_send.go +++ /dev/null @@ -1,120 +0,0 @@ -package api - -import ( - "io" - "mime/multipart" - "strconv" - "strings" - - "github.com/diamondburned/arikawa/discord" - "github.com/diamondburned/arikawa/utils/httputil" - "github.com/diamondburned/arikawa/utils/json" - "github.com/pkg/errors" -) - -func (c *Client) SendMessageComplex( - channelID discord.Snowflake, - data SendMessageData) (*discord.Message, error) { - - if data.Embed != nil { - if err := data.Embed.Validate(); err != nil { - return nil, errors.Wrap(err, "Embed error") - } - } - - var URL = EndpointChannels + channelID.String() + "/messages" - var msg *discord.Message - - if len(data.Files) == 0 { - // No files, so no need for streaming. - return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data)) - } - - writer := func(mw *multipart.Writer) error { - return data.WriteMultipart(c, mw) - } - - resp, err := c.MeanwhileMultipart(writer, "POST", URL) - if err != nil { - return nil, err - } - - var body = resp.GetBody() - defer body.Close() - - return msg, c.DecodeStream(body, &msg) -} - -const AttachmentSpoilerPrefix = "SPOILER_" - -var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`) - -type SendMessageFile struct { - Name string - Reader io.Reader -} - -type SendMessageData struct { - Content string `json:"content,omitempty"` - Nonce string `json:"nonce,omitempty"` - TTS bool `json:"tts"` - - Embed *discord.Embed `json:"embed,omitempty"` - - Files []SendMessageFile `json:"-"` -} - -func (data *SendMessageData) WriteMultipart( - c json.Driver, body *multipart.Writer) error { - - return writeMultipart(c, body, data, data.Files) -} - -type ExecuteWebhookData struct { - Content string `json:"content,omitempty"` - Nonce string `json:"nonce,omitempty"` - TTS bool `json:"tts"` - - Embeds []discord.Embed `json:"embeds,omitempty"` - - Files []SendMessageFile `json:"-"` - - Username string `json:"username,omitempty"` - AvatarURL discord.URL `json:"avatar_url,omitempty"` -} - -func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error { - return writeMultipart(c, body, data, data.Files) -} - -func writeMultipart( - c json.Driver, body *multipart.Writer, - item interface{}, files []SendMessageFile) error { - - defer body.Close() - - // Encode the JSON body first - w, err := body.CreateFormField("payload_json") - if err != nil { - return errors.Wrap(err, "Failed to create bodypart for JSON") - } - - if err := c.EncodeStream(w, item); err != nil { - return errors.Wrap(err, "Failed to encode JSON") - } - - for i, file := range files { - num := strconv.Itoa(i) - - w, err := body.CreateFormFile("file"+num, file.Name) - if err != nil { - return errors.Wrap(err, "Failed to create bodypart for "+num) - } - - if _, err := io.Copy(w, file.Reader); err != nil { - return errors.Wrap(err, "Failed to write for file "+num) - } - } - - return nil -} diff --git a/api/rate/emoji_test.go b/api/rate/emoji_test.go index dd5ec51..22b4aef 100644 --- a/api/rate/emoji_test.go +++ b/api/rate/emoji_test.go @@ -1,5 +1,3 @@ -// +build unit - package rate import "testing" diff --git a/api/rate/majors_test.go b/api/rate/majors_test.go index b6ec7ce..e8d1e61 100644 --- a/api/rate/majors_test.go +++ b/api/rate/majors_test.go @@ -1,5 +1,3 @@ -// +build unit - package rate import "testing" diff --git a/api/rate/rate_test.go b/api/rate/rate_test.go index c1a0837..ea65c23 100644 --- a/api/rate/rate_test.go +++ b/api/rate/rate_test.go @@ -1,5 +1,3 @@ -// +build unit - package rate import ( diff --git a/api/send.go b/api/send.go new file mode 100644 index 0000000..457ad80 --- /dev/null +++ b/api/send.go @@ -0,0 +1,266 @@ +package api + +import ( + "io" + "mime/multipart" + "net/url" + "strconv" + "strings" + + "github.com/diamondburned/arikawa/discord" + "github.com/diamondburned/arikawa/utils/httputil" + "github.com/diamondburned/arikawa/utils/json" + "github.com/pkg/errors" +) + +const AttachmentSpoilerPrefix = "SPOILER_" + +var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +// AllowedMentions is a whitelist of mentions for a message. +// https://discordapp.com/developers/docs/resources/channel#allowed-mentions-object +// +// Whitelists +// +// Roles and Users are slices that act as whitelists for IDs that are allowed to +// be mentioned. For example, if only 1 ID is provided in Users, then only that +// ID will be parsed in the message. No other IDs will be. The same example also +// applies for roles. +// +// If Parse is an empty slice and both Users and Roles are empty slices, then no +// mentions will be parsed. +// +// Constraints +// +// If the Users slice is not empty, then Parse must not have AllowUserMention. +// Likewise, if the Roles slice is not empty, then Parse must not have +// AllowRoleMention. This is because everything provided in Parse will make +// Discord parse it completely, meaning they would be mutually exclusive with +// whitelist slices, Roles and Users. +type AllowedMentions struct { + Parse []AllowedMentionType `json:"parse"` + Roles []discord.Snowflake `json:"roles,omitempty"` // max 100 + Users []discord.Snowflake `json:"users,omitempty"` // max 100 +} + +// AllowedMentionType is a constant that tells Discord what is allowed to parse +// from a message content. This can help prevent things such as an unintentional +// @everyone mention. +type AllowedMentionType string + +const ( + // AllowRoleMention makes Discord parse roles in the content. + AllowRoleMention AllowedMentionType = "roles" + // AllowUserMention makes Discord parse user mentions in the content. + AllowUserMention AllowedMentionType = "users" + // AllowEveryoneMention makes Discord parse @everyone mentions. + AllowEveryoneMention AllowedMentionType = "everyone" +) + +// Verify checks the AllowedMentions against constraints mentioned in +// AllowedMentions' documentation. This will be called on SendMessageComplex. +func (am AllowedMentions) Verify() error { + if len(am.Roles) > 100 { + return errors.Errorf("Roles slice length %d is over 100", len(am.Roles)) + } + if len(am.Users) > 100 { + return errors.Errorf("Users slice length %d is over 100", len(am.Users)) + } + + for _, allowed := range am.Parse { + switch allowed { + case AllowRoleMention: + if len(am.Roles) > 0 { + return errors.New(`Parse has AllowRoleMention and Roles slice is not empty`) + } + case AllowUserMention: + if len(am.Users) > 0 { + return errors.New(`Parse has AllowUserMention and Users slice is not empty`) + } + } + } + + return nil +} + +// ErrEmptyMessage is returned if either a SendMessageData or an +// ExecuteWebhookData has both an empty Content and no Embed(s). +var ErrEmptyMessage = errors.New("Message is empty") + +// SendMessageFile represents a file to be uploaded to Discord. +type SendMessageFile struct { + Name string + Reader io.Reader +} + +// SendMessageData is the full structure to send a new message to Discord with. +type SendMessageData struct { + // Either of these fields must not be empty. + Content string `json:"content,omitempty"` + Nonce string `json:"nonce,omitempty"` + + TTS bool `json:"tts,omitempty"` + Embed *discord.Embed `json:"embed,omitempty"` + + Files []SendMessageFile `json:"-"` + + AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"` +} + +func (data *SendMessageData) WriteMultipart(c json.Driver, body *multipart.Writer) error { + return writeMultipart(c, body, data, data.Files) +} + +func (c *Client) SendMessageComplex( + channelID discord.Snowflake, data SendMessageData) (*discord.Message, error) { + + if data.Content == "" && data.Embed == nil { + return nil, ErrEmptyMessage + } + + if data.AllowedMentions != nil { + if err := data.AllowedMentions.Verify(); err != nil { + return nil, errors.Wrap(err, "AllowedMentions error") + } + } + + if data.Embed != nil { + if err := data.Embed.Validate(); err != nil { + return nil, errors.Wrap(err, "Embed error") + } + } + + var URL = EndpointChannels + channelID.String() + "/messages" + var msg *discord.Message + + if len(data.Files) == 0 { + // No files, so no need for streaming. + return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data)) + } + + writer := func(mw *multipart.Writer) error { + return data.WriteMultipart(c, mw) + } + + resp, err := c.MeanwhileMultipart(writer, "POST", URL) + if err != nil { + return nil, err + } + + var body = resp.GetBody() + defer body.Close() + + return msg, c.DecodeStream(body, &msg) +} + +type ExecuteWebhookData struct { + // Either of these fields must not be empty. + Content string `json:"content,omitempty"` + Nonce string `json:"nonce,omitempty"` + + TTS bool `json:"tts,omitempty"` + Embeds []discord.Embed `json:"embeds,omitempty"` + + Files []SendMessageFile `json:"-"` + + AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"` + + // Optional fields specific to Webhooks. + Username string `json:"username,omitempty"` + AvatarURL discord.URL `json:"avatar_url,omitempty"` +} + +func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error { + return writeMultipart(c, body, data, data.Files) +} + +// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will +// wait for the message to be delivered and will return the message body. This +// also means the returned message will only be there if wait is true. +func (c *Client) ExecuteWebhook( + webhookID discord.Snowflake, + token string, + wait bool, // if false, then nil returned for *Message. + data ExecuteWebhookData) (*discord.Message, error) { + + if data.Content == "" && len(data.Embeds) == 0 { + return nil, ErrEmptyMessage + } + + if data.AllowedMentions != nil { + if err := data.AllowedMentions.Verify(); err != nil { + return nil, errors.Wrap(err, "AllowedMentions error") + } + } + + for i, embed := range data.Embeds { + if err := embed.Validate(); err != nil { + return nil, errors.Wrap(err, "Embed error at "+strconv.Itoa(i)) + } + } + + var param = url.Values{} + if wait { + param.Set("wait", "true") + } + + var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode() + var msg *discord.Message + + if len(data.Files) == 0 { + // No files, so no need for streaming. + return msg, c.RequestJSON(&msg, "POST", URL, + httputil.WithJSONBody(c, data)) + } + + writer := func(mw *multipart.Writer) error { + return data.WriteMultipart(c, mw) + } + + resp, err := c.MeanwhileMultipart(writer, "POST", URL) + if err != nil { + return nil, err + } + + var body = resp.GetBody() + defer body.Close() + + if !wait { + // Since we didn't tell Discord to wait, we have nothing to parse. + return nil, nil + } + + return msg, c.DecodeStream(body, &msg) +} + +func writeMultipart( + c json.Driver, body *multipart.Writer, + item interface{}, files []SendMessageFile) error { + + defer body.Close() + + // Encode the JSON body first + w, err := body.CreateFormField("payload_json") + if err != nil { + return errors.Wrap(err, "Failed to create bodypart for JSON") + } + + if err := c.EncodeStream(w, item); err != nil { + return errors.Wrap(err, "Failed to encode JSON") + } + + for i, file := range files { + num := strconv.Itoa(i) + + w, err := body.CreateFormFile("file"+num, file.Name) + if err != nil { + return errors.Wrap(err, "Failed to create bodypart for "+num) + } + + if _, err := io.Copy(w, file.Reader); err != nil { + return errors.Wrap(err, "Failed to write for file "+num) + } + } + + return nil +} diff --git a/api/send_test.go b/api/send_test.go new file mode 100644 index 0000000..2f200b6 --- /dev/null +++ b/api/send_test.go @@ -0,0 +1,156 @@ +package api + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/diamondburned/arikawa/discord" +) + +func TestMarshalAllowedMentions(t *testing.T) { + t.Run("parse nothing", func(t *testing.T) { + var data = SendMessageData{ + AllowedMentions: &AllowedMentions{ + Parse: []AllowedMentionType{}, + }, + } + + if j := mustMarshal(t, data); j != `{"allowed_mentions":{"parse":[]}}` { + t.Fatal("Unexpected JSON:", j) + } + }) + + t.Run("allow everything", func(t *testing.T) { + var data = SendMessageData{ + Content: "a", + } + + if j := mustMarshal(t, data); j != `{"content":"a"}` { + t.Fatal("Unexpected JSON:", j) + } + }) + + t.Run("allow certain user IDs", func(t *testing.T) { + var data = SendMessageData{ + AllowedMentions: &AllowedMentions{ + Users: []discord.Snowflake{1, 2}, + }, + } + + if j := mustMarshal(t, data); j != `{"allowed_mentions":{"parse":null,"users":["1","2"]}}` { + t.Fatal("Unexpected JSON:", j) + } + }) +} + +func TestVerifyAllowedMentions(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + var am = AllowedMentions{ + Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention}, + Users: []discord.Snowflake{69, 420}, + } + + err := am.Verify() + errMustContain(t, err, "Users slice is not empty") + }) + + t.Run("users too long", func(t *testing.T) { + var am = AllowedMentions{ + Users: make([]discord.Snowflake, 101), + } + + err := am.Verify() + errMustContain(t, err, "Users slice length 101 is over 100") + }) + + t.Run("roles too long", func(t *testing.T) { + var am = AllowedMentions{ + Roles: make([]discord.Snowflake, 101), + } + + err := am.Verify() + errMustContain(t, err, "Roles slice length 101 is over 100") + }) + + t.Run("valid", func(t *testing.T) { + var am = AllowedMentions{ + Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention}, + Roles: []discord.Snowflake{1337}, + Users: []discord.Snowflake{}, + } + + if err := am.Verify(); err != nil { + t.Fatal("Unexpected error:", err) + } + }) +} + +func TestSendMessage(t *testing.T) { + send := func(data SendMessageData) error { + // shouldn't matter + client := (*Client)(nil) + _, err := client.SendMessageComplex(0, data) + return err + } + + t.Run("empty", func(t *testing.T) { + var empty = SendMessageData{ + Content: "", + Embed: nil, + } + + if err := send(empty); err != ErrEmptyMessage { + t.Fatal("Unexpected error:", err) + } + }) + + t.Run("invalid allowed mentions", func(t *testing.T) { + var data = SendMessageData{ + Content: "hime arikawa", + AllowedMentions: &AllowedMentions{ + Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention}, + Users: []discord.Snowflake{69, 420}, + }, + } + + err := send(data) + errMustContain(t, err, "AllowedMentions error") + }) + + t.Run("invalid embed", func(t *testing.T) { + var data = SendMessageData{ + Embed: &discord.Embed{ + // max 256 + Title: spaces(257), + }, + } + + err := send(data) + errMustContain(t, err, "Embed error") + }) +} + +func errMustContain(t *testing.T, err error, contains string) { + // mark function as helper so line traces are accurate. + t.Helper() + + if err != nil && strings.Contains(err.Error(), contains) { + return + } + t.Fatal("Unexpected error:", err) +} + +func spaces(length int) string { + return strings.Repeat(" ", length) +} + +func mustMarshal(t *testing.T, v interface{}) string { + t.Helper() + + j, err := json.Marshal(v) + if err != nil { + t.Fatal("Failed to marshal data:", err) + } + return string(j) +} diff --git a/api/webhook.go b/api/webhook.go index f6b344d..817e759 100644 --- a/api/webhook.go +++ b/api/webhook.go @@ -1,13 +1,8 @@ package api import ( - "mime/multipart" - "net/url" - "strconv" - "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/utils/httputil" - "github.com/pkg/errors" ) var EndpointWebhooks = Endpoint + "webhooks/" @@ -81,52 +76,3 @@ func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error { func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error { return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token) } - -// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will -// wait for the message to be delivered and will return the message body. This -// also means the returned message will only be there if wait is true. -func (c *Client) ExecuteWebhook( - webhookID discord.Snowflake, - token string, - wait bool, - data ExecuteWebhookData) (*discord.Message, error) { - - for i, embed := range data.Embeds { - if err := embed.Validate(); err != nil { - return nil, errors.Wrap(err, "Embed error at "+strconv.Itoa(i)) - } - } - - var param = url.Values{} - if wait { - param.Set("wait", "true") - } - - var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode() - var msg *discord.Message - - if len(data.Files) == 0 { - // No files, so no need for streaming. - return msg, c.RequestJSON(&msg, "POST", URL, - httputil.WithJSONBody(c, data)) - } - - writer := func(mw *multipart.Writer) error { - return data.WriteMultipart(c, mw) - } - - resp, err := c.MeanwhileMultipart(writer, "POST", URL) - if err != nil { - return nil, err - } - - var body = resp.GetBody() - defer body.Close() - - if !wait { - // Since we didn't tell Discord to wait, we have nothing to parse. - return nil, nil - } - - return msg, c.DecodeStream(body, &msg) -} diff --git a/bot/ctx_plumb_test.go b/bot/ctx_plumb_test.go index ee08666..79a7476 100644 --- a/bot/ctx_plumb_test.go +++ b/bot/ctx_plumb_test.go @@ -1,5 +1,3 @@ -// +build unit - package bot import ( diff --git a/bot/ctx_test.go b/bot/ctx_test.go index 1bdfe94..641923c 100644 --- a/bot/ctx_test.go +++ b/bot/ctx_test.go @@ -1,5 +1,3 @@ -// +build unit - package bot import ( diff --git a/bot/nameflag_test.go b/bot/nameflag_test.go index 29b863b..915a0b9 100644 --- a/bot/nameflag_test.go +++ b/bot/nameflag_test.go @@ -1,5 +1,3 @@ -// +build unit - package bot import "testing" diff --git a/bot/subcommand_test.go b/bot/subcommand_test.go index 3df5230..ad8736c 100644 --- a/bot/subcommand_test.go +++ b/bot/subcommand_test.go @@ -1,5 +1,3 @@ -// +build unit - package bot import "testing" diff --git a/gateway/commands.go b/gateway/commands.go index a0fe52e..bfe0d42 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -11,13 +11,6 @@ import ( // Identify structure is at identify.go -func (i *IdentifyData) SetShard(id, num int) { - if i.Shard == nil { - i.Shard = new(Shard) - } - i.Shard[0], i.Shard[1] = id, num -} - func (g *Gateway) Identify() error { ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() @@ -26,7 +19,7 @@ func (g *Gateway) Identify() error { return errors.Wrap(err, "Can't wait for identify()") } - return g.send(false, IdentifyOP, g.Identifier) + return g.Send(IdentifyOP, g.Identifier) } type ResumeData struct { @@ -47,7 +40,7 @@ func (g *Gateway) Resume() error { return ErrMissingForResume } - return g.send(false, ResumeOP, ResumeData{ + return g.Send(ResumeOP, ResumeData{ Token: g.Identifier.Token, SessionID: ses, Sequence: seq, diff --git a/gateway/gateway.go b/gateway/gateway.go index b12cbbb..ce8cc3c 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -389,10 +389,6 @@ func (g *Gateway) eventLoop() error { } func (g *Gateway) Send(code OPCode, v interface{}) error { - return g.send(true, code, v) -} - -func (g *Gateway) send(lock bool, code OPCode, v interface{}) error { var op = OP{ Code: code, } @@ -411,11 +407,6 @@ func (g *Gateway) send(lock bool, code OPCode, v interface{}) error { return errors.Wrap(err, "Failed to encode payload") } - // if lock { - // g.available.RLock() - // defer g.available.RUnlock() - // } - // WS should already be thread-safe. return g.WS.Send(b) } diff --git a/gateway/identify.go b/gateway/identify.go index 72fa406..916795e 100644 --- a/gateway/identify.go +++ b/gateway/identify.go @@ -48,6 +48,13 @@ type IdentifyData struct { Intents Intents `json:"intents,omitempty"` } +func (i *IdentifyData) SetShard(id, num int) { + if i.Shard == nil { + i.Shard = new(Shard) + } + i.Shard[0], i.Shard[1] = id, num +} + // Intents is a new Discord API feature that's documented at // https://discordapp.com/developers/docs/topics/gateway#gateway-intents. type Intents uint32 diff --git a/handler/handler_test.go b/handler/handler_test.go index 79b5c47..2b0e3c1 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -1,5 +1,3 @@ -// +build unit - package handler import (