Compare commits

...

21 Commits

Author SHA1 Message Date
diamondburned 16a408bf30 wsutil: Refactored and decoupled structures for better thread safety 2020-10-28 10:19:22 -07:00
diamondburned 6c332ac145 {Voice,}Gateway: Fixed various race conditions
This commit fixes race conditions in both package voice, package
voicegateway and package gateway.

Originally, several race conditions exist when both the user's and the
pacemaker's goroutines both want to do several things to the websocket
connection. For example, the user's goroutine could be writing, and the
pacemaker's goroutine could trigger a reconnection. This is racey.

This issue is partially fixed by removing the pacer loop from package
heart and combining the ticker into the event (pacemaker) loop itself.

Technically, a race condition could still be triggered with care, but
the API itself never guaranteed any of those. As events are handled
using an internal loop into a channel, a race condition will not be
triggered just by handling events and writing to the websocket.
2020-10-22 10:47:27 -07:00
diamondburned 91ee92e9d5 Gateway: Fixed a race condition on ReconnectOP 2020-10-21 22:42:16 -07:00
diamondburned 86795e42a6 Session: Fixed a potential race condition on Close 2020-10-21 22:42:16 -07:00
Maximilian von Lindern 397d288927
API: fix errors in message pagination and streamline changes with other pagination methods (#150)
* API: fix faulty pagination behavior

This fix fixes a condition which lead to all messages getting fetched if the limit was a multiple of 100, instead of just the limit.

* API: add NewestMessages

* API: clarify MessageAfter docs

* API: adapt paginating methods for guild, member and message reaction to match the style of message's pagination methods

* API: return nil if no items were fetched

* API: remove Messages and Rename NewestMessages to Messages
2020-10-19 07:47:43 -07:00
diamondburned dec39c4c2d API: Fixed Messages{Before,After} fetching incorrectly beyond 100s 2020-10-18 22:14:49 -07:00
mavolin 6dabffb46c State: fix case where Role would return nil error, even though no role was found 2020-10-18 13:44:37 -07:00
diamondburned 1bec57523d Gateway: GuildSubscribeData should omit empty Channels map 2020-10-17 03:18:50 -07:00
diamondburned 86dd05da9e Gateway: Fixed empty Query on RequestGuildMembersData broken 2020-10-16 02:17:59 -07:00
mavolin 647efb8030 Discord: add Mention method to mentionable Snowflakes 2020-09-24 11:54:45 -07:00
diamondburned 64ab8c4f30 Bot: Fixed trailing backticks causing out of bound panic 2020-08-29 22:09:58 -07:00
mavolin 5acf9f3f22 Discord: fix invalid role mention generation 2020-08-24 16:32:51 -07:00
mavolin 7d5cc89ff0 API: add KickWithReason 2020-08-22 10:05:37 -07:00
diamondburned 6b4e26e839 wsutil: Improved internal code 2020-08-20 14:15:52 -07:00
diamondburned fd818e181e Gateway: GuildFolderID is now a signed int because Discord 2020-08-19 21:54:20 -07:00
diamondburned 87c648ae1d Discord: ParseSnowflake now uses ParseUint 2020-08-19 21:53:22 -07:00
diamondburned 3312c66515 Voice: Made EventLoop a valid struct value instead of nil pointer 2020-08-19 21:32:40 -07:00
diamondburned de61fd912d wsutil: Made PacemakerLoop valid as zero-value 2020-08-19 21:30:57 -07:00
diamondburned f0c73f4c99 State: Ready events now automatically reset the state 2020-08-18 10:20:48 -07:00
Maximilian von Lindern a7e9439109
Discord/API: implement changes to permission, allow and deny fields (#141) 2020-08-17 17:10:43 -07:00
diamondburned af7f413cea Gateway: Clarified GuildMemberListGroup.ID docs 2020-08-14 21:13:48 -07:00
31 changed files with 810 additions and 611 deletions

View File

@ -167,13 +167,14 @@ func (c *Client) DeleteChannel(channelID discord.ChannelID) error {
return c.FastRequest("DELETE", EndpointChannels+channelID.String())
}
// https://discord.com/developers/docs/resources/channel#edit-channel-permissions-json-params
type EditChannelPermissionData struct {
// Type is either "role" or "member".
Type discord.OverwriteType `json:"type"`
// Allow is a permission bit set for granted permissions.
Allow discord.Permissions `json:"allow"`
Allow discord.Permissions `json:"allow,string"`
// Deny is a permission bit set for denied permissions.
Deny discord.Permissions `json:"deny"`
Deny discord.Permissions `json:"deny,string"`
}
// EditChannelPermission edits the channel's permission overwrites for a user

View File

@ -9,6 +9,10 @@ import (
"github.com/diamondburned/arikawa/utils/json/option"
)
// maxGuildFetchLimit is the limit of max guilds per request, as imposed by
// Discord.
const maxGuildFetchLimit = 100
var EndpointGuilds = Endpoint + "guilds/"
// https://discordapp.com/developers/docs/resources/guild#create-guild-json-params
@ -105,8 +109,7 @@ func (c *Client) GuildWithCount(id discord.GuildID) (*discord.Guild, error) {
// Guilds returns a list of partial guild objects the current user is a member
// of. This method automatically paginates until it reaches the passed limit,
// or, if the limit is set to 0, has fetched all guilds within the passed
// range.
// or, if the limit is set to 0, has fetched all guilds the user has joined.
//
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
@ -125,8 +128,8 @@ func (c *Client) Guilds(limit uint) ([]discord.Guild, error) {
// GuildsBefore returns a list of partial guild objects the current user is a
// member of. This method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all guilds within
// the passed range.
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
// id smaller than before.
//
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
@ -134,15 +137,16 @@ func (c *Client) Guilds(limit uint) ([]discord.Guild, error) {
//
// Requires the guilds OAuth2 scope.
func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Guild, error) {
var guilds []discord.Guild
guilds := make([]discord.Guild, 0, limit)
// this is the limit of max guilds per request,as imposed by Discord
const hardLimit int = 100
fetch := uint(maxGuildFetchLimit)
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
@ -155,20 +159,24 @@ func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Gui
}
guilds = append(g, guilds...)
if len(g) < hardLimit {
if len(g) < maxGuildFetchLimit {
break
}
before = g[0].ID
}
if len(guilds) == 0 {
return nil, nil
}
return guilds, nil
}
// GuildsAfter returns a list of partial guild objects the current user is a
// member of. This method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all guilds within
// the passed range.
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
// id higher than after.
//
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
@ -176,14 +184,15 @@ func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Gui
//
// Requires the guilds OAuth2 scope.
func (c *Client) GuildsAfter(after discord.GuildID, limit uint) ([]discord.Guild, error) {
var guilds []discord.Guild
guilds := make([]discord.Guild, 0, limit)
// this is the limit of max guilds per request, as imposed by Discord
const hardLimit int = 100
fetch := uint(maxGuildFetchLimit)
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
@ -197,13 +206,17 @@ func (c *Client) GuildsAfter(after discord.GuildID, limit uint) ([]discord.Guild
}
guilds = append(guilds, g...)
if len(g) < hardLimit {
if len(g) < maxGuildFetchLimit {
break
}
after = g[len(g)-1].ID
}
if len(guilds) == 0 {
return nil, nil
}
return guilds, nil
}

View File

@ -6,7 +6,9 @@ import (
"github.com/diamondburned/arikawa/utils/json/option"
)
// Member returns a guild member object for the specified user..
const maxMemberFetchLimit = 1000
// Member returns a guild member object for the specified user.
func (c *Client) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
var m *discord.Member
return m, c.RequestJSON(&m, "GET", EndpointGuilds+guildID.String()+"/members/"+userID.String())
@ -14,11 +16,11 @@ func (c *Client) Member(guildID discord.GuildID, userID discord.UserID) (*discor
// Members returns a list of members of the guild with the passed id. This
// method automatically paginates until it reaches the passed limit, or, if the
// limit is set to 0, has fetched all members within the passed range.
// limit is set to 0, has fetched all members in the guild.
//
// As the underlying endpoint has a maximum of 1000 members per request, at
// maximum a total of limit/1000 rounded up requests will be made, although
// they may be less, if no more members are available.
// they may be less if no more members are available.
//
// When fetching the members, those with the smallest ID will be fetched first.
func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member, error) {
@ -27,7 +29,7 @@ func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member,
// MembersAfter returns a list of members of the guild with the passed id. This
// method automatically paginates until it reaches the passed limit, or, if the
// limit is set to 0, has fetched all members within the passed range.
// limit is set to 0, has fetched all members with an id higher than after.
//
// As the underlying endpoint has a maximum of 1000 members per request, at
// maximum a total of limit/1000 rounded up requests will be made, although
@ -35,13 +37,15 @@ func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member,
func (c *Client) MembersAfter(
guildID discord.GuildID, after discord.UserID, limit uint) ([]discord.Member, error) {
var mems []discord.Member
mems := make([]discord.Member, 0, limit)
const hardLimit int = 1000
fetch := uint(maxMemberFetchLimit)
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
@ -56,13 +60,17 @@ func (c *Client) MembersAfter(
mems = append(mems, m...)
// There aren't any to fetch, even if this is less than limit.
if len(m) < hardLimit {
if len(m) < maxMemberFetchLimit {
break
}
after = mems[len(mems)-1].User.ID
}
if len(mems) == 0 {
return nil, nil
}
return mems, nil
}
@ -248,9 +256,27 @@ func (c *Client) Prune(guildID discord.GuildID, data PruneData) (uint, error) {
// Requires KICK_MEMBERS permission.
// Fires a Guild Member Remove Gateway event.
func (c *Client) Kick(guildID discord.GuildID, userID discord.UserID) error {
return c.KickWithReason(guildID, userID, "")
}
// KickWithReason removes a member from a guild.
// The reason, if non-empty, will be displayed in the audit log of the guild.
//
// Requires KICK_MEMBERS permission.
// Fires a Guild Member Remove Gateway event.
func (c *Client) KickWithReason(
guildID discord.GuildID, userID discord.UserID, reason string) error {
var data struct {
Reason string `schema:"reason,omitempty"`
}
data.Reason = reason
return c.FastRequest(
"DELETE",
EndpointGuilds+guildID.String()+"/members/"+userID.String(),
httputil.WithSchema(c, data),
)
}

View File

@ -8,18 +8,26 @@ import (
"github.com/diamondburned/arikawa/utils/json/option"
)
// Messages returns a list of messages sent in the channel with the passed ID.
// This method automatically paginates until it reaches the passed limit, or,
// if the limit is set to 0, has fetched all guilds within the passed ange.
// the limit of max messages per request, as imposed by Discord
const maxMessageFetchLimit = 100
// Messages returns a slice filled with the most recent messages sent in the
// channel with the passed ID. The method automatically paginates until it
// reaches the passed limit, or, if the limit is set to 0, has fetched all
// messages in the channel.
//
// As the underlying endpoint has a maximum of 100 messages per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more messages are available.
// As the underlying endpoint is capped at a maximum of 100 messages per
// request, at maximum a total of limit/100 rounded up requests will be made,
// although they may be less, if no more messages are available.
//
// When fetching the messages, those with the smallest ID will be fetched
// When fetching the messages, those with the highest ID, will be fetched
// first.
// The returned slice will be sorted from latest to oldest.
func (c *Client) Messages(channelID discord.ChannelID, limit uint) ([]discord.Message, error) {
return c.MessagesAfter(channelID, 0, limit)
// Since before is 0 it will be omitted by the http lib, which in turn
// will lead discord to send us the most recent messages without having to
// specify a Snowflake.
return c.MessagesBefore(channelID, 0, limit)
}
// MessagesAround returns messages around the ID, with a limit of 100.
@ -29,85 +37,112 @@ func (c *Client) MessagesAround(
return c.messagesRange(channelID, 0, 0, around, limit)
}
// MessagesBefore returns a list messages sent in the channel with the passed
// ID. This method automatically paginates until it reaches the passed limit,
// or, if the limit is set to 0, has fetched all guilds within the passed
// range.
// MessagesBefore returns a slice filled with the messages sent in the channel
// with the passed id. The method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all messages in the
// channel with an id smaller than before.
//
// As the underlying endpoint has a maximum of 100 messages per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more messages are available.
//
// The returned slice will be sorted from latest to oldest.
func (c *Client) MessagesBefore(
channelID discord.ChannelID, before discord.MessageID, limit uint) ([]discord.Message, error) {
var msgs []discord.Message
msgs := make([]discord.Message, 0, limit)
// this is the limit of max messages per request, as imposed by Discord
const hardLimit int = 100
fetch := uint(maxMessageFetchLimit)
// Check if we are truly fetching unlimited messages to avoid confusion
// later on, if the limit reaches 0.
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
limit -= fetch
limit -= maxMessageFetchLimit
}
m, err := c.messagesRange(channelID, before, 0, 0, fetch)
if err != nil {
return msgs, err
}
msgs = append(m, msgs...)
// Append the older messages into the list of newer messages.
msgs = append(msgs, m...)
if len(m) < hardLimit {
if len(m) < maxMessageFetchLimit {
break
}
before = m[0].ID
before = m[len(m)-1].ID
}
if len(msgs) == 0 {
return nil, nil
}
return msgs, nil
}
// MessagesAfter returns a list messages sent in the channel with the passed
// ID. This method automatically paginates until it reaches the passed limit,
// or, if the limit is set to 0, has fetched all guilds within the passed
// range.
// MessagesAfter returns a slice filled with the messages sent in the channel
// with the passed ID. The method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all messages in the
// channel with an id higher than after.
//
// As the underlying endpoint has a maximum of 100 messages per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more messages are available.
//
// The returned slice will be sorted from latest to oldest.
func (c *Client) MessagesAfter(
channelID discord.ChannelID, after discord.MessageID, limit uint) ([]discord.Message, error) {
// 0 is uint's zero value and will lead to the after param getting omitted,
// which in turn will lead to the most recent messages being returned.
// Setting this to 1 will prevent that.
if after == 0 {
after = 1
}
var msgs []discord.Message
// this is the limit of max messages per request, as imposed by Discord
const hardLimit int = 100
fetch := uint(maxMessageFetchLimit)
// Check if we are truly fetching unlimited messages to avoid confusion
// later on, if the limit reaches 0.
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
limit -= fetch
limit -= maxMessageFetchLimit
}
m, err := c.messagesRange(channelID, 0, after, 0, fetch)
if err != nil {
return msgs, err
}
msgs = append(msgs, m...)
// Prepend the older messages into the newly-fetched messages list.
msgs = append(m, msgs...)
if len(m) < hardLimit {
if len(m) < maxMessageFetchLimit {
break
}
after = m[len(m)-1].ID
after = m[0].ID
}
if len(msgs) == 0 {
return nil, nil
}
return msgs, nil

View File

@ -7,6 +7,8 @@ import (
"github.com/diamondburned/arikawa/utils/httputil"
)
const maxMessageReactionFetchLimit = 100
// React creates a reaction for the message.
//
// This endpoint requires the READ_MESSAGE_HISTORY permission to be present on
@ -42,7 +44,8 @@ func (c *Client) Reactions(
// ReactionsBefore returns a list of users that reacted with the passed Emoji.
// This method automatically paginates until it reaches the passed limit, or,
// if the limit is set to 0, has fetched all users within the passed range.
// if the limit is set to 0, has fetched all users with an id smaller than
// before.
//
// As the underlying endpoint has a maximum of 100 users per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
@ -51,13 +54,15 @@ func (c *Client) ReactionsBefore(
channelID discord.ChannelID, messageID discord.MessageID, before discord.UserID, emoji Emoji,
limit uint) ([]discord.User, error) {
var users []discord.User
users := make([]discord.User, 0, limit)
const hardLimit int = 100
fetch := uint(maxMessageReactionFetchLimit)
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
@ -71,19 +76,24 @@ func (c *Client) ReactionsBefore(
}
users = append(r, users...)
if len(r) < hardLimit {
if len(r) < maxMessageReactionFetchLimit {
break
}
before = r[0].ID
}
if len(users) == 0 {
return nil, nil
}
return users, nil
}
// ReactionsAfter returns a list of users that reacted with the passed Emoji.
// This method automatically paginates until it reaches the passed limit, or,
// if the limit is set to 0, has fetched all users within the passed range.
// if the limit is set to 0, has fetched all users with an id higher than
// after.
//
// As the underlying endpoint has a maximum of 100 users per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
@ -92,13 +102,15 @@ func (c *Client) ReactionsAfter(
channelID discord.ChannelID, messageID discord.MessageID, after discord.UserID, emoji Emoji,
limit uint) ([]discord.User, error) {
var users []discord.User
users := make([]discord.User, 0, limit)
const hardLimit int = 100
fetch := uint(maxMessageReactionFetchLimit)
unlimited := limit == 0
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
@ -112,13 +124,17 @@ func (c *Client) ReactionsAfter(
}
users = append(users, r...)
if len(r) < hardLimit {
if len(r) < maxMessageReactionFetchLimit {
break
}
after = r[len(r)-1].ID
}
if len(users) == 0 {
return nil, nil
}
return users, nil
}

View File

@ -42,7 +42,7 @@ type CreateRoleData struct {
// Permissions is the bitwise value of the enabled/disabled permissions.
//
// Default: @everyone permissions in guild
Permissions discord.Permissions `json:"permissions,omitempty"`
Permissions discord.Permissions `json:"permissions,omitempty,string"`
// Color is the RGB color value of the role.
//
// Default: 0
@ -98,7 +98,7 @@ type ModifyRoleData struct {
// Name is the name of the role.
Name option.NullableString `json:"name,omitempty"`
// Permissions is the bitwise value of the enabled/disabled permissions.
Permissions *discord.Permissions `json:"permissions,omitempty"`
Permissions *discord.Permissions `json:"permissions,omitempty,string"`
// Permissions is the bitwise value of the enabled/disabled permissions.
Color option.NullableColor `json:"color,omitempty"`
// Hoist specifies whether the role should be displayed separately in the

View File

@ -5,18 +5,45 @@ import (
"strings"
)
// WordOffset is the offset from the position cursor to print on the error.
const WordOffset = 7
var escaper = strings.NewReplacer(
"`", "\\`",
"@", "\\@",
"\\", "\\\\",
)
type ErrParse struct {
Position int
ErrorStart,
ErrorPart,
ErrorEnd string
Words string // joined
}
func (e ErrParse) Error() string {
return fmt.Sprintf(
"Unexpected quote or escape: %s__%s__%s",
e.ErrorStart, e.ErrorPart, e.ErrorEnd,
// Magic number 5.
var a = max(0, e.Position-WordOffset)
var b = min(len(e.Words), e.Position+WordOffset)
var word = e.Words[a:b]
var uidx = e.Position - a
errstr := strings.Builder{}
errstr.WriteString("Unexpected quote or escape")
// Do a bound check.
if uidx+1 > len(word) {
// Invalid.
errstr.WriteString(".")
return errstr.String()
}
// Write the pre-underline part.
fmt.Fprintf(
&errstr, ": %s__%s__",
escaper.Replace(word[:uidx]),
escaper.Replace(string(word[uidx:])),
)
return errstr.String()
}
// Parse parses the given text to a slice of words.
@ -76,11 +103,6 @@ func Parse(line string) ([]string, error) {
got = true
}
// // If this is a backtick, then write it.
// if r == '`' {
// buf.WriteByte('`')
// }
singleQuoted = !singleQuoted
continue
}
@ -95,19 +117,9 @@ func Parse(line string) ([]string, error) {
}
if escaped || singleQuoted || doubleQuoted {
// the number of characters to highlight
var (
pos = cursor + 5
start = string(runes[max(cursor-100, 0) : pos-1])
end = string(runes[pos+1 : min(cursor+100, len(runes))])
part = string(runes[max(pos-1, 0):min(len(runes), pos+2)])
)
return args, &ErrParse{
Position: cursor,
ErrorStart: start,
ErrorPart: part,
ErrorEnd: end,
Position: cursor + buf.Len(),
Words: strings.Join(args, " "),
}
}

View File

@ -38,6 +38,16 @@ func TestParse(t *testing.T) {
[]string{"how", "about", "a", "go\npackage main\n", "go", "code?"},
false,
},
{
"this should not crash `",
[]string{"this", "should", "not", "crash"},
true,
},
{
"this should not crash '",
[]string{"this", "should", "not", "crash"},
true,
},
}
for _, test := range tests {

View File

@ -1,5 +1,7 @@
package discord
import "github.com/diamondburned/arikawa/utils/json"
// https://discord.com/developers/docs/resources/channel#channel-object
type Channel struct {
// ID is the id of this channel.
@ -55,7 +57,7 @@ type Channel struct {
// Mention returns a mention of the channel.
func (ch Channel) Mention() string {
return "<#" + ch.ID.String() + ">"
return ch.ID.Mention()
}
// IconURL returns the URL to the channel icon in the PNG format.
@ -103,13 +105,37 @@ var (
// https://discord.com/developers/docs/resources/channel#overwrite-object
type Overwrite struct {
// ID is the role or user id.
ID Snowflake `json:"id,string"`
ID Snowflake `json:"id"`
// Type is either "role" or "member".
Type OverwriteType `json:"type"`
// Allow is a permission bit set for granted permissions.
Allow Permissions `json:"allow"`
Allow Permissions `json:"allow,string"`
// Deny is a permission bit set for denied permissions.
Deny Permissions `json:"deny"`
Deny Permissions `json:"deny,string"`
}
// UnmarshalJSON unmarshals the passed json data into the Overwrite.
// This is necessary because Discord has different names for fields when
// sending than receiving.
func (o *Overwrite) UnmarshalJSON(data []byte) (err error) {
var recv struct {
ID Snowflake `json:"id"`
Type OverwriteType `json:"type"`
Allow Permissions `json:"allow_new,string"`
Deny Permissions `json:"deny_new,string"`
}
err = json.Unmarshal(data, &recv)
if err != nil {
return
}
o.ID = recv.ID
o.Type = recv.Type
o.Allow = recv.Allow
o.Deny = recv.Deny
return
}
type OverwriteType string

View File

@ -23,7 +23,7 @@ type Guild struct {
// Permissions are the total permissions for the user in the guild
// (excludes overrides).
Permissions Permissions `json:"permissions,omitempty"`
Permissions Permissions `json:"permissions_new,omitempty,string"`
// VoiceRegion is the voice region id for the guild.
VoiceRegion string `json:"region"`
@ -297,7 +297,7 @@ type Role struct {
Position int `json:"position"`
// Permissions is the permission bit set.
Permissions Permissions `json:"permissions"`
Permissions Permissions `json:"permissions_new,string"`
// Manages specifies whether this role is managed by an integration.
Managed bool `json:"managed"`
@ -307,7 +307,7 @@ type Role struct {
// Mention returns the mention of the Role.
func (r Role) Mention() string {
return "<&" + r.ID.String() + ">"
return r.ID.Mention()
}
// https://discord.com/developers/docs/topics/gateway#presence-update
@ -375,7 +375,7 @@ type Member struct {
// Mention returns the mention of the role.
func (m Member) Mention() string {
return "<@!" + m.User.ID.String() + ">"
return m.User.Mention()
}
// https://discord.com/developers/docs/resources/guild#ban-object

View File

@ -30,12 +30,12 @@ func ParseSnowflake(sf string) (Snowflake, error) {
return NullSnowflake, nil
}
i, err := strconv.ParseInt(sf, 10, 64)
u, err := strconv.ParseUint(sf, 10, 64)
if err != nil {
return 0, err
}
return Snowflake(i), nil
return Snowflake(u), nil
}
func (s *Snowflake) UnmarshalJSON(v []byte) error {
@ -149,6 +149,7 @@ func (s ChannelID) Time() time.Time { return Snowflake(s).Time() }
func (s ChannelID) Worker() uint8 { return Snowflake(s).Worker() }
func (s ChannelID) PID() uint8 { return Snowflake(s).PID() }
func (s ChannelID) Increment() uint16 { return Snowflake(s).Increment() }
func (s ChannelID) Mention() string { return "<#" + s.String() + ">" }
type EmojiID Snowflake
@ -219,6 +220,7 @@ func (s RoleID) Time() time.Time { return Snowflake(s).Time() }
func (s RoleID) Worker() uint8 { return Snowflake(s).Worker() }
func (s RoleID) PID() uint8 { return Snowflake(s).PID() }
func (s RoleID) Increment() uint16 { return Snowflake(s).Increment() }
func (s RoleID) Mention() string { return "<@&" + s.String() + ">" }
type UserID Snowflake
@ -233,6 +235,7 @@ func (s UserID) Time() time.Time { return Snowflake(s).Time() }
func (s UserID) Worker() uint8 { return Snowflake(s).Worker() }
func (s UserID) PID() uint8 { return Snowflake(s).PID() }
func (s UserID) Increment() uint16 { return Snowflake(s).Increment() }
func (s UserID) Mention() string { return "<@" + s.String() + ">" }
type WebhookID Snowflake

View File

@ -27,7 +27,7 @@ type User struct {
}
func (u User) Mention() string {
return "<@" + u.ID.String() + ">"
return u.ID.Mention()
}
// AvatarURL returns the URL of the Avatar Image. It automatically detects a

View File

@ -81,7 +81,7 @@ type RequestGuildMembersData struct {
GuildID []discord.GuildID `json:"guild_id"`
UserIDs []discord.UserID `json:"user_ids,omitempty"`
Query string `json:"query,omitempty"`
Query string `json:"query"`
Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
@ -149,7 +149,7 @@ type GuildSubscribeData struct {
GuildID discord.GuildID `json:"guild_id"`
// Channels is not documented. It's used to fetch the right members sidebar.
Channels map[discord.ChannelID][][2]int `json:"channels"`
Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"`
}
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {

View File

@ -133,7 +133,7 @@ type (
Ops []GuildMemberListOp `json:"ops"`
}
GuildMemberListGroup struct {
ID string `json:"id"` // either discord.RoleID or "online"
ID string `json:"id"` // either discord.RoleID, "online" or "offline"
Count uint64 `json:"count"`
}
GuildMemberListOp struct {

View File

@ -85,11 +85,14 @@ type Gateway struct {
// Session.
Events chan Event
// SessionID is used to store the session ID received after Ready. It is not
// thread-safe.
SessionID string
Identifier *Identifier
Sequence *Sequence
PacerLoop *wsutil.PacemakerLoop
PacerLoop wsutil.PacemakerLoop
ErrorLog func(err error) // default to log.Println
@ -98,11 +101,6 @@ type Gateway struct {
// reconnections or any type of connection interruptions.
AfterClose func(err error) // noop by default
// Mutex to hold off calls when the WS is not available. Doesn't block if
// Start() is not called or Close() is called. Also doesn't block for
// Identify or Resume.
// available sync.RWMutex
// Filled by methods, internal use
waitGroup *sync.WaitGroup
}
@ -163,40 +161,38 @@ func (g *Gateway) AddIntent(i Intents) {
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
func (g *Gateway) Close() (err error) {
wsutil.WSDebug("Trying to close.")
// Check if the WS is already closed:
if g.waitGroup == nil && g.PacerLoop.Stopped() {
if g.PacerLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
g.AfterClose(nil)
return nil
return err
}
// Trigger the close callback on exit.
defer func() { g.AfterClose(err) }()
// If the pacemaker is running:
if !g.PacerLoop.Stopped() {
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
// Stop the pacemaker and the event handler.
g.PacerLoop.Stop()
wsutil.WSDebug("Stopped pacemaker.")
}
wsutil.WSDebug("Closing the websocket...")
err = g.WS.Close()
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
g.waitGroup.Wait()
// Mark g.waitGroup as empty:
g.waitGroup = nil
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := g.WS.Close()
g.AfterClose(err)
return err
}
@ -356,13 +352,11 @@ func (g *Gateway) start(ctx context.Context) error {
return errors.Wrap(err, "first error")
}
// Use the pacemaker loop.
g.PacerLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, g)
// Start the event handler, which also handles the pacemaker death signal.
g.waitGroup.Add(1)
g.PacerLoop.RunAsync(func(err error) {
// Use the pacemaker loop.
g.PacerLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, g, func(err error) {
g.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("Event loop stopped with error:", err)

View File

@ -30,6 +30,11 @@ const (
GuildSubscriptionsOP OPCode = 14
)
// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This
// is used mostly internally to signal the heartbeat loop to reconnect, if
// needed. It is not a fatal error.
var ErrReconnectRequest = errors.New("ReconnectOP received")
func (g *Gateway) HandleOP(op *wsutil.OP) error {
switch op.Code {
case HeartbeatAckOP:
@ -41,19 +46,17 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
defer cancel()
// Server requesting a heartbeat.
return g.PacerLoop.Pace(ctx)
if err := g.PacerLoop.Pace(ctx); err != nil {
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
}
case ReconnectOP:
// Server requests to reconnect, die and retry.
wsutil.WSDebug("ReconnectOP received.")
// We must reconnect in another goroutine, as running Reconnect
// synchronously would prevent the main event loop from exiting.
go g.Reconnect()
// Gracefully exit with a nil let the event handler take the signal from
// the pacemaker.
return nil
// Exit with the ReconnectOP error to force the heartbeat event loop to
// reconnect synchronously. Not really a fatal error.
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
case InvalidSessionOP:
// Discord expects us to sleep for no reason
@ -65,13 +68,12 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Invalid session, try and Identify.
if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, reconnect.
go g.Reconnect()
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
}
return nil
case HelloOP:
// What is this OP doing here???
return nil
case DispatchOP:
@ -102,7 +104,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
g.SessionID = ev.SessionID
}
// Throw the event into a channel, it's valid now.
// Throw the event into a channel; it's valid now.
g.Events <- ev
return nil

View File

@ -120,7 +120,7 @@ type GuildFolder struct {
// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low
// number of unknown significance.
type GuildFolderID uint64
type GuildFolderID int64
func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
var body = string(b)
@ -130,7 +130,7 @@ func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
body = strings.Trim(body, `"`)
u, err := strconv.ParseUint(body, 10, 64)
u, err := strconv.ParseInt(body, 10, 64)
if err != nil {
return err
}
@ -144,5 +144,5 @@ func (g GuildFolderID) MarshalJSON() ([]byte, error) {
return []byte("null"), nil
}
return []byte(strconv.FormatUint(uint64(g), 10)), nil
return []byte(strconv.FormatInt(int64(g), 10)), nil
}

1
go.sum
View File

@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s=

View File

@ -3,7 +3,6 @@ package heart
import (
"context"
"sync"
"sync/atomic"
"time"
@ -36,23 +35,30 @@ type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
ticker time.Ticker
Ticks <-chan time.Time
// Time in nanoseconds, guarded by atomic read/writes.
SentBeat AtomicTime
EchoBeat AtomicTime
// Any callback that returns an error will stop the pacer.
Pace func(context.Context) error
stop chan struct{}
once sync.Once
death chan error
Pacer func(context.Context) error
}
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) *Pacemaker {
return &Pacemaker{
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
p := Pacemaker{
Heartrate: heartrate,
Pace: pacer,
Pacer: pacer,
ticker: *time.NewTicker(heartrate),
}
p.Ticks = p.ticker.C
// Reset states to its old position.
now := time.Now()
p.EchoBeat.Set(now)
p.SentBeat.Set(now)
return p
}
func (p *Pacemaker) Echo() {
@ -62,14 +68,6 @@ func (p *Pacemaker) Echo() {
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
/* Deprecated
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
return false
}
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
*/
var (
echo = p.EchoBeat.Get()
sent = p.SentBeat.Get()
@ -84,75 +82,84 @@ func (p *Pacemaker) Dead() bool {
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
func (p *Pacemaker) Stop() {
Debug("(*Pacemaker).stop is trying sync.Once.")
p.once.Do(func() {
Debug("(*Pacemaker).stop closed the channel.")
close(p.stop)
})
p.ticker.Stop()
}
// pace sends a heartbeat with the appropriate timeout for the context.
func (p *Pacemaker) pace() error {
func (p *Pacemaker) Pace() error {
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate)
defer cancel()
return p.Pace(ctx)
return p.PaceCtx(ctx)
}
func (p *Pacemaker) start() error {
// Reset states to its old position.
p.EchoBeat.Set(time.Time{})
p.SentBeat.Set(time.Time{})
// Create a new ticker.
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
// Echo at least once
p.Echo()
for {
if err := p.pace(); err != nil {
return errors.Wrap(err, "failed to pace")
}
// Paced, save:
p.SentBeat.Set(time.Now())
if p.Dead() {
return ErrDead
}
select {
case <-p.stop:
return nil
case <-tick.C:
}
}
}
// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
p.death = make(chan error)
p.stop = make(chan struct{})
p.once = sync.Once{}
if wg != nil {
wg.Add(1)
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
if err := p.Pacer(ctx); err != nil {
return err
}
go func() {
p.death <- p.start()
// Debug.
Debug("Pacemaker returned.")
p.SentBeat.Set(time.Now())
// Mark the pacemaker loop as done.
if wg != nil {
wg.Done()
}
}()
if p.Dead() {
return ErrDead
}
return p.death
return nil
}
// func (p *Pacemaker) start() error {
// // Reset states to its old position.
// p.EchoBeat.Set(time.Time{})
// p.SentBeat.Set(time.Time{})
// // Create a new ticker.
// tick := time.NewTicker(p.Heartrate)
// defer tick.Stop()
// // Echo at least once
// p.Echo()
// for {
// if err := p.pace(); err != nil {
// return errors.Wrap(err, "failed to pace")
// }
// // Paced, save:
// p.SentBeat.Set(time.Now())
// if p.Dead() {
// return ErrDead
// }
// select {
// case <-p.stop:
// return nil
// case <-tick.C:
// }
// }
// }
// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
// p.death = make(chan error)
// p.stop = make(chan struct{})
// p.once = sync.Once{}
// if wg != nil {
// wg.Add(1)
// }
// go func() {
// p.death <- p.start()
// // Debug.
// Debug("Pacemaker returned.")
// // Mark the pacemaker loop as done.
// if wg != nil {
// wg.Done()
// }
// }()
// return p.death
// }

View File

@ -4,6 +4,8 @@
package session
import (
"sync"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/api"
@ -39,6 +41,7 @@ type Session struct {
Ticket string
hstop chan struct{}
wstop sync.Once
}
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
@ -103,9 +106,9 @@ func NewWithGateway(gw *gateway.Gateway) *Session {
func (s *Session) Open() error {
// Start the handler beforehand so no events are missed.
stop := make(chan struct{})
s.hstop = stop
go s.startHandler(stop)
s.hstop = make(chan struct{})
s.wstop = sync.Once{}
go s.startHandler()
// Set the AfterClose's handler.
s.Gateway.AfterClose = func(err error) {
@ -121,10 +124,10 @@ func (s *Session) Open() error {
return nil
}
func (s *Session) startHandler(stop <-chan struct{}) {
func (s *Session) startHandler() {
for {
select {
case <-stop:
case <-s.hstop:
return
case ev := <-s.Gateway.Events:
s.Call(ev)
@ -134,14 +137,7 @@ func (s *Session) startHandler(stop <-chan struct{}) {
func (s *Session) Close() error {
// Stop the event handler
s.close()
s.wstop.Do(func() { s.hstop <- struct{}{} })
// Close the websocket
return s.Gateway.Close()
}
func (s *Session) close() {
if s.hstop != nil {
close(s.hstop)
}
}

View File

@ -636,6 +636,10 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.R
}
}
if role == nil {
return nil, ErrStoreNotFound
}
return role, nil
}

View File

@ -50,6 +50,13 @@ func (s *State) hookSession() {
func (s *State) onEvent(iface interface{}) {
switch ev := iface.(type) {
case *gateway.ReadyEvent:
// Reset the store before proceeding.
if resetter, ok := s.Store.(StoreResetter); ok {
if err := resetter.Reset(); err != nil {
s.stateErr(err, "Failed to reset state on READY")
}
}
// Set Ready to the state
s.Ready = *ev

View File

@ -92,6 +92,12 @@ type StoreModifier interface {
VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error
}
// StoreResetter is used by the state to reset the store on every Ready event.
type StoreResetter interface {
// Reset resets the store to a new valid instance.
Reset() error
}
// ErrStoreNotFound is an error that a store can use to return when something
// isn't in the storage. There is no strict restrictions on what uses this (the
// default one does, though), so be advised.

View File

@ -9,12 +9,17 @@ import (
"sync"
"time"
"github.com/diamondburned/arikawa/utils/json"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
const CopyBufferSize = 2048
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
// size is 4KB.
var CopyBufferSize = 4096
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
var MaxCapUntilReset = CopyBufferSize * 4
// CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second
@ -45,72 +50,54 @@ type Connection interface {
// Conn is the default Websocket connection. It compresses all payloads using
// zlib.
type Conn struct {
mutex sync.Mutex
Conn *websocket.Conn
json.Driver
dialer *websocket.Dialer
events chan Event
// write channels
writes chan []byte
errors chan error
buf bytes.Buffer
zlib io.ReadCloser // (compress/zlib).reader
// nil until Dial().
closeOnce *sync.Once
// zlib *zlib.Inflator // zlib.NewReader
// buf []byte // io.Copy buffer
}
var _ Connection = (*Conn)(nil)
// NewConn creates a new default websocket connection with a default dialer.
func NewConn() *Conn {
return NewConnWithDriver(json.Default)
return NewConnWithDialer(&websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
})
}
func NewConnWithDriver(driver json.Driver) *Conn {
return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
EnableCompression: true,
},
// zlib: zlib.NewInflator(),
// buf: make([]byte, CopyBufferSize),
}
// NewConn creates a new default websocket connection with a custom dialer.
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
return &Conn{dialer: dialer}
}
func (c *Conn) Dial(ctx context.Context, addr string) error {
var err error
// Enable compression:
headers := http.Header{}
headers.Set("Accept-Encoding", "zlib")
headers := http.Header{
"Accept-Encoding": {"zlib"},
}
// BUG: https://github.com/golang/go/issues/31514
// // Enable stream compression:
// addr = InjectValues(addr, url.Values{
// "compress": {"zlib-stream"},
// })
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
conn, _, err := c.dialer.DialContext(ctx, addr, headers)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
// Set up the closer.
c.closeOnce = &sync.Once{}
events := make(chan Event, WSBuffer)
go startReadLoop(conn, events)
c.events = make(chan Event)
go c.readLoop()
c.mutex.Lock()
defer c.mutex.Unlock()
c.writes = make(chan []byte)
c.errors = make(chan error)
go c.writeLoop()
c.Conn = conn
c.events = events
return err
}
@ -119,18 +106,88 @@ func (c *Conn) Listen() <-chan Event {
return c.events
}
func (c *Conn) readLoop() {
// Acquire the read lock throughout the span of the loop. This would still
// allow Send to acquire another RLock, but wouldn't allow Close to
// prematurely exit, as Close acquires a write lock.
// c.mut.RLock()
// defer c.mut.RUnlock()
// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
}
func (c *Conn) Close() error {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
WSDebug("Conn: Acquiring write lock...")
// Acquire the write lock forever.
c.mutex.Lock()
defer c.mutex.Unlock()
WSDebug("Conn: Write lock acquired; closing.")
// Close the WS.
err := c.closeWS()
WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...")
// Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed.
for range c.events {
}
WSDebug("Flushed events.")
// Mark c.Conn as empty.
c.Conn = nil
return err
}
func (c *Conn) closeWS() error {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.
// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)
// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
return c.Conn.Close()
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
type loopState struct {
conn *websocket.Conn
zlib io.ReadCloser
buf bytes.Buffer
}
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
// Clean up the events channel in the end.
defer close(c.events)
defer close(eventCh)
// Allocate the read loop its own private resources.
state := loopState{conn: conn}
state.buf.Grow(CopyBufferSize)
for {
b, err := c.handle()
b, err := state.handle()
if err != nil {
// Is the error an EOF?
if errors.Is(err, io.EOF) {
@ -144,7 +201,7 @@ func (c *Conn) readLoop() {
}
// Unusual error; log and exit:
c.events <- Event{nil, errors.Wrap(err, "WS error")}
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
return
}
@ -153,34 +210,13 @@ func (c *Conn) readLoop() {
continue
}
c.events <- Event{b, nil}
eventCh <- Event{b, nil}
}
}
func (c *Conn) writeLoop() {
// Closing c.writes would break the loop immediately.
for b := range c.writes {
c.errors <- c.Conn.WriteMessage(websocket.TextMessage, b)
}
// Quick deadline:
deadline := time.Now().Add(CloseDeadline)
// Make a closure message:
msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
// Send a close message before closing the connection. We're not error
// checking this because it's not important.
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
// Safe to close now.
c.errors <- c.Conn.Close()
close(c.errors)
}
func (c *Conn) handle() ([]byte, error) {
func (state *loopState) handle() ([]byte, error) {
// skip message type
t, r, err := c.Conn.NextReader()
t, r, err := state.conn.NextReader()
if err != nil {
return nil, err
}
@ -188,112 +224,43 @@ func (c *Conn) handle() ([]byte, error) {
if t == websocket.BinaryMessage {
// Probably a zlib payload
if c.zlib == nil {
if state.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
return nil, errors.Wrap(err, "failed to create a zlib reader")
}
c.zlib = z
state.zlib = z
} else {
if err := c.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
return nil, errors.Wrap(err, "failed to reset zlib reader")
}
}
defer c.zlib.Close()
r = c.zlib
defer state.zlib.Close()
r = state.zlib
}
return readAll(&c.buf, r)
// if t is a text message, then handle it normally.
// if t == websocket.TextMessage {
// return readAll(&c.buf, r)
// }
// // Write to the zlib writer.
// c.zlib.Write(r)
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
// // return nil, errors.Wrap(err, "Failed to write to zlib")
// // }
// if !c.zlib.CanFlush() {
// return nil, nil
// }
// // Flush and get the uncompressed payload.
// b, err := c.zlib.Flush()
// if err != nil {
// return nil, errors.Wrap(err, "Failed to flush zlib")
// }
// return nil, errors.New("Unexpected binary message.")
}
func (c *Conn) Send(ctx context.Context, b []byte) error {
// If websocket is already closed.
if c.writes == nil {
return ErrWebsocketClosed
}
// Send the bytes.
select {
case c.writes <- b:
// continue
case <-ctx.Done():
return ctx.Err()
}
// Receive the error.
select {
case err := <-c.errors:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) Close() (err error) {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
c.closeOnce.Do(func() {
// Close c.writes. This should trigger the websocket to close itself.
close(c.writes)
// Mark c.writes as empty.
c.writes = nil
// Wait for the write loop to exit by flusing the errors channel.
err = <-c.errors // get close error
for range c.errors { // then flush
}
// Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed.
for range c.events {
}
// Mark c.events as empty.
c.events = nil
// Mark c.Conn as empty.
c.Conn = nil
})
return err
return state.readAll(r)
}
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
defer buf.Reset()
if _, err := buf.ReadFrom(r); err != nil {
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
defer state.buf.Reset()
if _, err := state.buf.ReadFrom(r); err != nil {
return nil, err
}
// Copy the bytes so we could empty the buffer for reuse.
p := buf.Bytes()
cpy := make([]byte, len(p))
copy(cpy, p)
cpy := make([]byte, state.buf.Len())
copy(cpy, state.buf.Bytes())
// If the buffer's capacity is over the limit, then re-allocate a new one.
if state.buf.Cap() > MaxCapUntilReset {
state.buf = bytes.Buffer{}
state.buf.Grow(CopyBufferSize)
}
return cpy, nil
}

View File

@ -2,6 +2,7 @@ package wsutil
import (
"context"
"runtime/debug"
"time"
"github.com/pkg/errors"
@ -10,35 +11,57 @@ import (
"github.com/diamondburned/arikawa/internal/moreatomic"
)
type errBrokenConnection struct {
underneath error
}
// Error formats the broken connection error with the message "explicit
// connection break."
func (err errBrokenConnection) Error() string {
return "explicit connection break: " + err.underneath.Error()
}
// Unwrap returns the underlying error.
func (err errBrokenConnection) Unwrap() error {
return err.underneath
}
// ErrBrokenConnection marks the given error as a broken connection error. This
// error will cause the pacemaker loop to break and return the error. The error,
// when stringified, will say "explicit connection break."
func ErrBrokenConnection(err error) error {
return errBrokenConnection{underneath: err}
}
// IsBrokenConnection returns true if the error is a broken connection error.
func IsBrokenConnection(err error) bool {
var broken *errBrokenConnection
return errors.As(err, &broken)
}
// TODO API
type EventLoopHandler interface {
EventHandler
HeartbeatCtx(context.Context) error
}
// PacemakerLoop provides an event loop with a pacemaker.
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
// is a valid instance only when RunAsync is called first.
type PacemakerLoop struct {
pacemaker *heart.Pacemaker // let's not copy this
pacedeath chan error
heart.Pacemaker
running moreatomic.Bool
stop chan struct{}
events <-chan Event
handler func(*OP) error
stack []byte
Extras ExtraHandlers
ErrorLog func(error)
}
func NewLoop(heartrate time.Duration, evs <-chan Event, evl EventLoopHandler) *PacemakerLoop {
return &PacemakerLoop{
pacemaker: heart.NewPacemaker(heartrate, evl.HeartbeatCtx),
events: evs,
handler: evl.HandleOP,
}
}
func (p *PacemakerLoop) errorLog(err error) {
if p.ErrorLog == nil {
WSDebug("Uncaught error:", err)
@ -50,28 +73,36 @@ func (p *PacemakerLoop) errorLog(err error) {
// Pace calls the pacemaker's Pace function.
func (p *PacemakerLoop) Pace(ctx context.Context) error {
return p.pacemaker.Pace(ctx)
return p.Pacemaker.PaceCtx(ctx)
}
// Echo calls the pacemaker's Echo function.
func (p *PacemakerLoop) Echo() {
p.pacemaker.Echo()
}
// Stop calls the pacemaker's Stop function.
// Stop stops the pacer loop. It does nothing if the loop is already stopped.
func (p *PacemakerLoop) Stop() {
p.pacemaker.Stop()
if p.Stopped() {
return
}
// Despite p.running and p.stop being thread-safe on their own, this entire
// block is actually not thread-safe.
p.Pacemaker.Stop()
close(p.stop)
}
func (p *PacemakerLoop) Stopped() bool {
return p == nil || !p.running.Get()
}
func (p *PacemakerLoop) RunAsync(exit func(error)) {
func (p *PacemakerLoop) RunAsync(
heartrate time.Duration, evs <-chan Event, evl EventLoopHandler, exit func(error)) {
WSDebug("Starting the pacemaker loop.")
// callers should explicitly handle waitgroups.
p.pacedeath = p.pacemaker.StartAsync(nil)
p.Pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
p.handler = evl.HandleOP
p.events = evs
p.stack = debug.Stack()
p.stop = make(chan struct{})
p.running.Set(true)
go func() {
@ -82,21 +113,27 @@ func (p *PacemakerLoop) RunAsync(exit func(error)) {
func (p *PacemakerLoop) startLoop() error {
defer WSDebug("Pacemaker loop has exited.")
defer p.running.Set(false)
defer p.Pacemaker.Stop()
for {
select {
case err := <-p.pacedeath:
WSDebug("Pacedeath returned with error:", err)
return errors.Wrap(err, "pacemaker died, reconnecting")
case <-p.stop:
WSDebug("Stop requested; exiting.")
return nil
case <-p.Pacemaker.Ticks:
if err := p.Pacemaker.Pace(); err != nil {
return errors.Wrap(err, "pace failed, reconnecting")
}
case ev, ok := <-p.events:
if !ok {
WSDebug("Events channel closed, stopping pacemaker.")
defer WSDebug("Pacemaker stopped automatically.")
// Events channel is closed. Kill the pacemaker manually and
// die.
p.pacemaker.Stop()
return <-p.pacedeath
return nil
}
if ev.Error != nil {
return errors.Wrap(ev.Error, "event returned error")
}
o, err := DecodeOP(ev)
@ -109,7 +146,11 @@ func (p *PacemakerLoop) startLoop() error {
// Handle the event
if err := p.handler(o); err != nil {
p.errorLog(errors.Wrap(err, "handler failed"))
if IsBrokenConnection(err) {
return errors.Wrap(err, "handler failed")
}
p.errorLog(err)
}
}
}

View File

@ -15,15 +15,12 @@ import (
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = 5 * time.Minute
WSTimeout = 30 * time.Second
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
@ -82,9 +79,6 @@ func (ws *Websocket) Dial(ctx context.Context) error {
return errors.Wrap(err, "failed to dial")
}
// Reset the SendLimiter:
ws.SendLimiter = NewSendLimiter()
return nil
}

View File

@ -3,6 +3,7 @@
package voice
import (
"context"
"encoding/binary"
"io"
"log"
@ -18,73 +19,6 @@ import (
"github.com/diamondburned/arikawa/voice/voicegateway"
)
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
t.Cleanup(func() {
f.Close()
})
var lenbuf [4]byte
for {
if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) {
return
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) {
return
}
}
}
func catchRead(t *testing.T, err error) bool {
t.Helper()
if err == io.EOF {
return false
}
if err != nil {
t.Fatal("Failed to read:", err)
}
return true
}
func TestIntegration(t *testing.T) {
config := mustConfig(t)
@ -150,12 +84,76 @@ func TestIntegration(t *testing.T) {
finish("sending the speaking command")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := vs.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err)
}
// Copy the audio?
nicoReadTo(t, vs)
finish("copying the audio")
}
type testConfig struct {
BotToken string
VoiceChID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
t.Fatal("Missing $VOICE_ID")
}
id, err := discord.ParseSnowflake(sid)
if err != nil {
t.Fatal("Invalid $VOICE_ID:", err)
}
return testConfig{
BotToken: token,
VoiceChID: discord.ChannelID(id),
}
}
// file is only a few bytes lolmao
func nicoReadTo(t *testing.T, dst io.Writer) {
t.Helper()
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("Failed to open nico.dca:", err)
}
defer f.Close()
var lenbuf [4]byte
for {
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
if err == io.EOF {
break
}
t.Fatal("failed to read:", err)
}
// Read the integer
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
// Copy the frame.
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
t.Fatal("failed to write:", err)
}
}
}
// simple shitty benchmark thing
func timer() func(finished string) {
var then = time.Now()

View File

@ -110,14 +110,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
}
}
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
func (s *Session) JoinChannel(
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
}
func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
// Acquire the mutex during join, locking during IO as well.
s.mut.Lock()
defer s.mut.Unlock()
@ -211,8 +215,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
return errors.Wrap(err, "failed to select protocol")
}
// Start the UDP loop.
go s.voiceUDP.Start(&d.SecretKey)
s.voiceUDP.UseSecret(d.SecretKey)
return nil
}
@ -237,6 +240,18 @@ func (s *Session) StopSpeaking() error {
return nil
}
// UseContext tells the UDP voice connection to write with the given mutex.
func (s *Session) UseContext(ctx context.Context) error {
s.mut.RLock()
defer s.mut.RUnlock()
if s.voiceUDP == nil {
return ErrCannotSend
}
return s.voiceUDP.UseContext(ctx)
}
// Write writes into the UDP voice connection WITHOUT a timeout.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)

View File

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
"golang.org/x/crypto/nacl/secretbox"
"golang.org/x/time/rate"
)
// Dialer is the default dialer that this package uses for all its dialing.
@ -17,31 +18,29 @@ var Dialer = net.Dialer{
Timeout: 10 * time.Second,
}
// ErrClosed is returned if a Write was called on a closed connection.
var ErrClosed = errors.New("UDP connection closed")
type Connection struct {
GatewayIP string
GatewayPort uint16
ssrc uint32
mutex chan struct{} // for ctx
context context.Context
conn net.Conn
ssrc uint32
frequency rate.Limiter
packet [12]byte
secret [32]byte
sequence uint16
timestamp uint32
nonce [24]byte
conn net.Conn
close chan struct{}
closed chan struct{}
send chan []byte
reply chan error
}
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// // Resolve the host.
// a, err := net.ResolveUDPAddr("udp", addr)
// if err != nil {
// return nil, errors.Wrap(err, "failed to resolve host")
// }
// Create a new UDP connection.
conn, err := Dialer.DialContext(ctx, "udp", addr)
if err != nil {
@ -78,20 +77,6 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
ip := ipbody[:nullPos]
port := binary.LittleEndian.Uint16(ipBuffer[68:70])
return &Connection{
GatewayIP: string(ip),
GatewayPort: port,
ssrc: ssrc,
conn: conn,
send: make(chan []byte),
reply: make(chan error),
close: make(chan struct{}),
closed: make(chan struct{}),
}, nil
}
func (c *Connection) Start(secret *[32]byte) {
// https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice
packet := [12]byte{
0: 0x80, // Version + Flags
@ -101,81 +86,118 @@ func (c *Connection) Start(secret *[32]byte) {
}
// Write SSRC to the header.
binary.BigEndian.PutUint32(packet[8:12], c.ssrc) // SSRC
binary.BigEndian.PutUint32(packet[8:12], ssrc) // SSRC
// 50 sends per second, 960 samples each at 48kHz
frequency := time.NewTicker(time.Millisecond * 20)
defer frequency.Stop()
return &Connection{
GatewayIP: string(ip),
GatewayPort: port,
// 50 sends per second, 960 samples each at 48kHz
frequency: *rate.NewLimiter(rate.Every(20*time.Millisecond), 1),
context: context.Background(),
mutex: make(chan struct{}, 1),
packet: packet,
ssrc: ssrc,
conn: conn,
}, nil
}
var b []byte
var ok bool
// UseSecret uses the given secret. This method is not thread-safe, so it should
// only be used right after initialization.
func (c *Connection) UseSecret(secret [32]byte) {
c.secret = secret
}
// Close these channels at the end so Write() doesn't block.
defer func() {
close(c.send)
close(c.closed)
}()
// UseContext lets the connection use the given context for its Write method.
// WriteCtx will override this context.
func (c *Connection) UseContext(ctx context.Context) error {
c.mutex <- struct{}{}
defer func() { <-c.mutex }()
for {
select {
case b, ok = <-c.send:
if !ok {
return
}
case <-c.close:
return
}
return c.useContext(ctx)
}
// Write a new sequence.
binary.BigEndian.PutUint16(packet[2:4], c.sequence)
c.sequence++
func (c *Connection) useContext(ctx context.Context) error {
if c.conn == nil {
return ErrClosed
}
binary.BigEndian.PutUint32(packet[4:8], c.timestamp)
c.timestamp += 960 // Samples
if c.context == ctx {
return nil
}
copy(c.nonce[:], packet[:])
c.context = ctx
toSend := secretbox.Seal(packet[:], b, &c.nonce, secret)
select {
case <-frequency.C:
case <-c.close:
// Prevent Write() from stalling before exiting.
c.reply <- nil
return
}
_, err := c.conn.Write(toSend)
c.reply <- err
if deadline, ok := c.context.Deadline(); ok {
return c.conn.SetWriteDeadline(deadline)
} else {
return c.conn.SetWriteDeadline(time.Time{})
}
}
func (c *Connection) Close() error {
close(c.close)
<-c.closed
return c.conn.Close()
c.mutex <- struct{}{}
err := c.conn.Close()
c.conn = nil
<-c.mutex
return err
}
// Write sends bytes into the voice UDP connection.
func (c *Connection) Write(b []byte) (int, error) {
return c.WriteCtx(context.Background(), b)
select {
case c.mutex <- struct{}{}:
defer func() { <-c.mutex }()
case <-c.context.Done():
return 0, c.context.Err()
}
if c.conn == nil {
return 0, ErrClosed
}
return c.write(b)
}
// WriteCtx sends bytes into the voice UDP connection with a timeout.
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
select {
case c.send <- b:
break
case c.mutex <- struct{}{}:
defer func() { <-c.mutex }()
case <-c.context.Done():
return 0, c.context.Err()
case <-ctx.Done():
return 0, ctx.Err()
}
select {
case err := <-c.reply:
return len(b), err
case <-ctx.Done():
return len(b), ctx.Err()
if err := c.useContext(ctx); err != nil {
return 0, errors.Wrap(err, "failed to use context")
}
return c.write(b)
}
// write is thread-unsafe.
func (c *Connection) write(b []byte) (int, error) {
// Write a new sequence.
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
c.sequence++
binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp)
c.timestamp += 960 // Samples
copy(c.nonce[:], c.packet[:])
if err := c.frequency.Wait(c.context); err != nil {
return 0, errors.Wrap(err, "failed to wait for frequency tick")
}
toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret)
n, err := c.conn.Write(toSend)
if err != nil {
return n, errors.Wrap(err, "failed to write to UDP connection")
}
// We're not really returning everything, since we're "sealing" the bytes.
return len(b), nil
}

View File

@ -5,9 +5,11 @@
package voice
import (
"context"
"log"
"strconv"
"sync"
"time"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
@ -170,9 +172,13 @@ func (v *Voice) Close() error {
}
for gID, s := range v.sessions {
if dErr := s.Disconnect(); dErr != nil {
log.Println("closing", gID)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
if dErr := s.DisconnectCtx(ctx); dErr != nil {
err.SessionErrors[gID] = dErr
}
cancel()
log.Println("closed", gID)
}
err.StateErr = v.State.Close()

View File

@ -51,12 +51,12 @@ type Gateway struct {
mutex sync.RWMutex
ready ReadyEvent
ws *wsutil.Websocket
WS *wsutil.Websocket
Timeout time.Duration
reconnect moreatomic.Bool
EventLoop *wsutil.PacemakerLoop
EventLoop wsutil.PacemakerLoop
// ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error)
@ -96,14 +96,14 @@ func (c *Gateway) OpenCtx(ctx context.Context) error {
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")")
c.ws = wsutil.New(endpoint)
c.WS = wsutil.New(endpoint)
// Create a new context with a timeout for the connection.
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
defer cancel()
// Connect to the Gateway Gateway.
if err := c.ws.Dial(ctx); err != nil {
if err := c.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to connect to voice gateway")
}
@ -138,7 +138,7 @@ func (c *Gateway) __start(ctx context.Context) error {
// Make a new WaitGroup for use in background loops:
c.waitGroup = new(sync.WaitGroup)
ch := c.ws.Listen()
ch := c.WS.Listen()
// Wait for hello.
wsutil.WSDebug("Waiting for Hello..")
@ -181,13 +181,10 @@ func (c *Gateway) __start(ctx context.Context) error {
return errors.Wrap(err, "failed to wait for Ready or Resumed")
}
// Create an event loop executor.
c.EventLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, c)
// Start the event handler, which also handles the pacemaker death signal.
c.waitGroup.Add(1)
c.EventLoop.RunAsync(func(err error) {
c.EventLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, c, func(err error) {
c.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("Event loop stopped.")
@ -208,38 +205,38 @@ func (c *Gateway) __start(ctx context.Context) error {
}
// Close .
func (c *Gateway) Close() error {
// Check if the WS is already closed:
if c.waitGroup == nil && c.EventLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
func (c *Gateway) Close() (err error) {
wsutil.WSDebug("Trying to close.")
c.AfterClose(nil)
return nil
// Check if the WS is already closed:
if c.EventLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
return err
}
// Trigger the close callback on exit.
defer func() { c.AfterClose(err) }()
// If the pacemaker is running:
if !c.EventLoop.Stopped() {
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
// Stop the pacemaker and the event handler.
c.EventLoop.Stop()
wsutil.WSDebug("Stopped pacemaker.")
}
wsutil.WSDebug("Closing the websocket...")
err = c.WS.Close()
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
c.waitGroup.Wait()
// Mark g.waitGroup as empty:
c.waitGroup = nil
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := c.ws.Close()
c.AfterClose(err)
return err
}
@ -311,11 +308,11 @@ func (c *Gateway) Send(code OPCode, v interface{}) error {
}
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
if c.ws == nil {
if c.WS == nil {
return errors.New("tried to send data to a connection without a Websocket")
}
if c.ws.Conn == nil {
if c.WS.Conn == nil {
return errors.New("tried to send data to a connection with a closed Websocket")
}
@ -338,5 +335,5 @@ func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error
}
// WS should already be thread-safe.
return c.ws.SendCtx(ctx, b)
return c.WS.SendCtx(ctx, b)
}