1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-10-31 20:14:21 +00:00

gateway: Refactor for a better concurrent API

This commit refactors the whole package gateway as well as utils/ws
(formerly utils/wsutil) and voice/voicegateway. The new refactor
utilizes a design pattern involving a concurrent loop and an arriving
event channel.

An additional change was made to the way gateway events are typed.
Before, pretty much any type will satisfy a gateway event type, since
the actual type was just interface{}. The new refactor defines a
concrete interface that events can implement:

    type Event interface {
        Op() OpCode
        EventType() EventType
    }

Using this interface, the user can easily add custom gateway events
independently of the library without relying on string maps. This adds a
lot of type safety into the library and makes type-switching on Event
types much more reasonable.

Gateway error callbacks are also almost entirely removed in favor of
custom gateway events. A catch-all can easily be added like this:

    s.AddHandler(func(err error) {
        log.Println("gateway error:, err")
    })
This commit is contained in:
diamondburned 2021-09-28 13:19:04 -07:00
parent 5c88317130
commit 17b9c73ce3
No known key found for this signature in database
GPG key ID: D78C4471CE776659
67 changed files with 4391 additions and 3687 deletions

View file

@ -12,15 +12,25 @@ environment:
GO111MODULE: "on"
CGO_ENABLED: "1"
# Integration test variables.
SHARD_COUNT: "3"
SHARD_COUNT: "2"
tested: "./api,./gateway,./bot,./discord"
cov_file: "/tmp/cov_results"
dismock: "github.com/mavolin/dismock/v2/pkg/dismock"
dismock_v: "259685b84e4b6ab364b0fd858aac2aa2dfa42502"
tasks:
- generate: |-
cd arikawa
go generate ./...
if [[ "$(git status --porcelain)" ]]; then
echo "Repository differ after regeneration."
exit 1
fi
- build: cd arikawa && go build ./...
- unit: cd arikawa && go test -tags unitonly -race ./...
- integration: |-
sh -c '
test -f ~/.env || {

View file

@ -23,11 +23,7 @@ func main() {
log.Fatalln("No $BOT_TOKEN given.")
}
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("Session failed:", err)
return
}
s := state.New("Bot " + token)
app, err := s.CurrentApplication()
if err != nil {

View file

@ -22,11 +22,7 @@ func main() {
log.Fatalln("no $BOT_TOKEN given")
}
s, err := session.New("Bot " + token)
if err != nil {
log.Fatalln("session failed:", err)
return
}
s := session.New("Bot " + token)
app, err := s.CurrentApplication()
if err != nil {

View file

@ -22,11 +22,7 @@ func main() {
log.Fatalln("No $BOT_TOKEN given.")
}
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("Session failed:", err)
return
}
s := state.New("Bot " + token)
app, err := s.CurrentApplication()
if err != nil {

View file

@ -8,7 +8,7 @@ import (
"os"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/session/shard"
"github.com/diamondburned/arikawa/v3/state"
)

View file

@ -19,11 +19,7 @@ func main() {
log.Fatalln("No $BOT_TOKEN given.")
}
s, err := session.New("Bot " + token)
if err != nil {
log.Fatalln("Session failed:", err)
}
s := session.New("Bot " + token)
s.AddHandler(func(c *gateway.MessageCreateEvent) {
log.Println(c.Author.Username, "sent", c.Content)
})

View file

@ -19,11 +19,7 @@ func main() {
log.Fatalln("No $BOT_TOKEN given.")
}
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("Session failed:", err)
}
s := state.New("Bot " + token)
// Make a pre-handler
s.PreHandler = handler.New()
s.PreHandler.AddSyncHandler(func(c *gateway.MessageDeleteEvent) {

View file

@ -47,6 +47,46 @@ This example demonstrates the PreHandler feature of the state library.
PreHandler calls all handlers that are registered (separately from the session),
calling them before the state is updated.
### Bare Minimum Print Example
The least amount of code recommended to have a bot that logs all messages to
console.
```go
package main
import (
"context"
"log"
"os"
"os/signal"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/state"
)
func main() {
s := state.New("Bot " + os.Getenv("DISCORD_TOKEN"))
s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages)
s.AddHandler(func(m *gateway.MessageCreateEvent) {
log.Printf("%s: %s", m.Author.Username, m.Content)
})
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if err := s.Open(ctx); err != nil {
log.Println("cannot open:", err)
}
<-ctx.Done() // block until Ctrl+C
if err := s.Close(); err != nil {
log.Println("cannot close:", err)
}
}
```
### Bare Minimum Bot
The least amount of code for a basic ping-pong bot. It's similar to Serenity's

View file

@ -1,6 +1,8 @@
package api
import (
"context"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/utils/httputil"
)
@ -30,7 +32,8 @@ func (c *Client) BotURL() (*BotData, error) {
}
// GatewayURL asks Discord for a Websocket URL to the Gateway.
func GatewayURL() (string, error) {
func GatewayURL(ctx context.Context) (string, error) {
var g BotData
return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway)
err := httputil.NewClient().WithContext(ctx).RequestJSON(&g, "GET", EndpointGateway)
return g.URL, err
}

View file

@ -15,7 +15,7 @@ func DurationSinceEpoch(t time.Time) time.Duration {
return time.Duration(t.UnixNano()) - Epoch
}
//go:generate go run ../utils/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID
//go:generate go run ../utils/cmd/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID
// Mention generates the mention syntax for this channel ID.
func (s ChannelID) Mention() string { return "<#" + s.String() + ">" }

View file

@ -1,167 +0,0 @@
package gateway
import (
"context"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/discord"
)
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
// Identify structure is at identify.go
// Identify sends off the Identify command with the Gateway's IdentifyData.
func (g *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.IdentifyCtx(ctx)
}
// IdentifyCtx sends off the Identify command with the Gateway's IdentifyData
// with the given context for time out.
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
if err := g.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for identify()")
}
return g.SendCtx(ctx, IdentifyOP, g.Identifier)
}
type ResumeData struct {
Token string `json:"token"`
SessionID string `json:"session_id"`
Sequence int64 `json:"seq"`
}
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ResumeCtx(ctx)
}
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) ResumeCtx(ctx context.Context) error {
var (
ses = g.SessionID()
seq = g.Sequence.Get()
)
if ses == "" || seq == 0 {
return ErrMissingForResume
}
return g.SendCtx(ctx, ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
})
}
// HeartbeatData is the last sequence number to be sent.
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.HeartbeatCtx(ctx)
}
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
}
type RequestGuildMembersData struct {
// GuildIDs contains the ids of the guilds to request data from. Multiple
// guilds can only be requested when using user accounts.
GuildIDs []discord.GuildID `json:"guild_id"`
UserIDs []discord.UserID `json:"user_ids,omitempty"`
Query string `json:"query"`
Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.RequestGuildMembersCtx(ctx, data)
}
func (g *Gateway) RequestGuildMembersCtx(
ctx context.Context, data RequestGuildMembersData) error {
return g.SendCtx(ctx, RequestGuildMembersOP, data)
}
type UpdateVoiceStateData struct {
GuildID discord.GuildID `json:"guild_id"`
ChannelID discord.ChannelID `json:"channel_id"` // nullable
SelfMute bool `json:"self_mute"`
SelfDeaf bool `json:"self_deaf"`
}
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateVoiceStateCtx(ctx, data)
}
func (g *Gateway) UpdateVoiceStateCtx(ctx context.Context, data UpdateVoiceStateData) error {
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
}
// UpdateStatusData is sent by this client to indicate a presence or status
// update.
type UpdateStatusData struct {
Since discord.UnixMsTimestamp `json:"since"` // 0 if not idle
// Activities can be null or an empty slice.
Activities []discord.Activity `json:"activities"`
Status discord.Status `json:"status"`
AFK bool `json:"afk"`
}
func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateStatusCtx(ctx, data)
}
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
return g.SendCtx(ctx, StatusUpdateOP, data)
}
// Undocumented
type GuildSubscribeData struct {
Typing bool `json:"typing"`
Threads bool `json:"threads"`
Activities bool `json:"activities"`
GuildID discord.GuildID `json:"guild_id"`
// Channels is not documented. It's used to fetch the right members sidebar.
Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"`
}
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.GuildSubscribeCtx(ctx, data)
}
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
}

460
gateway/event_methods.go Normal file
View file

@ -0,0 +1,460 @@
// Code generated by genevent. DO NOT EDIT.
package gateway
import "github.com/diamondburned/arikawa/v3/utils/ws"
func init() {
OpUnmarshalers.Add(
func() ws.Event { return new(HeartbeatCommand) },
func() ws.Event { return new(HeartbeatAckEvent) },
func() ws.Event { return new(ReconnectEvent) },
func() ws.Event { return new(HelloEvent) },
func() ws.Event { return new(ResumeCommand) },
func() ws.Event { return new(InvalidSessionEvent) },
func() ws.Event { return new(RequestGuildMembersCommand) },
func() ws.Event { return new(UpdateVoiceStateCommand) },
func() ws.Event { return new(UpdatePresenceCommand) },
func() ws.Event { return new(GuildSubscribeCommand) },
func() ws.Event { return new(ResumedEvent) },
func() ws.Event { return new(ChannelCreateEvent) },
func() ws.Event { return new(ChannelUpdateEvent) },
func() ws.Event { return new(ChannelDeleteEvent) },
func() ws.Event { return new(ChannelPinsUpdateEvent) },
func() ws.Event { return new(ChannelUnreadUpdateEvent) },
func() ws.Event { return new(ThreadCreateEvent) },
func() ws.Event { return new(ThreadUpdateEvent) },
func() ws.Event { return new(ThreadDeleteEvent) },
func() ws.Event { return new(ThreadListSyncEvent) },
func() ws.Event { return new(ThreadMemberUpdateEvent) },
func() ws.Event { return new(ThreadMembersUpdateEvent) },
func() ws.Event { return new(GuildCreateEvent) },
func() ws.Event { return new(GuildUpdateEvent) },
func() ws.Event { return new(GuildDeleteEvent) },
func() ws.Event { return new(GuildBanAddEvent) },
func() ws.Event { return new(GuildBanRemoveEvent) },
func() ws.Event { return new(GuildEmojisUpdateEvent) },
func() ws.Event { return new(GuildIntegrationsUpdateEvent) },
func() ws.Event { return new(GuildMemberAddEvent) },
func() ws.Event { return new(GuildMemberRemoveEvent) },
func() ws.Event { return new(GuildMemberUpdateEvent) },
func() ws.Event { return new(GuildMembersChunkEvent) },
func() ws.Event { return new(GuildRoleCreateEvent) },
func() ws.Event { return new(GuildRoleUpdateEvent) },
func() ws.Event { return new(GuildRoleDeleteEvent) },
func() ws.Event { return new(InviteCreateEvent) },
func() ws.Event { return new(InviteDeleteEvent) },
func() ws.Event { return new(MessageCreateEvent) },
func() ws.Event { return new(MessageUpdateEvent) },
func() ws.Event { return new(MessageDeleteEvent) },
func() ws.Event { return new(MessageDeleteBulkEvent) },
func() ws.Event { return new(MessageReactionAddEvent) },
func() ws.Event { return new(MessageReactionRemoveEvent) },
func() ws.Event { return new(MessageReactionRemoveAllEvent) },
func() ws.Event { return new(MessageReactionRemoveEmojiEvent) },
func() ws.Event { return new(MessageAckEvent) },
func() ws.Event { return new(PresenceUpdateEvent) },
func() ws.Event { return new(PresencesReplaceEvent) },
func() ws.Event { return new(SessionsReplaceEvent) },
func() ws.Event { return new(TypingStartEvent) },
func() ws.Event { return new(UserUpdateEvent) },
func() ws.Event { return new(VoiceStateUpdateEvent) },
func() ws.Event { return new(VoiceServerUpdateEvent) },
func() ws.Event { return new(WebhooksUpdateEvent) },
func() ws.Event { return new(InteractionCreateEvent) },
func() ws.Event { return new(UserGuildSettingsUpdateEvent) },
func() ws.Event { return new(UserSettingsUpdateEvent) },
func() ws.Event { return new(UserNoteUpdateEvent) },
func() ws.Event { return new(RelationshipAddEvent) },
func() ws.Event { return new(RelationshipRemoveEvent) },
func() ws.Event { return new(ReadyEvent) },
func() ws.Event { return new(ReadySupplementalEvent) },
func() ws.Event { return new(IdentifyCommand) },
)
}
// Op implements Event. It always returns Op 1.
func (*HeartbeatCommand) Op() ws.OpCode { return 1 }
// EventType implements Event.
func (*HeartbeatCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 11.
func (*HeartbeatAckEvent) Op() ws.OpCode { return 11 }
// EventType implements Event.
func (*HeartbeatAckEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 7.
func (*ReconnectEvent) Op() ws.OpCode { return 7 }
// EventType implements Event.
func (*ReconnectEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 10.
func (*HelloEvent) Op() ws.OpCode { return 10 }
// EventType implements Event.
func (*HelloEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 6.
func (*ResumeCommand) Op() ws.OpCode { return 6 }
// EventType implements Event.
func (*ResumeCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 9.
func (*InvalidSessionEvent) Op() ws.OpCode { return 9 }
// EventType implements Event.
func (*InvalidSessionEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 8.
func (*RequestGuildMembersCommand) Op() ws.OpCode { return 8 }
// EventType implements Event.
func (*RequestGuildMembersCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 4.
func (*UpdateVoiceStateCommand) Op() ws.OpCode { return 4 }
// EventType implements Event.
func (*UpdateVoiceStateCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 3.
func (*UpdatePresenceCommand) Op() ws.OpCode { return 3 }
// EventType implements Event.
func (*UpdatePresenceCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 14.
func (*GuildSubscribeCommand) Op() ws.OpCode { return 14 }
// EventType implements Event.
func (*GuildSubscribeCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns 0.
func (*ResumedEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ResumedEvent) EventType() ws.EventType { return "RESUMED" }
// Op implements Event. It always returns 0.
func (*ChannelCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ChannelCreateEvent) EventType() ws.EventType { return "CHANNEL_CREATE" }
// Op implements Event. It always returns 0.
func (*ChannelUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ChannelUpdateEvent) EventType() ws.EventType { return "CHANNEL_UPDATE" }
// Op implements Event. It always returns 0.
func (*ChannelDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ChannelDeleteEvent) EventType() ws.EventType { return "CHANNEL_DELETE" }
// Op implements Event. It always returns 0.
func (*ChannelPinsUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ChannelPinsUpdateEvent) EventType() ws.EventType { return "CHANNEL_PINS_UPDATE" }
// Op implements Event. It always returns 0.
func (*ChannelUnreadUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ChannelUnreadUpdateEvent) EventType() ws.EventType { return "CHANNEL_UNREAD_UPDATE" }
// Op implements Event. It always returns 0.
func (*ThreadCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadCreateEvent) EventType() ws.EventType { return "THREAD_CREATE" }
// Op implements Event. It always returns 0.
func (*ThreadUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadUpdateEvent) EventType() ws.EventType { return "THREAD_UPDATE" }
// Op implements Event. It always returns 0.
func (*ThreadDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadDeleteEvent) EventType() ws.EventType { return "THREAD_DELETE" }
// Op implements Event. It always returns 0.
func (*ThreadListSyncEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadListSyncEvent) EventType() ws.EventType { return "THREAD_LIST_SYNC" }
// Op implements Event. It always returns 0.
func (*ThreadMemberUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadMemberUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBER_UPDATE" }
// Op implements Event. It always returns 0.
func (*ThreadMembersUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ThreadMembersUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBERS_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildCreateEvent) EventType() ws.EventType { return "GUILD_CREATE" }
// Op implements Event. It always returns 0.
func (*GuildUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildUpdateEvent) EventType() ws.EventType { return "GUILD_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildDeleteEvent) EventType() ws.EventType { return "GUILD_DELETE" }
// Op implements Event. It always returns 0.
func (*GuildBanAddEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildBanAddEvent) EventType() ws.EventType { return "GUILD_BAN_ADD" }
// Op implements Event. It always returns 0.
func (*GuildBanRemoveEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildBanRemoveEvent) EventType() ws.EventType { return "GUILD_BAN_REMOVE" }
// Op implements Event. It always returns 0.
func (*GuildEmojisUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildEmojisUpdateEvent) EventType() ws.EventType { return "GUILD_EMOJIS_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildIntegrationsUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildIntegrationsUpdateEvent) EventType() ws.EventType { return "GUILD_INTEGRATIONS_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildMemberAddEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildMemberAddEvent) EventType() ws.EventType { return "GUILD_MEMBER_ADD" }
// Op implements Event. It always returns 0.
func (*GuildMemberRemoveEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildMemberRemoveEvent) EventType() ws.EventType { return "GUILD_MEMBER_REMOVE" }
// Op implements Event. It always returns 0.
func (*GuildMemberUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildMemberUpdateEvent) EventType() ws.EventType { return "GUILD_MEMBER_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildMembersChunkEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildMembersChunkEvent) EventType() ws.EventType { return "GUILD_MEMBERS_CHUNK" }
// Op implements Event. It always returns 0.
func (*GuildRoleCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildRoleCreateEvent) EventType() ws.EventType { return "GUILD_ROLE_CREATE" }
// Op implements Event. It always returns 0.
func (*GuildRoleUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildRoleUpdateEvent) EventType() ws.EventType { return "GUILD_ROLE_UPDATE" }
// Op implements Event. It always returns 0.
func (*GuildRoleDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*GuildRoleDeleteEvent) EventType() ws.EventType { return "GUILD_ROLE_DELETE" }
// Op implements Event. It always returns 0.
func (*InviteCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*InviteCreateEvent) EventType() ws.EventType { return "INVITE_CREATE" }
// Op implements Event. It always returns 0.
func (*InviteDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*InviteDeleteEvent) EventType() ws.EventType { return "INVITE_DELETE" }
// Op implements Event. It always returns 0.
func (*MessageCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageCreateEvent) EventType() ws.EventType { return "MESSAGE_CREATE" }
// Op implements Event. It always returns 0.
func (*MessageUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageUpdateEvent) EventType() ws.EventType { return "MESSAGE_UPDATE" }
// Op implements Event. It always returns 0.
func (*MessageDeleteEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageDeleteEvent) EventType() ws.EventType { return "MESSAGE_DELETE" }
// Op implements Event. It always returns 0.
func (*MessageDeleteBulkEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageDeleteBulkEvent) EventType() ws.EventType { return "MESSAGE_DELETE_BULK" }
// Op implements Event. It always returns 0.
func (*MessageReactionAddEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageReactionAddEvent) EventType() ws.EventType { return "MESSAGE_REACTION_ADD" }
// Op implements Event. It always returns 0.
func (*MessageReactionRemoveEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageReactionRemoveEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE" }
// Op implements Event. It always returns 0.
func (*MessageReactionRemoveAllEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageReactionRemoveAllEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE_ALL" }
// Op implements Event. It always returns 0.
func (*MessageReactionRemoveEmojiEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageReactionRemoveEmojiEvent) EventType() ws.EventType {
return "MESSAGE_REACTION_REMOVE_EMOJI"
}
// Op implements Event. It always returns 0.
func (*MessageAckEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*MessageAckEvent) EventType() ws.EventType { return "MESSAGE_ACK" }
// Op implements Event. It always returns 0.
func (*PresenceUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*PresenceUpdateEvent) EventType() ws.EventType { return "PRESENCE_UPDATE" }
// Op implements Event. It always returns 0.
func (*PresencesReplaceEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*PresencesReplaceEvent) EventType() ws.EventType { return "PRESENCES_REPLACE" }
// Op implements Event. It always returns 0.
func (*SessionsReplaceEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*SessionsReplaceEvent) EventType() ws.EventType { return "SESSIONS_REPLACE" }
// Op implements Event. It always returns 0.
func (*TypingStartEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*TypingStartEvent) EventType() ws.EventType { return "TYPING_START" }
// Op implements Event. It always returns 0.
func (*UserUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*UserUpdateEvent) EventType() ws.EventType { return "USER_UPDATE" }
// Op implements Event. It always returns 0.
func (*VoiceStateUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*VoiceStateUpdateEvent) EventType() ws.EventType { return "VOICE_STATE_UPDATE" }
// Op implements Event. It always returns 0.
func (*VoiceServerUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*VoiceServerUpdateEvent) EventType() ws.EventType { return "VOICE_SERVER_UPDATE" }
// Op implements Event. It always returns 0.
func (*WebhooksUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*WebhooksUpdateEvent) EventType() ws.EventType { return "WEBHOOKS_UPDATE" }
// Op implements Event. It always returns 0.
func (*InteractionCreateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*InteractionCreateEvent) EventType() ws.EventType { return "INTERACTION_CREATE" }
// Op implements Event. It always returns 0.
func (*UserGuildSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*UserGuildSettingsUpdateEvent) EventType() ws.EventType { return "USER_GUILD_SETTINGS_UPDATE" }
// Op implements Event. It always returns 0.
func (*UserSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*UserSettingsUpdateEvent) EventType() ws.EventType { return "USER_SETTINGS_UPDATE" }
// Op implements Event. It always returns 0.
func (*UserNoteUpdateEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*UserNoteUpdateEvent) EventType() ws.EventType { return "USER_NOTE_UPDATE" }
// Op implements Event. It always returns 0.
func (*RelationshipAddEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*RelationshipAddEvent) EventType() ws.EventType { return "RELATIONSHIP_ADD" }
// Op implements Event. It always returns 0.
func (*RelationshipRemoveEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*RelationshipRemoveEvent) EventType() ws.EventType { return "RELATIONSHIP_REMOVE" }
// Op implements Event. It always returns 0.
func (*ReadyEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ReadyEvent) EventType() ws.EventType { return "READY" }
// Op implements Event. It always returns 0.
func (*ReadySupplementalEvent) Op() ws.OpCode { return dispatchOp }
// EventType implements Event.
func (*ReadySupplementalEvent) EventType() ws.EventType { return "READY_SUPPLEMENTAL" }
// Op implements Event. It always returns Op 2.
func (*IdentifyCommand) Op() ws.OpCode { return 2 }
// EventType implements Event.
func (*IdentifyCommand) EventType() ws.EventType { return "" }

File diff suppressed because it is too large Load diff

View file

@ -1,76 +0,0 @@
package gateway
// Event is any event struct. They have an "Event" suffixed to them.
type Event = interface{}
// EventCreator maps an event type string to a constructor.
var EventCreator = map[string]func() Event{
"HELLO": func() Event { return new(HelloEvent) },
"READY": func() Event { return new(ReadyEvent) },
"READY_SUPPLEMENTAL": func() Event { return new(ReadySupplementalEvent) },
"RESUMED": func() Event { return new(ResumedEvent) },
"INVALID_SESSION": func() Event { return new(InvalidSessionEvent) },
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
"CHANNEL_UNREAD_UPDATE": func() Event { return new(ChannelUnreadUpdateEvent) },
"GUILD_CREATE": func() Event { return new(GuildCreateEvent) },
"GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) },
"GUILD_DELETE": func() Event { return new(GuildDeleteEvent) },
"GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) },
"GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) },
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
"GUILD_INTEGRATIONS_UPDATE": func() Event { return new(GuildIntegrationsUpdateEvent) },
"GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) },
"GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) },
"GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) },
"GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) },
"GUILD_MEMBER_LIST_UPDATE": func() Event { return new(GuildMemberListUpdate) },
"GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) },
"GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) },
"GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) },
"INVITE_CREATE": func() Event { return new(InviteCreateEvent) },
"INVITE_DELETE": func() Event { return new(InviteDeleteEvent) },
"MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) },
"MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) },
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
"MESSAGE_REACTION_REMOVE_EMOJI": func() Event { return new(MessageReactionRemoveEmojiEvent) },
"MESSAGE_ACK": func() Event { return new(MessageAckEvent) },
"PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) },
"PRESENCES_REPLACE": func() Event { return new(PresencesReplaceEvent) },
"SESSIONS_REPLACE": func() Event { return new(SessionsReplaceEvent) },
"TYPING_START": func() Event { return new(TypingStartEvent) },
"VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) },
"VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) },
"WEBHOOKS_UPDATE": func() Event { return new(WebhooksUpdateEvent) },
"INTERACTION_CREATE": func() Event { return new(InteractionCreateEvent) },
"USER_UPDATE": func() Event { return new(UserUpdateEvent) },
"USER_SETTINGS_UPDATE": func() Event { return new(UserSettingsUpdateEvent) },
"USER_GUILD_SETTINGS_UPDATE": func() Event { return new(UserGuildSettingsUpdateEvent) },
"USER_NOTE_UPDATE": func() Event { return new(UserNoteUpdateEvent) },
"RELATIONSHIP_ADD": func() Event { return new(RelationshipAddEvent) },
"RELATIONSHIP_REMOVE": func() Event { return new(RelationshipRemoveEvent) },
}

View file

@ -9,19 +9,15 @@ package gateway
import (
"context"
"math/rand"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/diamondburned/arikawa/v3/internal/lazytime"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/pkg/errors"
)
var (
@ -29,26 +25,24 @@ var (
Encoding = "json"
)
var (
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
ErrWSMaxTries = errors.New(
"could not connect to the Discord gateway before reaching the timeout")
ErrClosed = errors.New("the gateway is closed and cannot reconnect")
)
// CodeInvalidSequence is the code returned by Discord to signal that the given
// sequence number is invalid.
const CodeInvalidSequence = 4007
// see
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
const errCodeShardingRequired = 4011
// CodeShardingRequired is the code returned by Discord to signal that the bot
// must reshard before proceeding. For more information, see
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes.
const CodeShardingRequired = 4011
// URL asks Discord for a Websocket URL to the Gateway.
func URL() (string, error) {
return api.GatewayURL()
func URL(ctx context.Context) (string, error) {
return api.GatewayURL(ctx)
}
// BotURL fetches the Gateway URL along with extra metadata. The token
// passed in will NOT be prefixed with Bot.
func BotURL(token string) (*api.BotData, error) {
return api.NewClient(token).BotURL()
func BotURL(ctx context.Context, token string) (*api.BotData, error) {
return api.NewClient(token).WithContext(ctx).BotURL()
}
// AddGatewayParams appends into the given URL string the gateway URL
@ -62,469 +56,371 @@ func AddGatewayParams(baseURL string) string {
return baseURL + "?" + param.Encode()
}
type Gateway struct {
WS *wsutil.Websocket
// WSTimeout is a timeout for an arbitrary action. An example of this is the
// timeout for Start and the timeout for sending each Gateway command
// independently.
WSTimeout time.Duration
// ReconnectAttempts are the amount of attempts made to Reconnect, before
// aborting. If this set to 0, unlimited attempts will be made.
ReconnectAttempts uint
// All events sent over are pointers to Event structs (structs suffixed with
// "Event"). This shouldn't be accessed if the Gateway is created with a
// Session.
Events chan Event
sessionMu sync.RWMutex
sessionID string
Identifier *Identifier
Sequence *moreatomic.Int64
PacerLoop wsutil.PacemakerLoop
ErrorLog func(err error) // default to log.Println
// FatalErrorCallback is called, if the Gateway exits fatally. At the point
// of calling, the gateway will be already closed.
//
// Currently this will only be called, if the ReconnectTimeout was changed
// to a definite timeout, and connection could not be established during
// that time.
// err will be ErrWSMaxTries in that case.
//
// Defaults to noop.
FatalErrorCallback func(err error)
// AfterClose is called after each close or pause. It is used mainly for
// reconnections or any type of connection interruptions.
//
// Constructors will use a no-op function by default.
AfterClose func(err error)
onShardingRequired func()
waitGroup sync.WaitGroup
closed chan struct{}
// State contains the gateway state. It is a piece of data that can be shared
// across gateways during construction to be used for resuming a connection or
// starting a new one with the previous data.
//
// The data structure itself is not thread-safe, so they may only be pulled from
// the gateway after it's done and set before it's done.
type State struct {
Identifier Identifier
SessionID string
Sequence int64
}
// NewGatewayWithIntents creates a new Gateway with the given intents and the
// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
g, err := NewGateway(token)
// Gateway describes an instance that handles the Discord gateway. It is
// basically an abstracted concurrent event loop that the user could signal to
// start connecting to the Discord gateway server.
type Gateway struct {
gateway *ws.Gateway
state State
// non-mutex-guarded states
// TODO: make lastBeat part of ws.Gateway so it can keep track of whether or
// not the websocket is dead.
beatMutex sync.Mutex
sentBeat time.Time
echoBeat time.Time
retryTimer lazytime.Timer
}
// NewWithIntents creates a new Gateway with the given intents and the default
// stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
func NewWithIntents(ctx context.Context, token string, intents ...Intents) (*Gateway, error) {
var allIntents Intents
for _, intent := range intents {
allIntents |= intent
}
g, err := New(ctx, token)
if err != nil {
return nil, err
}
for _, intent := range intents {
g.AddIntents(intent)
}
g.AddIntents(allIntents)
return g, nil
}
// NewGateway creates a new Gateway to the default Discord server.
func NewGateway(token string) (*Gateway, error) {
return NewIdentifiedGateway(DefaultIdentifier(token))
// New creates a new Gateway to the default Discord server.
func New(ctx context.Context, token string) (*Gateway, error) {
return NewWithIdentifier(ctx, DefaultIdentifier(token))
}
// NewIdentifiedGateway creates a new Gateway with the given gateway identifier
// and the default everything. Sharded bots should prefer this function for the
// shared identifier.
func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
var gatewayURL string
var botData *api.BotData
var err error
if strings.HasPrefix(id.Token, "Bot ") {
botData, err = BotURL(id.Token)
if err != nil {
return nil, errors.Wrap(err, "failed to get bot data")
}
gatewayURL = botData.URL
} else {
gatewayURL, err = URL()
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway endpoint")
}
// NewWithIdentifier creates a new Gateway with the given gateway identifier and
// the default everything. Sharded bots should prefer this function for the
// shared identifier. The given Identifier will be modified.
func NewWithIdentifier(ctx context.Context, id Identifier) (*Gateway, error) {
gatewayURL, err := id.QueryGateway(ctx)
if err != nil {
return nil, err
}
gatewayURL = AddGatewayParams(gatewayURL)
gateway := NewCustomIdentifiedGateway(gatewayURL, id)
// Use the supplied connect rate limit, if any.
if botData != nil && botData.StartLimit != nil {
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
limiter := gateway.Identifier.IdentifyGlobalLimit
// Update the burst to be the current given time and reset it back to
// the default when the given time is reached.
limiter.SetBurst(botData.StartLimit.Remaining)
limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
// Update the maximum number of identify requests allowed per 5s.
gateway.Identifier.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
}
gateway := NewCustomWithIdentifier(gatewayURL, id, nil)
return gateway, nil
}
// NewCustomGateway creates a new Gateway with a custom gateway URL and a new
// NewCustom creates a new Gateway with a custom gateway URL and a new
// Identifier. Most bots connecting to the official server should not use these
// custom functions.
func NewCustomGateway(gatewayURL, token string) *Gateway {
return NewCustomIdentifiedGateway(gatewayURL, DefaultIdentifier(token))
func NewCustom(gatewayURL, token string) *Gateway {
return NewCustomWithIdentifier(gatewayURL, DefaultIdentifier(token), nil)
}
// NewCustomIdentifiedGateway creates a new Gateway with a custom gateway URL
// and a pre-existing Identifier. Refer to NewCustomGateway.
func NewCustomIdentifiedGateway(gatewayURL string, id *Identifier) *Gateway {
// DefaultGatewayOpts contains the default options to be used for connecting to
// the gateway.
var DefaultGatewayOpts = ws.GatewayOpts{
ReconnectDelay: func(try int) time.Duration {
// minimum 4 seconds
return time.Duration(4+(2*try)) * time.Second
},
// FatalCloseCodes contains the default gateway close codes that will cause
// the gateway to exit. In other words, it's a list of unrecoverable close
// codes.
FatalCloseCodes: []int{
4003, // not authenticated
4004, // authentication failed
4010, // invalid shard sent
4011, // sharding required
4012, // invalid API version
4013, // invalid intents
4014, // disallowed intents
},
DialTimeout: 0,
ReconnectAttempt: 0,
AlwaysCloseGracefully: true,
}
// NewCustomWithIdentifier creates a new Gateway with a custom gateway URL and a
// pre-existing Identifier. If opts is nil, then DefaultGatewayOpts is used.
func NewCustomWithIdentifier(gatewayURL string, id Identifier, opts *ws.GatewayOpts) *Gateway {
return NewFromState(gatewayURL, State{Identifier: id}, opts)
}
// NewFromState creates a new gateway from the given state and optionally
// gateway options. If opts is nil, then DefaultGatewayOpts is used.
func NewFromState(gatewayURL string, state State, opts *ws.GatewayOpts) *Gateway {
if opts == nil {
opts = &DefaultGatewayOpts
}
gw := ws.NewGateway(ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), gatewayURL), opts)
return &Gateway{
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
WSTimeout: wsutil.WSTimeout,
Events: make(chan Event, wsutil.WSBuffer),
Identifier: id,
Sequence: moreatomic.NewInt64(0),
ErrorLog: wsutil.WSError,
AfterClose: func(error) {},
PacerLoop: wsutil.PacemakerLoop{ErrorLog: wsutil.WSError},
gateway: gw,
state: state,
}
}
// AddIntents adds a Gateway Intent before connecting to the Gateway. As such,
// this function will only work before Open() is called.
// State returns a copy of the gateway's internal state. It panics if the
// gateway is currently running.
func (g *Gateway) State() State {
g.gateway.AssertIsNotRunning()
return g.state
}
// SetState sets the gateway's state.
func (g *Gateway) SetState(state State) {
g.gateway.AssertIsNotRunning()
g.state = state
}
// AddIntents adds a Gateway Intent before connecting to the Gateway. This
// function will only work before Connect() is called. Calling it once Connect()
// is called will result in a panic.
func (g *Gateway) AddIntents(i Intents) {
if g.Identifier.Intents == nil {
g.Identifier.Intents = option.NewUint(uint(i))
} else {
*g.Identifier.Intents |= uint(i)
}
g.gateway.AssertIsNotRunning()
g.state.Identifier.AddIntents(i)
}
// HasIntents reports if the Gateway has the passed Intents.
// SentBeat returns the last time that the heart was beaten. If the gateway has
// never connected, then a zero-value time is returned.
func (g *Gateway) SentBeat() time.Time {
g.beatMutex.Lock()
defer g.beatMutex.Unlock()
return g.sentBeat
}
// EchoBeat returns the last time that the heartbeat was acknowledged. It is
// similar to SentBeat.
func (g *Gateway) EchoBeat() time.Time {
g.beatMutex.Lock()
defer g.beatMutex.Unlock()
return g.echoBeat
}
// Latency is a convenient function around SentBeat and EchoBeat. It subtracts
// the EchoBeat with the SentBeat.
func (g *Gateway) Latency() time.Duration {
g.beatMutex.Lock()
defer g.beatMutex.Unlock()
return g.echoBeat.Sub(g.sentBeat)
}
// LastError returns the last error that the gateway has received. It only
// returns a valid error if the gateway's event loop as exited. If the event
// loop hasn't been started AND stopped, the function will panic.
func (g *Gateway) LastError() error {
return g.gateway.LastError()
}
// Send is a function to send an Op payload to the Gateway.
func (g *Gateway) Send(ctx context.Context, data ws.Event) error {
return g.gateway.Send(ctx, data)
}
// Connect starts the background goroutine that tries its best to maintain a
// stable connection to the Discord gateway. To the user, the gateway should
// appear to be working seamlessly.
//
// If no intents are set, i.e. if using a user account HasIntents will always
// return true.
func (g *Gateway) HasIntents(intents Intents) bool {
if g.Identifier.Intents == nil {
return true
}
return Intents(*g.Identifier.Intents).Has(intents)
}
// Close closes the underlying Websocket connection, invalidating the session
// ID.
// Behaviors
//
// It will send a closing frame before ending the connection, closing it
// gracefully. This will cause the bot to appear as offline instantly.
func (g *Gateway) Close() error {
return g.close(true)
}
// Pause pauses the Gateway connection, by ending the connection without
// sending a closing frame. This allows the connection to be resumed at a later
// point, by calling Reconnect or ReconnectCtx.
func (g *Gateway) Pause() error {
return g.close(false)
}
func (g *Gateway) close(graceful bool) (err error) {
wsutil.WSDebug("Trying to close. Pacemaker check skipped.")
wsutil.WSDebug("Closing the Websocket...")
if graceful {
err = g.WS.CloseGracefully()
} else {
err = g.WS.Close()
}
if errors.Is(err, wsutil.ErrWebsocketClosed) {
wsutil.WSDebug("Websocket already closed.")
return nil
}
// Explicitly signal the pacemaker loop to stop. We should do this in case
// the Start function exited before it could bind the event channel into the
// loop.
g.PacerLoop.Stop()
wsutil.WSDebug("Websocket closed; error:", err)
wsutil.WSDebug("Waiting for the Pacemaker loop to exit.")
g.waitGroup.Wait()
wsutil.WSDebug("Pacemaker loop exited.")
g.AfterClose(err)
wsutil.WSDebug("AfterClose callback finished.")
if graceful {
// If a Reconnect is in progress, signal to cancel.
close(g.closed)
// Delete our session id, as we just invalidated it.
g.sessionMu.Lock()
g.sessionID = ""
g.sessionMu.Unlock()
}
return err
}
// SessionID returns the session ID received after Ready. This function is
// concurrently safe.
func (g *Gateway) SessionID() string {
g.sessionMu.RLock()
defer g.sessionMu.RUnlock()
return g.sessionID
}
// UseSessionID overrides the internal session ID for the one the user provides.
func (g *Gateway) UseSessionID(sessionID string) {
g.sessionMu.Lock()
defer g.sessionMu.Unlock()
g.sessionID = sessionID
}
// OnShardingRequired sets the function to be called if Discord closes with
// error code 4011 aka Sharding Required. When called, the Gateway will already
// be closed, and can (after increasing the number of shards) be reopened using
// Open. Reconnect or ReconnectCtx, however, will not be available as the
// session is invalidated.
// There are several behaviors that the gateway will overload onto the channel.
//
// The gateway will completely halt what it's doing in the background when this
// callback is called.
func (g *Gateway) OnShardingRequired(fn func()) {
g.onShardingRequired = fn
// Once the gateway has exited, fatally or not, the event channel returned by
// Connect will be closed. The user should therefore know whether or not the
// gateway has exited by spinning on the channel until it is closed.
//
// If Connect is called twice, the second call will return the same exact
// channel that the first call has made without starting any new goroutines,
// except if the gateway is already closed, then a new gateway will spin up with
// the existing gateway state.
//
// If the gateway stumbles upon any background errors, it will do its best to
// recover from it, but errors will be notified to the user using the
// BackgroundErrorEvent event. The user can type-assert the Op's data field,
// like so:
//
// switch data := ev.Data.(type) {
// case *gateway.BackgroundErrorEvent:
// log.Println("gateway error:", data.Error)
// }
//
// Closing
//
// As outlined in the first paragraph, closing the gateway would involve
// cancelling the context that's given to gateway. If AlwaysCloseGracefully is
// true (which it is by default), then the gateway is closed gracefully, and the
// session ID is invalidated.
//
// To wait until the gateway has completely successfully exited, the user can
// keep spinning on the event loop:
//
// for op := range ch {
// select op.Data.(type) {
// case *gateway.ReadyEvent:
// // Close the gateway on READY.
// cancel()
// }
// }
//
// // Gateway is now completely closed.
//
// To capture the final close errors, the user can use the Error method once the
// event channel is closed, like so:
//
// var err error
//
// for op := range ch {
// switch data := op.Data.(type) {
// case *gateway.ReadyEvent:
// cancel()
// }
// }
//
// // Gateway is now completely closed.
// if gateway.LastError() != nil {
// return gateway.LastError()
// }
//
func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op {
return g.gateway.Connect(ctx, &gatewayImpl{Gateway: g})
}
// Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are
// reached.
func (g *Gateway) Reconnect() {
g.ReconnectCtx(context.Background())
type gatewayImpl struct {
*Gateway
lastSentBeat time.Time
}
// ReconnectCtx attempts to Reconnect until context expires.
// If the context expires FatalErrorCallback will be called with ErrWSMaxTries,
// and the last error returned by Open will be returned.
func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Reconnecting...")
func (g *gatewayImpl) invalidate() {
g.state.SessionID = ""
g.state.Sequence = 0
}
// Guarantee the gateway is already closed. Ignore its error, as we're
// redialing anyway.
g.Pause()
// sendIdentify sends off the Identify command with the Gateway's IdentifyData
// with the given context for timeout.
func (g *gatewayImpl) sendIdentify(ctx context.Context) error {
if err := g.state.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for identify()")
}
for try := uint(1); g.ReconnectAttempts == 0 || g.ReconnectAttempts >= try; try++ {
select {
case <-g.closed:
g.ErrorLog(ErrClosed)
return ErrClosed
case <-ctx.Done():
wsutil.WSDebug("Unable to Reconnect after", try, "attempts, aborting")
g.FatalErrorCallback(ErrWSMaxTries)
return err
default:
return g.gateway.Send(ctx, &g.state.Identifier.IdentifyCommand)
}
func (g *gatewayImpl) sendResume(ctx context.Context) error {
return g.gateway.Send(ctx, &ResumeCommand{
Token: g.state.Identifier.Token,
SessionID: g.state.SessionID,
Sequence: g.state.Sequence,
})
}
func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool {
if op.Code == dispatchOp {
g.state.Sequence = op.Sequence
}
switch data := op.Data.(type) {
case *ws.CloseEvent:
if data.Code == CodeInvalidSequence {
// Invalid sequence.
g.invalidate()
}
wsutil.WSDebug("Trying to dial, attempt", try)
g.gateway.QueueReconnect()
// if we encounter an error, make sure we return it, and not nil
if oerr := g.Open(ctx); oerr != nil {
err = oerr
g.ErrorLog(oerr)
case *HelloEvent:
g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration())
wait := time.Duration(4+2*try) * time.Second
if wait > 60*time.Second {
wait = 60 * time.Second
// Send Discord either the Identify packet (if it's a fresh
// connection), or a Resume packet (if it's a dead connection).
if g.state.SessionID == "" || g.state.Sequence == 0 {
// SessionID is empty, so this is a completely new session.
if err := g.sendIdentify(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to send identify")
g.gateway.QueueReconnect()
}
time.Sleep(wait)
continue
}
wsutil.WSDebug("Started after attempt:", try)
return nil
}
wsutil.WSDebug("Unable to Reconnect after", g.ReconnectAttempts, "attempts, aborting")
return err
}
// Open connects to the Websocket and authenticates it. You should usually use
// this function over Start(). The given context provides cancellation and
// timeout.
func (g *Gateway) Open(ctx context.Context) error {
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to Reconnect")
}
wsutil.WSDebug("Trying to start...")
// Try to resume the connection
if err := g.StartCtx(ctx); err != nil {
return err
}
// Started successfully, return
return nil
}
// Start calls StartCtx with a background context. You wouldn't usually use this
// function, but Open() instead.
func (g *Gateway) Start() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.StartCtx(ctx)
}
// StartCtx authenticates with the websocket, or resume from a dead Websocket
// connection. You wouldn't usually use this function, but OpenCtx() instead.
func (g *Gateway) StartCtx(ctx context.Context) error {
g.closed = make(chan struct{})
if err := g.start(ctx); err != nil {
wsutil.WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the
// pacemaker hasn't started yet.
if err := g.Close(); err != nil {
wsutil.WSDebug("Failed to close after start fail:", err)
}
return err
}
return nil
}
func (g *Gateway) start(ctx context.Context) error {
// This is where we'll get our events
ch := g.WS.Listen()
// Create a new Hello event and wait for it.
var hello HelloEvent
// Wait for an OP 10 Hello.
select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
}
wsutil.WSDebug("Hello received; duration:", hello.HeartbeatInterval)
// Start the event handler, which also handles the pacemaker death signal.
g.waitGroup.Add(1)
// Use the pacemaker loop.
g.PacerLoop.StartBeating(hello.HeartbeatInterval.Duration(), g, func(err error) {
g.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("Event loop stopped with error:", err)
if err != nil && g.onShardingRequired != nil {
// If Discord signals us sharding is required, do not attempt to
// Reconnect, unless we don't know what to do. Instead invalidate
// our session ID, as we cannot resume, call OnShardingRequired, and
// exit.
var cerr *websocket.CloseError
if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired {
g.ErrorLog(cerr)
g.UseSessionID("")
g.onShardingRequired()
return
} else {
if err := g.sendResume(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to send resume")
g.gateway.QueueReconnect()
}
}
// Bail if there is no error or if the error is an explicit close, as
// there might be an ongoing reconnection.
if err == nil || errors.Is(err, wsutil.ErrWebsocketClosed) {
return
case *InvalidSessionEvent:
// Wipe the session state.
g.invalidate()
if !*data {
g.gateway.QueueReconnect()
break
}
// Only attempt to Reconnect if we have a session ID at all. We may not
// have one if we haven't even connected successfully once.
if g.SessionID() != "" {
g.ErrorLog(err)
g.Reconnect()
// Discord expects us to wait before reconnecting.
g.retryTimer.Reset(time.Duration(rand.Intn(5)+1) * time.Second)
if err := g.retryTimer.Wait(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to wait before identifying")
g.gateway.QueueReconnect()
break
}
})
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID() == "" {
// SessionID is empty, so this is a completely new session.
if err := g.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify")
}
} else {
if err := g.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume")
// If we fail to identify, then the gateway cannot continue with
// a bad identification, since it's likely a user error.
if err := g.sendIdentify(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to identify")
g.gateway.QueueReconnect()
break
}
case *HeartbeatCommand:
g.SendHeartbeat(ctx)
case *HeartbeatAckEvent:
now := time.Now()
g.beatMutex.Lock()
g.sentBeat = g.lastSentBeat
g.echoBeat = now
g.beatMutex.Unlock()
case *ReconnectEvent:
g.gateway.QueueReconnect()
case *ReadyEvent:
g.state.SessionID = data.SessionID
}
// Expect either READY or RESUMED before continuing.
wsutil.WSDebug("Waiting for either READY or RESUMED.")
return true
}
// WaitForEvent should until the bot becomes ready or resumes (if a
// previous ready event has already been called).
err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
switch op.EventName {
case "READY":
wsutil.WSDebug("Found READY event.")
return true
case "RESUMED":
wsutil.WSDebug("Found RESUMED event.")
return true
}
return false
})
// SendHeartbeat sends a heartbeat with the gateway's current sequence.
func (g *gatewayImpl) SendHeartbeat(ctx context.Context) {
g.lastSentBeat = time.Now()
if err != nil {
return errors.Wrap(err, "first error")
sequence := HeartbeatCommand(g.state.Sequence)
if err := g.gateway.Send(ctx, &sequence); err != nil {
g.gateway.SendErrorWrap(err, "heartbeat error")
g.gateway.QueueReconnect()
}
}
// Bind the event channel to the pacemaker loop.
g.PacerLoop.SetEventChannel(ch)
wsutil.WSDebug("Started successfully.")
// Close closes the state.
func (g *gatewayImpl) Close() error {
g.retryTimer.Stop()
g.invalidate()
return nil
}
// SendCtx is a low-level function to send an OP payload to the Gateway. Most
// users shouldn't touch this, unless they know what they're doing.
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
var op = wsutil.OP{
Code: code,
}
if v != nil {
b, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "failed to encode v")
}
op.Data = b
}
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "failed to encode payload")
}
// WS should already be thread-safe.
return g.WS.SendCtx(ctx, b)
}

View file

@ -0,0 +1,31 @@
package gateway_test
import (
"context"
"log"
"os"
"os/signal"
"github.com/diamondburned/arikawa/v3/gateway"
)
func Example() {
token := os.Getenv("BOT_TOKEN")
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
g, err := gateway.NewWithIntents(ctx, token, gateway.IntentGuilds)
if err != nil {
log.Fatalln("failed to initialize gateway:", err)
}
for op := range g.Connect(ctx) {
switch data := op.Data.(type) {
case *gateway.ReadyEvent:
log.Println("logged in as", data.User.Username)
case *gateway.MessageCreateEvent:
log.Println("got message", data.Content)
}
}
}

197
gateway/gateway_test.go Normal file
View file

@ -0,0 +1,197 @@
package gateway
import (
"context"
"log"
"strings"
"sync"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
var doLogOnce sync.Once
func doLog() {
doLogOnce.Do(func() {
if testing.Verbose() {
ws.WSDebug = func(v ...interface{}) {
log.Println(append([]interface{}{"Debug:"}, v...)...)
}
}
})
}
func TestURL(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
t.Cleanup(cancel)
u, err := URL(ctx)
if err != nil {
t.Fatal("failed to get gateway URL:", err)
}
if u == "" {
t.Fatal("gateway URL is empty")
}
if !strings.HasPrefix(u, "wss://") {
t.Fatal("gatewayURL is invalid:", u)
}
}
func TestInvalidToken(t *testing.T) {
doLog()
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
t.Cleanup(cancel)
g, err := New(ctx, "bad token")
if err != nil {
t.Fatal("failed to make a Gateway:", err)
}
assertIsClose := func(err error) {
if err == nil {
t.Fatal("unexpected nil error")
}
// 4004 Authentication Failed.
if !strings.Contains(err.Error(), "4004") {
t.Fatal("unexpected error:", err)
}
}
for op := range g.Connect(ctx) {
if op.Data == nil {
// This shouldn't happen; the loop should've broken out.
t.Fatal("nil event received")
}
switch data := op.Data.(type) {
case *ws.CloseEvent:
assertIsClose(data)
case *ws.BackgroundErrorEvent:
t.Error("gateway error:", data)
case *HelloEvent:
t.Log("got Hello")
case *InvalidSessionEvent:
t.Log("got InvalidSession")
default:
t.Errorf("got unexpected event %#v", data)
}
}
assertIsClose(g.LastError())
}
func TestIntegration(t *testing.T) {
doLog()
config := testenv.Must(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
t.Cleanup(cancel)
// NewGateway should call Start for us.
g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds)
if err != nil {
t.Fatal("failed to make a Gateway:", err)
}
gatewayOpenAndSpin(t, ctx, g)
cancel()
}
func TestReuseGateway(t *testing.T) {
doLog()
config := testenv.Must(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
t.Cleanup(cancel)
// NewGateway should call Start for us.
g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds)
if err != nil {
t.Fatal("failed to make a Gateway:", err)
}
// Reuse this 3 times.
for i := 0; i < 3; i++ {
cctx, cancel := context.WithCancel(ctx)
gatewayOpenAndSpin(t, cctx, g)
cancel()
}
}
func gatewayOpenAndSpin(t *testing.T, ctx context.Context, g *Gateway) {
ch := g.Connect(ctx)
var reconnected bool
reconnect := func() {
if !reconnected {
reconnected = true
g.gateway.QueueReconnect()
}
}
for op := range ch {
if op.Data == nil {
// This shouldn't happen; the loop should've broken out.
t.Fatal("nil event received")
}
switch data := op.Data.(type) {
case *ReadyEvent:
t.Log("got Ready")
if g.state.SessionID != data.SessionID {
t.Fatal("missing SessionID")
}
log.Println("Bot's username is", data.User.Username)
reconnect()
case *ResumedEvent:
t.Log("got Resumed, test done")
return
case *HelloEvent:
t.Log("got Hello")
case *ws.BackgroundErrorEvent:
t.Error("gateway error:", data)
default:
t.Logf("got event %T", data)
}
}
}
func wait(t *testing.T, evCh chan interface{}) interface{} {
select {
case ev := <-evCh:
return ev
case <-time.After(20 * time.Second):
t.Fatal("timed out waiting for event")
return nil
}
}
func gotimeout(t *testing.T, fn func(context.Context)) {
t.Helper()
// Try and reconnect for 20 seconds maximum.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
var done = make(chan struct{})
go func() {
fn(ctx)
done <- struct{}{}
}()
select {
case <-ctx.Done():
t.Fatal("timed out waiting for function.")
case <-done:
return
}
}

View file

@ -3,8 +3,10 @@ package gateway
import (
"context"
"runtime"
"strings"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/pkg/errors"
@ -13,27 +15,27 @@ import (
// DefaultPresence is used as the default presence when initializing a new
// Gateway.
var DefaultPresence *UpdateStatusData
var DefaultPresence *UpdatePresenceCommand
// Identifier is a wrapper around IdentifyData to add in appropriate rate
// Identifier is a wrapper around IdentifyCommand to add in appropriate rate
// limiters.
type Identifier struct {
IdentifyData
IdentifyCommand
IdentifyShortLimit *rate.Limiter `json:"-"` // optional
IdentifyGlobalLimit *rate.Limiter `json:"-"` // optional
}
// DefaultIdentifier creates a new default Identifier
func DefaultIdentifier(token string) *Identifier {
return NewIdentifier(DefaultIdentifyData(token))
func DefaultIdentifier(token string) Identifier {
return NewIdentifier(DefaultIdentifyCommand(token))
}
// NewIdentifier creates a new identifier with the given IdentifyData and
// NewIdentifier creates a new identifier with the given IdentifyCommand and
// default rate limiters.
func NewIdentifier(data IdentifyData) *Identifier {
return &Identifier{
IdentifyData: data,
func NewIdentifier(data IdentifyCommand) Identifier {
return Identifier{
IdentifyCommand: data,
IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1),
IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000),
}
@ -41,15 +43,15 @@ func NewIdentifier(data IdentifyData) *Identifier {
// Wait waits for the rate limiters to pass. If a limiter is nil, then it will
// not be used to wait. This is useful
func (i *Identifier) Wait(ctx context.Context) error {
if i.IdentifyShortLimit != nil {
if err := i.IdentifyShortLimit.Wait(ctx); err != nil {
func (id *Identifier) Wait(ctx context.Context) error {
if id.IdentifyShortLimit != nil {
if err := id.IdentifyShortLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for short limit")
}
}
if i.IdentifyGlobalLimit != nil {
if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil {
if id.IdentifyGlobalLimit != nil {
if err := id.IdentifyGlobalLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "can't wait for global limit")
}
}
@ -57,6 +59,41 @@ func (i *Identifier) Wait(ctx context.Context) error {
return nil
}
// QueryGateway queries the gateway for the URL and updates the Identifier with
// the appropriate information.
func (id *Identifier) QueryGateway(ctx context.Context) (gatewayURL string, err error) {
var botData *api.BotData
if strings.HasPrefix(id.Token, "Bot ") {
botData, err = BotURL(ctx, id.Token)
if err != nil {
return "", errors.Wrap(err, "failed to get bot data")
}
gatewayURL = botData.URL
} else {
gatewayURL, err = URL(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to get gateway endpoint")
}
}
// Use the supplied connect rate limit, if any.
if botData != nil && botData.StartLimit != nil {
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
limiter := id.IdentifyGlobalLimit
// Update the burst to be the current given time and reset it back to
// the default when the given time is reached.
limiter.SetBurst(botData.StartLimit.Remaining)
limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
// Update the maximum number of identify requests allowed per 5s.
id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
}
return
}
// DefaultIdentity is used as the default identity when initializing a new
// Gateway.
var DefaultIdentity = IdentifyProperties{
@ -65,9 +102,9 @@ var DefaultIdentity = IdentifyProperties{
Device: "Arikawa",
}
// IdentifyData is the struct for a data that's sent over in an Identify
// command.
type IdentifyData struct {
// IdentifyCommand is a command for Op 2. It is the struct for a data that's
// sent over in an Identify command.
type IdentifyCommand struct {
Token string `json:"token"`
Properties IdentifyProperties `json:"properties"`
@ -76,7 +113,7 @@ type IdentifyData struct {
Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ]
Presence *UpdateStatusData `json:"presence,omitempty"`
Presence *UpdatePresenceCommand `json:"presence,omitempty"`
// ClientState is the client state for a user's accuont. Bot accounts should
// NOT touch this field.
@ -97,9 +134,9 @@ type IdentifyData struct {
Intents option.Uint `json:"intents"`
}
// DefaultIdentifyData creates a default IdentifyData with the given token.
func DefaultIdentifyData(token string) IdentifyData {
return IdentifyData{
// DefaultIdentifyCommand creates a default IdentifyCommand with the given token.
func DefaultIdentifyCommand(token string) IdentifyCommand {
return IdentifyCommand{
Token: token,
Properties: DefaultIdentity,
Presence: DefaultPresence,
@ -110,14 +147,35 @@ func DefaultIdentifyData(token string) IdentifyData {
}
// SetShard is a helper function to set the shard configuration inside
// IdentifyData.
func (i *IdentifyData) SetShard(id, num int) {
// IdentifyCommand.
func (i *IdentifyCommand) SetShard(id, num int) {
if i.Shard == nil {
i.Shard = new(Shard)
}
i.Shard[0], i.Shard[1] = id, num
}
// AddIntents adds gateway intents into the identify data.
func (i *IdentifyCommand) AddIntents(intents Intents) {
if i.Intents == nil {
i.Intents = option.NewUint(uint(intents))
} else {
*i.Intents |= uint(intents)
}
}
// HasIntents reports if the Gateway has the passed Intents.
//
// If no intents are set, e.g. if using a user account, HasIntents will always
// return true.
func (i *IdentifyCommand) HasIntents(intents Intents) bool {
if i.Intents == nil {
return true
}
return Intents(*i.Intents).Has(intents)
}
type IdentifyProperties struct {
// Required
OS string `json:"os"` // GOOS

View file

@ -1,160 +0,0 @@
package gateway
import (
"context"
"log"
"strings"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/internal/heart"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
)
func init() {
wsutil.WSDebug = func(v ...interface{}) {
log.Println(append([]interface{}{"Debug:"}, v...)...)
}
heart.Debug = func(v ...interface{}) {
log.Println(append([]interface{}{"Heart:"}, v...)...)
}
}
func TestURL(t *testing.T) {
u, err := URL()
if err != nil {
t.Fatal("failed to get gateway URL:", err)
}
if u == "" {
t.Fatal("gateway URL is empty")
}
if !strings.HasPrefix(u, "wss://") {
t.Fatal("gatewayURL is invalid:", u)
}
}
func TestInvalidToken(t *testing.T) {
g, err := NewGateway("bad token")
if err != nil {
t.Fatal("failed to make a Gateway:", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err = g.Open(ctx); err == nil {
t.Fatal("unexpected success while opening with a bad token.")
}
// 4004 Authentication Failed.
if !strings.Contains(err.Error(), "4004") {
t.Fatal("unexpected error:", err)
}
}
func TestIntegration(t *testing.T) {
config := testenv.Must(t)
var gateway *Gateway
// NewGateway should call Start for us.
g, err := NewGateway("Bot " + config.BotToken)
if err != nil {
t.Fatal("failed to make a Gateway:", err)
}
g.AddIntents(IntentGuilds)
g.AfterClose = func(err error) {
t.Log("closed.")
}
g.ErrorLog = func(err error) {
t.Log("gateway error:", err)
}
gateway = g
gotimeout(t, func(ctx context.Context) {
if err := g.Open(ctx); err != nil {
t.Fatal("failed to authenticate with Discord:", err)
}
})
ev := wait(t, gateway.Events)
ready, ok := ev.(*ReadyEvent)
if !ok {
t.Fatal("event received is not of type Ready:", ev)
}
if gateway.SessionID() == "" {
t.Fatal("session ID is empty")
}
log.Println("Bot's username is", ready.User.Username)
// Send a faster heartbeat every second for testing.
g.PacerLoop.SetPace(time.Second)
// Sleep past the rate limiter before reconnecting:
time.Sleep(5 * time.Second)
gotimeout(t, func(ctx context.Context) {
g.ErrorLog = func(err error) {
t.Error("unexpected error while reconnecting:", err)
}
if err := gateway.ReconnectCtx(ctx); err != nil {
t.Error("failed to reconnect Gateway:", err)
}
})
g.ErrorLog = func(err error) { t.Log("warning:", err) }
// Wait for the desired event:
gotimeout(t, func(context.Context) {
for ev := range gateway.Events {
switch ev.(type) {
// Accept only a Resumed event.
case *ResumedEvent:
return // exit
case *ReadyEvent:
t.Fatal("Ready event received instead of Resumed.")
}
}
})
if err := g.Close(); err != nil {
t.Fatal("failed to close Gateway:", err)
}
}
func wait(t *testing.T, evCh chan interface{}) interface{} {
select {
case ev := <-evCh:
return ev
case <-time.After(20 * time.Second):
t.Fatal("timed out waiting for event")
return nil
}
}
func gotimeout(t *testing.T, fn func(context.Context)) {
t.Helper()
// Try and reconnect for 20 seconds maximum.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
var done = make(chan struct{})
go func() {
fn(ctx)
done <- struct{}{}
}()
select {
case <-ctx.Done():
t.Fatal("timed out waiting for function.")
case <-done:
return
}
}

View file

@ -1,6 +1,9 @@
package gateway
import "github.com/diamondburned/arikawa/v3/discord"
import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
// Intents for the new Discord API feature, documented at
// https://discord.com/developers/docs/topics/gateway#gateway-intents.
@ -44,7 +47,7 @@ func (i Intents) IsPrivileged() (presences, member bool) {
}
// EventIntents maps event types to intents.
var EventIntents = map[string]Intents{
var EventIntents = map[ws.EventType]Intents{
"GUILD_CREATE": IntentGuilds,
"GUILD_UPDATE": IntentGuilds,
"GUILD_DELETE": IntentGuilds,

View file

@ -1,118 +0,0 @@
package gateway
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/pkg/errors"
)
type OPCode = wsutil.OPCode
const (
DispatchOP OPCode = 0 // recv
HeartbeatOP OPCode = 1 // send/recv
IdentifyOP OPCode = 2 // send...
StatusUpdateOP OPCode = 3 //
VoiceStateUpdateOP OPCode = 4 //
VoiceServerPingOP OPCode = 5 //
ResumeOP OPCode = 6 //
ReconnectOP OPCode = 7 // recv
RequestGuildMembersOP OPCode = 8 // send
InvalidSessionOP OPCode = 9 // recv...
HelloOP OPCode = 10
HeartbeatAckOP OPCode = 11
CallConnectOP OPCode = 13
GuildSubscriptionsOP OPCode = 14
)
// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This
// is used mostly internally to signal the heartbeat loop to reconnect, if
// needed. It is not a fatal error.
var ErrReconnectRequest = errors.New("ReconnectOP received")
func (g *Gateway) HandleOP(op *wsutil.OP) error {
switch op.Code {
case HeartbeatAckOP:
// Heartbeat from the server?
g.PacerLoop.Echo()
case HeartbeatOP:
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Server requesting a heartbeat.
if err := g.PacerLoop.Pace(ctx); err != nil {
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
}
case ReconnectOP:
// Server requests to Reconnect, die and retry.
wsutil.WSDebug("ReconnectOP received.")
// Exit with the ReconnectOP error to force the heartbeat event loop to
// Reconnect synchronously. Not really a fatal error.
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
case InvalidSessionOP:
// Discord expects us to sleep for no reason
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Invalid session, try and Identify.
if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, Reconnect.
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
}
return nil
case HelloOP:
return nil
case DispatchOP:
// Set the sequence
if op.Sequence > 0 {
g.Sequence.Set(op.Sequence)
}
// Check if we know the event
fn, ok := EventCreator[op.EventName]
if !ok {
return &wsutil.UnknownEventError{
Name: op.EventName,
Data: op.Data,
}
}
// Make a new pointer to the event
var ev = fn()
// Try and parse the event
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "failed to parse event "+op.EventName)
}
// If the event is a ready, we'll want its sessionID
if ev, ok := ev.(*ReadyEvent); ok {
g.sessionMu.Lock()
g.sessionID = ev.SessionID
g.sessionMu.Unlock()
}
// Throw the event into a channel; it's valid now.
g.Events <- ev
return nil
default:
return fmt.Errorf("unknown OP code %d (event %s)", op.Code, op.EventName)
}
return nil
}

View file

@ -1,56 +0,0 @@
// +build perseverance
package gateway
import (
"testing"
"time"
"github.com/diamondburned/arikawa/v3/internal/testenv"
)
func TestPerseverance(t *testing.T) {
t.Parallel()
config := testenv.Must(t)
g, err := NewGateway("Bot " + config.BotToken)
if err != nil {
t.Fatal("failed to make the gateway:", err)
}
g.AddIntents(IntentGuilds)
if err := g.Open(); err != nil {
t.Fatal("failed to open the gateway:", err)
}
timeout := make(chan struct{}, 1)
// Automatically close the gateway after set duration.
time.AfterFunc(testenv.PerseveranceTime, func() {
t.Log("Perserverence test finshed. Closing gateway.")
timeout <- struct{}{}
if err := g.Close(); err != nil {
t.Error("failed to close gateway:", err)
}
})
// Spin on events.
for ev := range g.Events {
t.Logf("Received event %T.", ev)
}
// Exit gracefully if we have not.
select {
case <-timeout:
return
default:
}
if err := g.Close(); err != nil {
t.Fatal("failed to clean up gateway after fail:", err)
}
t.Fatal("Test failed before timeout.")
}

View file

@ -1,267 +0,0 @@
package gateway
import (
"strconv"
"strings"
"github.com/diamondburned/arikawa/v3/discord"
)
type (
// ReadyEvent is the struct for a READY event.
ReadyEvent struct {
Version int `json:"version"`
User discord.User `json:"user"`
SessionID string `json:"session_id"`
PrivateChannels []discord.Channel `json:"private_channels"`
Guilds []GuildCreateEvent `json:"guilds"`
Shard *Shard `json:"shard,omitempty"`
// Undocumented fields
UserSettings *UserSettings `json:"user_settings,omitempty"`
ReadStates []ReadState `json:"read_state,omitempty"`
UserGuildSettings []UserGuildSetting `json:"user_guild_settings,omitempty"`
Relationships []discord.Relationship `json:"relationships,omitempty"`
Presences []discord.Presence `json:"presences,omitempty"`
FriendSuggestionCount int `json:"friend_suggestion_count,omitempty"`
GeoOrderedRTCRegions []string `json:"geo_ordered_rtc_regions,omitempty"`
}
// ReadState is a single ReadState entry. It is undocumented.
ReadState struct {
ChannelID discord.ChannelID `json:"id"`
LastMessageID discord.MessageID `json:"last_message_id"`
LastPinTimestamp discord.Timestamp `json:"last_pin_timestamp"`
MentionCount int `json:"mention_count"`
}
// UserSettings is the struct for (almost) all user settings. It is
// undocumented.
UserSettings struct {
ShowCurrentGame bool `json:"show_current_game"`
DefaultGuildsRestricted bool `json:"default_guilds_restricted"`
InlineAttachmentMedia bool `json:"inline_attachment_media"`
InlineEmbedMedia bool `json:"inline_embed_media"`
GIFAutoPlay bool `json:"gif_auto_play"`
RenderEmbeds bool `json:"render_embeds"`
RenderReactions bool `json:"render_reactions"`
AnimateEmoji bool `json:"animate_emoji"`
AnimateStickers int `json:"animate_stickers"`
EnableTTSCommand bool `json:"enable_tts_command"`
MessageDisplayCompact bool `json:"message_display_compact"`
ConvertEmoticons bool `json:"convert_emoticons"`
ExplicitContentFilter uint8 `json:"explicit_content_filter"` // ???
DisableGamesTab bool `json:"disable_games_tab"`
DeveloperMode bool `json:"developer_mode"`
DetectPlatformAccounts bool `json:"detect_platform_accounts"`
StreamNotification bool `json:"stream_notification_enabled"`
AccessibilityDetection bool `json:"allow_accessibility_detection"`
ContactSync bool `json:"contact_sync_enabled"`
NativePhoneIntegration bool `json:"native_phone_integration_enabled"`
TimezoneOffset int `json:"timezone_offset"`
Locale string `json:"locale"`
Theme string `json:"theme"`
GuildPositions []discord.GuildID `json:"guild_positions"`
GuildFolders []GuildFolder `json:"guild_folders"`
RestrictedGuilds []discord.GuildID `json:"restricted_guilds"`
FriendSourceFlags FriendSourceFlags `json:"friend_source_flags"`
Status discord.Status `json:"status"`
CustomStatus *CustomUserStatus `json:"custom_status"`
}
// CustomUserStatus is the custom user status that allows setting an emoji
// and a piece of text on each user.
CustomUserStatus struct {
Text string `json:"text"`
ExpiresAt discord.Timestamp `json:"expires_at,omitempty"`
EmojiID discord.EmojiID `json:"emoji_id,string"`
EmojiName string `json:"emoji_name"`
}
// UserGuildSetting stores the settings for a single guild. It is
// undocumented.
UserGuildSetting struct {
GuildID discord.GuildID `json:"guild_id"`
SuppressRoles bool `json:"suppress_roles"`
SuppressEveryone bool `json:"suppress_everyone"`
Muted bool `json:"muted"`
MuteConfig *UserMuteConfig `json:"mute_config"`
MobilePush bool `json:"mobile_push"`
Notifications UserNotification `json:"message_notifications"`
ChannelOverrides []UserChannelOverride `json:"channel_overrides"`
}
// A UserChannelOverride struct describes a channel settings override for a
// users guild settings.
UserChannelOverride struct {
Muted bool `json:"muted"`
MuteConfig *UserMuteConfig `json:"mute_config"`
Notifications UserNotification `json:"message_notifications"`
ChannelID discord.ChannelID `json:"channel_id"`
}
// UserMuteConfig seems to describe the mute settings. It belongs to the
// UserGuildSettingEntry and UserChannelOverride structs and is
// undocumented.
UserMuteConfig struct {
SelectedTimeWindow int `json:"selected_time_window"`
EndTime discord.Timestamp `json:"end_time"`
}
// GuildFolder holds a single folder that you see in the left guild panel.
GuildFolder struct {
Name string `json:"name"`
ID GuildFolderID `json:"id"`
GuildIDs []discord.GuildID `json:"guild_ids"`
Color discord.Color `json:"color"`
}
// FriendSourceFlags describes sources that friend requests could be sent
// from. It belongs to the UserSettings struct and is undocumented.
FriendSourceFlags struct {
All bool `json:"all,omitempty"`
MutualGuilds bool `json:"mutual_guilds,omitempty"`
MutualFriends bool `json:"mutual_friends,omitempty"`
}
)
// UserNotification is the notification setting for a channel or guild.
type UserNotification uint8
const (
AllNotifications UserNotification = iota
OnlyMentions
NoNotifications
GuildDefaults
)
// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low
// number of unknown significance.
type GuildFolderID int64
func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
var body = string(b)
if body == "null" {
return nil
}
body = strings.Trim(body, `"`)
u, err := strconv.ParseInt(body, 10, 64)
if err != nil {
return err
}
*g = GuildFolderID(u)
return nil
}
func (g GuildFolderID) MarshalJSON() ([]byte, error) {
if g == 0 {
return []byte("null"), nil
}
return []byte(strconv.FormatInt(int64(g), 10)), nil
}
// ReadySupplemental event structs. For now, this event is never used, and its
// usage have yet been discovered.
type (
// ReadySupplementalEvent is the struct for a READY_SUPPLEMENTAL event,
// which is an undocumented event.
ReadySupplementalEvent struct {
Guilds []GuildCreateEvent `json:"guilds"` // only have ID and VoiceStates
MergedMembers [][]SupplementalMember `json:"merged_members"`
MergedPresences MergedPresences `json:"merged_presences"`
}
// SupplementalMember is the struct for a member in the MergedMembers field
// of ReadySupplementalEvent. It has slight differences to discord.Member.
SupplementalMember struct {
UserID discord.UserID `json:"user_id"`
Nick string `json:"nick,omitempty"`
RoleIDs []discord.RoleID `json:"roles"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
IsPending bool `json:"pending,omitempty"`
HoistedRole discord.RoleID `json:"hoisted_role"`
Mute bool `json:"mute"`
Deaf bool `json:"deaf"`
// Joined specifies when the user joined the guild.
Joined discord.Timestamp `json:"joined_at"`
// BoostedSince specifies when the user started boosting the guild.
BoostedSince discord.Timestamp `json:"premium_since,omitempty"`
}
// MergedPresences is the struct for presences of guilds' members and
// friends. It is undocumented.
MergedPresences struct {
Guilds [][]SupplementalPresence `json:"guilds"`
Friends []SupplementalPresence `json:"friends"`
}
// SupplementalPresence is a single presence for either a guild member or
// friend. It is used in MergedPresences and is undocumented.
SupplementalPresence struct {
UserID discord.UserID `json:"user_id"`
// Status is either "idle", "dnd", "online", or "offline".
Status discord.Status `json:"status"`
// Activities are the user's current activities.
Activities []discord.Activity `json:"activities"`
// ClientStaus is the user's platform-dependent status.
ClientStatus discord.ClientStatus `json:"client_status"`
// LastModified is only present in Friends.
LastModified discord.UnixMsTimestamp `json:"last_modified,omitempty"`
}
)
// ConvertSupplementalMembers converts a SupplementalMember to a regular Member.
func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member {
members := make([]discord.Member, len(sms))
for i, sm := range sms {
members[i] = discord.Member{
User: discord.User{ID: sm.UserID},
Nick: sm.Nick,
RoleIDs: sm.RoleIDs,
Joined: sm.Joined,
BoostedSince: sm.BoostedSince,
Deaf: sm.Deaf,
Mute: sm.Mute,
IsPending: sm.IsPending,
}
}
return members
}
// ConvertSupplementalPresences converts a SupplementalPresence to a regular
// Presence with an empty GuildID.
func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence {
presences := make([]discord.Presence, len(sps))
for i, sp := range sps {
presences[i] = discord.Presence{
User: discord.User{ID: sp.UserID},
Status: sp.Status,
Activities: sp.Activities,
ClientStatus: sp.ClientStatus,
}
}
return presences
}

View file

@ -1,60 +0,0 @@
// Package handleloop provides clean abstractions to handle listening to
// channels and passing them onto event handlers.
package handleloop
import "github.com/diamondburned/arikawa/v3/utils/handler"
// Loop provides a reusable event looper abstraction. It is thread-safe to use
// concurrently.
type Loop struct {
dst *handler.Handler
run chan struct{}
stop chan struct{}
}
func NewLoop(dst *handler.Handler) *Loop {
return &Loop{
dst: dst,
run: make(chan struct{}, 1), // intentional 1 buffer
stop: make(chan struct{}), // intentional unbuffer
}
}
// Start starts a new event loop. It will try to stop existing loops before.
func (l *Loop) Start(src <-chan interface{}) {
// Ensure we're stopped.
l.Stop()
// Mark that we're running.
l.run <- struct{}{}
go func() {
for {
select {
case event := <-src:
l.dst.Call(event)
case <-l.stop:
l.stop <- struct{}{}
return
}
}
}()
}
// Stop tries to stop the Loop. If the Loop is not running, then it does
// nothing; thus, it can be called multiple times.
func (l *Loop) Stop() {
// Ensure that we are running before stopping.
select {
case <-l.run:
// running
default:
return
}
// send a close request
l.stop <- struct{}{}
// wait for a reply
<-l.stop
}

View file

@ -1,135 +0,0 @@
// Package heart implements a general purpose pacemaker.
package heart
import (
"context"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
// Debug is the default logger that Pacemaker uses.
var Debug = func(v ...interface{}) {}
var ErrDead = errors.New("no heartbeat replied")
// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic.
type AtomicTime struct {
unixnano int64
}
func (t *AtomicTime) Get() int64 {
return atomic.LoadInt64(&t.unixnano)
}
func (t *AtomicTime) Set(time time.Time) {
atomic.StoreInt64(&t.unixnano, time.UnixNano())
}
func (t *AtomicTime) Time() time.Time {
return time.Unix(0, t.Get())
}
// AtomicDuration is a thread-safe Duration guarded by atomic.
type AtomicDuration struct {
duration int64
}
func (d *AtomicDuration) Get() time.Duration {
return time.Duration(atomic.LoadInt64(&d.duration))
}
func (d *AtomicDuration) Set(dura time.Duration) {
atomic.StoreInt64(&d.duration, int64(dura))
}
// Pacemaker is the internal pacemaker state. All fields are not thread-safe
// unless they're atomic.
type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate AtomicDuration
ticker time.Ticker
Ticks <-chan time.Time
// Time in nanoseconds, guarded by atomic read/writes.
SentBeat AtomicTime
EchoBeat AtomicTime
// Any callback that returns an error will stop the pacer.
Pacer func(context.Context) error
}
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
p := Pacemaker{
Heartrate: AtomicDuration{int64(heartrate)},
Pacer: pacer,
ticker: *time.NewTicker(heartrate),
}
p.Ticks = p.ticker.C
// Reset states to its old position.
now := time.Now()
p.EchoBeat.Set(now)
p.SentBeat.Set(now)
return p
}
func (p *Pacemaker) Echo() {
// Swap our received heartbeats
p.EchoBeat.Set(time.Now())
}
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
var (
echo = p.EchoBeat.Get()
sent = p.SentBeat.Get()
)
if echo == 0 || sent == 0 {
return false
}
return sent-echo > int64(p.Heartrate.Get())*2
}
// SetHeartRate sets the ticker's heart rate.
func (p *Pacemaker) SetPace(heartrate time.Duration) {
p.Heartrate.Set(heartrate)
// To uncomment when 1.16 releases and we drop support for 1.14.
// p.ticker.Reset(heartrate)
p.ticker.Stop()
p.ticker = *time.NewTicker(heartrate)
p.Ticks = p.ticker.C
}
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
func (p *Pacemaker) StopTicker() {
p.ticker.Stop()
}
// pace sends a heartbeat with the appropriate timeout for the context.
func (p *Pacemaker) Pace() error {
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate.Get())
defer cancel()
return p.PaceCtx(ctx)
}
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
if err := p.Pacer(ctx); err != nil {
return err
}
p.SentBeat.Set(time.Now())
if p.Dead() {
return ErrDead
}
return nil
}

View file

@ -0,0 +1,30 @@
package lazytime
import "time"
type Ticker struct {
C <-chan time.Time
ticker *time.Ticker
}
// Reset resets the ticker. If this is the first time calling, then a new timer
// is created.
func (t *Ticker) Reset(d time.Duration) {
if t.ticker == nil {
t.ticker = time.NewTicker(d)
t.C = t.ticker.C
} else {
t.ticker.Reset(d)
}
}
// Stop stops the ticker. If the ticker has never been used, then it does
// nothing.
func (t *Ticker) Stop() {
if t.ticker == nil {
return
}
t.ticker.Stop()
}

View file

@ -0,0 +1,50 @@
package lazytime
import (
"context"
"time"
)
type Timer struct {
C <-chan time.Time
timer *time.Timer
}
// Reset resets the timer by draining it and resetting the internal channel. If
// this is the first time calling, then a new timer is created.
func (t *Timer) Reset(d time.Duration) {
if t.timer == nil {
t.timer = time.NewTimer(d)
t.C = t.timer.C
return
}
t.Stop()
t.timer.Reset(d)
}
// Stop stops the timer and drains it. If the timer has never been used, then it
// does nothing.
func (t *Timer) Stop() {
if t.timer == nil {
return
}
if !t.timer.Stop() {
select {
case <-t.timer.C:
default:
}
}
}
// Wait blocks until the timer fires or until the context expires.
func (t *Timer) Wait(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}

View file

@ -17,3 +17,17 @@ func (b *Bool) Set(val bool) {
}
atomic.StoreUint32(&b.val, x)
}
func (b *Bool) SetTrue() {
atomic.StoreUint32(&b.val, 1)
}
func (b *Bool) SetFalse() {
atomic.StoreUint32(&b.val, 0)
}
// Acquire sets bool to true if it's false and returns true, otherwise returns
// false.
func (b *Bool) Acquire() bool {
return atomic.CompareAndSwapUint32(&b.val, 0, 1)
}

View file

@ -5,89 +5,64 @@ package session
import (
"context"
"sync"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/handleloop"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/ws/ophandler"
)
// ErrMFA is returned if the account requires a 2FA code to log in.
var ErrMFA = errors.New("account has 2FA enabled")
// Closed is an event that's sent to Session's command handler. This works by
// using (*Gateway).AfterClose. If the user sets this callback, no Closed events
// would be sent.
//
// Usage
//
// ses.AddHandler(func(*session.Closed) {})
//
type Closed struct {
Error error
}
// NewShardFunc creates a shard constructor for a session.
func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc {
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
s := NewCustomShard(m, id)
if f != nil {
f(m, s)
}
return s, nil
}
}
// NewCustomShard creates a new session from the given shard manager and other
// parameters.
func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session {
return NewCustomSession(
shard.NewGatewayShard(m, id),
api.NewClient(id.Token),
handler.New(),
)
}
// Session manages both the API and Gateway. As such, Session inherits all of
// API's methods, as well has the Handler used for Gateway.
type Session struct {
*api.Client
*gateway.Gateway
// Command handler with inherited methods.
*handler.Handler
// internal state to not be copied around.
looper *handleloop.Loop
state *sessionState
}
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
g, err := gateway.NewGatewayWithIntents(token, intents...)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to Gateway")
type sessionState struct {
sync.Mutex
id gateway.Identifier
gateway *gateway.Gateway
ctx context.Context
cancel context.CancelFunc
doneCh <-chan struct{}
}
// NewWithIntents is similar to New but adds the given intents in during
// construction.
func NewWithIntents(token string, intents ...gateway.Intents) *Session {
var allIntent gateway.Intents
for _, intent := range intents {
allIntent |= intent
}
return NewWithGateway(g), nil
id := gateway.DefaultIdentifier(token)
id.Intents = option.NewUint(uint(allIntent))
return NewWithIdentifier(id)
}
// New creates a new session from a given token. Most bots should be using
// NewWithIntents instead.
func New(token string) (*Session, error) {
// Create a gateway
g, err := gateway.NewGateway(token)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
return NewWithGateway(g), nil
func New(token string) *Session {
return NewWithIdentifier(gateway.DefaultIdentifier(token))
}
// Login tries to log in as a normal user account; MFA is optional.
func Login(email, password, mfa string) (*Session, error) {
func Login(ctx context.Context, email, password, mfa string) (*Session, error) {
// Make a scratch HTTP client without a token
client := api.NewClient("")
client := api.NewClient("").WithContext(ctx)
// Try to login without TOTP
l, err := client.Login(email, password)
@ -97,7 +72,7 @@ func Login(email, password, mfa string) (*Session, error) {
if l.Token != "" && !l.MFA {
// We got the token, return with a new Session.
return New(l.Token)
return New(l.Token), nil
}
// Discord requests MFA, so we need the MFA token.
@ -111,40 +86,117 @@ func Login(email, password, mfa string) (*Session, error) {
return nil, errors.Wrap(err, "failed to login with 2FA")
}
return New(l.Token)
return New(l.Token), nil
}
// NewWithGateway creates a new Session with the given Gateway.
func NewWithGateway(gw *gateway.Gateway) *Session {
return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New())
// NewWithIdentifier creates a bare Session with the given identifier.
func NewWithIdentifier(id gateway.Identifier) *Session {
return NewCustom(id, api.NewClient(id.Token), handler.New())
}
// NewCustomSession constructs a bare Session from the given parameters.
func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session {
// NewWithGateway constructs a bare Session from the given UNOPENED gateway.
func NewWithGateway(g *gateway.Gateway, h *handler.Handler) *Session {
state := g.State()
return &Session{
Client: api.NewClient(state.Identifier.Token),
Handler: h,
state: &sessionState{
gateway: g,
id: state.Identifier,
},
}
}
// NewCustom constructs a bare Session from the given parameters.
func NewCustom(id gateway.Identifier, cl *api.Client, h *handler.Handler) *Session {
return &Session{
Gateway: gw,
Client: cl,
Handler: h,
looper: handleloop.NewLoop(h),
state: &sessionState{id: id},
}
}
// AddIntents adds the given intents into the gateway. Calling it after Open has
// already been called will result in a panic.
func (s *Session) AddIntents(intents gateway.Intents) {
s.state.Lock()
s.state.id.AddIntents(intents)
if s.state.gateway != nil {
s.state.gateway.AddIntents(intents)
}
s.state.Unlock()
}
// HasIntents reports if the Gateway has the passed Intents.
//
// If no intents are set, e.g. if using a user account, HasIntents will always
// return true.
func (s *Session) HasIntents(intents gateway.Intents) bool {
return s.state.id.HasIntents(intents)
}
// Gateway returns the current session's gateway. If Open has never been called
// or Session was never constructed with a gateway, then nil is returned.
func (s *Session) Gateway() *gateway.Gateway {
s.state.Lock()
g := s.state.gateway
s.state.Unlock()
return g
}
// Open opens the Discord gateway and its handler, then waits until either the
// Ready or Resumed event gets through.
func (s *Session) Open(ctx context.Context) error {
// Start the handler beforehand so no events are missed.
s.looper.Start(s.Gateway.Events)
evCh := make(chan interface{})
// Set the AfterClose's handler.
s.Gateway.AfterClose = func(err error) {
s.Handler.Call(&Closed{
Error: err,
})
s.state.Lock()
defer s.state.Unlock()
if s.state.cancel != nil {
if err := s.close(ctx); err != nil {
return err
}
}
if err := s.Gateway.Open(ctx); err != nil {
return errors.Wrap(err, "failed to start gateway")
if s.state.gateway == nil {
g, err := gateway.NewWithIdentifier(ctx, s.state.id)
if err != nil {
return err
}
s.state.gateway = g
}
return nil
ctx, cancel := context.WithCancel(context.Background())
s.state.ctx = ctx
s.state.cancel = cancel
// TODO: change this to AddSyncHandler.
rm := s.AddHandler(evCh)
defer rm()
opCh := s.state.gateway.Connect(s.state.ctx)
s.state.doneCh = ophandler.Loop(opCh, s.Handler)
for {
select {
case <-ctx.Done():
s.close(ctx)
return ctx.Err()
case <-s.state.doneCh:
// Event loop died.
return s.state.gateway.LastError()
case ev := <-evCh:
switch ev.(type) {
case *gateway.ReadyEvent, *gateway.ResumedEvent:
return nil
}
}
}
}
// WithContext returns a shallow copy of Session with the context replaced in
@ -160,21 +212,34 @@ func (s *Session) WithContext(ctx context.Context) *Session {
}
// Close closes the underlying Websocket connection, invalidating the session
// ID.
//
// It will send a closing frame before ending the connection, closing it
// gracefully. This will cause the bot to appear as offline instantly.
// ID. It will send a closing frame before ending the connection, closing it
// gracefully. This will cause the bot to appear as offline instantly. To
// prevent this behavior, change Gateway.AlwaysCloseGracefully.
func (s *Session) Close() error {
// Stop the event handler
s.looper.Stop()
return s.Gateway.Close()
s.state.Lock()
defer s.state.Unlock()
return s.close(context.Background())
}
// Pause pauses the Gateway connection, by ending the connection without
// sending a closing frame. This allows the connection to be resumed at a later
// point.
func (s *Session) Pause() error {
// Stop the event handler
s.looper.Stop()
return s.Gateway.Pause()
func (s *Session) close(ctx context.Context) error {
if s.state.cancel == nil {
return errors.New("Session is already closed")
}
s.state.cancel()
s.state.cancel = nil
s.state.ctx = nil
// Wait until we've successfully disconnected.
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "cannot wait for gateway exit")
case <-s.state.doneCh:
// ok
}
s.state.doneCh = nil
return s.state.gateway.LastError()
}

50
session/session_test.go Normal file
View file

@ -0,0 +1,50 @@
package session
import (
"context"
"testing"
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/testenv"
)
func TestSession(t *testing.T) {
attempts := 1
timeout := 15 * time.Second
if !testing.Short() {
attempts = 5
timeout = time.Minute // 5s-10s each reconnection
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(cancel)
env := testenv.Must(t)
readyCh := make(chan *gateway.ReadyEvent, 1)
s := NewWithIntents(env.BotToken, gateway.IntentGuilds)
s.AddHandler(readyCh)
for i := 0; i < attempts; i++ {
if err := s.Open(ctx); err != nil {
t.Fatal("failed to open:", err)
}
if ready, ok := <-readyCh; !ok {
t.Fatal("ready not received")
} else {
now := time.Now()
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
}
if err := s.Close(); err != nil {
t.Fatal("failed to close:", err)
}
// Hold for an additional one second.
time.Sleep(time.Second)
}
}

View file

@ -61,7 +61,7 @@ type rescalingState struct {
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
id := gateway.DefaultIdentifier(token)
url, err := updateIdentifier(context.Background(), id)
url, err := updateIdentifier(context.Background(), &id)
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway info")
}
@ -76,20 +76,20 @@ func NewManager(token string, fn NewShardFunc) (*Manager, error) {
//
// This function should rarely be used, since the shard information will be
// queried from Discord if it's required to shard anyway.
func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) {
func NewIdentifiedManager(idData gateway.IdentifyCommand, fn NewShardFunc) (*Manager, error) {
// Ensure id.Shard is never nil.
if data.Shard == nil {
data.Shard = gateway.DefaultShard
if idData.Shard == nil {
idData.Shard = gateway.DefaultShard
}
id := gateway.NewIdentifier(data)
id := gateway.NewIdentifier(idData)
url, err := updateIdentifier(context.Background(), id)
url, err := updateIdentifier(context.Background(), &id)
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway info")
}
id.Shard = data.Shard
id.Shard = idData.Shard
return NewIdentifiedManagerWithURL(url, id, fn)
}
@ -97,7 +97,7 @@ func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager,
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
// and gateway URL. It behaves similarly to NewIdentifiedManager.
func NewIdentifiedManagerWithURL(
url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) {
url string, id gateway.Identifier, fn NewShardFunc) (*Manager, error) {
m := Manager{
gatewayURL: gateway.AddGatewayParams(url),
@ -108,12 +108,12 @@ func NewIdentifiedManagerWithURL(
var err error
for i := range m.shards {
data := id.IdentifyData
data := id.IdentifyCommand
data.Shard = &gateway.Shard{i, len(m.shards)}
m.shards[i] = ShardState{
ID: gateway.Identifier{
IdentifyData: data,
IdentifyCommand: data,
IdentifyShortLimit: id.IdentifyShortLimit,
IdentifyGlobalLimit: id.IdentifyGlobalLimit,
},
@ -263,10 +263,10 @@ func (m *Manager) rescale() {
func (m *Manager) tryRescale(ctx context.Context) bool {
m.mutex.Lock()
data := m.shards[0].ID.IdentifyData
data := m.shards[0].ID.IdentifyCommand
newID := gateway.NewIdentifier(data)
url, err := updateIdentifier(ctx, newID)
url, err := updateIdentifier(ctx, &newID)
if err != nil {
m.mutex.Unlock()
return false
@ -282,12 +282,12 @@ func (m *Manager) tryRescale(ctx context.Context) bool {
newShards := make([]ShardState, numShards)
for i := 0; i < numShards; i++ {
data := newID.IdentifyData
data := newID.IdentifyCommand
data.Shard = &gateway.Shard{i, len(m.shards)}
newShards[i] = ShardState{
ID: gateway.Identifier{
IdentifyData: data,
IdentifyCommand: data,
IdentifyShortLimit: newID.IdentifyShortLimit,
IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
},

View file

@ -3,7 +3,10 @@ package shard
import (
"context"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/pkg/errors"
)
@ -23,17 +26,14 @@ type Shard interface {
// methods without deadlocking.
type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error)
// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with
// NewShardFunc.
var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) {
return NewGatewayShard(m, id), nil
}
// NewGatewayShard creates a new gateway that's plugged into the shard manager.
func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway {
gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id)
gw.OnShardingRequired(m.Rescale)
return gw
// NewSessionShard creates a shard constructor for a session.
// Accessing any shard and adding a handler will add a handler for all shards.
func NewSessionShard(f func(m *Manager, s *session.Session)) NewShardFunc {
return func(m *Manager, id *gateway.Identifier) (Shard, error) {
s := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
f(m, s)
return s, nil
}
}
// ShardState wraps around the Gateway interface to provide additional state.

View file

@ -1,4 +1,4 @@
package session
package shard
import (
"context"
@ -6,29 +6,28 @@ import (
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/session"
)
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
readyCh := make(chan *gateway.ReadyEvent)
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
func(m *shard.Manager, s *Session) {
m, err := NewIdentifiedManager(data, NewSessionShard(
func(m *Manager, s *session.Session) {
now := time.Now().Format(time.StampMilli)
t.Log(now, "initializing shard")
s.Gateway.ErrorLog = func(err error) {
t.Error("gateway error:", err)
}
s.AddIntents(gateway.IntentGuilds)
s.AddHandler(readyCh)
s.AddHandler(func(err error) {
t.Error("unexpected error:", err)
})
},
))
if err != nil {

View file

@ -6,10 +6,11 @@ import (
"context"
"sync"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/session/shard"
"github.com/diamondburned/arikawa/v3/state/store"
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
"github.com/diamondburned/arikawa/v3/utils/handler"
@ -27,7 +28,8 @@ var (
// The user should initialize handlers and intents in the opts function.
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New())
sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
state := NewFromSession(sessn, defaultstore.New())
opts(m, state)
return state, nil
}
@ -110,29 +112,21 @@ type State struct {
}
// New creates a new state.
func New(token string) (*State, error) {
func New(token string) *State {
return NewWithStore(token, defaultstore.New())
}
// NewWithIntents creates a new state with the given gateway intents. For more
// information, refer to gateway.Intents.
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
s, err := session.NewWithIntents(token, intents...)
if err != nil {
return nil, err
}
return NewFromSession(s, defaultstore.New()), nil
func NewWithIntents(token string, intents ...gateway.Intents) *State {
s := session.NewWithIntents(token, intents...)
return NewFromSession(s, defaultstore.New())
}
// NewWithStore creates a new state with the given store cabinet.
func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) {
s, err := session.New(token)
if err != nil {
return nil, err
}
return NewFromSession(s, cabinet), nil
func NewWithStore(token string, cabinet *store.Cabinet) *State {
s := session.New(token)
return NewFromSession(s, cabinet)
}
// NewFromSession creates a new State from the passed Session and Cabinet.
@ -239,11 +233,11 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis
merr = store.ErrNotFound
)
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
g, gerr = s.Cabinet.Guild(guildID)
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if s.HasIntents(gateway.IntentGuildMembers) {
m, merr = s.Cabinet.Member(guildID, userID)
}
@ -300,11 +294,11 @@ func (s *State) Permissions(
merr = store.ErrNotFound
)
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
g, gerr = s.Cabinet.Guild(ch.GuildID)
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if s.HasIntents(gateway.IntentGuildMembers) {
m, merr = s.Cabinet.Member(ch.GuildID, userID)
}
@ -372,7 +366,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) {
}
func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
cs, err = s.Cabinet.Channels(guildID)
if err == nil {
return
@ -384,7 +378,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err
return
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
for i := range cs {
if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil {
return
@ -436,7 +430,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
func (s *State) Emoji(
guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) {
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
if s.HasIntents(gateway.IntentGuildEmojis) {
e, err = s.Cabinet.Emoji(guildID, emojiID)
if err == nil {
return
@ -464,7 +458,7 @@ func (s *State) Emoji(
}
func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) {
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
if s.HasIntents(gateway.IntentGuildEmojis) {
es, err = s.Cabinet.Emojis(guildID)
if err == nil {
return
@ -476,7 +470,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error)
return
}
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
if s.HasIntents(gateway.IntentGuildEmojis) {
err = s.Cabinet.EmojiSet(guildID, es, false)
}
@ -486,7 +480,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error)
////
func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
c, err := s.Cabinet.Guild(id)
if err == nil {
return c, nil
@ -498,7 +492,7 @@ func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
// Guilds will only fill a maximum of 100 guilds from the API.
func (s *State) Guilds() (gs []discord.Guild, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
gs, err = s.Cabinet.Guilds()
if err == nil {
return
@ -510,7 +504,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
return
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
for i := range gs {
if err = s.Cabinet.GuildSet(&gs[i], false); err != nil {
return
@ -524,7 +518,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
////
func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if s.HasIntents(gateway.IntentGuildMembers) {
m, err := s.Cabinet.Member(guildID, userID)
if err == nil {
return m, nil
@ -535,7 +529,7 @@ func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord
}
func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) {
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if s.HasIntents(gateway.IntentGuildMembers) {
ms, err = s.Cabinet.Members(guildID)
if err == nil {
return
@ -547,7 +541,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error
return
}
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if s.HasIntents(gateway.IntentGuildMembers) {
for i := range ms {
if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil {
return
@ -580,7 +574,7 @@ func (s *State) Message(
wg.Add(1)
go func() {
c, cerr = s.Session.Channel(channelID)
if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
if cerr == nil && s.HasIntents(gateway.IntentGuilds) {
cerr = s.Cabinet.ChannelSet(c, false)
}
@ -706,13 +700,13 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
// Presence checks the state for user presences. If no guildID is given, it
// will look for the presence in all cached guilds.
func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Presence, error) {
if !s.Gateway.HasIntents(gateway.IntentGuildPresences) {
if !s.HasIntents(gateway.IntentGuildPresences) {
return nil, store.ErrNotFound
}
// If there's no guild ID, look in all guilds
if !gID.IsValid() {
if !s.Gateway.HasIntents(gateway.IntentGuilds) {
if !s.HasIntents(gateway.IntentGuilds) {
return nil, store.ErrNotFound
}
@ -736,7 +730,7 @@ func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Pres
////
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) {
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
target, err = s.Cabinet.Role(guildID, roleID)
if err == nil {
return
@ -754,7 +748,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di
target = &r
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
if err = s.RoleSet(guildID, &rs[i], false); err != nil {
return
}
@ -779,7 +773,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
return nil, err
}
if s.Gateway.HasIntents(gateway.IntentGuilds) {
if s.HasIntents(gateway.IntentGuilds) {
for i := range rs {
if err := s.RoleSet(guildID, &rs[i], false); err != nil {
return rs, err
@ -792,7 +786,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
if err == nil && s.HasIntents(gateway.IntentGuilds) {
err = s.Cabinet.GuildSet(g, false)
}
@ -801,7 +795,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(gID, uID)
if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) {
if err == nil && s.HasIntents(gateway.IntentGuildMembers) {
err = s.Cabinet.MemberSet(gID, m, false)
}
@ -811,14 +805,12 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord
// tracksMessage reports whether the state would track the passed message and
// messages from the same channel.
func (s *State) tracksMessage(m *discord.Message) bool {
return s.Gateway.Identifier.Intents == nil ||
(m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
(!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages))
return (m.GuildID.IsValid() && s.HasIntents(gateway.IntentGuildMessages)) ||
(!m.GuildID.IsValid() && s.HasIntents(gateway.IntentDirectMessages))
}
// tracksChannel reports whether the state would track the passed channel.
func (s *State) tracksChannel(c *discord.Channel) bool {
return s.Gateway.Identifier.Intents == nil ||
(c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
return (c.GuildID.IsValid() && s.HasIntents(gateway.IntentGuilds)) ||
!c.GuildID.IsValid()
}

View file

@ -152,7 +152,7 @@ func (s *State) onEvent(iface interface{}) {
}
// Update available fields from ev into m
ev.Update(m)
ev.UpdateMember(m)
if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil {
s.stateErr(err, "failed to update a member in state")

View file

@ -0,0 +1,32 @@
package state_test
import (
"context"
"log"
"os"
"os/signal"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/state"
)
func Example() {
s := state.New("Bot " + os.Getenv("DISCORD_TOKEN"))
s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages)
s.AddHandler(func(m *gateway.MessageCreateEvent) {
log.Printf("%s: %s", m.Author.Username, m.Content)
})
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if err := s.Open(ctx); err != nil {
log.Println("cannot open:", err)
}
<-ctx.Done() // block until Ctrl+C
if err := s.Close(); err != nil {
log.Println("cannot close:", err)
}
}

View file

@ -5,16 +5,24 @@ import (
"testing"
"time"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/session/shard"
)
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
data.Presence = &gateway.UpdatePresenceCommand{
Status: discord.DoNotDisturbStatus,
Activities: []discord.Activity{{
Name: "Testing shards...",
Type: discord.CustomActivity,
}},
}
readyCh := make(chan *gateway.ReadyEvent)
@ -23,19 +31,18 @@ func TestSharding(t *testing.T) {
now := time.Now().Format(time.StampMilli)
t.Log(now, "initializing shard")
s.Gateway.ErrorLog = func(err error) {
t.Error("gateway error:", err)
}
s.AddIntents(gateway.IntentGuilds)
s.AddHandler(readyCh)
s.AddSyncHandler(readyCh)
s.AddSyncHandler(func(err error) {
t.Log("background error:", err)
})
},
))
if err != nil {
t.Fatal("failed to make shard manager:", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
go func() {

View file

@ -2,22 +2,29 @@ package bot
import (
"reflect"
"sync"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
// eventIntents maps event pointer types to intents.
var eventIntents = map[reflect.Type]gateway.Intents{}
var (
// eventIntents maps event pointer types to intents.
eventIntents map[reflect.Type]gateway.Intents
eventIntentsOnce sync.Once
)
func init() {
for event, intent := range gateway.EventIntents {
fn, ok := gateway.EventCreator[event]
if !ok {
continue
}
eventIntents[reflect.TypeOf(fn())] = intent
}
func ensureEventIntents() {
eventIntentsOnce.Do(func() {
eventIntents = map[reflect.Type]gateway.Intents{}
gateway.OpUnmarshalers.Each(func(_ ws.OpCode, t ws.EventType, f ws.OpFunc) bool {
intent, ok := gateway.EventIntents[t]
if ok {
eventIntents[reflect.TypeOf(f())] = intent
}
return false
})
})
}
type command struct {
@ -44,6 +51,8 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er
// intents returns the command's intents from the event.
func (c *command) intents() gateway.Intents {
ensureEventIntents()
intents, ok := eventIntents[c.event]
if !ok {
return 0

View file

@ -14,13 +14,13 @@ import (
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/session/shard"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/state/store"
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
"github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords"
"github.com/diamondburned/arikawa/v3/utils/handler"
)
// Prefixer checks a message if it starts with the desired prefix. By default,
@ -55,22 +55,15 @@ func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc {
panic("bot.NewShardFunc missing fn")
}
var once sync.Once
var cab *store.Cabinet
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
state := state.NewFromSession(session.NewCustomShard(m, id), nil)
sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
state := state.NewFromSession(sessn, defaultstore.New())
bot, err := fn(state)
if err != nil {
return nil, errors.Wrap(err, "failed to create bot instance")
}
if state.Cabinet == nil {
once.Do(func() { cab = defaultstore.New() })
state.Cabinet = cab
}
return bot, nil
}
}
@ -206,10 +199,6 @@ func Start(
// fail api request if they (will) take up more than 5 minutes
ctx.Client.Client.Timeout = 5 * time.Minute
ctx.Gateway.ErrorLog = func(err error) {
ctx.ErrorLogger(err)
}
if opts != nil {
if err := opts(ctx); err != nil {
return nil, err
@ -317,7 +306,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
// AddIntents adds the given Gateway Intent into the Gateway. This is a
// convenient function that calls Gateway's AddIntent.
func (ctx *Context) AddIntents(i gateway.Intents) {
ctx.Gateway.AddIntents(i)
ctx.Session.AddIntents(i)
}
// Subcommands returns the slice of subcommands. To add subcommands, use

View file

@ -6,8 +6,8 @@ import (
"time"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/gateway/shard"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/session/shard"
"github.com/diamondburned/arikawa/v3/state"
)
@ -24,7 +24,7 @@ func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) {
func TestSharding(t *testing.T) {
env := testenv.Must(t)
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
data.Shard = &gateway.Shard{0, env.ShardCount}
readyCh := make(chan *gateway.ReadyEvent)

View file

@ -4,27 +4,18 @@ import (
"errors"
"testing"
"github.com/diamondburned/arikawa/v3/utils/bot"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/state/store"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/bot"
)
func TestAdminOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: option.NewUint(uint(gateway.IntentGuilds | gateway.IntentGuildMembers)),
},
},
},
},
Session: session.New(""),
Cabinet: mockCabinet(),
},
}
@ -62,15 +53,7 @@ func TestAdminOnly(t *testing.T) {
func TestGuildOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: option.NewUint(uint(gateway.IntentGuilds)),
},
},
},
},
Session: session.New(""),
Cabinet: mockCabinet(),
},
}

View file

@ -0,0 +1,172 @@
package main
import (
"bytes"
"flag"
"go/format"
"log"
"os"
"regexp"
"strconv"
"strings"
"text/template"
"unicode"
_ "embed"
"github.com/pkg/errors"
)
var (
pkg = "gateway"
out = "-"
)
type registry struct {
PackageName string
EventTypes []EventType
}
type EventType struct {
StructName string
EventName string
IsDispatch bool
OpCode int
}
func (t *EventType) MethodRecv() string {
if len(t.StructName) == 0 {
return "e"
}
return string(unicode.ToLower([]rune(t.StructName)[0]))
}
//go:embed template.tmpl
var packageTmpl string
var tmpl = template.Must(template.New("").Parse(packageTmpl))
const eventStructRegex = "(?m)" +
`^// ([A-Za-z]+(?:Event|Command)) is (a dispatch event|an event|a command)` +
`(?:` +
` for ([A-Z_]+)` + "|" +
` for Op (\d+)` +
`)?` +
`\.(?:.|\n)*?\ntype ([A-Za-z]+(?:Event|Command)) .*`
func main() {
flag.StringVar(&pkg, "p", pkg, "the package name to use")
flag.StringVar(&out, "o", out, "output file, - for stdout")
flag.Parse()
log.Println(eventStructRegex)
r := registry{
PackageName: pkg,
}
files, err := os.ReadDir(".")
if err != nil {
log.Fatalln("failed to read current directory:", err)
}
for _, file := range files {
if file.IsDir() || !strings.HasSuffix(file.Name(), ".go") {
continue
}
if err := r.CrawlFile(file.Name()); err != nil {
log.Fatalln("failed to crawl file:", err)
}
}
buf := bytes.Buffer{}
if err := tmpl.Execute(&buf, &r); err != nil {
log.Fatalln("failed to execute template:", err)
}
b, err := format.Source(buf.Bytes())
if err != nil {
log.Fatalln("failed to fmt:", err)
}
output := os.Stdout
if out != "-" {
f, err := os.Create(out)
if err != nil {
log.Fatalln("failed to create output:", err)
}
defer f.Close()
output = f
}
if _, err := output.Write(b); err != nil {
log.Fatalln("failed to write rendered:", err)
}
}
var reEventStruct = regexp.MustCompile(eventStructRegex)
func (r *registry) CrawlFile(name string) error {
f, err := os.ReadFile(name)
if err != nil {
return errors.Wrap(err, "failed to read file")
}
for _, match := range reEventStruct.FindAllSubmatch(f, -1) {
// Validity check.
if string(match[1]) != string(match[5]) {
continue
}
if strings.HasSuffix(string(match[1]), "Command") && string(match[2]) != "a command" {
log.Println(string(match[1]), "has invalid comment %q", string(match[2]))
continue
}
t := EventType{
StructName: string(match[1]),
EventName: string(match[3]),
IsDispatch: string(match[2]) == "a dispatch event",
OpCode: -1,
}
if op := string(match[4]); op != "" && !t.IsDispatch {
i, err := strconv.Atoi(op)
if err != nil {
log.Printf("error at struct %s: error parsing Op %v", t.StructName, err)
}
t.OpCode = i
}
if t.IsDispatch && t.EventName == "" {
t.EventName = guessEventName(t.StructName)
}
r.EventTypes = append(r.EventTypes, t)
}
return nil
}
func guessEventName(structName string) string {
name := strings.TrimSuffix(structName, "Event")
var newName strings.Builder
newName.Grow(len(name) * 2)
for i, r := range name {
if unicode.IsLower(r) {
newName.WriteRune(unicode.ToUpper(r))
continue
}
if i > 0 {
newName.WriteByte('_')
}
newName.WriteRune(r)
}
return newName.String()
}

View file

@ -0,0 +1,27 @@
// Code generated by genevent. DO NOT EDIT.
package {{ .PackageName }}
import "github.com/diamondburned/arikawa/v3/utils/ws"
func init() {
OpUnmarshalers.Add(
{{ range .EventTypes -}}
func() ws.Event { return new({{ .StructName }}) },
{{ end -}}
)
}
{{ range .EventTypes }}
{{ if .IsDispatch }}
// Op implements Event. It always returns 0.
func (*{{ .StructName }}) Op() ws.OpCode { return dispatchOp }
{{ else if (gt .OpCode -1) }}
// Op implements Event. It always returns Op {{ .OpCode }}.
func (*{{ .StructName }}) Op() ws.OpCode { return {{ .OpCode }} }
{{ end }}
// EventType implements Event.
func (*{{ .StructName }}) EventType() ws.EventType { return "{{ .EventName }}" }
{{ end }}

View file

@ -239,6 +239,7 @@ type handler struct {
chanclose reflect.Value // IsValid() if chan
isIface bool
isSync bool
isOnce bool
}
// newHandler reflects either a channel or a function into a handler. A function

101
utils/ws/codec.go Normal file
View file

@ -0,0 +1,101 @@
package ws
import (
"io"
"net/http"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/pkg/errors"
)
// Codec holds the codec states for Websocket implementations to share with the
// manager. It is used internally in the Websocket and the Connection
// implementation.
type Codec struct {
Unmarshalers OpUnmarshalers
Headers http.Header
}
// NewCodec creates a new default Codec instance.
func NewCodec(unmarshalers OpUnmarshalers) Codec {
return Codec{
Unmarshalers: unmarshalers,
Headers: http.Header{
"Accept-Encoding": {"zlib"},
},
}
}
type codecOp struct {
Op
Data json.Raw `json:"d,omitempty"`
}
const maxSharedBufferSize = 1 << 15 // 32KB
// DecodeBuffer boxes a byte slice to provide a shared and thread-unsafe buffer.
// It is used internally and should only be handled around as an opaque thing.
type DecodeBuffer struct {
buf []byte
}
// NewDecodeBuffer creates a new preallocated DecodeBuffer.
func NewDecodeBuffer(cap int) DecodeBuffer {
if cap > maxSharedBufferSize {
cap = maxSharedBufferSize
}
return DecodeBuffer{
buf: make([]byte, 0, cap),
}
}
// DecodeFrom reads the given reader and decodes it into an Op.
//
// buf is optional.
func (c Codec) DecodeFrom(r io.Reader, buf *DecodeBuffer) Op {
var op codecOp
op.Data = json.Raw(buf.buf)
if err := json.DecodeStream(r, &op); err != nil {
return newErrOp(err, "cannot read JSON stream")
}
// buf isn't grown from here out. Set it back right now. If Data hasn't been
// grown, then this will just set buf back to what it was.
if cap(op.Data) < maxSharedBufferSize {
buf.buf = op.Data[:0]
}
fn := c.Unmarshalers.Lookup(op.Code, op.Type)
if fn == nil {
err := UnknownEventError{
Op: op.Code,
Type: op.Type,
}
return newErrOp(err, "")
}
op.Op.Data = fn()
if err := op.Data.UnmarshalTo(op.Op.Data); err != nil {
return newErrOp(err, "cannot unmarshal JSON data from gateway")
}
return op.Op
}
func newErrOp(err error, wrap string) Op {
if wrap != "" {
err = errors.Wrap(err, wrap)
}
ev := &BackgroundErrorEvent{
Err: err,
}
return Op{
Code: ev.Op(),
Type: ev.EventType(),
Data: ev,
}
}

278
utils/ws/conn.go Normal file
View file

@ -0,0 +1,278 @@
package ws
import (
"compress/zlib"
"context"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
const rwBufferSize = 1 << 15 // 32KB
// ErrWebsocketClosed is returned if the websocket is already closed.
var ErrWebsocketClosed = errors.New("websocket is closed")
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) (<-chan Op, error)
// Send allows the caller to send bytes.
Send(context.Context, []byte) error
// Close should close the websocket connection. The underlying connection
// may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error. If
// gracefully is true, then the implementation must send a close frame
// prior.
Close(gracefully bool) error
}
// Conn is the default Websocket connection. It tries to compresses all payloads
// using zlib.
type Conn struct {
dialer websocket.Dialer
codec Codec
// conn is used for synchronizing the conn instance itself. Any use of conn
// must copy conn out.
conn *connMutex
// mut is used for synchronizing the conn field.
mut sync.Mutex
// CloseTimeout is the timeout for graceful closing. It's defaulted to 5s.
CloseTimeout time.Duration
}
type connMutex struct {
wrmut chan struct{}
*websocket.Conn
}
var _ Connection = (*Conn)(nil)
// NewConn creates a new default websocket connection with a default dialer.
func NewConn(codec Codec) *Conn {
return NewConnWithDialer(codec, websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: rwBufferSize,
WriteBufferSize: rwBufferSize,
EnableCompression: true,
})
}
// NewConnWithDialer creates a new default websocket connection with a custom
// dialer.
func NewConnWithDialer(codec Codec, dialer websocket.Dialer) *Conn {
return &Conn{
dialer: dialer,
codec: codec,
CloseTimeout: 5 * time.Second,
}
}
// Dial starts a new connection and returns the listening channel for it. If the
// websocket is already dialed, then the connection is closed first.
func (c *Conn) Dial(ctx context.Context, addr string) (<-chan Op, error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
c.mut.Lock()
defer c.mut.Unlock()
// Ensure that the connection is already closed.
if c.conn != nil {
c.conn.close(c.CloseTimeout, false)
}
conn, _, err := c.dialer.DialContext(ctx, addr, c.codec.Headers)
if err != nil {
return nil, errors.Wrap(err, "failed to dial WS")
}
events := make(chan Op, 1)
go readLoop(conn, c.codec, events)
c.conn = &connMutex{
wrmut: make(chan struct{}, 1),
Conn: conn,
}
return events, err
}
// Close implements Connection.
func (c *Conn) Close(gracefully bool) error {
c.mut.Lock()
defer c.mut.Unlock()
return c.conn.close(c.CloseTimeout, gracefully)
}
func (c *connMutex) close(timeout time.Duration, gracefully bool) error {
if c == nil || c.Conn == nil {
WSDebug("Conn: Close is called on already closed connection")
return ErrWebsocketClosed
}
WSDebug("Conn: Close is called; shutting down the Websocket connection.")
if gracefully {
// Have a deadline before closing.
deadline := time.Now().Add(timeout)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
select {
case c.wrmut <- struct{}{}:
// Lock acquired. We can now safely set the deadline and write.
c.SetWriteDeadline(deadline)
WSDebug("Conn: Graceful closing requested, sending close frame.")
if err := c.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
); err != nil {
WSError(err)
}
// Release the lock.
<-c.wrmut
case <-ctx.Done():
// We couldn't acquire the lock. Resort to just closing the
// connection directly.
}
}
// Close the WS.
err := c.Conn.Close()
if err != nil {
WSDebug("Conn: Websocket closed; error:", err)
} else {
WSDebug("Conn: Websocket closed successfully")
}
c.Conn = nil
return err
}
// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}
// Send implements Connection.
func (c *Conn) Send(ctx context.Context, b []byte) error {
c.mut.Lock()
conn := c.conn
c.mut.Unlock()
select {
case conn.wrmut <- struct{}{}:
defer func() { <-conn.wrmut }()
if ctx != context.Background() {
d, ok := ctx.Deadline()
if ok {
conn.SetWriteDeadline(d)
defer conn.SetWriteDeadline(resetDeadline)
}
}
return conn.WriteMessage(websocket.TextMessage, b)
case <-ctx.Done():
return ctx.Err()
}
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
type loopState struct {
conn *websocket.Conn
codec Codec
zlib io.ReadCloser
buf DecodeBuffer
}
func readLoop(conn *websocket.Conn, codec Codec, opCh chan<- Op) {
// Clean up the events channel in the end.
defer close(opCh)
// Allocate the read loop its own private resources.
state := loopState{
conn: conn,
codec: codec,
buf: NewDecodeBuffer(1 << 14), // 16KB
}
for {
b, err := state.handle()
if err != nil {
WSDebug("Conn: fatal Conn error:", err)
closeEv := &CloseEvent{
Err: err,
Code: -1,
}
var closeErr *websocket.CloseError
if errors.As(err, &closeErr) {
closeEv.Code = closeErr.Code
closeEv.Err = fmt.Errorf("%d %s", closeErr.Code, closeErr.Text)
}
opCh <- Op{
Code: closeEv.Op(),
Type: closeEv.EventType(),
Data: closeEv,
}
return
}
opCh <- b
}
}
func (state *loopState) handle() (Op, error) {
// skip message type
t, r, err := state.conn.NextReader()
if err != nil {
return Op{}, err
}
if t == websocket.BinaryMessage {
// Probably a zlib payload.
if state.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
return Op{}, errors.Wrap(err, "failed to create a zlib reader")
}
state.zlib = z
} else {
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
return Op{}, errors.Wrap(err, "failed to reset zlib reader")
}
}
defer state.zlib.Close()
r = state.zlib
}
return state.codec.DecodeFrom(r, &state.buf), nil
}

364
utils/ws/gateway.go Normal file
View file

@ -0,0 +1,364 @@
package ws
import (
"context"
"fmt"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/internal/lazytime"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/pkg/errors"
)
// ConnectionError is given to the user if the gateway fails to connect to the
// gateway for any reason, including during an initial connection or a
// reconnection. To check for this error, use the errors.As function.
type ConnectionError struct {
Err error
}
// Unwrap unwraps the ConnectionError.
func (err ConnectionError) Unwrap() error { return err.Err }
// Error formats the error.
func (err ConnectionError) Error() string {
return fmt.Sprintf("error reconnecting: %s", err.Err)
}
// BackgroundErrorEvent describes an error that the gateway event loop might
// stumble upon while it's running. See Gateway's documentation for possible
// usages.
type BackgroundErrorEvent struct {
Err error
}
var _ Event = (*BackgroundErrorEvent)(nil)
// Unwrap returns err.Err.
func (err *BackgroundErrorEvent) Unwrap() error { return err.Err }
// Error formats the BackgroundErrorEvent.
func (err *BackgroundErrorEvent) Error() string {
return "background gateway error: " + err.Err.Error()
}
// Op implements Op. It returns -1.
func (err *BackgroundErrorEvent) Op() OpCode { return -1 }
// EventType implements Op. It returns an opaque unique string.
func (err *BackgroundErrorEvent) EventType() EventType {
return "__ws.BackgroundErrorEvent"
}
// GatewayOpts describes the gateway event loop options.
type GatewayOpts struct {
// ReconnectDelay determines the duration to idle after each failed retry.
// This can be used to implement exponential backoff. The default is already
// sane, so this field rarely needs to be changed.
ReconnectDelay func(try int) time.Duration
// FatalCloseCodes is a list of close codes that will cause the gateway to
// exit out if it stumbles on one of these. It is a copy of FatalCloseCodes
// (the global variable) by default.
FatalCloseCodes []int
// DialTimeout is the timeout to wait for each websocket dial before failing
// it and retrying. Default is 0.
DialTimeout time.Duration
// ReconnectAttempt is the maximum number of attempts made to Reconnect
// before aborting the whole gateway. If this set to 0, unlimited attempts
// will be made. Default is 0.
ReconnectAttempt int
// AlwaysCloseGracefully, if true, will always make the Gateway close
// gracefully once the context given to Open is cancelled. It governs the
// Close behavior. The default is true.
AlwaysCloseGracefully bool
}
// DefaultGatewayOpts is the default event loop options.
var DefaultGatewayOpts = GatewayOpts{
ReconnectDelay: func(try int) time.Duration {
// minimum 4 seconds
return time.Duration(4+(2*try)) * time.Second
},
DialTimeout: 0,
ReconnectAttempt: 0,
AlwaysCloseGracefully: true,
}
// Gateway describes an instance that handles the Discord gateway. It is
// basically an abstracted concurrent event loop that the user could signal to
// start connecting to the Discord gateway server.
type Gateway struct {
ws *Websocket
reconnect chan struct{}
heart lazytime.Ticker
srcOp <-chan Op // from WS
outer outerState
lastError error
opts GatewayOpts
}
// outerState holds gateway state that the caller may change concurrently. As
// such, it holds a mutex to allow that. The main purpose of this
// synchronization is to allow the caller to use the gateway while the event
// loop is still running without having the event loop muddle in without locking
// properly. For example, opCh is given to the event loop as a copy; the event
// loop must never access the outerState directly.
type outerState struct {
sync.Mutex
ch chan Op
started bool
}
// Handler describes a gateway handler. It describes the core that governs the
// behavior of the gateway event loop.
type Handler interface {
// OnOp is called by the gateway event loop on every new Op. If the returned
// boolean is false, then the loop fatally exits.
OnOp(context.Context, Op) (canContinue bool)
// SendHeartbeat is called by the gateway event loop everytime a heartbeat
// needs to be sent over.
SendHeartbeat(context.Context)
// Close closes the handler.
Close() error
}
// NewGateway creates a new Gateway with a custom gateway URL and a pre-existing
// Identifier. If opts is nil, then DefaultOpts is used.
func NewGateway(ws *Websocket, opts *GatewayOpts) *Gateway {
if opts == nil {
opts = &DefaultGatewayOpts
}
return &Gateway{
ws: ws,
opts: *opts,
}
}
// Send is a function to send an Op payload to the Gateway.
func (g *Gateway) Send(ctx context.Context, data Event) error {
op := Op{
Code: data.Op(),
Type: data.EventType(),
Data: data,
}
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "failed to encode payload")
}
// WS should already be thread-safe.
return g.ws.Send(ctx, b)
}
// HasStarted returns true if the gateway event loop is currently spinning.
func (g *Gateway) HasStarted() bool {
g.outer.Lock()
defer g.outer.Unlock()
return g.outer.started
}
// AssertIsNotRunning asserts that the gateway is currently not running. If the
// gateway is running, the method will panic. Since a gateway cannot be started
// back up, this method can be used to detect whether or not the caller in a
// single goroutine can read the state safely.
func (g *Gateway) AssertIsNotRunning() {
g.outer.Lock()
defer g.outer.Unlock()
if !g.outer.started {
return
}
// Hack to ensure that the event channel is closed.
select {
case _, ok := <-g.outer.ch:
if !ok {
return
}
// The panic behavior is a must, because if this branch is hit, then
// we've actually stolen an event from the channel unexpectedly, putting
// the event loop under a weird state.
//
// An alternative solution to this bug would be to mutex-guard the error
// field, but the purpose of this method isn't to be called before the
// gateway has been stopped.
panic("ws: Error called while Gateway is still running")
default:
panic("ws: Error called while Gateway is still running")
}
}
// Connect starts the background goroutine that tries its best to maintain a
// stable connection to the Websocket gateway. To the user, the gateway should
// appear to be working seamlessly.
//
// For more documentation, refer to (*gateway.Gateway).Connect.
func (g *Gateway) Connect(ctx context.Context, h Handler) <-chan Op {
g.outer.Lock()
defer g.outer.Unlock()
if !g.outer.started {
g.outer.started = true
g.outer.ch = make(chan Op, 1)
go g.spin(ctx, h)
}
return g.outer.ch
}
// LastError returns the last error that the gateway has received.
func (g *Gateway) LastError() error {
g.AssertIsNotRunning()
return g.lastError
}
// finalize closes the gateway permanently.
func (g *Gateway) finalize(h Handler) {
var err error
if g.opts.AlwaysCloseGracefully {
err = g.ws.CloseGracefully()
} else {
err = g.ws.Close()
}
if err != nil {
g.SendErrorWrap(err, "failed to finalize websocket")
}
if err := h.Close(); err != nil {
g.SendError(err)
}
g.outer.Lock()
close(g.outer.ch)
g.outer.started = false
g.outer.Unlock()
}
// QueueReconnect queues a reconnection in the gateway loop. This method should
// only be called in the event loop ONCE; calling more than once will deadlock
// the loop.
func (g *Gateway) QueueReconnect() {
select {
case g.reconnect <- struct{}{}:
default:
}
g.heart.Stop()
}
// ResetHeartbeat resets the heartbeat to be the given duration.
func (g *Gateway) ResetHeartbeat(d time.Duration) {
g.heart.Reset(d)
}
// SendError sends the given error wrapped in a BackgroundErrorEvent into the
// event channel.
func (g *Gateway) SendError(err error) {
event := &BackgroundErrorEvent{err}
g.outer.ch <- Op{
Code: event.Op(),
Type: event.EventType(),
Data: event,
}
g.lastError = err
}
// SendErrorWrap is a convenient function over SendError.
func (g *Gateway) SendErrorWrap(err error, message string) {
g.SendError(errors.Wrap(err, message))
}
func (g *Gateway) spin(ctx context.Context, h Handler) {
// Always close the event channel once we exit.
defer g.finalize(h)
var retryTimer lazytime.Timer
defer retryTimer.Stop()
g.reconnect = make(chan struct{}, 1)
g.reconnect <- struct{}{}
for {
select {
case <-ctx.Done():
return
case op, ok := <-g.srcOp:
if !ok {
// Skip zero-value Ops that may happen on gateway closure.
continue
}
switch data := op.Data.(type) {
case *CloseEvent:
for _, code := range g.opts.FatalCloseCodes {
if code == data.Code {
// Don't wrap the error, but instead, just pipe it as-is
// through the channel.
g.outer.ch <- op
g.lastError = data
return
}
}
}
ok = h.OnOp(ctx, op)
g.outer.ch <- op
if !ok {
return
}
case <-g.heart.C:
h.SendHeartbeat(ctx)
case <-g.reconnect:
// Close the previous connection if it's not already. Ignore the
// already closed error.
if err := g.ws.Close(); err != nil && !errors.Is(err, ErrWebsocketClosed) {
g.SendErrorWrap(err, "error closing before reconnecting")
}
// Invalidate our srcOp.
g.srcOp = nil
// Keep track of the last error for notifying.
var err error
for try := 0; g.opts.ReconnectAttempt == 0 || try < g.opts.ReconnectAttempt; try++ {
g.srcOp, err = g.ws.Dial(ctx)
if err == nil {
break
}
// Signal an error before retrying.
g.SendError(ConnectionError{err})
retryTimer.Reset(g.opts.ReconnectDelay(try))
if err := retryTimer.Wait(ctx); err != nil {
g.SendError(ConnectionError{ctx.Err()})
return
}
}
// Ensure that we've reconnected successfully. Exit otherwise.
if g.srcOp == nil {
err = errors.Wrap(err, "failed to reconnect after max attempts")
g.SendError(ConnectionError{err})
return
}
}
}
}

212
utils/ws/op.go Normal file
View file

@ -0,0 +1,212 @@
package ws
import (
"context"
"fmt"
"sync"
"github.com/pkg/errors"
)
// OpCode is the type for websocket Op codes. Op codes less than 0 are
// internal Op codes and should usually be ignored.
type OpCode int
// CloseEvent is an event that is given from wsutil when the websocket is
// closed.
type CloseEvent struct {
// Err is the underlying error.
Err error
// Code is the websocket close code, if any.
Code int
}
// Unwrap returns err.Err.
func (e *CloseEvent) Unwrap() error { return e.Err }
// Error formats the CloseEvent. A CloseEvent is also an error.
func (e *CloseEvent) Error() string {
return fmt.Sprintf("websocket closed, reason: %s", e.Err)
}
// Op implements Event. It returns -1.
func (e *CloseEvent) Op() OpCode { return -1 }
// EventType implements Event. It returns an emty string.
func (e *CloseEvent) EventType() EventType { return "__ws.CloseEvent" }
// EventType is a type for event types, which is the "t" field in the payload.
type EventType string
// Event describes an Event data that comes from a gateway Operation.
type Event interface {
Op() OpCode
EventType() EventType
}
// OpFunc is a constructor function for an Operation.
type OpFunc func() Event
// OpUnmarshalers contains a map of event constructor function.
type OpUnmarshalers struct {
r map[opFuncID]OpFunc
}
type opFuncID struct {
Op OpCode `json:"op"`
T EventType `json:"t"`
}
// NewOpUnmarshalers creates a nwe OpUnmarshalers instance from the given
// constructor functions.
func NewOpUnmarshalers(funcs ...OpFunc) OpUnmarshalers {
m := OpUnmarshalers{r: make(map[opFuncID]OpFunc)}
m.Add(funcs...)
return m
}
// Each iterates over the marshaler map.
func (m OpUnmarshalers) Each(f func(OpCode, EventType, OpFunc) (done bool)) {
for id, fn := range m.r {
if f(id.Op, id.T, fn) {
return
}
}
}
// Add adds the given functions into the unmarshaler registry.
func (m OpUnmarshalers) Add(funcs ...OpFunc) {
for _, fn := range funcs {
ev := fn()
id := opFuncID{
Op: ev.Op(),
T: ev.EventType(),
}
m.r[id] = fn
}
}
// Lookup searches the OpMarshalers map for the given constructor function.
func (m OpUnmarshalers) Lookup(op OpCode, t EventType) OpFunc {
return m.r[opFuncID{op, t}]
}
// Op is a gateway Operation.
type Op struct {
Code OpCode `json:"op"`
Data Event `json:"d,omitempty"`
// Type is only for gateway dispatch events.
Type EventType `json:"t,omitempty"`
// Sequence is only for gateway dispatch events (Op 0).
Sequence int64 `json:"s,omitempty"`
}
// UnknownEventError is required by HandleOp if an event is encountered that is
// not known. Internally, unknown events are logged and ignored. It is not a
// fatal error.
type UnknownEventError struct {
Op OpCode
Type EventType
}
// Error formats the unknown event error to with the event name and payload
func (err UnknownEventError) Error() string {
return fmt.Sprintf("unknown op %d, event %s", err.Op, err.Type)
}
// IsBrokenConnection returns true if the error is a broken connection error.
func IsUnknownEvent(err error) bool {
var uevent *UnknownEventError
return errors.As(err, &uevent)
}
// ReadOps reads maximum n Ops and accumulate them into a slice.
func ReadOps(ctx context.Context, ch <-chan Op, n int) ([]Op, error) {
ops := make([]Op, 0, n)
for {
select {
case <-ctx.Done():
return ops, ctx.Err()
case op := <-ch:
ops = append(ops, op)
if len(ops) == n {
return ops, nil
}
}
}
}
// ReadOp reads a single Op.
func ReadOp(ctx context.Context, ch <-chan Op) (Op, error) {
select {
case <-ctx.Done():
return Op{}, ctx.Err()
case op := <-ch:
return op, nil
}
}
// Broadcaster is primarily used for debugging.
type Broadcaster struct {
src <-chan Op
dst map[chan<- Op]struct{}
mut sync.Mutex
void bool
}
// NewBroadcaster creates a new broadcaster.
func NewBroadcaster(src <-chan Op) *Broadcaster {
return &Broadcaster{
src: src,
dst: make(map[chan<- Op]struct{}),
}
}
// Start starts the broadcasting loop.
func (b *Broadcaster) Start() {
b.mut.Lock()
if b.void {
panic("Start called on voided Broadcaster")
}
b.mut.Unlock()
go func() {
for op := range b.src {
b.mut.Lock()
for ch := range b.dst {
ch <- op
}
b.mut.Unlock()
}
b.mut.Lock()
b.void = true
for ch := range b.dst {
close(ch)
}
b.mut.Unlock()
}()
}
// Subscribe subscribes the given channel
func (b *Broadcaster) Subscribe(ch chan<- Op) {
b.mut.Lock()
if b.void {
panic("Subscribe called on voided Broadcaster")
}
b.dst[ch] = struct{}{}
b.mut.Unlock()
}
// NewSubscribed creates a newly subscribed Op channel.
func (b *Broadcaster) NewSubscribed() <-chan Op {
ch := make(chan Op, 1)
b.Subscribe(ch)
return ch
}

View file

@ -0,0 +1,35 @@
// Package ophandler provides an Op channel reader that redistributes the events
// into handlers.
package ophandler
import (
"context"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
// Loop starts a background goroutine that starts reading from src and
// distributes received events into the given handler. It's stopped once src is
// closed. The returned channel will be closed once src is closed.
func Loop(src <-chan ws.Op, dst *handler.Handler) <-chan struct{} {
done := make(chan struct{})
go func() {
for op := range src {
dst.Call(op.Data)
}
close(done)
}()
return done
}
// WaitForDone waits for the done channel returned by Loop until the channel is
// closed or the context expires.
func WaitForDone(ctx context.Context, done <-chan struct{}) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View file

@ -1,4 +1,4 @@
package wsutil
package ws
import (
"time"

118
utils/ws/ws.go Normal file
View file

@ -0,0 +1,118 @@
// Package wsutil provides abstractions around the Websocket, including rate
// limits.
package ws
import (
"context"
"log"
"sync"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
var (
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
)
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
type Websocket struct {
mutex sync.Mutex
conn Connection
addr string
// If you ever need access to these fields from outside the package, please
// open an issue. It might be worth it to refactor these out for distributed
// sharding.
sendLimiter *rate.Limiter
dialLimiter *rate.Limiter
}
// NewWebsocket creates a default Websocket with the given address.
func NewWebsocket(c Codec, addr string) *Websocket {
return NewCustomWebsocket(NewConn(c), addr)
}
// NewCustomWebsocket creates a new undialed Websocket.
func NewCustomWebsocket(conn Connection, addr string) *Websocket {
return &Websocket{
conn: conn,
addr: addr,
sendLimiter: NewSendLimiter(),
dialLimiter: NewDialLimiter(),
}
}
// Dial waits until the rate limiter allows then dials the websocket.
func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) {
if err := ws.dialLimiter.Wait(ctx); err != nil {
// Expired, fatal error
return nil, errors.Wrap(err, "failed to wait for dial rate limiter")
}
ws.mutex.Lock()
defer ws.mutex.Unlock()
// Reset the send limiter.
// TODO: see if each limit only applies to one connection or not.
ws.sendLimiter = NewSendLimiter()
return ws.conn.Dial(ctx, ws.addr)
}
// Send sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
WSDebug("Acquiring the websoccket mutex for sending.")
ws.mutex.Lock()
WSDebug("Mutex lock acquired.")
sendLimiter := ws.sendLimiter
conn := ws.conn
ws.mutex.Unlock()
WSDebug("Waiting for the send rate limiter...")
if err := sendLimiter.Wait(ctx); err != nil {
WSDebug("Send rate limiter timed out.")
return errors.Wrap(err, "SendLimiter failed")
}
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
return conn.Send(ctx, b)
}
// Close closes the websocket connection. It assumes that the Websocket is
// closed even when it returns an error. If the Websocket was already closed
// before, ErrWebsocketClosed will be returned.
func (ws *Websocket) Close() error {
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired")
return ws.conn.Close(false)
}
// CloseGracefully is similar to Close, but a proper close frame is sent to
// Discord, invalidating the internal session ID and voiding resumes.
func (ws *Websocket) CloseGracefully() error {
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired")
return ws.conn.Close(true)
}

View file

@ -1,270 +0,0 @@
package wsutil
import (
"bytes"
"compress/zlib"
"context"
"io"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
// size is 4KB.
var CopyBufferSize = 4096
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
var MaxCapUntilReset = CopyBufferSize * 4
// CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second
// ErrWebsocketClosed is returned if the websocket is already closed.
var ErrWebsocketClosed = errors.New("websocket is closed")
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself, including
// modifying the connection URL. The implementation doesn't have to be safe for
// concurrent use.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
// Listen returns an event channel that sends over events constantly. It can
// return nil if there isn't an ongoing connection.
Listen() <-chan Event
// Send allows the caller to send bytes. It does not need to clean itself
// up on errors, as the Websocket wrapper will do that.
//
// If the data is nil, it should send a close frame
Send(context.Context, []byte) error
// Close should close the websocket connection. The underlying connection
// may be reused, but this Connection instance will be reused with Dial. The
// Connection must still be reusable even if Close returns an error.
Close() error
// CloseGracefully sends a close frame and then closes the websocket
// connection.
CloseGracefully() error
}
// Conn is the default Websocket connection. It tries to compresses all payloads
// using zlib.
type Conn struct {
Dialer websocket.Dialer
Header http.Header
Conn *websocket.Conn
events chan Event
}
var _ Connection = (*Conn)(nil)
// NewConn creates a new default websocket connection with a default dialer.
func NewConn() *Conn {
return NewConnWithDialer(websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
})
}
// NewConnWithDialer creates a new default websocket connection with a custom
// dialer.
func NewConnWithDialer(dialer websocket.Dialer) *Conn {
return &Conn{
Dialer: dialer,
Header: http.Header{
"Accept-Encoding": {"zlib"},
},
}
}
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
// BUG which prevents stream compression.
// See https://github.com/golang/go/issues/31514.
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header)
if err != nil {
return errors.Wrap(err, "failed to dial WS")
}
// Reset the deadline.
c.Conn.SetWriteDeadline(resetDeadline)
c.events = make(chan Event, WSBuffer)
go startReadLoop(c.Conn, c.events)
return err
}
// Listen returns an event channel if there is a connection associated with it.
// It returns nil if there is none.
func (c *Conn) Listen() <-chan Event {
return c.events
}
// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}
func (c *Conn) Send(ctx context.Context, b []byte) error {
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
return err
}
return nil
}
func (c *Conn) Close() error {
WSDebug("Conn: Close is called; shutting down the Websocket connection.")
// Have a deadline before closing.
var deadline = time.Now().Add(5 * time.Second)
c.Conn.SetWriteDeadline(deadline)
// Close the WS.
err := c.Conn.Close()
c.Conn.SetWriteDeadline(resetDeadline)
WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flushing events...")
// Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed.
for range c.events {
}
WSDebug("Flushed events.")
return err
}
func (c *Conn) CloseGracefully() error {
WSDebug("Conn: CloseGracefully is called; sending close frame.")
c.Conn.SetWriteDeadline(time.Now().Add(CloseDeadline))
err := c.Conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
WSError(err)
}
WSDebug("Conn: Close frame sent; error:", err)
return c.Close()
}
// loopState is a thread-unsafe disposable state container for the read loop.
// It's made to completely separate the read loop of any synchronization that
// doesn't involve the websocket connection itself.
type loopState struct {
conn *websocket.Conn
zlib io.ReadCloser
buf bytes.Buffer
}
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
// Clean up the events channel in the end.
defer close(eventCh)
// Allocate the read loop its own private resources.
state := loopState{conn: conn}
state.buf.Grow(CopyBufferSize)
for {
b, err := state.handle()
if err != nil {
WSDebug("Conn: Read error:", err)
// Is the error an EOF?
if errors.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Is the error an intentional close call? Go 1.16 exposes
// ErrClosing, but we have to do this for now.
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return
}
// Unusual error; log and exit:
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
return
}
// If the payload length is 0, skip it.
if len(b) == 0 {
continue
}
eventCh <- Event{b, nil}
}
}
func (state *loopState) handle() ([]byte, error) {
// skip message type
t, r, err := state.conn.NextReader()
if err != nil {
return nil, err
}
if t == websocket.BinaryMessage {
// Probably a zlib payload.
if state.zlib == nil {
z, err := zlib.NewReader(r)
if err != nil {
return nil, errors.Wrap(err, "failed to create a zlib reader")
}
state.zlib = z
} else {
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
return nil, errors.Wrap(err, "failed to reset zlib reader")
}
}
defer state.zlib.Close()
r = state.zlib
}
return state.readAll(r)
}
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
defer state.buf.Reset()
if _, err := state.buf.ReadFrom(r); err != nil {
return nil, err
}
// Copy the bytes so we could empty the buffer for reuse.
cpy := make([]byte, state.buf.Len())
copy(cpy, state.buf.Bytes())
// If the buffer's capacity is over the limit, then re-allocate a new one.
if state.buf.Cap() > MaxCapUntilReset {
state.buf = bytes.Buffer{}
state.buf.Grow(CopyBufferSize)
}
return cpy, nil
}

View file

@ -1,151 +0,0 @@
package wsutil
import (
"context"
"time"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/internal/heart"
)
type brokenConnectionError struct {
underneath error
}
// Error formats the broken connection error with the message "explicit
// connection break."
func (err brokenConnectionError) Error() string {
return "explicit connection break: " + err.underneath.Error()
}
// Unwrap returns the underlying error.
func (err brokenConnectionError) Unwrap() error {
return err.underneath
}
// ErrBrokenConnection marks the given error as a broken connection error. This
// error will cause the pacemaker loop to break and return the error. The error,
// when stringified, will say "explicit connection break."
func ErrBrokenConnection(err error) error {
return brokenConnectionError{underneath: err}
}
// IsBrokenConnection returns true if the error is a broken connection error.
func IsBrokenConnection(err error) bool {
var broken *brokenConnectionError
return errors.As(err, &broken)
}
// TODO API
type EventLoopHandler interface {
EventHandler
HeartbeatCtx(context.Context) error
}
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
// is a valid instance only when RunAsync is called first.
type PacemakerLoop struct {
heart.Pacemaker
Extras ExtraHandlers
ErrorLog func(error)
events <-chan Event
control chan func()
handler func(*OP) error
}
func (p *PacemakerLoop) errorLog(err error) {
if p.ErrorLog == nil {
WSDebug("Uncaught error:", err)
return
}
p.ErrorLog(err)
}
// Pace calls the pacemaker's Pace function.
func (p *PacemakerLoop) Pace(ctx context.Context) error {
return p.Pacemaker.PaceCtx(ctx)
}
// StartBeating asynchronously starts the pacemaker loop.
func (p *PacemakerLoop) StartBeating(pace time.Duration, evl EventLoopHandler, exit func(error)) {
WSDebug("Starting the pacemaker loop.")
p.Pacemaker = heart.NewPacemaker(pace, evl.HeartbeatCtx)
p.control = make(chan func())
p.handler = evl.HandleOP
p.events = nil // block forever
go func() { exit(p.startLoop()) }()
}
// Stop signals the pacemaker to stop. It does not wait for the pacer to stop.
// The pacer will call the given callback with a nil error.
func (p *PacemakerLoop) Stop() {
close(p.control)
}
// SetEventChannel sets the event channel inside the event loop. There is no
// guarantee that the channel is set when the function returns. This function is
// concurrently safe.
func (p *PacemakerLoop) SetEventChannel(evCh <-chan Event) {
p.control <- func() { p.events = evCh }
}
// SetPace (re)sets the pace duration. As with SetEventChannel, there is no
// guarantee that the pacer is reset when the function returns. This function is
// concurrently safe.
func (p *PacemakerLoop) SetPace(pace time.Duration) {
p.control <- func() { p.Pacemaker.SetPace(pace) }
}
func (p *PacemakerLoop) startLoop() error {
defer WSDebug("Pacemaker loop has exited.")
defer p.Pacemaker.StopTicker()
for {
select {
case <-p.Pacemaker.Ticks:
if err := p.Pacemaker.Pace(); err != nil {
return errors.Wrap(err, "pace failed, reconnecting")
}
case fn, ok := <-p.control:
if !ok { // Intentional stop at p.Close().
WSDebug("Pacemaker intentionally stopped using p.control.")
return nil
}
fn()
case ev, ok := <-p.events:
if !ok {
WSDebug("Events channel closed, stopping pacemaker.")
return nil
}
if ev.Error != nil {
return errors.Wrap(ev.Error, "event returned error")
}
o, err := DecodeOP(ev)
if err != nil {
return errors.Wrap(err, "failed to decode OP")
}
// Check the events before handling.
p.Extras.Check(o)
// Handle the event
if err := p.handler(o); err != nil {
if IsBrokenConnection(err) {
return errors.Wrap(err, "handler failed")
}
p.errorLog(err)
}
}
}
}

View file

@ -1,202 +0,0 @@
package wsutil
import (
"context"
"fmt"
"sync"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/utils/json"
)
var ErrEmptyPayload = errors.New("empty payload")
// OPCode is a generic type for websocket OP codes.
type OPCode uint8
type OP struct {
Code OPCode `json:"op"`
Data json.Raw `json:"d,omitempty"`
// Only for Gateway Dispatch (op 0)
Sequence int64 `json:"s,omitempty"`
EventName string `json:"t,omitempty"`
}
func (op *OP) UnmarshalData(v interface{}) error {
return json.Unmarshal(op.Data, v)
}
func DecodeOP(ev Event) (*OP, error) {
if ev.Error != nil {
return nil, ev.Error
}
if len(ev.Data) == 0 {
return nil, ErrEmptyPayload
}
var op *OP
if err := json.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
}
return op, nil
}
func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) {
op, err := DecodeOP(ev)
if err != nil {
return nil, err
}
if op.Code != code {
return op, fmt.Errorf(
"unexpected OP Code: %d, expected %d (%s)",
op.Code, code, op.Data,
)
}
if err := json.Unmarshal(op.Data, v); err != nil {
return op, errors.Wrap(err, "failed to decode data")
}
return op, nil
}
// UnknownEventError is required by HandleOP if an event is encountered that is
// not known. Internally, unknown events are logged and ignored. It is not a
// fatal error.
type UnknownEventError struct {
Name string
Data json.Raw
}
// Error formats the unknown event error to with the event name and payload
func (err UnknownEventError) Error() string {
return fmt.Sprintf("unknown event %s: %s", err.Name, string(err.Data))
}
// IsBrokenConnection returns true if the error is a broken connection error.
func IsUnknownEvent(err error) bool {
var uevent *UnknownEventError
return errors.As(err, &uevent)
}
type EventHandler interface {
HandleOP(op *OP) error
}
func HandleEvent(h EventHandler, ev Event) error {
o, err := DecodeOP(ev)
if err != nil {
return err
}
return h.HandleOP(o)
}
// WaitForEvent blocks until fn() returns true. All incoming events are handled
// regardless.
func WaitForEvent(ctx context.Context, h EventHandler, ch <-chan Event, fn func(*OP) bool) error {
for {
select {
case e, ok := <-ch:
if !ok {
return errors.New("event not found and event channel is closed")
}
o, err := DecodeOP(e)
if err != nil {
return err
}
// Handle the *OP first, in case it's an Invalid Session. This should
// also prevent a race condition with things that need Ready after
// Open().
if err := h.HandleOP(o); err != nil {
// Explicitly ignore events we don't know.
if IsUnknownEvent(err) {
WSError(err)
continue
}
return err
}
// Are these events what we're looking for? If we've found the event,
// return.
if fn(o) {
return nil
}
case <-ctx.Done():
return ctx.Err()
}
}
}
type ExtraHandlers struct {
mutex sync.Mutex
handlers map[uint32]*ExtraHandler
serial uint32
}
type ExtraHandler struct {
Check func(*OP) bool
send chan *OP
closed moreatomic.Bool
}
func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) {
handler := &ExtraHandler{
Check: check,
send: make(chan *OP),
}
ex.mutex.Lock()
defer ex.mutex.Unlock()
if ex.handlers == nil {
ex.handlers = make(map[uint32]*ExtraHandler, 1)
}
i := ex.serial
ex.serial++
ex.handlers[i] = handler
return handler.send, func() {
// Check the atomic bool before acquiring the mutex. Might help a bit in
// performance.
if handler.closed.Get() {
return
}
ex.mutex.Lock()
defer ex.mutex.Unlock()
delete(ex.handlers, i)
}
}
// Check runs and sends OP data. It is not thread-safe.
func (ex *ExtraHandlers) Check(op *OP) {
ex.mutex.Lock()
defer ex.mutex.Unlock()
for i, handler := range ex.handlers {
if handler.Check(op) {
// Attempt to send.
handler.send <- op
// Mark the handler as closed.
handler.closed.Set(true)
// Delete the handler.
delete(ex.handlers, i)
}
}
}

View file

@ -1,200 +0,0 @@
// Package wsutil provides abstractions around the Websocket, including rate
// limits.
package wsutil
import (
"context"
"log"
"sync"
"time"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = 30 * time.Second
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
)
type Event struct {
Data []byte
// Error is non-nil if Data is nil.
Error error
}
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
type Websocket struct {
mutex sync.Mutex
conn Connection
addr string
closed bool
sendLimiter *rate.Limiter
dialLimiter *rate.Limiter
// Timeout is the default timeout used if a context with no deadline is
// given to Dial.
//
// It must not be changed after the Websocket is used once.
Timeout time.Duration
}
// New creates a default Websocket with the given address.
func New(addr string) *Websocket {
return NewCustom(NewConn(), addr)
}
// NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{
conn: conn,
addr: addr,
closed: true,
sendLimiter: NewSendLimiter(),
dialLimiter: NewDialLimiter(),
Timeout: WSTimeout,
}
}
// Dial waits until the rate limiter allows then dials the websocket.
//
// If the passed context has no deadline, Dial will wrap it in a
// context.WithTimeout using ws.Timeout as timeout.
func (ws *Websocket) Dial(ctx context.Context) error {
if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, ws.Timeout)
defer cancel()
}
if err := ws.dialLimiter.Wait(ctx); err != nil {
// Expired, fatal error
return errors.Wrap(err, "failed to wait")
}
ws.mutex.Lock()
defer ws.mutex.Unlock()
if !ws.closed {
WSDebug("Old connection not yet closed while dialog; closing it.")
ws.conn.Close()
}
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
return errors.Wrap(err, "failed to dial")
}
ws.closed = false
// Reset the send limiter.
ws.sendLimiter = NewSendLimiter()
return nil
}
// Listen returns the inner event channel or nil if the Websocket connection is
// not alive.
func (ws *Websocket) Listen() <-chan Event {
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return nil
}
return ws.conn.Listen()
}
// Send sends b over the Websocket without a timeout.
func (ws *Websocket) Send(b []byte) error {
return ws.SendCtx(context.Background(), b)
}
// SendCtx sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
WSDebug("Waiting for the send rate limiter...")
if err := ws.sendLimiter.Wait(ctx); err != nil {
WSDebug("Send rate limiter timed out.")
return errors.Wrap(err, "SendLimiter failed")
}
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Mutex lock acquired.")
if ws.closed {
return ErrWebsocketClosed
}
if err := ws.conn.Send(ctx, b); err != nil {
// We need to clean up ourselves if things are erroring out.
WSDebug("Conn: Error while sending; closing the connection. Error:", err)
ws.close(false)
return err
}
return nil
}
// Close closes the websocket connection. It assumes that the Websocket is
// closed even when it returns an error. If the Websocket was already closed
// before, ErrWebsocketClosed will be returned.
func (ws *Websocket) Close() error {
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired")
return ws.close(false)
}
func (ws *Websocket) CloseGracefully() error {
WSDebug("Conn: Acquiring mutex lock to close...")
ws.mutex.Lock()
defer ws.mutex.Unlock()
WSDebug("Conn: Write mutex acquired")
return ws.close(true)
}
// close closes the Websocket without acquiring the mutex. Refer to Close for
// more information.
func (ws *Websocket) close(graceful bool) error {
if ws.closed {
WSDebug("Conn: Websocket is already closed.")
return ErrWebsocketClosed
}
ws.closed = true
if graceful {
WSDebug("Conn: Closing gracefully")
return ws.conn.CloseGracefully()
}
WSDebug("Conn: Closing")
return ws.conn.Close()
}

View file

@ -2,20 +2,22 @@ package voice
import (
"context"
"net"
"sync"
"time"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/handler"
"github.com/diamondburned/arikawa/v3/utils/ws/ophandler"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/internal/handleloop"
"github.com/diamondburned/arikawa/v3/internal/lazytime"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/diamondburned/arikawa/v3/voice/udp"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
)
@ -32,24 +34,65 @@ var ErrCannotSend = errors.New("cannot send audio to closed channel")
// WSTimeout is the duration to wait for a gateway operation including Session
// to complete before erroring out. This only applies to functions that don't
// take in a context already.
var WSTimeout = 10 * time.Second
const WSTimeout = 25 * time.Second
// ReconnectError is emitted into Session.Handler everytime the voice gateway
// fails to be reconnected. It implements the error interface.
type ReconnectError struct {
Err error
}
// Error implements error.
func (e ReconnectError) Error() string {
return "voice reconnect error: " + e.Err.Error()
}
// Unwrap returns e.Err.
func (e ReconnectError) Unwrap() error { return e.Err }
// MainSession abstracts both session.Session and state.State.
type MainSession interface {
// AddHandler describes the method in handler.Handler.
AddHandler(handler interface{}) (rm func())
// Gateway returns the session's main Discord gateway.
Gateway() *gateway.Gateway
// Me returns the current user.
Me() (*discord.User, error)
// Channel queries for the channel with the given ID.
Channel(discord.ChannelID) (*discord.Channel, error)
}
var (
_ MainSession = (*session.Session)(nil)
_ MainSession = (*state.State)(nil)
)
// UDPDialer is the UDP dialer function type. It's the function signature for
// udp.DialConnection.
type UDPDialer = func(ctx context.Context, addr string, ssrc uint32) (*udp.Connection, error)
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct {
*handler.Handler
ErrorLog func(err error)
session MainSession
session *session.Session
looper *handleloop.Loop
detach func()
mut sync.RWMutex
// connected is a non-nil blocking channel after Join is called and is
// closed once Leave is called.
disconnected chan struct{}
mut sync.RWMutex
state voicegateway.State // guarded except UserID
// TODO: expose getters mutex-guarded.
gateway *voicegateway.Gateway
detachReconnect []func()
voiceUDP *udp.Connection
// end of mutex
gateway *voicegateway.Gateway
gwCancel context.CancelFunc
gwDone <-chan struct{}
// DialUDP is the custom function for dialing up a UDP connection.
DialUDP UDPDialer
WSTimeout time.Duration // global WSTimeout
WSMaxRetry int // 2
@ -59,76 +102,102 @@ type Session struct {
// joining determines the behavior of incoming event callbacks (Update).
// If this is true, incoming events will just send into Updated channels. If
// false, events will trigger a reconnection.
joining moreatomic.Bool
connected bool
joining moreatomic.Bool
// disconnectClosed is true if connected is already closed. It is only used
// to keep track of closing connected.
disconnectClosed bool
}
// NewSession creates a new voice session for the current user.
func NewSession(state *state.State) (*Session, error) {
func NewSession(state MainSession) (*Session, error) {
u, err := state.Me()
if err != nil {
return nil, errors.Wrap(err, "failed to get me")
}
return NewSessionCustom(state.Session, u.ID), nil
return NewSessionCustom(state, u.ID), nil
}
// NewSessionCustom creates a new voice session from the given session and user
// ID.
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
handler := handler.New()
hlooper := handleloop.NewLoop(handler)
func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
closed := make(chan struct{})
close(closed)
session := &Session{
Handler: handler,
looper: hlooper,
Handler: handler.New(),
session: ses,
state: voicegateway.State{
UserID: userID,
},
ErrorLog: func(err error) {},
DialUDP: udp.DialConnection,
WSTimeout: WSTimeout,
WSMaxRetry: 2,
WSRetryDelay: 2 * time.Second,
WSWaitDuration: 5 * time.Second,
// Set this pair of value in so we never have to nil-check the channel.
// We can just assume that it's either closed or connected.
disconnected: closed,
disconnectClosed: true,
}
return session
}
// updateServer is specifically used to monitor for reconnects.
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
func (s *Session) acquireUpdate(f func()) bool {
if s.joining.Get() {
return
return false
}
s.mut.Lock()
defer s.mut.Unlock()
// Ignore if we haven't connected yet or we're still joining.
if !s.connected || s.state.GuildID != ev.GuildID {
return
select {
case <-s.disconnected:
return false
default:
// ok
}
// Reconnect.
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
if err := s.reconnectCtx(ctx); err != nil {
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
}
f()
return true
}
// JoinChannel joins a voice channel with the default WS timeout. See
// JoinChannelCtx for more information.
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
// updateServer is specifically used to monitor for reconnects.
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
s.acquireUpdate(func() {
if s.state.GuildID != ev.GuildID {
return
}
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
s.state.Endpoint = ev.Endpoint
s.state.Token = ev.Token
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
s.reconnectCtx(ctx)
})
}
// updateState is specifically used after connecting to monitor when the bot is
// forced across channels.
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
s.acquireUpdate(func() {
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
return
}
s.state.ChannelID = ev.ChannelID
s.state.SessionID = ev.SessionID
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
s.reconnectCtx(ctx)
})
}
type waitEventChs struct {
@ -136,11 +205,17 @@ type waitEventChs struct {
stateUpdate chan *gateway.VoiceStateUpdateEvent
}
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
// directly, but rather Voice's. This method shouldn't ever be called
// concurrently.
func (s *Session) JoinChannelCtx(
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
// JoinChannel joins the given voice channel with the default timeout.
func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error {
var ch *discord.Channel
if chID.IsValid() {
var err error
ch, err = s.session.Channel(chID)
if err != nil {
return errors.Wrap(err, "invalid channel ID")
}
}
s.mut.Lock()
defer s.mut.Unlock()
@ -155,14 +230,20 @@ func (s *Session) JoinChannelCtx(
defer s.joining.Set(false)
// Set the state.
s.state.ChannelID = cID
s.state.GuildID = gID
s.detach = s.session.AddHandler(s.updateServer)
if ch != nil {
s.state.ChannelID = ch.ID
s.state.GuildID = ch.GuildID
} else {
s.state.GuildID = 0
// Ensure that if `cID` is zero that it passes null to the update event.
s.state.ChannelID = discord.NullChannelID
}
// Ensure that if `cID` is zero that it passes null to the update event.
channelID := discord.NullChannelID
if cID.IsValid() {
channelID = cID
if s.detachReconnect == nil {
s.detachReconnect = []func(){
s.session.AddHandler(s.updateServer),
s.session.AddHandler(s.updateState),
}
}
chs := waitEventChs{
@ -185,15 +266,16 @@ func (s *Session) JoinChannelCtx(
// Ensure gateway and voiceUDP are already closed.
s.ensureClosed()
data := gateway.UpdateVoiceStateData{
GuildID: gID,
ChannelID: channelID,
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
data := &gateway.UpdateVoiceStateCommand{
GuildID: s.state.GuildID,
ChannelID: s.state.ChannelID,
SelfMute: mute,
SelfDeaf: deaf,
}
var err error
var timer *time.Timer
// Retry 3 times maximum.
@ -230,17 +312,17 @@ func (s *Session) JoinChannelCtx(
// Mark the session as connected and move on. This allows one of the
// connected handlers to reconnect on its own.
s.connected = true
s.disconnected = make(chan struct{})
return s.reconnectCtx(ctx)
}
func (s *Session) askDiscord(
ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error {
ctx context.Context, data *gateway.UpdateVoiceStateCommand, chs waitEventChs) error {
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
// Send a Voice State Update event to the gateway.
if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil {
if err := s.session.Gateway().Send(ctx, data); err != nil {
return errors.Wrap(err, "failed to send Voice State Update event")
}
@ -290,118 +372,183 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
// reconnect uses the current state to reconnect to a new gateway and UDP
// connection.
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Sending stop handle.")
s.looper.Stop()
func (s *Session) reconnectCtx(ctx context.Context) error {
ws.WSDebug("Sending stop handle.")
wsutil.WSDebug("Start gateway.")
s.ensureClosed()
ws.WSDebug("Start gateway.")
s.gateway = voicegateway.New(s.state)
// Open the voice gateway. The function will block until Ready is received.
if err := s.gateway.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to open voice gateway")
gwctx, gwcancel := context.WithCancel(context.Background())
s.gwCancel = gwcancel
gwch := s.gateway.Connect(gwctx)
if err := s.spinGateway(ctx, gwch); err != nil {
// Early cancel the gateway.
gwcancel()
// Nil this so future reconnects don't use the invalid gwDone.
s.gwCancel = nil
// Emit the error. It's fine to do this here since this is the only
// place that can error out.
s.Handler.Call(&ReconnectError{err})
return err
}
// Start the handler dispatching
s.looper.Start(s.gateway.Events)
// Get the Ready event.
voiceReady := s.gateway.Ready()
// Prepare the UDP voice connection.
s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection")
}
// Get the session description from the voice gateway.
d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP,
Port: s.voiceUDP.GatewayPort,
Mode: Protocol,
},
})
if err != nil {
return errors.Wrap(err, "failed to select protocol")
}
s.voiceUDP.UseSecret(d.SecretKey)
// Start dispatching.
s.gwDone = ophandler.Loop(gwch, s.Handler)
return nil
}
func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
var err error
for {
select {
case <-ctx.Done():
return ctx.Err()
case ev, ok := <-gwch:
if !ok {
return s.gateway.LastError()
}
switch data := ev.Data.(type) {
case *ws.CloseEvent:
return errors.Wrap(err, "voice gateway error")
case *voicegateway.ReadyEvent:
// Prepare the UDP voice connection.
s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC)
if err != nil {
return errors.Wrap(err, "failed to open voice UDP connection")
}
if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{
Protocol: "udp",
Data: voicegateway.SelectProtocolData{
Address: s.voiceUDP.GatewayIP,
Port: s.voiceUDP.GatewayPort,
Mode: Protocol,
},
}); err != nil {
return errors.Wrap(err, "failed to send SelectProtocolCommand")
}
case *voicegateway.SessionDescriptionEvent:
// We're done.
s.voiceUDP.UseSecret(data.SecretKey)
return nil
}
// Dispatch this event to the handler.
s.Handler.Call(ev.Data)
}
}
}
// Speaking tells Discord we're speaking. This method should not be called
// concurrently.
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
s.mut.RLock()
gateway := s.gateway
s.mut.RUnlock()
return gateway.Speaking(flag)
}
// UseContext tells the UDP voice connection to write with the given context.
func (s *Session) UseContext(ctx context.Context) error {
func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error {
s.mut.Lock()
defer s.mut.Unlock()
gateway := s.gateway
s.mut.Unlock()
if s.voiceUDP == nil {
return ErrCannotSend
return gateway.Speaking(ctx, flag)
}
func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) {
const maxAttempts = 5
const retryDelay = 250 * time.Millisecond // adds up to about 1.25s
var lazyWait lazytime.Timer
// Hack: loop until we no longer get an error closed or until the connection
// is dead. This is a workaround for when the session is trying to reconnect
// itself in the background, which would drop the UDP connection.
for i := 0; i < maxAttempts; i++ {
s.mut.RLock()
voiceUDP := s.voiceUDP
disconnected := s.disconnected
s.mut.RUnlock()
select {
case <-disconnected:
return net.ErrClosed
default:
if voiceUDP == nil {
// Session is still connected, but our voice UDP connection is
// nil, so we're probably in the process of reconnecting
// already.
goto retry
}
}
if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) {
// Session is still connected, but our UDP connection is somehow
// closed, so we're probably waiting for the server to ask us to
// reconnect with a new session.
goto retry
}
// Unknown error or none at all; exit.
return err
retry:
// Wait a slight bit. We can probably make the caller wait a couple
// milliseconds without a wait.
lazyWait.Reset(retryDelay)
select {
case <-lazyWait.C:
continue
case <-disconnected:
return net.ErrClosed
}
}
return s.voiceUDP.UseContext(ctx)
return
}
// VoiceUDPConn gets a voice UDP connection. The caller could use this method to
// circumvent the rapid mutex-read-lock acquire inside Write.
func (s *Session) VoiceUDPConn() *udp.Connection {
s.mut.RLock()
defer s.mut.RUnlock()
return s.voiceUDP
}
// Write writes into the UDP voice connection WITHOUT a timeout. Refer to
// WriteCtx for more information.
// Write writes into the UDP voice connection. This method is thread safe as far
// as calling other methods of Session goes; HOWEVER it is not thread safe to
// call Write itself concurrently.
func (s *Session) Write(b []byte) (int, error) {
return s.WriteCtx(context.Background(), b)
var n int
err := s.useUDP(func(c *udp.Connection) (err error) {
n, err = c.Write(b)
return
})
return n, err
}
// WriteCtx writes into the UDP voice connection with a context for timeout.
// This method is thread safe as far as calling other methods of Session goes;
// HOWEVER it is not thread safe to call Write itself concurrently.
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
voiceUDP := s.VoiceUDPConn()
if voiceUDP == nil {
return 0, ErrCannotSend
}
return voiceUDP.WriteCtx(ctx, b)
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
var p *udp.Packet
err := s.useUDP(func(c *udp.Connection) (err error) {
p, err = c.ReadPacket()
return
})
return p, err
}
// Leave disconnects the current voice session from the currently connected
// channel.
func (s *Session) Leave() error {
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
defer cancel()
return s.LeaveCtx(ctx)
}
// LeaveCtx disconencts with a context. Refer to Leave for more information.
func (s *Session) LeaveCtx(ctx context.Context) error {
func (s *Session) Leave(ctx context.Context) error {
s.mut.Lock()
defer s.mut.Unlock()
s.connected = false
s.ensureClosed()
// Unbind the handlers.
if s.detach != nil {
s.detach()
s.detach = nil
if s.detachReconnect != nil {
for _, detach := range s.detachReconnect {
detach()
}
s.detachReconnect = nil
}
// If we're already closed.
@ -409,46 +556,59 @@ func (s *Session) LeaveCtx(ctx context.Context) error {
return nil
}
s.looper.Stop()
// Notify Discord that we're leaving. This will send a
// VoiceStateUpdateEvent, in which our handler will promptly remove the
// session from the map.
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
// Notify Discord that we're leaving.
err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{
GuildID: s.state.GuildID,
ChannelID: discord.ChannelID(discord.NullSnowflake),
SelfMute: true,
SelfDeaf: true,
})
s.ensureClosed()
// wrap returns nil if err is nil
return errors.Wrap(err, "failed to update voice state")
// Wait for the gateway to exit first before we tell the user of the gateway
// send error.
if err := s.cancelGateway(ctx); err != nil {
return err
}
if err != nil {
return errors.Wrap(err, "failed to update voice state")
}
return nil
}
func (s *Session) cancelGateway(ctx context.Context) error {
if s.gwCancel != nil {
s.gwCancel()
s.gwCancel = nil
// Wait for the previous gateway to finish closing up, but make sure to
// bail if the context expires.
if err := ophandler.WaitForDone(ctx, s.gwDone); err != nil {
return errors.Wrap(err, "cannot wait for gateway to close")
}
}
return nil
}
// close ensures everything is closed. It does not acquire the mutex.
func (s *Session) ensureClosed() {
s.looper.Stop()
// Disconnect the UDP connection.
if s.voiceUDP != nil {
s.voiceUDP.Close()
s.voiceUDP = nil
}
// Disconnect the voice gateway, ignoring the error.
if s.gateway != nil {
if err := s.gateway.Close(); err != nil {
wsutil.WSDebug("Uncaught voice gateway close error:", err)
}
s.gateway = nil
if !s.disconnectClosed {
close(s.disconnected)
s.disconnectClosed = true
}
if s.gwCancel != nil {
s.gwCancel()
// Don't actually clear this field, because we still want the caller to
// be able to wait for the gateway to completely exit using
// cancelGateway.
}
}
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
// thread safe, and must be used very carefully. The backing buffer is always
// reused.
func (s *Session) ReadPacket() (*udp.Packet, error) {
return s.VoiceUDPConn().ReadPacket()
}

View file

@ -34,33 +34,25 @@ func TestNoop(t *testing.T) {
}
func ExampleSession() {
s, err := state.New("Bot " + token)
if err != nil {
log.Fatalln("failed to make state:", err)
}
s := state.New("Bot " + token)
// This is required for bots.
voice.AddIntents(s.Gateway)
voice.AddIntents(s)
if err := s.Open(context.TODO()); err != nil {
log.Fatalln("failed to open gateway:", err)
}
defer s.Close()
c, err := s.Channel(channelID)
if err != nil {
log.Fatalln("failed to get channel:", err)
}
v, err := voice.NewSession(s)
if err != nil {
log.Fatalln("failed to create voice session:", err)
}
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
if err := v.JoinChannel(context.TODO(), channelID, false, false); err != nil {
log.Fatalln("failed to join voice channel:", err)
}
defer v.Leave()
defer v.Leave(context.TODO())
// Start writing Opus frames.
for {

View file

@ -15,7 +15,7 @@ import (
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/internal/testenv"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/diamondburned/arikawa/v3/utils/ws"
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
"github.com/pkg/errors"
)
@ -23,17 +23,14 @@ import (
func TestIntegration(t *testing.T) {
config := testenv.Must(t)
wsutil.WSDebug = func(v ...interface{}) {
ws.WSDebug = func(v ...interface{}) {
_, file, line, _ := runtime.Caller(1)
caller := file + ":" + strconv.Itoa(line)
log.Println(append([]interface{}{caller}, v...)...)
}
s, err := state.New("Bot " + config.BotToken)
if err != nil {
t.Fatal("failed to create a new state:", err)
}
AddIntents(s.Gateway)
s := state.New("Bot " + config.BotToken)
AddIntents(s)
func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
@ -70,7 +67,6 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
if err != nil {
t.Fatal("failed to create a new voice session:", err)
}
v.ErrorLog = func(err error) { t.Error(err) }
// Grab a timer to benchmark things.
finish := timer()
@ -80,7 +76,10 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
finish("receiving voice speaking event")
})
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
t.Cleanup(cancel)
if err := v.JoinChannel(ctx, c.ID, false, false); err != nil {
t.Fatal("failed to join voice:", err)
}
@ -88,28 +87,19 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
log.Println("Leaving the voice channel concurrently.")
raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
return nil, v.Leave()
return nil, v.Leave(ctx)
})
})
finish("joining the voice channel")
// Create a context and only cancel it AFTER we're done sending silence
// frames.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
// Trigger speaking.
if err := v.Speaking(voicegateway.Microphone); err != nil {
if err := v.Speaking(ctx, voicegateway.Microphone); err != nil {
t.Fatal("failed to start speaking:", err)
}
finish("sending the speaking command")
if err := v.UseContext(ctx); err != nil {
t.Fatal("failed to set ctx into vs:", err)
}
f, err := os.Open("testdata/nico.dca")
if err != nil {
t.Fatal("failed to open nico.dca:", err)

View file

@ -8,6 +8,7 @@ import (
"net"
"time"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/pkg/errors"
"golang.org/x/crypto/nacl/secretbox"
)
@ -39,13 +40,13 @@ type Connection struct {
GatewayIP string
GatewayPort uint16
context context.Context
conn net.Conn
ssrc uint32
conn net.Conn
ssrc uint32
// frequency rate.Limiter
frequency *time.Ticker
timeIncr uint32
stopFreq chan struct{}
packet [12]byte
secret [32]byte
@ -59,9 +60,12 @@ type Connection struct {
recvBuf []byte // len 1400
recvOpus []byte // len 1400
recvPacket *Packet // uses recvOpus' backing array
closed moreatomic.Bool
}
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// DialConnection dials a UDP connection.
func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
// Create a new UDP connection.
conn, err := Dialer.DialContext(ctx, "udp", addr)
if err != nil {
@ -114,7 +118,7 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
GatewayPort: port,
frequency: time.NewTicker(20 * time.Millisecond),
timeIncr: 960,
context: context.Background(),
stopFreq: make(chan struct{}),
packet: packet,
ssrc: ssrc,
conn: conn,
@ -159,49 +163,29 @@ func (c *Connection) UseSecret(secret [32]byte) {
c.secret = secret
}
// UseContext lets the connection use the given context for its Write method.
// WriteCtx will override this context.
func (c *Connection) UseContext(ctx context.Context) error {
return c.useContext(ctx)
// SetWriteDeadline sets the UDP connection's write deadline.
func (c *Connection) SetWriteDeadline(deadline time.Time) {
c.conn.SetWriteDeadline(deadline)
}
func (c *Connection) useContext(ctx context.Context) error {
if c.context == ctx {
return nil
}
c.context = ctx
if deadline, ok := c.context.Deadline(); ok {
return c.conn.SetWriteDeadline(deadline)
} else {
return c.conn.SetWriteDeadline(time.Time{})
}
// SetReadDeadline sets the UDP connection's read deadline.
func (c *Connection) SetReadDeadline(deadline time.Time) {
c.conn.SetReadDeadline(deadline)
}
func (c *Connection) Close() error {
c.frequency.Stop()
if c.closed.Acquire() {
// Be sure to only run this ONCE.
c.frequency.Stop()
close(c.stopFreq)
}
return c.conn.Close()
}
// Write sends bytes into the voice UDP connection using the preset context.
// Write sends a packet of audio into the voice UDP connection using the preset
// context.
func (c *Connection) Write(b []byte) (int, error) {
return c.write(b)
}
// WriteCtx sends bytes into the voice UDP connection with a timeout using the
// given context. It ignores the context inside the connection, but will restore
// the deadline after this call is done.
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
oldCtx := c.context
c.useContext(ctx)
defer c.useContext(oldCtx)
return c.write(b)
}
func (c *Connection) write(b []byte) (int, error) {
// Write a new sequence.
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
c.sequence++
@ -215,9 +199,9 @@ func (c *Connection) write(b []byte) (int, error) {
select {
case <-c.frequency.C:
case <-c.context.Done():
return 0, c.context.Err()
// ok
case <-c.stopFreq:
return 0, net.ErrClosed
}
n, err := c.conn.Write(toSend)

View file

@ -8,7 +8,7 @@ import "github.com/diamondburned/arikawa/v3/gateway"
// AddIntents adds the needed voice intents into gw. Bots should always call
// this before Open if voice is required.
func AddIntents(gw *gateway.Gateway) {
func AddIntents(gw interface{ AddIntents(gateway.Intents) }) {
gw.AddIntents(gateway.IntentGuilds)
gw.AddIntents(gateway.IntentGuildVoiceStates)
}

View file

@ -1,167 +0,0 @@
package voicegateway
import (
"context"
"time"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/pkg/errors"
)
var (
// ErrMissingForIdentify is an error when we are missing information to identify.
ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify")
// ErrMissingForResume is an error when we are missing information to resume.
ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming")
)
// OPCode 0
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload
type IdentifyData struct {
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
UserID discord.UserID `json:"user_id"`
SessionID string `json:"session_id"`
Token string `json:"token"`
}
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.IdentifyCtx(ctx)
}
// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway.
func (c *Gateway) IdentifyCtx(ctx context.Context) error {
guildID := c.state.GuildID
userID := c.state.UserID
sessionID := c.state.SessionID
token := c.state.Token
if guildID == 0 || userID == 0 || sessionID == "" || token == "" {
return ErrMissingForIdentify
}
return c.SendCtx(ctx, IdentifyOP, IdentifyData{
GuildID: guildID,
UserID: userID,
SessionID: sessionID,
Token: token,
})
}
// OPCode 1
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload
type SelectProtocol struct {
Protocol string `json:"protocol"`
Data SelectProtocolData `json:"data"`
}
type SelectProtocolData struct {
Address string `json:"address"`
Port uint16 `json:"port"`
Mode string `json:"mode"`
}
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocol(data SelectProtocol) error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SelectProtocolCtx(ctx, data)
}
// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error {
return c.SendCtx(ctx, SelectProtocolOP, data)
}
// OPCode 3
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
// type Heartbeat uint64
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) Heartbeat() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.HeartbeatCtx(ctx)
}
// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano())
}
// https://discord.com/developers/docs/topics/voice-connections#speaking
type SpeakingFlag uint64
const (
NotSpeaking SpeakingFlag = 0
Microphone SpeakingFlag = 1 << iota
Soundshare
Priority
)
// OPCode 5
// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload
type SpeakingData struct {
Speaking SpeakingFlag `json:"speaking"`
Delay int `json:"delay"`
SSRC uint32 `json:"ssrc"`
UserID discord.UserID `json:"user_id,omitempty"`
}
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) Speaking(flag SpeakingFlag) error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SpeakingCtx(ctx, flag)
}
// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error {
// How do we allow a user to stop speaking?
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
return c.SendCtx(ctx, SpeakingOP, SpeakingData{
Speaking: flag,
Delay: 0,
SSRC: c.ready.SSRC,
})
}
// OPCode 7
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload
type ResumeData struct {
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
SessionID string `json:"session_id"`
Token string `json:"token"`
}
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.ResumeCtx(ctx)
}
// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway.
func (c *Gateway) ResumeCtx(ctx context.Context) error {
guildID := c.state.GuildID
sessionID := c.state.SessionID
token := c.state.Token
if !guildID.IsValid() || sessionID == "" || token == "" {
return ErrMissingForResume
}
return c.SendCtx(ctx, ResumeOP, ResumeData{
GuildID: guildID,
SessionID: sessionID,
Token: token,
})
}

View file

@ -0,0 +1,94 @@
// Code generated by genevent. DO NOT EDIT.
package voicegateway
import "github.com/diamondburned/arikawa/v3/utils/ws"
func init() {
OpUnmarshalers.Add(
func() ws.Event { return new(IdentifyCommand) },
func() ws.Event { return new(SelectProtocolCommand) },
func() ws.Event { return new(ReadyEvent) },
func() ws.Event { return new(HeartbeatCommand) },
func() ws.Event { return new(SessionDescriptionEvent) },
func() ws.Event { return new(SpeakingEvent) },
func() ws.Event { return new(HeartbeatAckEvent) },
func() ws.Event { return new(ResumeCommand) },
func() ws.Event { return new(HelloEvent) },
func() ws.Event { return new(ResumedEvent) },
func() ws.Event { return new(ClientConnectEvent) },
func() ws.Event { return new(ClientDisconnectEvent) },
)
}
// Op implements Event. It always returns Op 0.
func (*IdentifyCommand) Op() ws.OpCode { return 0 }
// EventType implements Event.
func (*IdentifyCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 1.
func (*SelectProtocolCommand) Op() ws.OpCode { return 1 }
// EventType implements Event.
func (*SelectProtocolCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 2.
func (*ReadyEvent) Op() ws.OpCode { return 2 }
// EventType implements Event.
func (*ReadyEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 3.
func (*HeartbeatCommand) Op() ws.OpCode { return 3 }
// EventType implements Event.
func (*HeartbeatCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 4.
func (*SessionDescriptionEvent) Op() ws.OpCode { return 4 }
// EventType implements Event.
func (*SessionDescriptionEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 5.
func (*SpeakingEvent) Op() ws.OpCode { return 5 }
// EventType implements Event.
func (*SpeakingEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 6.
func (*HeartbeatAckEvent) Op() ws.OpCode { return 6 }
// EventType implements Event.
func (*HeartbeatAckEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 7.
func (*ResumeCommand) Op() ws.OpCode { return 7 }
// EventType implements Event.
func (*ResumeCommand) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 8.
func (*HelloEvent) Op() ws.OpCode { return 8 }
// EventType implements Event.
func (*HelloEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 9.
func (*ResumedEvent) Op() ws.OpCode { return 9 }
// EventType implements Event.
func (*ResumedEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 12.
func (*ClientConnectEvent) Op() ws.OpCode { return 12 }
// EventType implements Event.
func (*ClientConnectEvent) EventType() ws.EventType { return "" }
// Op implements Event. It always returns Op 13.
func (*ClientDisconnectEvent) Op() ws.OpCode { return 13 }
// EventType implements Event.
func (*ClientDisconnectEvent) EventType() ws.EventType { return "" }

View file

@ -4,9 +4,41 @@ import (
"strconv"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
// OPCode 2
//go:generate go run ../../utils/cmd/genevent -p voicegateway -o event_methods.go
// OpUnmarshalers contains the Op unmarshalers for the voice gateway events.
var OpUnmarshalers = ws.NewOpUnmarshalers()
// IdentifyCommand is a command for Op 0.
//
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload
type IdentifyCommand struct {
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
UserID discord.UserID `json:"user_id"`
SessionID string `json:"session_id"`
Token string `json:"token"`
}
// SelectProtocolCommand is a command for Op 1.
//
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload
type SelectProtocolCommand struct {
Protocol string `json:"protocol"`
Data SelectProtocolData `json:"data"`
}
// SelectProtocolData is the data inside a SelectProtocolCommand.
type SelectProtocolData struct {
Address string `json:"address"`
Port uint16 `json:"port"`
Mode string `json:"mode"`
}
// ReadyEvent is an event for Op 2.
//
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-ready-payload
type ReadyEvent struct {
SSRC uint32 `json:"ssrc"`
@ -23,45 +55,79 @@ type ReadyEvent struct {
// HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
}
// Addr formats the URL inside Ready to be of format "host:port".
func (r ReadyEvent) Addr() string {
return r.IP + ":" + strconv.Itoa(r.Port)
}
// OPCode 4
// HeartbeatCommand is a command for Op 3.
//
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
type HeartbeatCommand uint64
// SessionDescriptionEvent is an event for Op 4.
//
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-session-description-payload
type SessionDescriptionEvent struct {
Mode string `json:"mode"`
SecretKey [32]byte `json:"secret_key"`
}
// OPCode 5
type SpeakingEvent SpeakingData
// https://discord.com/developers/docs/topics/voice-connections#speaking
type SpeakingFlag uint64
// OPCode 6
const (
NotSpeaking SpeakingFlag = 0
Microphone SpeakingFlag = 1 << iota
Soundshare
Priority
)
// SpeakingEvent is an event for Op 5. It is also a command.
//
// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload
type SpeakingEvent struct {
Speaking SpeakingFlag `json:"speaking"`
Delay int `json:"delay"`
SSRC uint32 `json:"ssrc"`
UserID discord.UserID `json:"user_id,omitempty"`
}
// HeartbeatAckEvent is an event for Op 6.
//
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-ack-payload
type HeartbeatACKEvent uint64
type HeartbeatAckEvent uint64
// OPCode 8
// ResumeCommand is a command for Op 7.
//
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload
type ResumeCommand struct {
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
SessionID string `json:"session_id"`
Token string `json:"token"`
}
// HelloEvent is an event for Op 8.
//
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-hello-payload-since-v3
type HelloEvent struct {
HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
}
// OPCode 9
// ResumedEvent is an event for Op 9.
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resumed-payload
type ResumedEvent struct{}
// OPCode 12
// (undocumented)
// ClientConnectEvent is an event for Op 12. It is undocumented.
type ClientConnectEvent struct {
UserID discord.UserID `json:"user_id"`
AudioSSRC uint32 `json:"audio_ssrc"`
VideoSSRC uint32 `json:"video_ssrc"`
}
// OPCode 13
// Undocumented, existence mentioned in below issue
// https://github.com/discord/discord-api-docs/issues/510
// ClientDisconnectEvent is an event for Op 13. It is undocumented, but its
// existence is mentioned in this issue:
// https://github.com/discord/discord-api-docs/issues/510.
type ClientDisconnectEvent struct {
UserID discord.UserID `json:"user_id"`
}

View file

@ -18,9 +18,7 @@ import (
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/diamondburned/arikawa/v3/utils/ws"
)
const (
@ -37,9 +35,9 @@ type Event = interface{}
// State contains state information of a voice gateway.
type State struct {
GuildID discord.GuildID
UserID discord.UserID // constant
GuildID discord.GuildID // constant
ChannelID discord.ChannelID
UserID discord.UserID
SessionID string
Token string
@ -48,282 +46,164 @@ type State struct {
// Gateway represents a Discord Gateway Gateway connection.
type Gateway struct {
state State // constant
gateway *ws.Gateway
state State // constant
mutex sync.RWMutex
ready ReadyEvent
WS *wsutil.Websocket
Timeout time.Duration
reconnect moreatomic.Bool
EventLoop wsutil.PacemakerLoop
Events chan Event
// ErrorLog will be called when an error occurs (defaults to log.Println)
ErrorLog func(err error)
// AfterClose is called after each close. Error can be non-nil, as this is
// called even when the Gateway is gracefully closed. It's used mainly for
// reconnections or any type of connection interruptions. (defaults to noop)
AfterClose func(err error)
// Filled by methods, internal use
waitGroup *sync.WaitGroup
ready *ReadyEvent
}
// DefaultGatewayOpts contains the default options to be used for connecting to
// the gateway.
var DefaultGatewayOpts = ws.GatewayOpts{
ReconnectDelay: func(try int) time.Duration {
// minimum 4 seconds
return time.Duration(4+(2*try)) * time.Second
},
// FatalCloseCodes contains the default gateway close codes that will cause
// the gateway to exit. In other words, it's a list of unrecoverable close
// codes.
FatalCloseCodes: []int{
4003, // not authenticated
4004, // authentication failed
4006, // session invalid
4009, // session timed out
4011, // server not found
4012, // unknown protocol
4014, // disconnected
4016, // unknown encryption mode
},
DialTimeout: 0,
ReconnectAttempt: 0,
AlwaysCloseGracefully: true,
}
// New creates a new voice gateway.
func New(state State) *Gateway {
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version
gw := ws.NewGateway(
ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint),
&DefaultGatewayOpts,
)
return &Gateway{
state: state,
WS: wsutil.New(endpoint),
Timeout: wsutil.WSTimeout,
Events: make(chan Event, wsutil.WSBuffer),
ErrorLog: wsutil.WSError,
AfterClose: func(error) {},
gateway: gw,
state: state,
}
}
// TODO: get rid of
func (c *Gateway) Ready() ReadyEvent {
c.mutex.RLock()
defer c.mutex.RUnlock()
// Ready returns the ready event.
func (g *Gateway) Ready() *ReadyEvent {
g.mutex.RLock()
defer g.mutex.RUnlock()
return c.ready
return g.ready
}
// OpenCtx shouldn't be used, but JoinServer instead.
func (c *Gateway) OpenCtx(ctx context.Context) error {
if c.state.Endpoint == "" {
return errors.New("missing endpoint in state")
}
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
wsutil.WSDebug("VoiceGateway: Connecting to voice endpoint (endpoint=" + endpoint + ")")
// Create a new context with a timeout for the connection.
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
defer cancel()
// Connect to the Gateway Gateway.
if err := c.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to connect to voice gateway")
}
wsutil.WSDebug("VoiceGateway: Trying to start...")
// Try to start or resume the connection.
if err := c.start(ctx); err != nil {
return err
}
return nil
// LastError returns the last error that the gateway has received. It only
// returns a valid error if the gateway's event loop as exited. If the event
// loop hasn't been started AND stopped, the function will panic.
func (g *Gateway) LastError() error {
return g.gateway.LastError()
}
// Start .
func (c *Gateway) start(ctx context.Context) error {
if err := c.__start(ctx); err != nil {
wsutil.WSDebug("VoiceGateway: Start failed: ", err)
// Close can be called with the mutex still acquired here, as the
// pacemaker hasn't started yet.
if err := c.Close(); err != nil {
wsutil.WSDebug("VoiceGateway: Failed to close after start fail: ", err)
}
return err
}
return nil
// Send is a function to send an Op payload to the Gateway.
func (g *Gateway) Send(ctx context.Context, data ws.Event) error {
return g.gateway.Send(ctx, data)
}
// this function blocks until READY.
func (c *Gateway) __start(ctx context.Context) error {
// Make a new WaitGroup for use in background loops:
c.waitGroup = new(sync.WaitGroup)
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error {
g.mutex.RLock()
ssrc := g.ready.SSRC
g.mutex.RUnlock()
ch := c.WS.Listen()
return g.gateway.Send(ctx, &SpeakingEvent{
Speaking: flag,
Delay: 0,
SSRC: ssrc,
})
}
// Wait for hello.
wsutil.WSDebug("VoiceGateway: Waiting for Hello..")
func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op {
return g.gateway.Connect(ctx, (*gatewayImpl)(g))
}
var hello *HelloEvent
// Wait for the Hello event; return if it times out.
select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
var (
// ErrMissingForIdentify is an error when we are missing information to
// identify.
ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify")
// ErrMissingForResume is an error when we are missing information to
// resume.
ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming")
)
type gatewayImpl Gateway
func (g *gatewayImpl) sendIdentify(ctx context.Context) error {
id := IdentifyCommand{
GuildID: g.state.GuildID,
UserID: g.state.UserID,
SessionID: g.state.SessionID,
Token: g.state.Token,
}
if !id.GuildID.IsValid() || id == (IdentifyCommand{}) {
return ErrMissingForIdentify
}
wsutil.WSDebug("VoiceGateway: Received Hello")
return g.gateway.Send(ctx, &id)
}
// Start the event handler, which also handles the pacemaker death signal.
c.waitGroup.Add(1)
func (g *gatewayImpl) sendResume(ctx context.Context) error {
if !g.state.GuildID.IsValid() || g.state.SessionID == "" || g.state.Token == "" {
return ErrMissingForResume
}
c.EventLoop.StartBeating(hello.HeartbeatInterval.Duration(), c, func(err error) {
c.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("VoiceGateway: Event loop stopped.")
return g.gateway.Send(ctx, &ResumeCommand{
GuildID: g.state.GuildID,
SessionID: g.state.SessionID,
Token: g.state.Token,
})
}
if err != nil {
c.ErrorLog(err)
func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool {
switch data := op.Data.(type) {
case *HelloEvent:
g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration())
if err := c.Reconnect(); err != nil {
c.ErrorLog(errors.Wrap(err, "failed to reconnect voice"))
// Send Discord either the Identify packet (if it's a fresh
// connection), or a Resume packet (if it's a dead connection).
if g.ready == nil {
// SessionID is empty, so this is a completely new session.
if err := g.sendIdentify(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to send identify")
g.gateway.QueueReconnect()
}
} else {
if err := g.sendResume(ctx); err != nil {
g.gateway.SendErrorWrap(err, "failed to send resume")
g.gateway.QueueReconnect()
}
// Reconnect should spawn another eventLoop in its Start function.
}
})
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
// Turns out Hello is sent right away on connection start.
if !c.reconnect.Get() {
if err := c.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify")
}
} else {
if err := c.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume")
}
}
// This bool is because we should only try and Resume once.
c.reconnect.Set(false)
// Wait for either Ready or Resumed.
err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool {
return op.Code == ReadyOP || op.Code == ResumedOP
})
if err != nil {
return errors.Wrap(err, "failed to wait for Ready or Resumed")
case *ReadyEvent:
g.mutex.Lock()
g.ready = data
g.mutex.Unlock()
}
// Bind the event channel away.
c.EventLoop.SetEventChannel(ch)
return true
}
wsutil.WSDebug("VoiceGateway: Started successfully.")
func (g *gatewayImpl) SendHeartbeat(ctx context.Context) {
heartbeat := HeartbeatCommand(time.Now().UnixNano())
if err := g.gateway.Send(ctx, &heartbeat); err != nil {
g.gateway.SendErrorWrap(err, "heartbeat error")
g.gateway.QueueReconnect()
}
}
func (g *gatewayImpl) Close() error {
return nil
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
wsutil.WSDebug("VoiceGateway: Trying to close. Pacemaker check skipped.")
wsutil.WSDebug("VoiceGateway: Closing the Websocket...")
err := g.WS.Close()
if errors.Is(err, wsutil.ErrWebsocketClosed) {
wsutil.WSDebug("VoiceGateway: Websocket already closed.")
return nil
}
wsutil.WSDebug("VoiceGateway: Websocket closed; error:", err)
wsutil.WSDebug("VoiceGateway: Waiting for the Pacemaker loop to exit.")
g.waitGroup.Wait()
wsutil.WSDebug("VoiceGateway: Pacemaker loop exited.")
g.AfterClose(err)
wsutil.WSDebug("VoiceGateway: AfterClose callback finished.")
return err
}
func (c *Gateway) Reconnect() error {
return c.ReconnectCtx(context.Background())
}
func (c *Gateway) ReconnectCtx(ctx context.Context) error {
wsutil.WSDebug("VoiceGateway: Reconnecting...")
// TODO: implement a reconnect loop
// Guarantee the gateway is already closed. Ignore its error, as we're
// redialing anyway.
c.Close()
c.reconnect.Set(true)
// Condition: err == ErrInvalidSession:
// If the connection is rate limited (documented behavior):
// https://discord.com/developers/docs/topics/gateway#rate-limiting
if err := c.OpenCtx(ctx); err != nil {
return errors.Wrap(err, "failed to reopen gateway")
}
wsutil.WSDebug("VoiceGateway: Reconnected successfully.")
return nil
}
func (c *Gateway) SessionDescriptionCtx(
ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) {
// Add the handler first.
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
return op.Code == SessionDescriptionOP
})
defer cancel()
if err := c.SelectProtocolCtx(ctx, sp); err != nil {
return nil, err
}
var sesdesc *SessionDescriptionEvent
// Wait for SessionDescriptionOP packet.
select {
case e, ok := <-ch:
if !ok {
return nil, errors.New("unexpected close waiting for session description")
}
if err := e.UnmarshalData(&sesdesc); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal session description")
}
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "failed to wait for session description")
}
return sesdesc, nil
}
// Send sends a payload to the Gateway with the default timeout.
func (c *Gateway) Send(code OPCode, v interface{}) error {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
return c.SendCtx(ctx, code, v)
}
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
var op = wsutil.OP{
Code: code,
}
if v != nil {
b, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "failed to encode v")
}
op.Data = b
}
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "failed to encode payload")
}
// WS should already be thread-safe.
return c.WS.SendCtx(ctx, b)
}

View file

@ -1,99 +0,0 @@
package voicegateway
import (
"fmt"
"sync"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/pkg/errors"
)
// OPCode represents a Discord Gateway Gateway operation code.
type OPCode = wsutil.OPCode
const (
IdentifyOP OPCode = 0 // send
SelectProtocolOP OPCode = 1 // send
ReadyOP OPCode = 2 // receive
HeartbeatOP OPCode = 3 // send
SessionDescriptionOP OPCode = 4 // receive
SpeakingOP OPCode = 5 // send/receive
HeartbeatAckOP OPCode = 6 // receive
ResumeOP OPCode = 7 // send
HelloOP OPCode = 8 // receive
ResumedOP OPCode = 9 // receive
ClientConnectOP OPCode = 12 // receive
ClientDisconnectOP OPCode = 13 // receive
)
func (c *Gateway) HandleOP(op *wsutil.OP) error {
wsutil.WSDebug("Handle OP", op.Code)
switch op.Code {
// Gives information required to make a UDP connection
case ReadyOP:
if err := unmarshalMutex(op.Data, &c.ready, &c.mutex); err != nil {
return errors.Wrap(err, "failed to parse READY event")
}
c.Events <- &c.ready
// Gives information about the encryption mode and secret key for sending voice packets
case SessionDescriptionOP:
// ?
// Already handled by Session.
// Someone started or stopped speaking.
case SpeakingOP:
ev := new(SpeakingEvent)
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "failed to parse Speaking event")
}
c.Events <- ev
// Heartbeat response from the server
case HeartbeatAckOP:
c.EventLoop.Echo()
// Hello server, we hear you! :)
case HelloOP:
// ?
// Already handled on initial connection.
// Server is saying the connection was resumed, no data here.
case ResumedOP:
wsutil.WSDebug("Gateway connection has been resumed.")
case ClientConnectOP:
ev := new(ClientConnectEvent)
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "failed to parse Speaking event")
}
c.Events <- ev
case ClientDisconnectOP:
ev := new(ClientDisconnectEvent)
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "failed to parse Speaking event")
}
c.Events <- ev
default:
return fmt.Errorf("unknown OP code %d", op.Code)
}
return nil
}
func unmarshalMutex(d []byte, v interface{}, m *sync.RWMutex) error {
m.Lock()
err := json.Unmarshal(d, v)
m.Unlock()
return err
}