Compare commits
21 Commits
Author | SHA1 | Date |
---|---|---|
diamondburned | 16a408bf30 | |
diamondburned | 6c332ac145 | |
diamondburned | 91ee92e9d5 | |
diamondburned | 86795e42a6 | |
Maximilian von Lindern | 397d288927 | |
diamondburned | dec39c4c2d | |
mavolin | 6dabffb46c | |
diamondburned | 1bec57523d | |
diamondburned | 86dd05da9e | |
mavolin | 647efb8030 | |
diamondburned | 64ab8c4f30 | |
mavolin | 5acf9f3f22 | |
mavolin | 7d5cc89ff0 | |
diamondburned | 6b4e26e839 | |
diamondburned | fd818e181e | |
diamondburned | 87c648ae1d | |
diamondburned | 3312c66515 | |
diamondburned | de61fd912d | |
diamondburned | f0c73f4c99 | |
Maximilian von Lindern | a7e9439109 | |
diamondburned | af7f413cea |
|
@ -167,13 +167,14 @@ func (c *Client) DeleteChannel(channelID discord.ChannelID) error {
|
|||
return c.FastRequest("DELETE", EndpointChannels+channelID.String())
|
||||
}
|
||||
|
||||
// https://discord.com/developers/docs/resources/channel#edit-channel-permissions-json-params
|
||||
type EditChannelPermissionData struct {
|
||||
// Type is either "role" or "member".
|
||||
Type discord.OverwriteType `json:"type"`
|
||||
// Allow is a permission bit set for granted permissions.
|
||||
Allow discord.Permissions `json:"allow"`
|
||||
Allow discord.Permissions `json:"allow,string"`
|
||||
// Deny is a permission bit set for denied permissions.
|
||||
Deny discord.Permissions `json:"deny"`
|
||||
Deny discord.Permissions `json:"deny,string"`
|
||||
}
|
||||
|
||||
// EditChannelPermission edits the channel's permission overwrites for a user
|
||||
|
|
45
api/guild.go
45
api/guild.go
|
@ -9,6 +9,10 @@ import (
|
|||
"github.com/diamondburned/arikawa/utils/json/option"
|
||||
)
|
||||
|
||||
// maxGuildFetchLimit is the limit of max guilds per request, as imposed by
|
||||
// Discord.
|
||||
const maxGuildFetchLimit = 100
|
||||
|
||||
var EndpointGuilds = Endpoint + "guilds/"
|
||||
|
||||
// https://discordapp.com/developers/docs/resources/guild#create-guild-json-params
|
||||
|
@ -105,8 +109,7 @@ func (c *Client) GuildWithCount(id discord.GuildID) (*discord.Guild, error) {
|
|||
|
||||
// Guilds returns a list of partial guild objects the current user is a member
|
||||
// of. This method automatically paginates until it reaches the passed limit,
|
||||
// or, if the limit is set to 0, has fetched all guilds within the passed
|
||||
// range.
|
||||
// or, if the limit is set to 0, has fetched all guilds the user has joined.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 guilds per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
|
@ -125,8 +128,8 @@ func (c *Client) Guilds(limit uint) ([]discord.Guild, error) {
|
|||
|
||||
// GuildsBefore returns a list of partial guild objects the current user is a
|
||||
// member of. This method automatically paginates until it reaches the
|
||||
// passed limit, or, if the limit is set to 0, has fetched all guilds within
|
||||
// the passed range.
|
||||
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
|
||||
// id smaller than before.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 guilds per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
|
@ -134,15 +137,16 @@ func (c *Client) Guilds(limit uint) ([]discord.Guild, error) {
|
|||
//
|
||||
// Requires the guilds OAuth2 scope.
|
||||
func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Guild, error) {
|
||||
var guilds []discord.Guild
|
||||
guilds := make([]discord.Guild, 0, limit)
|
||||
|
||||
// this is the limit of max guilds per request,as imposed by Discord
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxGuildFetchLimit)
|
||||
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
if limit > 0 {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
}
|
||||
|
@ -155,20 +159,24 @@ func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Gui
|
|||
}
|
||||
guilds = append(g, guilds...)
|
||||
|
||||
if len(g) < hardLimit {
|
||||
if len(g) < maxGuildFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
before = g[0].ID
|
||||
}
|
||||
|
||||
if len(guilds) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return guilds, nil
|
||||
}
|
||||
|
||||
// GuildsAfter returns a list of partial guild objects the current user is a
|
||||
// member of. This method automatically paginates until it reaches the
|
||||
// passed limit, or, if the limit is set to 0, has fetched all guilds within
|
||||
// the passed range.
|
||||
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
|
||||
// id higher than after.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 guilds per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
|
@ -176,14 +184,15 @@ func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Gui
|
|||
//
|
||||
// Requires the guilds OAuth2 scope.
|
||||
func (c *Client) GuildsAfter(after discord.GuildID, limit uint) ([]discord.Guild, error) {
|
||||
var guilds []discord.Guild
|
||||
guilds := make([]discord.Guild, 0, limit)
|
||||
|
||||
// this is the limit of max guilds per request, as imposed by Discord
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxGuildFetchLimit)
|
||||
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if limit > 0 {
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
|
@ -197,13 +206,17 @@ func (c *Client) GuildsAfter(after discord.GuildID, limit uint) ([]discord.Guild
|
|||
}
|
||||
guilds = append(guilds, g...)
|
||||
|
||||
if len(g) < hardLimit {
|
||||
if len(g) < maxGuildFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
after = g[len(g)-1].ID
|
||||
}
|
||||
|
||||
if len(guilds) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return guilds, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"github.com/diamondburned/arikawa/utils/json/option"
|
||||
)
|
||||
|
||||
// Member returns a guild member object for the specified user..
|
||||
const maxMemberFetchLimit = 1000
|
||||
|
||||
// Member returns a guild member object for the specified user.
|
||||
func (c *Client) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
|
||||
var m *discord.Member
|
||||
return m, c.RequestJSON(&m, "GET", EndpointGuilds+guildID.String()+"/members/"+userID.String())
|
||||
|
@ -14,11 +16,11 @@ func (c *Client) Member(guildID discord.GuildID, userID discord.UserID) (*discor
|
|||
|
||||
// Members returns a list of members of the guild with the passed id. This
|
||||
// method automatically paginates until it reaches the passed limit, or, if the
|
||||
// limit is set to 0, has fetched all members within the passed range.
|
||||
// limit is set to 0, has fetched all members in the guild.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 1000 members per request, at
|
||||
// maximum a total of limit/1000 rounded up requests will be made, although
|
||||
// they may be less, if no more members are available.
|
||||
// they may be less if no more members are available.
|
||||
//
|
||||
// When fetching the members, those with the smallest ID will be fetched first.
|
||||
func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member, error) {
|
||||
|
@ -27,7 +29,7 @@ func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member,
|
|||
|
||||
// MembersAfter returns a list of members of the guild with the passed id. This
|
||||
// method automatically paginates until it reaches the passed limit, or, if the
|
||||
// limit is set to 0, has fetched all members within the passed range.
|
||||
// limit is set to 0, has fetched all members with an id higher than after.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 1000 members per request, at
|
||||
// maximum a total of limit/1000 rounded up requests will be made, although
|
||||
|
@ -35,13 +37,15 @@ func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member,
|
|||
func (c *Client) MembersAfter(
|
||||
guildID discord.GuildID, after discord.UserID, limit uint) ([]discord.Member, error) {
|
||||
|
||||
var mems []discord.Member
|
||||
mems := make([]discord.Member, 0, limit)
|
||||
|
||||
const hardLimit int = 1000
|
||||
fetch := uint(maxMemberFetchLimit)
|
||||
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if limit > 0 {
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
|
@ -56,13 +60,17 @@ func (c *Client) MembersAfter(
|
|||
mems = append(mems, m...)
|
||||
|
||||
// There aren't any to fetch, even if this is less than limit.
|
||||
if len(m) < hardLimit {
|
||||
if len(m) < maxMemberFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
after = mems[len(mems)-1].User.ID
|
||||
}
|
||||
|
||||
if len(mems) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return mems, nil
|
||||
}
|
||||
|
||||
|
@ -248,9 +256,27 @@ func (c *Client) Prune(guildID discord.GuildID, data PruneData) (uint, error) {
|
|||
// Requires KICK_MEMBERS permission.
|
||||
// Fires a Guild Member Remove Gateway event.
|
||||
func (c *Client) Kick(guildID discord.GuildID, userID discord.UserID) error {
|
||||
return c.KickWithReason(guildID, userID, "")
|
||||
}
|
||||
|
||||
// KickWithReason removes a member from a guild.
|
||||
// The reason, if non-empty, will be displayed in the audit log of the guild.
|
||||
//
|
||||
// Requires KICK_MEMBERS permission.
|
||||
// Fires a Guild Member Remove Gateway event.
|
||||
func (c *Client) KickWithReason(
|
||||
guildID discord.GuildID, userID discord.UserID, reason string) error {
|
||||
|
||||
var data struct {
|
||||
Reason string `schema:"reason,omitempty"`
|
||||
}
|
||||
|
||||
data.Reason = reason
|
||||
|
||||
return c.FastRequest(
|
||||
"DELETE",
|
||||
EndpointGuilds+guildID.String()+"/members/"+userID.String(),
|
||||
httputil.WithSchema(c, data),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,18 +8,26 @@ import (
|
|||
"github.com/diamondburned/arikawa/utils/json/option"
|
||||
)
|
||||
|
||||
// Messages returns a list of messages sent in the channel with the passed ID.
|
||||
// This method automatically paginates until it reaches the passed limit, or,
|
||||
// if the limit is set to 0, has fetched all guilds within the passed ange.
|
||||
// the limit of max messages per request, as imposed by Discord
|
||||
const maxMessageFetchLimit = 100
|
||||
|
||||
// Messages returns a slice filled with the most recent messages sent in the
|
||||
// channel with the passed ID. The method automatically paginates until it
|
||||
// reaches the passed limit, or, if the limit is set to 0, has fetched all
|
||||
// messages in the channel.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 messages per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
// may be less, if no more messages are available.
|
||||
// As the underlying endpoint is capped at a maximum of 100 messages per
|
||||
// request, at maximum a total of limit/100 rounded up requests will be made,
|
||||
// although they may be less, if no more messages are available.
|
||||
//
|
||||
// When fetching the messages, those with the smallest ID will be fetched
|
||||
// When fetching the messages, those with the highest ID, will be fetched
|
||||
// first.
|
||||
// The returned slice will be sorted from latest to oldest.
|
||||
func (c *Client) Messages(channelID discord.ChannelID, limit uint) ([]discord.Message, error) {
|
||||
return c.MessagesAfter(channelID, 0, limit)
|
||||
// Since before is 0 it will be omitted by the http lib, which in turn
|
||||
// will lead discord to send us the most recent messages without having to
|
||||
// specify a Snowflake.
|
||||
return c.MessagesBefore(channelID, 0, limit)
|
||||
}
|
||||
|
||||
// MessagesAround returns messages around the ID, with a limit of 100.
|
||||
|
@ -29,85 +37,112 @@ func (c *Client) MessagesAround(
|
|||
return c.messagesRange(channelID, 0, 0, around, limit)
|
||||
}
|
||||
|
||||
// MessagesBefore returns a list messages sent in the channel with the passed
|
||||
// ID. This method automatically paginates until it reaches the passed limit,
|
||||
// or, if the limit is set to 0, has fetched all guilds within the passed
|
||||
// range.
|
||||
// MessagesBefore returns a slice filled with the messages sent in the channel
|
||||
// with the passed id. The method automatically paginates until it reaches the
|
||||
// passed limit, or, if the limit is set to 0, has fetched all messages in the
|
||||
// channel with an id smaller than before.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 messages per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
// may be less, if no more messages are available.
|
||||
//
|
||||
// The returned slice will be sorted from latest to oldest.
|
||||
func (c *Client) MessagesBefore(
|
||||
channelID discord.ChannelID, before discord.MessageID, limit uint) ([]discord.Message, error) {
|
||||
|
||||
var msgs []discord.Message
|
||||
msgs := make([]discord.Message, 0, limit)
|
||||
|
||||
// this is the limit of max messages per request, as imposed by Discord
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxMessageFetchLimit)
|
||||
|
||||
// Check if we are truly fetching unlimited messages to avoid confusion
|
||||
// later on, if the limit reaches 0.
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
if limit > 0 {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
}
|
||||
limit -= fetch
|
||||
limit -= maxMessageFetchLimit
|
||||
}
|
||||
|
||||
m, err := c.messagesRange(channelID, before, 0, 0, fetch)
|
||||
if err != nil {
|
||||
return msgs, err
|
||||
}
|
||||
msgs = append(m, msgs...)
|
||||
// Append the older messages into the list of newer messages.
|
||||
msgs = append(msgs, m...)
|
||||
|
||||
if len(m) < hardLimit {
|
||||
if len(m) < maxMessageFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
before = m[0].ID
|
||||
before = m[len(m)-1].ID
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// MessagesAfter returns a list messages sent in the channel with the passed
|
||||
// ID. This method automatically paginates until it reaches the passed limit,
|
||||
// or, if the limit is set to 0, has fetched all guilds within the passed
|
||||
// range.
|
||||
// MessagesAfter returns a slice filled with the messages sent in the channel
|
||||
// with the passed ID. The method automatically paginates until it reaches the
|
||||
// passed limit, or, if the limit is set to 0, has fetched all messages in the
|
||||
// channel with an id higher than after.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 messages per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
// may be less, if no more messages are available.
|
||||
//
|
||||
// The returned slice will be sorted from latest to oldest.
|
||||
func (c *Client) MessagesAfter(
|
||||
channelID discord.ChannelID, after discord.MessageID, limit uint) ([]discord.Message, error) {
|
||||
|
||||
// 0 is uint's zero value and will lead to the after param getting omitted,
|
||||
// which in turn will lead to the most recent messages being returned.
|
||||
// Setting this to 1 will prevent that.
|
||||
if after == 0 {
|
||||
after = 1
|
||||
}
|
||||
|
||||
var msgs []discord.Message
|
||||
|
||||
// this is the limit of max messages per request, as imposed by Discord
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxMessageFetchLimit)
|
||||
|
||||
// Check if we are truly fetching unlimited messages to avoid confusion
|
||||
// later on, if the limit reaches 0.
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
if limit > 0 {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
}
|
||||
limit -= fetch
|
||||
limit -= maxMessageFetchLimit
|
||||
}
|
||||
|
||||
m, err := c.messagesRange(channelID, 0, after, 0, fetch)
|
||||
if err != nil {
|
||||
return msgs, err
|
||||
}
|
||||
msgs = append(msgs, m...)
|
||||
// Prepend the older messages into the newly-fetched messages list.
|
||||
msgs = append(m, msgs...)
|
||||
|
||||
if len(m) < hardLimit {
|
||||
if len(m) < maxMessageFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
after = m[len(m)-1].ID
|
||||
after = m[0].ID
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return msgs, nil
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"github.com/diamondburned/arikawa/utils/httputil"
|
||||
)
|
||||
|
||||
const maxMessageReactionFetchLimit = 100
|
||||
|
||||
// React creates a reaction for the message.
|
||||
//
|
||||
// This endpoint requires the READ_MESSAGE_HISTORY permission to be present on
|
||||
|
@ -42,7 +44,8 @@ func (c *Client) Reactions(
|
|||
|
||||
// ReactionsBefore returns a list of users that reacted with the passed Emoji.
|
||||
// This method automatically paginates until it reaches the passed limit, or,
|
||||
// if the limit is set to 0, has fetched all users within the passed range.
|
||||
// if the limit is set to 0, has fetched all users with an id smaller than
|
||||
// before.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 users per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
|
@ -51,13 +54,15 @@ func (c *Client) ReactionsBefore(
|
|||
channelID discord.ChannelID, messageID discord.MessageID, before discord.UserID, emoji Emoji,
|
||||
limit uint) ([]discord.User, error) {
|
||||
|
||||
var users []discord.User
|
||||
users := make([]discord.User, 0, limit)
|
||||
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxMessageReactionFetchLimit)
|
||||
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if limit > 0 {
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
|
@ -71,19 +76,24 @@ func (c *Client) ReactionsBefore(
|
|||
}
|
||||
users = append(r, users...)
|
||||
|
||||
if len(r) < hardLimit {
|
||||
if len(r) < maxMessageReactionFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
before = r[0].ID
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// ReactionsAfter returns a list of users that reacted with the passed Emoji.
|
||||
// This method automatically paginates until it reaches the passed limit, or,
|
||||
// if the limit is set to 0, has fetched all users within the passed range.
|
||||
// if the limit is set to 0, has fetched all users with an id higher than
|
||||
// after.
|
||||
//
|
||||
// As the underlying endpoint has a maximum of 100 users per request, at
|
||||
// maximum a total of limit/100 rounded up requests will be made, although they
|
||||
|
@ -92,13 +102,15 @@ func (c *Client) ReactionsAfter(
|
|||
channelID discord.ChannelID, messageID discord.MessageID, after discord.UserID, emoji Emoji,
|
||||
limit uint) ([]discord.User, error) {
|
||||
|
||||
var users []discord.User
|
||||
users := make([]discord.User, 0, limit)
|
||||
|
||||
const hardLimit int = 100
|
||||
fetch := uint(maxMessageReactionFetchLimit)
|
||||
|
||||
unlimited := limit == 0
|
||||
|
||||
for fetch := uint(hardLimit); limit > 0 || unlimited; fetch = uint(hardLimit) {
|
||||
for limit > 0 || unlimited {
|
||||
// Only fetch as much as we need. Since limit gradually decreases,
|
||||
// we only need to fetch min(fetch, limit).
|
||||
if limit > 0 {
|
||||
if fetch > limit {
|
||||
fetch = limit
|
||||
|
@ -112,13 +124,17 @@ func (c *Client) ReactionsAfter(
|
|||
}
|
||||
users = append(users, r...)
|
||||
|
||||
if len(r) < hardLimit {
|
||||
if len(r) < maxMessageReactionFetchLimit {
|
||||
break
|
||||
}
|
||||
|
||||
after = r[len(r)-1].ID
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ type CreateRoleData struct {
|
|||
// Permissions is the bitwise value of the enabled/disabled permissions.
|
||||
//
|
||||
// Default: @everyone permissions in guild
|
||||
Permissions discord.Permissions `json:"permissions,omitempty"`
|
||||
Permissions discord.Permissions `json:"permissions,omitempty,string"`
|
||||
// Color is the RGB color value of the role.
|
||||
//
|
||||
// Default: 0
|
||||
|
@ -98,7 +98,7 @@ type ModifyRoleData struct {
|
|||
// Name is the name of the role.
|
||||
Name option.NullableString `json:"name,omitempty"`
|
||||
// Permissions is the bitwise value of the enabled/disabled permissions.
|
||||
Permissions *discord.Permissions `json:"permissions,omitempty"`
|
||||
Permissions *discord.Permissions `json:"permissions,omitempty,string"`
|
||||
// Permissions is the bitwise value of the enabled/disabled permissions.
|
||||
Color option.NullableColor `json:"color,omitempty"`
|
||||
// Hoist specifies whether the role should be displayed separately in the
|
||||
|
|
|
@ -5,18 +5,45 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// WordOffset is the offset from the position cursor to print on the error.
|
||||
const WordOffset = 7
|
||||
|
||||
var escaper = strings.NewReplacer(
|
||||
"`", "\\`",
|
||||
"@", "\\@",
|
||||
"\\", "\\\\",
|
||||
)
|
||||
|
||||
type ErrParse struct {
|
||||
Position int
|
||||
ErrorStart,
|
||||
ErrorPart,
|
||||
ErrorEnd string
|
||||
Words string // joined
|
||||
}
|
||||
|
||||
func (e ErrParse) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"Unexpected quote or escape: %s__%s__%s",
|
||||
e.ErrorStart, e.ErrorPart, e.ErrorEnd,
|
||||
// Magic number 5.
|
||||
var a = max(0, e.Position-WordOffset)
|
||||
var b = min(len(e.Words), e.Position+WordOffset)
|
||||
var word = e.Words[a:b]
|
||||
var uidx = e.Position - a
|
||||
|
||||
errstr := strings.Builder{}
|
||||
errstr.WriteString("Unexpected quote or escape")
|
||||
|
||||
// Do a bound check.
|
||||
if uidx+1 > len(word) {
|
||||
// Invalid.
|
||||
errstr.WriteString(".")
|
||||
return errstr.String()
|
||||
}
|
||||
|
||||
// Write the pre-underline part.
|
||||
fmt.Fprintf(
|
||||
&errstr, ": %s__%s__",
|
||||
escaper.Replace(word[:uidx]),
|
||||
escaper.Replace(string(word[uidx:])),
|
||||
)
|
||||
|
||||
return errstr.String()
|
||||
}
|
||||
|
||||
// Parse parses the given text to a slice of words.
|
||||
|
@ -76,11 +103,6 @@ func Parse(line string) ([]string, error) {
|
|||
got = true
|
||||
}
|
||||
|
||||
// // If this is a backtick, then write it.
|
||||
// if r == '`' {
|
||||
// buf.WriteByte('`')
|
||||
// }
|
||||
|
||||
singleQuoted = !singleQuoted
|
||||
continue
|
||||
}
|
||||
|
@ -95,19 +117,9 @@ func Parse(line string) ([]string, error) {
|
|||
}
|
||||
|
||||
if escaped || singleQuoted || doubleQuoted {
|
||||
// the number of characters to highlight
|
||||
var (
|
||||
pos = cursor + 5
|
||||
start = string(runes[max(cursor-100, 0) : pos-1])
|
||||
end = string(runes[pos+1 : min(cursor+100, len(runes))])
|
||||
part = string(runes[max(pos-1, 0):min(len(runes), pos+2)])
|
||||
)
|
||||
|
||||
return args, &ErrParse{
|
||||
Position: cursor,
|
||||
ErrorStart: start,
|
||||
ErrorPart: part,
|
||||
ErrorEnd: end,
|
||||
Position: cursor + buf.Len(),
|
||||
Words: strings.Join(args, " "),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,6 +38,16 @@ func TestParse(t *testing.T) {
|
|||
[]string{"how", "about", "a", "go\npackage main\n", "go", "code?"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"this should not crash `",
|
||||
[]string{"this", "should", "not", "crash"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"this should not crash '",
|
||||
[]string{"this", "should", "not", "crash"},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package discord
|
||||
|
||||
import "github.com/diamondburned/arikawa/utils/json"
|
||||
|
||||
// https://discord.com/developers/docs/resources/channel#channel-object
|
||||
type Channel struct {
|
||||
// ID is the id of this channel.
|
||||
|
@ -55,7 +57,7 @@ type Channel struct {
|
|||
|
||||
// Mention returns a mention of the channel.
|
||||
func (ch Channel) Mention() string {
|
||||
return "<#" + ch.ID.String() + ">"
|
||||
return ch.ID.Mention()
|
||||
}
|
||||
|
||||
// IconURL returns the URL to the channel icon in the PNG format.
|
||||
|
@ -103,13 +105,37 @@ var (
|
|||
// https://discord.com/developers/docs/resources/channel#overwrite-object
|
||||
type Overwrite struct {
|
||||
// ID is the role or user id.
|
||||
ID Snowflake `json:"id,string"`
|
||||
ID Snowflake `json:"id"`
|
||||
// Type is either "role" or "member".
|
||||
Type OverwriteType `json:"type"`
|
||||
// Allow is a permission bit set for granted permissions.
|
||||
Allow Permissions `json:"allow"`
|
||||
Allow Permissions `json:"allow,string"`
|
||||
// Deny is a permission bit set for denied permissions.
|
||||
Deny Permissions `json:"deny"`
|
||||
Deny Permissions `json:"deny,string"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the passed json data into the Overwrite.
|
||||
// This is necessary because Discord has different names for fields when
|
||||
// sending than receiving.
|
||||
func (o *Overwrite) UnmarshalJSON(data []byte) (err error) {
|
||||
var recv struct {
|
||||
ID Snowflake `json:"id"`
|
||||
Type OverwriteType `json:"type"`
|
||||
Allow Permissions `json:"allow_new,string"`
|
||||
Deny Permissions `json:"deny_new,string"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &recv)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
o.ID = recv.ID
|
||||
o.Type = recv.Type
|
||||
o.Allow = recv.Allow
|
||||
o.Deny = recv.Deny
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type OverwriteType string
|
||||
|
|
|
@ -23,7 +23,7 @@ type Guild struct {
|
|||
|
||||
// Permissions are the total permissions for the user in the guild
|
||||
// (excludes overrides).
|
||||
Permissions Permissions `json:"permissions,omitempty"`
|
||||
Permissions Permissions `json:"permissions_new,omitempty,string"`
|
||||
|
||||
// VoiceRegion is the voice region id for the guild.
|
||||
VoiceRegion string `json:"region"`
|
||||
|
@ -297,7 +297,7 @@ type Role struct {
|
|||
Position int `json:"position"`
|
||||
|
||||
// Permissions is the permission bit set.
|
||||
Permissions Permissions `json:"permissions"`
|
||||
Permissions Permissions `json:"permissions_new,string"`
|
||||
|
||||
// Manages specifies whether this role is managed by an integration.
|
||||
Managed bool `json:"managed"`
|
||||
|
@ -307,7 +307,7 @@ type Role struct {
|
|||
|
||||
// Mention returns the mention of the Role.
|
||||
func (r Role) Mention() string {
|
||||
return "<&" + r.ID.String() + ">"
|
||||
return r.ID.Mention()
|
||||
}
|
||||
|
||||
// https://discord.com/developers/docs/topics/gateway#presence-update
|
||||
|
@ -375,7 +375,7 @@ type Member struct {
|
|||
|
||||
// Mention returns the mention of the role.
|
||||
func (m Member) Mention() string {
|
||||
return "<@!" + m.User.ID.String() + ">"
|
||||
return m.User.Mention()
|
||||
}
|
||||
|
||||
// https://discord.com/developers/docs/resources/guild#ban-object
|
||||
|
|
|
@ -30,12 +30,12 @@ func ParseSnowflake(sf string) (Snowflake, error) {
|
|||
return NullSnowflake, nil
|
||||
}
|
||||
|
||||
i, err := strconv.ParseInt(sf, 10, 64)
|
||||
u, err := strconv.ParseUint(sf, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return Snowflake(i), nil
|
||||
return Snowflake(u), nil
|
||||
}
|
||||
|
||||
func (s *Snowflake) UnmarshalJSON(v []byte) error {
|
||||
|
@ -149,6 +149,7 @@ func (s ChannelID) Time() time.Time { return Snowflake(s).Time() }
|
|||
func (s ChannelID) Worker() uint8 { return Snowflake(s).Worker() }
|
||||
func (s ChannelID) PID() uint8 { return Snowflake(s).PID() }
|
||||
func (s ChannelID) Increment() uint16 { return Snowflake(s).Increment() }
|
||||
func (s ChannelID) Mention() string { return "<#" + s.String() + ">" }
|
||||
|
||||
type EmojiID Snowflake
|
||||
|
||||
|
@ -219,6 +220,7 @@ func (s RoleID) Time() time.Time { return Snowflake(s).Time() }
|
|||
func (s RoleID) Worker() uint8 { return Snowflake(s).Worker() }
|
||||
func (s RoleID) PID() uint8 { return Snowflake(s).PID() }
|
||||
func (s RoleID) Increment() uint16 { return Snowflake(s).Increment() }
|
||||
func (s RoleID) Mention() string { return "<@&" + s.String() + ">" }
|
||||
|
||||
type UserID Snowflake
|
||||
|
||||
|
@ -233,6 +235,7 @@ func (s UserID) Time() time.Time { return Snowflake(s).Time() }
|
|||
func (s UserID) Worker() uint8 { return Snowflake(s).Worker() }
|
||||
func (s UserID) PID() uint8 { return Snowflake(s).PID() }
|
||||
func (s UserID) Increment() uint16 { return Snowflake(s).Increment() }
|
||||
func (s UserID) Mention() string { return "<@" + s.String() + ">" }
|
||||
|
||||
type WebhookID Snowflake
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ type User struct {
|
|||
}
|
||||
|
||||
func (u User) Mention() string {
|
||||
return "<@" + u.ID.String() + ">"
|
||||
return u.ID.Mention()
|
||||
}
|
||||
|
||||
// AvatarURL returns the URL of the Avatar Image. It automatically detects a
|
||||
|
|
|
@ -81,7 +81,7 @@ type RequestGuildMembersData struct {
|
|||
GuildID []discord.GuildID `json:"guild_id"`
|
||||
UserIDs []discord.UserID `json:"user_ids,omitempty"`
|
||||
|
||||
Query string `json:"query,omitempty"`
|
||||
Query string `json:"query"`
|
||||
Limit uint `json:"limit"`
|
||||
Presences bool `json:"presences,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
|
@ -149,7 +149,7 @@ type GuildSubscribeData struct {
|
|||
GuildID discord.GuildID `json:"guild_id"`
|
||||
|
||||
// Channels is not documented. It's used to fetch the right members sidebar.
|
||||
Channels map[discord.ChannelID][][2]int `json:"channels"`
|
||||
Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
|
||||
|
|
|
@ -133,7 +133,7 @@ type (
|
|||
Ops []GuildMemberListOp `json:"ops"`
|
||||
}
|
||||
GuildMemberListGroup struct {
|
||||
ID string `json:"id"` // either discord.RoleID or "online"
|
||||
ID string `json:"id"` // either discord.RoleID, "online" or "offline"
|
||||
Count uint64 `json:"count"`
|
||||
}
|
||||
GuildMemberListOp struct {
|
||||
|
|
|
@ -85,11 +85,14 @@ type Gateway struct {
|
|||
// Session.
|
||||
Events chan Event
|
||||
|
||||
// SessionID is used to store the session ID received after Ready. It is not
|
||||
// thread-safe.
|
||||
SessionID string
|
||||
|
||||
Identifier *Identifier
|
||||
Sequence *Sequence
|
||||
PacerLoop *wsutil.PacemakerLoop
|
||||
|
||||
PacerLoop wsutil.PacemakerLoop
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
|
@ -98,11 +101,6 @@ type Gateway struct {
|
|||
// reconnections or any type of connection interruptions.
|
||||
AfterClose func(err error) // noop by default
|
||||
|
||||
// Mutex to hold off calls when the WS is not available. Doesn't block if
|
||||
// Start() is not called or Close() is called. Also doesn't block for
|
||||
// Identify or Resume.
|
||||
// available sync.RWMutex
|
||||
|
||||
// Filled by methods, internal use
|
||||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
@ -163,40 +161,38 @@ func (g *Gateway) AddIntent(i Intents) {
|
|||
}
|
||||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
func (g *Gateway) Close() (err error) {
|
||||
wsutil.WSDebug("Trying to close.")
|
||||
|
||||
// Check if the WS is already closed:
|
||||
if g.waitGroup == nil && g.PacerLoop.Stopped() {
|
||||
if g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
|
||||
g.AfterClose(nil)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger the close callback on exit.
|
||||
defer func() { g.AfterClose(err) }()
|
||||
|
||||
// If the pacemaker is running:
|
||||
if !g.PacerLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
// Stop the pacemaker and the event handler.
|
||||
g.PacerLoop.Stop()
|
||||
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Closing the websocket...")
|
||||
err = g.WS.Close()
|
||||
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
g.waitGroup.Wait()
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
g.waitGroup = nil
|
||||
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := g.WS.Close()
|
||||
g.AfterClose(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -356,13 +352,11 @@ func (g *Gateway) start(ctx context.Context) error {
|
|||
return errors.Wrap(err, "first error")
|
||||
}
|
||||
|
||||
// Use the pacemaker loop.
|
||||
g.PacerLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, g)
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
g.waitGroup.Add(1)
|
||||
|
||||
g.PacerLoop.RunAsync(func(err error) {
|
||||
// Use the pacemaker loop.
|
||||
g.PacerLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, g, func(err error) {
|
||||
g.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("Event loop stopped with error:", err)
|
||||
|
||||
|
|
|
@ -30,6 +30,11 @@ const (
|
|||
GuildSubscriptionsOP OPCode = 14
|
||||
)
|
||||
|
||||
// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This
|
||||
// is used mostly internally to signal the heartbeat loop to reconnect, if
|
||||
// needed. It is not a fatal error.
|
||||
var ErrReconnectRequest = errors.New("ReconnectOP received")
|
||||
|
||||
func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
||||
switch op.Code {
|
||||
case HeartbeatAckOP:
|
||||
|
@ -41,19 +46,17 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
defer cancel()
|
||||
|
||||
// Server requesting a heartbeat.
|
||||
return g.PacerLoop.Pace(ctx)
|
||||
if err := g.PacerLoop.Pace(ctx); err != nil {
|
||||
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
|
||||
}
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
wsutil.WSDebug("ReconnectOP received.")
|
||||
|
||||
// We must reconnect in another goroutine, as running Reconnect
|
||||
// synchronously would prevent the main event loop from exiting.
|
||||
go g.Reconnect()
|
||||
|
||||
// Gracefully exit with a nil let the event handler take the signal from
|
||||
// the pacemaker.
|
||||
return nil
|
||||
// Exit with the ReconnectOP error to force the heartbeat event loop to
|
||||
// reconnect synchronously. Not really a fatal error.
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
|
||||
case InvalidSessionOP:
|
||||
// Discord expects us to sleep for no reason
|
||||
|
@ -65,13 +68,12 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
// Invalid session, try and Identify.
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
// Can't identify, reconnect.
|
||||
go g.Reconnect()
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case HelloOP:
|
||||
// What is this OP doing here???
|
||||
return nil
|
||||
|
||||
case DispatchOP:
|
||||
|
@ -102,7 +104,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
|||
g.SessionID = ev.SessionID
|
||||
}
|
||||
|
||||
// Throw the event into a channel, it's valid now.
|
||||
// Throw the event into a channel; it's valid now.
|
||||
g.Events <- ev
|
||||
return nil
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ type GuildFolder struct {
|
|||
|
||||
// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low
|
||||
// number of unknown significance.
|
||||
type GuildFolderID uint64
|
||||
type GuildFolderID int64
|
||||
|
||||
func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
|
||||
var body = string(b)
|
||||
|
@ -130,7 +130,7 @@ func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
|
|||
|
||||
body = strings.Trim(body, `"`)
|
||||
|
||||
u, err := strconv.ParseUint(body, 10, 64)
|
||||
u, err := strconv.ParseInt(body, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -144,5 +144,5 @@ func (g GuildFolderID) MarshalJSON() ([]byte, error) {
|
|||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return []byte(strconv.FormatUint(uint64(g), 10)), nil
|
||||
return []byte(strconv.FormatInt(int64(g), 10)), nil
|
||||
}
|
||||
|
|
1
go.sum
1
go.sum
|
@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
|
||||
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s=
|
||||
|
|
|
@ -3,7 +3,6 @@ package heart
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -36,23 +35,30 @@ type Pacemaker struct {
|
|||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate time.Duration
|
||||
|
||||
ticker time.Ticker
|
||||
Ticks <-chan time.Time
|
||||
|
||||
// Time in nanoseconds, guarded by atomic read/writes.
|
||||
SentBeat AtomicTime
|
||||
EchoBeat AtomicTime
|
||||
|
||||
// Any callback that returns an error will stop the pacer.
|
||||
Pace func(context.Context) error
|
||||
|
||||
stop chan struct{}
|
||||
once sync.Once
|
||||
death chan error
|
||||
Pacer func(context.Context) error
|
||||
}
|
||||
|
||||
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) *Pacemaker {
|
||||
return &Pacemaker{
|
||||
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
|
||||
p := Pacemaker{
|
||||
Heartrate: heartrate,
|
||||
Pace: pacer,
|
||||
Pacer: pacer,
|
||||
ticker: *time.NewTicker(heartrate),
|
||||
}
|
||||
p.Ticks = p.ticker.C
|
||||
// Reset states to its old position.
|
||||
now := time.Now()
|
||||
p.EchoBeat.Set(now)
|
||||
p.SentBeat.Set(now)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
|
@ -62,14 +68,6 @@ func (p *Pacemaker) Echo() {
|
|||
|
||||
// Dead, if true, will have Pace return an ErrDead.
|
||||
func (p *Pacemaker) Dead() bool {
|
||||
/* Deprecated
|
||||
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
|
||||
*/
|
||||
|
||||
var (
|
||||
echo = p.EchoBeat.Get()
|
||||
sent = p.SentBeat.Get()
|
||||
|
@ -84,75 +82,84 @@ func (p *Pacemaker) Dead() bool {
|
|||
|
||||
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
|
||||
func (p *Pacemaker) Stop() {
|
||||
Debug("(*Pacemaker).stop is trying sync.Once.")
|
||||
|
||||
p.once.Do(func() {
|
||||
Debug("(*Pacemaker).stop closed the channel.")
|
||||
close(p.stop)
|
||||
})
|
||||
p.ticker.Stop()
|
||||
}
|
||||
|
||||
// pace sends a heartbeat with the appropriate timeout for the context.
|
||||
func (p *Pacemaker) pace() error {
|
||||
func (p *Pacemaker) Pace() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate)
|
||||
defer cancel()
|
||||
|
||||
return p.Pace(ctx)
|
||||
return p.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
func (p *Pacemaker) start() error {
|
||||
// Reset states to its old position.
|
||||
p.EchoBeat.Set(time.Time{})
|
||||
p.SentBeat.Set(time.Time{})
|
||||
|
||||
// Create a new ticker.
|
||||
tick := time.NewTicker(p.Heartrate)
|
||||
defer tick.Stop()
|
||||
|
||||
// Echo at least once
|
||||
p.Echo()
|
||||
|
||||
for {
|
||||
if err := p.pace(); err != nil {
|
||||
return errors.Wrap(err, "failed to pace")
|
||||
}
|
||||
|
||||
// Paced, save:
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stop:
|
||||
return nil
|
||||
|
||||
case <-tick.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
|
||||
func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||
p.death = make(chan error)
|
||||
p.stop = make(chan struct{})
|
||||
p.once = sync.Once{}
|
||||
|
||||
if wg != nil {
|
||||
wg.Add(1)
|
||||
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
|
||||
if err := p.Pacer(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.death <- p.start()
|
||||
// Debug.
|
||||
Debug("Pacemaker returned.")
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
// Mark the pacemaker loop as done.
|
||||
if wg != nil {
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
}
|
||||
|
||||
return p.death
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (p *Pacemaker) start() error {
|
||||
// // Reset states to its old position.
|
||||
// p.EchoBeat.Set(time.Time{})
|
||||
// p.SentBeat.Set(time.Time{})
|
||||
|
||||
// // Create a new ticker.
|
||||
// tick := time.NewTicker(p.Heartrate)
|
||||
// defer tick.Stop()
|
||||
|
||||
// // Echo at least once
|
||||
// p.Echo()
|
||||
|
||||
// for {
|
||||
// if err := p.pace(); err != nil {
|
||||
// return errors.Wrap(err, "failed to pace")
|
||||
// }
|
||||
|
||||
// // Paced, save:
|
||||
// p.SentBeat.Set(time.Now())
|
||||
|
||||
// if p.Dead() {
|
||||
// return ErrDead
|
||||
// }
|
||||
|
||||
// select {
|
||||
// case <-p.stop:
|
||||
// return nil
|
||||
|
||||
// case <-tick.C:
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
|
||||
// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
|
||||
// p.death = make(chan error)
|
||||
// p.stop = make(chan struct{})
|
||||
// p.once = sync.Once{}
|
||||
|
||||
// if wg != nil {
|
||||
// wg.Add(1)
|
||||
// }
|
||||
|
||||
// go func() {
|
||||
// p.death <- p.start()
|
||||
// // Debug.
|
||||
// Debug("Pacemaker returned.")
|
||||
|
||||
// // Mark the pacemaker loop as done.
|
||||
// if wg != nil {
|
||||
// wg.Done()
|
||||
// }
|
||||
// }()
|
||||
|
||||
// return p.death
|
||||
// }
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
|
@ -39,6 +41,7 @@ type Session struct {
|
|||
Ticket string
|
||||
|
||||
hstop chan struct{}
|
||||
wstop sync.Once
|
||||
}
|
||||
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
|
||||
|
@ -103,9 +106,9 @@ func NewWithGateway(gw *gateway.Gateway) *Session {
|
|||
|
||||
func (s *Session) Open() error {
|
||||
// Start the handler beforehand so no events are missed.
|
||||
stop := make(chan struct{})
|
||||
s.hstop = stop
|
||||
go s.startHandler(stop)
|
||||
s.hstop = make(chan struct{})
|
||||
s.wstop = sync.Once{}
|
||||
go s.startHandler()
|
||||
|
||||
// Set the AfterClose's handler.
|
||||
s.Gateway.AfterClose = func(err error) {
|
||||
|
@ -121,10 +124,10 @@ func (s *Session) Open() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) startHandler(stop <-chan struct{}) {
|
||||
func (s *Session) startHandler() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
case <-s.hstop:
|
||||
return
|
||||
case ev := <-s.Gateway.Events:
|
||||
s.Call(ev)
|
||||
|
@ -134,14 +137,7 @@ func (s *Session) startHandler(stop <-chan struct{}) {
|
|||
|
||||
func (s *Session) Close() error {
|
||||
// Stop the event handler
|
||||
s.close()
|
||||
|
||||
s.wstop.Do(func() { s.hstop <- struct{}{} })
|
||||
// Close the websocket
|
||||
return s.Gateway.Close()
|
||||
}
|
||||
|
||||
func (s *Session) close() {
|
||||
if s.hstop != nil {
|
||||
close(s.hstop)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -636,6 +636,10 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.R
|
|||
}
|
||||
}
|
||||
|
||||
if role == nil {
|
||||
return nil, ErrStoreNotFound
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -50,6 +50,13 @@ func (s *State) hookSession() {
|
|||
func (s *State) onEvent(iface interface{}) {
|
||||
switch ev := iface.(type) {
|
||||
case *gateway.ReadyEvent:
|
||||
// Reset the store before proceeding.
|
||||
if resetter, ok := s.Store.(StoreResetter); ok {
|
||||
if err := resetter.Reset(); err != nil {
|
||||
s.stateErr(err, "Failed to reset state on READY")
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ready to the state
|
||||
s.Ready = *ev
|
||||
|
||||
|
|
|
@ -92,6 +92,12 @@ type StoreModifier interface {
|
|||
VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error
|
||||
}
|
||||
|
||||
// StoreResetter is used by the state to reset the store on every Ready event.
|
||||
type StoreResetter interface {
|
||||
// Reset resets the store to a new valid instance.
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// ErrStoreNotFound is an error that a store can use to return when something
|
||||
// isn't in the storage. There is no strict restrictions on what uses this (the
|
||||
// default one does, though), so be advised.
|
||||
|
|
|
@ -9,12 +9,17 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const CopyBufferSize = 2048
|
||||
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
|
||||
// size is 4KB.
|
||||
var CopyBufferSize = 4096
|
||||
|
||||
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
|
||||
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
|
||||
var MaxCapUntilReset = CopyBufferSize * 4
|
||||
|
||||
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
||||
var CloseDeadline = time.Second
|
||||
|
@ -45,72 +50,54 @@ type Connection interface {
|
|||
// Conn is the default Websocket connection. It compresses all payloads using
|
||||
// zlib.
|
||||
type Conn struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
Conn *websocket.Conn
|
||||
json.Driver
|
||||
|
||||
dialer *websocket.Dialer
|
||||
events chan Event
|
||||
|
||||
// write channels
|
||||
writes chan []byte
|
||||
errors chan error
|
||||
|
||||
buf bytes.Buffer
|
||||
zlib io.ReadCloser // (compress/zlib).reader
|
||||
|
||||
// nil until Dial().
|
||||
closeOnce *sync.Once
|
||||
|
||||
// zlib *zlib.Inflator // zlib.NewReader
|
||||
// buf []byte // io.Copy buffer
|
||||
}
|
||||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
||||
// NewConn creates a new default websocket connection with a default dialer.
|
||||
func NewConn() *Conn {
|
||||
return NewConnWithDriver(json.Default)
|
||||
return NewConnWithDialer(&websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: WSTimeout,
|
||||
ReadBufferSize: CopyBufferSize,
|
||||
WriteBufferSize: CopyBufferSize,
|
||||
EnableCompression: true,
|
||||
})
|
||||
}
|
||||
|
||||
func NewConnWithDriver(driver json.Driver) *Conn {
|
||||
return &Conn{
|
||||
Driver: driver,
|
||||
dialer: &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: WSTimeout,
|
||||
EnableCompression: true,
|
||||
},
|
||||
// zlib: zlib.NewInflator(),
|
||||
// buf: make([]byte, CopyBufferSize),
|
||||
}
|
||||
// NewConn creates a new default websocket connection with a custom dialer.
|
||||
func NewConnWithDialer(dialer *websocket.Dialer) *Conn {
|
||||
return &Conn{dialer: dialer}
|
||||
}
|
||||
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||
var err error
|
||||
|
||||
// Enable compression:
|
||||
headers := http.Header{}
|
||||
headers.Set("Accept-Encoding", "zlib")
|
||||
headers := http.Header{
|
||||
"Accept-Encoding": {"zlib"},
|
||||
}
|
||||
|
||||
// BUG: https://github.com/golang/go/issues/31514
|
||||
// // Enable stream compression:
|
||||
// addr = InjectValues(addr, url.Values{
|
||||
// "compress": {"zlib-stream"},
|
||||
// })
|
||||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
||||
conn, _, err := c.dialer.DialContext(ctx, addr, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
||||
// Set up the closer.
|
||||
c.closeOnce = &sync.Once{}
|
||||
events := make(chan Event, WSBuffer)
|
||||
go startReadLoop(conn, events)
|
||||
|
||||
c.events = make(chan Event)
|
||||
go c.readLoop()
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.writes = make(chan []byte)
|
||||
c.errors = make(chan error)
|
||||
go c.writeLoop()
|
||||
c.Conn = conn
|
||||
c.events = events
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -119,18 +106,88 @@ func (c *Conn) Listen() <-chan Event {
|
|||
return c.events
|
||||
}
|
||||
|
||||
func (c *Conn) readLoop() {
|
||||
// Acquire the read lock throughout the span of the loop. This would still
|
||||
// allow Send to acquire another RLock, but wouldn't allow Close to
|
||||
// prematurely exit, as Close acquires a write lock.
|
||||
// c.mut.RLock()
|
||||
// defer c.mut.RUnlock()
|
||||
// resetDeadline is used to reset the write deadline after using the context's.
|
||||
var resetDeadline = time.Time{}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
d, ok := ctx.Deadline()
|
||||
if ok {
|
||||
c.Conn.SetWriteDeadline(d)
|
||||
defer c.Conn.SetWriteDeadline(resetDeadline)
|
||||
}
|
||||
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
// Use a sync.Once to guarantee that other Close() calls block until the
|
||||
// main call is done. It also prevents future calls.
|
||||
WSDebug("Conn: Acquiring write lock...")
|
||||
|
||||
// Acquire the write lock forever.
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write lock acquired; closing.")
|
||||
|
||||
// Close the WS.
|
||||
err := c.closeWS()
|
||||
|
||||
WSDebug("Conn: Websocket closed; error:", err)
|
||||
WSDebug("Conn: Flusing events...")
|
||||
|
||||
// Flush all events before closing the channel. This will return as soon as
|
||||
// c.events is closed, or after closed.
|
||||
for range c.events {
|
||||
}
|
||||
|
||||
WSDebug("Flushed events.")
|
||||
|
||||
// Mark c.Conn as empty.
|
||||
c.Conn = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) closeWS() error {
|
||||
// We can't close with a write control here, since it will invalidate the
|
||||
// old session, breaking resumes.
|
||||
|
||||
// // Quick deadline:
|
||||
// deadline := time.Now().Add(CloseDeadline)
|
||||
|
||||
// // Make a closure message:
|
||||
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
|
||||
|
||||
// // Send a close message before closing the connection. We're not error
|
||||
// // checking this because it's not important.
|
||||
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)
|
||||
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// loopState is a thread-unsafe disposable state container for the read loop.
|
||||
// It's made to completely separate the read loop of any synchronization that
|
||||
// doesn't involve the websocket connection itself.
|
||||
type loopState struct {
|
||||
conn *websocket.Conn
|
||||
zlib io.ReadCloser
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
|
||||
// Clean up the events channel in the end.
|
||||
defer close(c.events)
|
||||
defer close(eventCh)
|
||||
|
||||
// Allocate the read loop its own private resources.
|
||||
state := loopState{conn: conn}
|
||||
state.buf.Grow(CopyBufferSize)
|
||||
|
||||
for {
|
||||
b, err := c.handle()
|
||||
b, err := state.handle()
|
||||
if err != nil {
|
||||
// Is the error an EOF?
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
@ -144,7 +201,7 @@ func (c *Conn) readLoop() {
|
|||
}
|
||||
|
||||
// Unusual error; log and exit:
|
||||
c.events <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -153,34 +210,13 @@ func (c *Conn) readLoop() {
|
|||
continue
|
||||
}
|
||||
|
||||
c.events <- Event{b, nil}
|
||||
eventCh <- Event{b, nil}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) writeLoop() {
|
||||
// Closing c.writes would break the loop immediately.
|
||||
for b := range c.writes {
|
||||
c.errors <- c.Conn.WriteMessage(websocket.TextMessage, b)
|
||||
}
|
||||
|
||||
// Quick deadline:
|
||||
deadline := time.Now().Add(CloseDeadline)
|
||||
|
||||
// Make a closure message:
|
||||
msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")
|
||||
|
||||
// Send a close message before closing the connection. We're not error
|
||||
// checking this because it's not important.
|
||||
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)
|
||||
|
||||
// Safe to close now.
|
||||
c.errors <- c.Conn.Close()
|
||||
close(c.errors)
|
||||
}
|
||||
|
||||
func (c *Conn) handle() ([]byte, error) {
|
||||
func (state *loopState) handle() ([]byte, error) {
|
||||
// skip message type
|
||||
t, r, err := c.Conn.NextReader()
|
||||
t, r, err := state.conn.NextReader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -188,112 +224,43 @@ func (c *Conn) handle() ([]byte, error) {
|
|||
if t == websocket.BinaryMessage {
|
||||
// Probably a zlib payload
|
||||
|
||||
if c.zlib == nil {
|
||||
if state.zlib == nil {
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
||||
}
|
||||
c.zlib = z
|
||||
state.zlib = z
|
||||
} else {
|
||||
if err := c.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
||||
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
||||
}
|
||||
}
|
||||
|
||||
defer c.zlib.Close()
|
||||
r = c.zlib
|
||||
defer state.zlib.Close()
|
||||
r = state.zlib
|
||||
}
|
||||
|
||||
return readAll(&c.buf, r)
|
||||
|
||||
// if t is a text message, then handle it normally.
|
||||
// if t == websocket.TextMessage {
|
||||
// return readAll(&c.buf, r)
|
||||
// }
|
||||
|
||||
// // Write to the zlib writer.
|
||||
// c.zlib.Write(r)
|
||||
// // if _, err := io.CopyBuffer(c.zlib, r, c.buf); err != nil {
|
||||
// // return nil, errors.Wrap(err, "Failed to write to zlib")
|
||||
// // }
|
||||
|
||||
// if !c.zlib.CanFlush() {
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
// // Flush and get the uncompressed payload.
|
||||
// b, err := c.zlib.Flush()
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "Failed to flush zlib")
|
||||
// }
|
||||
|
||||
// return nil, errors.New("Unexpected binary message.")
|
||||
}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// If websocket is already closed.
|
||||
if c.writes == nil {
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
// Send the bytes.
|
||||
select {
|
||||
case c.writes <- b:
|
||||
// continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Receive the error.
|
||||
select {
|
||||
case err := <-c.errors:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() (err error) {
|
||||
// Use a sync.Once to guarantee that other Close() calls block until the
|
||||
// main call is done. It also prevents future calls.
|
||||
c.closeOnce.Do(func() {
|
||||
// Close c.writes. This should trigger the websocket to close itself.
|
||||
close(c.writes)
|
||||
// Mark c.writes as empty.
|
||||
c.writes = nil
|
||||
|
||||
// Wait for the write loop to exit by flusing the errors channel.
|
||||
err = <-c.errors // get close error
|
||||
for range c.errors { // then flush
|
||||
}
|
||||
|
||||
// Flush all events before closing the channel. This will return as soon as
|
||||
// c.events is closed, or after closed.
|
||||
for range c.events {
|
||||
}
|
||||
|
||||
// Mark c.events as empty.
|
||||
c.events = nil
|
||||
|
||||
// Mark c.Conn as empty.
|
||||
c.Conn = nil
|
||||
})
|
||||
|
||||
return err
|
||||
return state.readAll(r)
|
||||
}
|
||||
|
||||
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
||||
// buffer.
|
||||
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
|
||||
defer buf.Reset()
|
||||
if _, err := buf.ReadFrom(r); err != nil {
|
||||
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
|
||||
defer state.buf.Reset()
|
||||
|
||||
if _, err := state.buf.ReadFrom(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy the bytes so we could empty the buffer for reuse.
|
||||
p := buf.Bytes()
|
||||
cpy := make([]byte, len(p))
|
||||
copy(cpy, p)
|
||||
cpy := make([]byte, state.buf.Len())
|
||||
copy(cpy, state.buf.Bytes())
|
||||
|
||||
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
||||
if state.buf.Cap() > MaxCapUntilReset {
|
||||
state.buf = bytes.Buffer{}
|
||||
state.buf.Grow(CopyBufferSize)
|
||||
}
|
||||
|
||||
return cpy, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package wsutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -10,35 +11,57 @@ import (
|
|||
"github.com/diamondburned/arikawa/internal/moreatomic"
|
||||
)
|
||||
|
||||
type errBrokenConnection struct {
|
||||
underneath error
|
||||
}
|
||||
|
||||
// Error formats the broken connection error with the message "explicit
|
||||
// connection break."
|
||||
func (err errBrokenConnection) Error() string {
|
||||
return "explicit connection break: " + err.underneath.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (err errBrokenConnection) Unwrap() error {
|
||||
return err.underneath
|
||||
}
|
||||
|
||||
// ErrBrokenConnection marks the given error as a broken connection error. This
|
||||
// error will cause the pacemaker loop to break and return the error. The error,
|
||||
// when stringified, will say "explicit connection break."
|
||||
func ErrBrokenConnection(err error) error {
|
||||
return errBrokenConnection{underneath: err}
|
||||
}
|
||||
|
||||
// IsBrokenConnection returns true if the error is a broken connection error.
|
||||
func IsBrokenConnection(err error) bool {
|
||||
var broken *errBrokenConnection
|
||||
return errors.As(err, &broken)
|
||||
}
|
||||
|
||||
// TODO API
|
||||
type EventLoopHandler interface {
|
||||
EventHandler
|
||||
HeartbeatCtx(context.Context) error
|
||||
}
|
||||
|
||||
// PacemakerLoop provides an event loop with a pacemaker.
|
||||
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
|
||||
// is a valid instance only when RunAsync is called first.
|
||||
type PacemakerLoop struct {
|
||||
pacemaker *heart.Pacemaker // let's not copy this
|
||||
pacedeath chan error
|
||||
|
||||
heart.Pacemaker
|
||||
running moreatomic.Bool
|
||||
|
||||
stop chan struct{}
|
||||
events <-chan Event
|
||||
handler func(*OP) error
|
||||
|
||||
stack []byte
|
||||
|
||||
Extras ExtraHandlers
|
||||
|
||||
ErrorLog func(error)
|
||||
}
|
||||
|
||||
func NewLoop(heartrate time.Duration, evs <-chan Event, evl EventLoopHandler) *PacemakerLoop {
|
||||
return &PacemakerLoop{
|
||||
pacemaker: heart.NewPacemaker(heartrate, evl.HeartbeatCtx),
|
||||
events: evs,
|
||||
handler: evl.HandleOP,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) errorLog(err error) {
|
||||
if p.ErrorLog == nil {
|
||||
WSDebug("Uncaught error:", err)
|
||||
|
@ -50,28 +73,36 @@ func (p *PacemakerLoop) errorLog(err error) {
|
|||
|
||||
// Pace calls the pacemaker's Pace function.
|
||||
func (p *PacemakerLoop) Pace(ctx context.Context) error {
|
||||
return p.pacemaker.Pace(ctx)
|
||||
return p.Pacemaker.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
// Echo calls the pacemaker's Echo function.
|
||||
func (p *PacemakerLoop) Echo() {
|
||||
p.pacemaker.Echo()
|
||||
}
|
||||
|
||||
// Stop calls the pacemaker's Stop function.
|
||||
// Stop stops the pacer loop. It does nothing if the loop is already stopped.
|
||||
func (p *PacemakerLoop) Stop() {
|
||||
p.pacemaker.Stop()
|
||||
if p.Stopped() {
|
||||
return
|
||||
}
|
||||
|
||||
// Despite p.running and p.stop being thread-safe on their own, this entire
|
||||
// block is actually not thread-safe.
|
||||
p.Pacemaker.Stop()
|
||||
close(p.stop)
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) Stopped() bool {
|
||||
return p == nil || !p.running.Get()
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) RunAsync(exit func(error)) {
|
||||
func (p *PacemakerLoop) RunAsync(
|
||||
heartrate time.Duration, evs <-chan Event, evl EventLoopHandler, exit func(error)) {
|
||||
|
||||
WSDebug("Starting the pacemaker loop.")
|
||||
|
||||
// callers should explicitly handle waitgroups.
|
||||
p.pacedeath = p.pacemaker.StartAsync(nil)
|
||||
p.Pacemaker = heart.NewPacemaker(heartrate, evl.HeartbeatCtx)
|
||||
p.handler = evl.HandleOP
|
||||
p.events = evs
|
||||
p.stack = debug.Stack()
|
||||
p.stop = make(chan struct{})
|
||||
|
||||
p.running.Set(true)
|
||||
|
||||
go func() {
|
||||
|
@ -82,21 +113,27 @@ func (p *PacemakerLoop) RunAsync(exit func(error)) {
|
|||
func (p *PacemakerLoop) startLoop() error {
|
||||
defer WSDebug("Pacemaker loop has exited.")
|
||||
defer p.running.Set(false)
|
||||
defer p.Pacemaker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-p.pacedeath:
|
||||
WSDebug("Pacedeath returned with error:", err)
|
||||
return errors.Wrap(err, "pacemaker died, reconnecting")
|
||||
case <-p.stop:
|
||||
WSDebug("Stop requested; exiting.")
|
||||
return nil
|
||||
|
||||
case <-p.Pacemaker.Ticks:
|
||||
if err := p.Pacemaker.Pace(); err != nil {
|
||||
return errors.Wrap(err, "pace failed, reconnecting")
|
||||
}
|
||||
|
||||
case ev, ok := <-p.events:
|
||||
if !ok {
|
||||
WSDebug("Events channel closed, stopping pacemaker.")
|
||||
defer WSDebug("Pacemaker stopped automatically.")
|
||||
// Events channel is closed. Kill the pacemaker manually and
|
||||
// die.
|
||||
p.pacemaker.Stop()
|
||||
return <-p.pacedeath
|
||||
return nil
|
||||
}
|
||||
|
||||
if ev.Error != nil {
|
||||
return errors.Wrap(ev.Error, "event returned error")
|
||||
}
|
||||
|
||||
o, err := DecodeOP(ev)
|
||||
|
@ -109,7 +146,11 @@ func (p *PacemakerLoop) startLoop() error {
|
|||
|
||||
// Handle the event
|
||||
if err := p.handler(o); err != nil {
|
||||
p.errorLog(errors.Wrap(err, "handler failed"))
|
||||
if IsBrokenConnection(err) {
|
||||
return errors.Wrap(err, "handler failed")
|
||||
}
|
||||
|
||||
p.errorLog(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,15 +15,12 @@ import (
|
|||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = 5 * time.Minute
|
||||
WSTimeout = 30 * time.Second
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSExtraReadTimeout is the duration to be added to Hello, as a read
|
||||
// timeout for the websocket.
|
||||
WSExtraReadTimeout = time.Second
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
|
@ -82,9 +79,6 @@ func (ws *Websocket) Dial(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to dial")
|
||||
}
|
||||
|
||||
// Reset the SendLimiter:
|
||||
ws.SendLimiter = NewSendLimiter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -18,73 +19,6 @@ import (
|
|||
"github.com/diamondburned/arikawa/voice/voicegateway"
|
||||
)
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.ChannelID
|
||||
}
|
||||
|
||||
func mustConfig(t *testing.T) testConfig {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
t.Fatal("Missing $VOICE_ID")
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
t.Fatal("Invalid $VOICE_ID:", err)
|
||||
}
|
||||
|
||||
return testConfig{
|
||||
BotToken: token,
|
||||
VoiceChID: discord.ChannelID(id),
|
||||
}
|
||||
}
|
||||
|
||||
// file is only a few bytes lolmao
|
||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
f.Close()
|
||||
})
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); !catchRead(t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func catchRead(t *testing.T, err error) bool {
|
||||
t.Helper()
|
||||
|
||||
if err == io.EOF {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("Failed to read:", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
config := mustConfig(t)
|
||||
|
||||
|
@ -150,12 +84,76 @@ func TestIntegration(t *testing.T) {
|
|||
|
||||
finish("sending the speaking command")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := vs.UseContext(ctx); err != nil {
|
||||
t.Fatal("failed to set ctx into vs:", err)
|
||||
}
|
||||
|
||||
// Copy the audio?
|
||||
nicoReadTo(t, vs)
|
||||
|
||||
finish("copying the audio")
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
BotToken string
|
||||
VoiceChID discord.ChannelID
|
||||
}
|
||||
|
||||
func mustConfig(t *testing.T) testConfig {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
t.Fatal("Missing $VOICE_ID")
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
t.Fatal("Invalid $VOICE_ID:", err)
|
||||
}
|
||||
|
||||
return testConfig{
|
||||
BotToken: token,
|
||||
VoiceChID: discord.ChannelID(id),
|
||||
}
|
||||
}
|
||||
|
||||
// file is only a few bytes lolmao
|
||||
func nicoReadTo(t *testing.T, dst io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open nico.dca:", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var lenbuf [4]byte
|
||||
|
||||
for {
|
||||
if _, err := io.ReadFull(f, lenbuf[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Fatal("failed to read:", err)
|
||||
}
|
||||
|
||||
// Read the integer
|
||||
framelen := int64(binary.LittleEndian.Uint32(lenbuf[:]))
|
||||
|
||||
// Copy the frame.
|
||||
if _, err := io.CopyN(dst, f, framelen); err != nil && err != io.EOF {
|
||||
t.Fatal("failed to write:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simple shitty benchmark thing
|
||||
func timer() func(finished string) {
|
||||
var then = time.Now()
|
||||
|
|
|
@ -110,14 +110,18 @@ func (s *Session) UpdateState(ev *gateway.VoiceStateUpdateEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
func (s *Session) JoinChannel(
|
||||
gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.JoinChannelCtx(ctx, gID, cID, muted, deafened)
|
||||
}
|
||||
|
||||
func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
func (s *Session) JoinChannelCtx(
|
||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) error {
|
||||
|
||||
// Acquire the mutex during join, locking during IO as well.
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
@ -211,8 +215,7 @@ func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
|||
return errors.Wrap(err, "failed to select protocol")
|
||||
}
|
||||
|
||||
// Start the UDP loop.
|
||||
go s.voiceUDP.Start(&d.SecretKey)
|
||||
s.voiceUDP.UseSecret(d.SecretKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -237,6 +240,18 @@ func (s *Session) StopSpeaking() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UseContext tells the UDP voice connection to write with the given mutex.
|
||||
func (s *Session) UseContext(ctx context.Context) error {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return ErrCannotSend
|
||||
}
|
||||
|
||||
return s.voiceUDP.UseContext(ctx)
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout.
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
return s.WriteCtx(context.Background(), b)
|
||||
|
|
180
voice/udp/udp.go
180
voice/udp/udp.go
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Dialer is the default dialer that this package uses for all its dialing.
|
||||
|
@ -17,31 +18,29 @@ var Dialer = net.Dialer{
|
|||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// ErrClosed is returned if a Write was called on a closed connection.
|
||||
var ErrClosed = errors.New("UDP connection closed")
|
||||
|
||||
type Connection struct {
|
||||
GatewayIP string
|
||||
GatewayPort uint16
|
||||
|
||||
ssrc uint32
|
||||
mutex chan struct{} // for ctx
|
||||
|
||||
context context.Context
|
||||
conn net.Conn
|
||||
ssrc uint32
|
||||
|
||||
frequency rate.Limiter
|
||||
packet [12]byte
|
||||
secret [32]byte
|
||||
|
||||
sequence uint16
|
||||
timestamp uint32
|
||||
nonce [24]byte
|
||||
|
||||
conn net.Conn
|
||||
close chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
send chan []byte
|
||||
reply chan error
|
||||
}
|
||||
|
||||
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
|
||||
// // Resolve the host.
|
||||
// a, err := net.ResolveUDPAddr("udp", addr)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "failed to resolve host")
|
||||
// }
|
||||
|
||||
// Create a new UDP connection.
|
||||
conn, err := Dialer.DialContext(ctx, "udp", addr)
|
||||
if err != nil {
|
||||
|
@ -78,20 +77,6 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
|
|||
ip := ipbody[:nullPos]
|
||||
port := binary.LittleEndian.Uint16(ipBuffer[68:70])
|
||||
|
||||
return &Connection{
|
||||
GatewayIP: string(ip),
|
||||
GatewayPort: port,
|
||||
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
send: make(chan []byte),
|
||||
reply: make(chan error),
|
||||
close: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Start(secret *[32]byte) {
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#encrypting-and-sending-voice
|
||||
packet := [12]byte{
|
||||
0: 0x80, // Version + Flags
|
||||
|
@ -101,81 +86,118 @@ func (c *Connection) Start(secret *[32]byte) {
|
|||
}
|
||||
|
||||
// Write SSRC to the header.
|
||||
binary.BigEndian.PutUint32(packet[8:12], c.ssrc) // SSRC
|
||||
binary.BigEndian.PutUint32(packet[8:12], ssrc) // SSRC
|
||||
|
||||
// 50 sends per second, 960 samples each at 48kHz
|
||||
frequency := time.NewTicker(time.Millisecond * 20)
|
||||
defer frequency.Stop()
|
||||
return &Connection{
|
||||
GatewayIP: string(ip),
|
||||
GatewayPort: port,
|
||||
// 50 sends per second, 960 samples each at 48kHz
|
||||
frequency: *rate.NewLimiter(rate.Every(20*time.Millisecond), 1),
|
||||
context: context.Background(),
|
||||
mutex: make(chan struct{}, 1),
|
||||
packet: packet,
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var b []byte
|
||||
var ok bool
|
||||
// UseSecret uses the given secret. This method is not thread-safe, so it should
|
||||
// only be used right after initialization.
|
||||
func (c *Connection) UseSecret(secret [32]byte) {
|
||||
c.secret = secret
|
||||
}
|
||||
|
||||
// Close these channels at the end so Write() doesn't block.
|
||||
defer func() {
|
||||
close(c.send)
|
||||
close(c.closed)
|
||||
}()
|
||||
// UseContext lets the connection use the given context for its Write method.
|
||||
// WriteCtx will override this context.
|
||||
func (c *Connection) UseContext(ctx context.Context) error {
|
||||
c.mutex <- struct{}{}
|
||||
defer func() { <-c.mutex }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case b, ok = <-c.send:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-c.close:
|
||||
return
|
||||
}
|
||||
return c.useContext(ctx)
|
||||
}
|
||||
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(packet[2:4], c.sequence)
|
||||
c.sequence++
|
||||
func (c *Connection) useContext(ctx context.Context) error {
|
||||
if c.conn == nil {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(packet[4:8], c.timestamp)
|
||||
c.timestamp += 960 // Samples
|
||||
if c.context == ctx {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy(c.nonce[:], packet[:])
|
||||
c.context = ctx
|
||||
|
||||
toSend := secretbox.Seal(packet[:], b, &c.nonce, secret)
|
||||
|
||||
select {
|
||||
case <-frequency.C:
|
||||
case <-c.close:
|
||||
// Prevent Write() from stalling before exiting.
|
||||
c.reply <- nil
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, err := c.conn.Write(toSend)
|
||||
c.reply <- err
|
||||
if deadline, ok := c.context.Deadline(); ok {
|
||||
return c.conn.SetWriteDeadline(deadline)
|
||||
} else {
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
close(c.close)
|
||||
<-c.closed
|
||||
|
||||
return c.conn.Close()
|
||||
c.mutex <- struct{}{}
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
<-c.mutex
|
||||
return err
|
||||
}
|
||||
|
||||
// Write sends bytes into the voice UDP connection.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
return c.WriteCtx(context.Background(), b)
|
||||
select {
|
||||
case c.mutex <- struct{}{}:
|
||||
defer func() { <-c.mutex }()
|
||||
case <-c.context.Done():
|
||||
return 0, c.context.Err()
|
||||
}
|
||||
|
||||
if c.conn == nil {
|
||||
return 0, ErrClosed
|
||||
}
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout.
|
||||
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
select {
|
||||
case c.send <- b:
|
||||
break
|
||||
case c.mutex <- struct{}{}:
|
||||
defer func() { <-c.mutex }()
|
||||
case <-c.context.Done():
|
||||
return 0, c.context.Err()
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-c.reply:
|
||||
return len(b), err
|
||||
case <-ctx.Done():
|
||||
return len(b), ctx.Err()
|
||||
if err := c.useContext(ctx); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to use context")
|
||||
}
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// write is thread-unsafe.
|
||||
func (c *Connection) write(b []byte) (int, error) {
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
|
||||
c.sequence++
|
||||
|
||||
binary.BigEndian.PutUint32(c.packet[4:8], c.timestamp)
|
||||
c.timestamp += 960 // Samples
|
||||
|
||||
copy(c.nonce[:], c.packet[:])
|
||||
|
||||
if err := c.frequency.Wait(c.context); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to wait for frequency tick")
|
||||
}
|
||||
|
||||
toSend := secretbox.Seal(c.packet[:], b, &c.nonce, &c.secret)
|
||||
|
||||
n, err := c.conn.Write(toSend)
|
||||
if err != nil {
|
||||
return n, errors.Wrap(err, "failed to write to UDP connection")
|
||||
}
|
||||
|
||||
// We're not really returning everything, since we're "sealing" the bytes.
|
||||
return len(b), nil
|
||||
}
|
||||
|
|
|
@ -5,9 +5,11 @@
|
|||
package voice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
|
@ -170,9 +172,13 @@ func (v *Voice) Close() error {
|
|||
}
|
||||
|
||||
for gID, s := range v.sessions {
|
||||
if dErr := s.Disconnect(); dErr != nil {
|
||||
log.Println("closing", gID)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
if dErr := s.DisconnectCtx(ctx); dErr != nil {
|
||||
err.SessionErrors[gID] = dErr
|
||||
}
|
||||
cancel()
|
||||
log.Println("closed", gID)
|
||||
}
|
||||
|
||||
err.StateErr = v.State.Close()
|
||||
|
|
|
@ -51,12 +51,12 @@ type Gateway struct {
|
|||
mutex sync.RWMutex
|
||||
ready ReadyEvent
|
||||
|
||||
ws *wsutil.Websocket
|
||||
WS *wsutil.Websocket
|
||||
|
||||
Timeout time.Duration
|
||||
reconnect moreatomic.Bool
|
||||
|
||||
EventLoop *wsutil.PacemakerLoop
|
||||
EventLoop wsutil.PacemakerLoop
|
||||
|
||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
||||
ErrorLog func(err error)
|
||||
|
@ -96,14 +96,14 @@ func (c *Gateway) OpenCtx(ctx context.Context) error {
|
|||
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
wsutil.WSDebug("Connecting to voice endpoint (endpoint=" + endpoint + ")")
|
||||
c.ws = wsutil.New(endpoint)
|
||||
c.WS = wsutil.New(endpoint)
|
||||
|
||||
// Create a new context with a timeout for the connection.
|
||||
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the Gateway Gateway.
|
||||
if err := c.ws.Dial(ctx); err != nil {
|
||||
if err := c.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to connect to voice gateway")
|
||||
}
|
||||
|
||||
|
@ -138,7 +138,7 @@ func (c *Gateway) __start(ctx context.Context) error {
|
|||
// Make a new WaitGroup for use in background loops:
|
||||
c.waitGroup = new(sync.WaitGroup)
|
||||
|
||||
ch := c.ws.Listen()
|
||||
ch := c.WS.Listen()
|
||||
|
||||
// Wait for hello.
|
||||
wsutil.WSDebug("Waiting for Hello..")
|
||||
|
@ -181,13 +181,10 @@ func (c *Gateway) __start(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to wait for Ready or Resumed")
|
||||
}
|
||||
|
||||
// Create an event loop executor.
|
||||
c.EventLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, c)
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
c.waitGroup.Add(1)
|
||||
|
||||
c.EventLoop.RunAsync(func(err error) {
|
||||
c.EventLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, c, func(err error) {
|
||||
c.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("Event loop stopped.")
|
||||
|
||||
|
@ -208,38 +205,38 @@ func (c *Gateway) __start(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Close .
|
||||
func (c *Gateway) Close() error {
|
||||
// Check if the WS is already closed:
|
||||
if c.waitGroup == nil && c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
func (c *Gateway) Close() (err error) {
|
||||
wsutil.WSDebug("Trying to close.")
|
||||
|
||||
c.AfterClose(nil)
|
||||
return nil
|
||||
// Check if the WS is already closed:
|
||||
if c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Gateway is already closed.")
|
||||
return err
|
||||
}
|
||||
|
||||
// Trigger the close callback on exit.
|
||||
defer func() { c.AfterClose(err) }()
|
||||
|
||||
// If the pacemaker is running:
|
||||
if !c.EventLoop.Stopped() {
|
||||
wsutil.WSDebug("Stopping pacemaker...")
|
||||
|
||||
// Stop the pacemaker and the event handler
|
||||
// Stop the pacemaker and the event handler.
|
||||
c.EventLoop.Stop()
|
||||
|
||||
wsutil.WSDebug("Stopped pacemaker.")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Closing the websocket...")
|
||||
err = c.WS.Close()
|
||||
|
||||
wsutil.WSDebug("Waiting for WaitGroup to be done.")
|
||||
|
||||
// This should work, since Pacemaker should signal its loop to stop, which
|
||||
// would also exit our event loop. Both would be 2.
|
||||
c.waitGroup.Wait()
|
||||
|
||||
// Mark g.waitGroup as empty:
|
||||
c.waitGroup = nil
|
||||
|
||||
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
|
||||
|
||||
err := c.ws.Close()
|
||||
c.AfterClose(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -311,11 +308,11 @@ func (c *Gateway) Send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
if c.ws == nil {
|
||||
if c.WS == nil {
|
||||
return errors.New("tried to send data to a connection without a Websocket")
|
||||
}
|
||||
|
||||
if c.ws.Conn == nil {
|
||||
if c.WS.Conn == nil {
|
||||
return errors.New("tried to send data to a connection with a closed Websocket")
|
||||
}
|
||||
|
||||
|
@ -338,5 +335,5 @@ func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error
|
|||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return c.ws.SendCtx(ctx, b)
|
||||
return c.WS.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue