From 17b9c73ce3cb434d21309e29188673687f41ff11 Mon Sep 17 00:00:00 2001 From: diamondburned Date: Tue, 28 Sep 2021 13:19:04 -0700 Subject: [PATCH] gateway: Refactor for a better concurrent API This commit refactors the whole package gateway as well as utils/ws (formerly utils/wsutil) and voice/voicegateway. The new refactor utilizes a design pattern involving a concurrent loop and an arriving event channel. An additional change was made to the way gateway events are typed. Before, pretty much any type will satisfy a gateway event type, since the actual type was just interface{}. The new refactor defines a concrete interface that events can implement: type Event interface { Op() OpCode EventType() EventType } Using this interface, the user can easily add custom gateway events independently of the library without relying on string maps. This adds a lot of type safety into the library and makes type-switching on Event types much more reasonable. Gateway error callbacks are also almost entirely removed in favor of custom gateway events. A catch-all can easily be added like this: s.AddHandler(func(err error) { log.Println("gateway error:, err") }) --- .build.yml | 12 +- 0-examples/autocomplete/main.go | 6 +- 0-examples/buttons/main.go | 6 +- 0-examples/commands/main.go | 6 +- 0-examples/sharded/main.go | 2 +- 0-examples/simple/main.go | 6 +- 0-examples/undeleter/main.go | 6 +- README.md | 40 + api/bot.go | 7 +- discord/snowflake.go | 2 +- gateway/commands.go | 167 --- gateway/event_methods.go | 460 +++++++ gateway/events.go | 1150 ++++++++++++----- gateway/events_map.go | 76 -- gateway/gateway.go | 752 +++++------ gateway/gateway_example_test.go | 31 + gateway/gateway_test.go | 197 +++ gateway/identify.go | 104 +- gateway/integration_test.go | 160 --- gateway/intents.go | 7 +- gateway/op.go | 118 -- gateway/perseverance_test.go | 56 - gateway/ready.go | 267 ---- internal/handleloop/handleloop.go | 60 - internal/heart/heart.go | 135 -- internal/lazytime/ticker.go | 30 + internal/lazytime/timer.go | 50 + internal/moreatomic/bool.go | 14 + session/session.go | 239 ++-- session/session_test.go | 50 + {gateway => session}/shard/manager.go | 28 +- {gateway => session}/shard/shard.go | 22 +- .../shard_test.go} | 17 +- state/state.go | 82 +- state/state_events.go | 2 +- state/state_example_test.go | 32 + state/state_shard_test.go | 23 +- utils/bot/command.go | 31 +- utils/bot/ctx.go | 23 +- utils/bot/ctx_shard_test.go | 4 +- .../extras/middlewares/middlewares_test.go | 23 +- utils/cmd/genevent/genevent.go | 172 +++ utils/cmd/genevent/template.tmpl | 27 + utils/{ => cmd}/gensnowflake/main.go | 0 utils/{ => cmd}/gensnowflake/template.tmpl | 0 utils/handler/handler.go | 1 + utils/ws/codec.go | 101 ++ utils/ws/conn.go | 278 ++++ utils/ws/gateway.go | 364 ++++++ utils/ws/op.go | 212 +++ utils/ws/ophandler/ophandler.go | 35 + utils/{wsutil => ws}/throttler.go | 2 +- utils/ws/ws.go | 118 ++ utils/wsutil/conn.go | 270 ---- utils/wsutil/heart.go | 151 --- utils/wsutil/op.go | 202 --- utils/wsutil/ws.go | 200 --- voice/session.go | 498 ++++--- voice/session_example_test.go | 16 +- voice/session_test.go | 30 +- voice/udp/udp.go | 68 +- voice/voice.go | 2 +- voice/voicegateway/commands.go | 167 --- voice/voicegateway/event_methods.go | 94 ++ voice/voicegateway/events.go | 92 +- voice/voicegateway/gateway.go | 376 ++---- voice/voicegateway/op.go | 99 -- 67 files changed, 4391 insertions(+), 3687 deletions(-) delete mode 100644 gateway/commands.go create mode 100644 gateway/event_methods.go delete mode 100644 gateway/events_map.go create mode 100644 gateway/gateway_example_test.go create mode 100644 gateway/gateway_test.go delete mode 100644 gateway/integration_test.go delete mode 100644 gateway/op.go delete mode 100644 gateway/perseverance_test.go delete mode 100644 gateway/ready.go delete mode 100644 internal/handleloop/handleloop.go delete mode 100644 internal/heart/heart.go create mode 100644 internal/lazytime/ticker.go create mode 100644 internal/lazytime/timer.go create mode 100644 session/session_test.go rename {gateway => session}/shard/manager.go (92%) rename {gateway => session}/shard/shard.go (79%) rename session/{session_shard_test.go => shard/shard_test.go} (78%) create mode 100644 state/state_example_test.go create mode 100644 utils/cmd/genevent/genevent.go create mode 100644 utils/cmd/genevent/template.tmpl rename utils/{ => cmd}/gensnowflake/main.go (100%) rename utils/{ => cmd}/gensnowflake/template.tmpl (100%) create mode 100644 utils/ws/codec.go create mode 100644 utils/ws/conn.go create mode 100644 utils/ws/gateway.go create mode 100644 utils/ws/op.go create mode 100644 utils/ws/ophandler/ophandler.go rename utils/{wsutil => ws}/throttler.go (96%) create mode 100644 utils/ws/ws.go delete mode 100644 utils/wsutil/conn.go delete mode 100644 utils/wsutil/heart.go delete mode 100644 utils/wsutil/op.go delete mode 100644 utils/wsutil/ws.go delete mode 100644 voice/voicegateway/commands.go create mode 100644 voice/voicegateway/event_methods.go delete mode 100644 voice/voicegateway/op.go diff --git a/.build.yml b/.build.yml index bbe3be6..26cf584 100644 --- a/.build.yml +++ b/.build.yml @@ -12,15 +12,25 @@ environment: GO111MODULE: "on" CGO_ENABLED: "1" # Integration test variables. - SHARD_COUNT: "3" + SHARD_COUNT: "2" tested: "./api,./gateway,./bot,./discord" cov_file: "/tmp/cov_results" dismock: "github.com/mavolin/dismock/v2/pkg/dismock" dismock_v: "259685b84e4b6ab364b0fd858aac2aa2dfa42502" tasks: + - generate: |- + cd arikawa + go generate ./... + + if [[ "$(git status --porcelain)" ]]; then + echo "Repository differ after regeneration." + exit 1 + fi + - build: cd arikawa && go build ./... - unit: cd arikawa && go test -tags unitonly -race ./... + - integration: |- sh -c ' test -f ~/.env || { diff --git a/0-examples/autocomplete/main.go b/0-examples/autocomplete/main.go index 584c835..bd01538 100644 --- a/0-examples/autocomplete/main.go +++ b/0-examples/autocomplete/main.go @@ -23,11 +23,7 @@ func main() { log.Fatalln("No $BOT_TOKEN given.") } - s, err := state.New("Bot " + token) - if err != nil { - log.Fatalln("Session failed:", err) - return - } + s := state.New("Bot " + token) app, err := s.CurrentApplication() if err != nil { diff --git a/0-examples/buttons/main.go b/0-examples/buttons/main.go index 7847044..4af95fb 100644 --- a/0-examples/buttons/main.go +++ b/0-examples/buttons/main.go @@ -22,11 +22,7 @@ func main() { log.Fatalln("no $BOT_TOKEN given") } - s, err := session.New("Bot " + token) - if err != nil { - log.Fatalln("session failed:", err) - return - } + s := session.New("Bot " + token) app, err := s.CurrentApplication() if err != nil { diff --git a/0-examples/commands/main.go b/0-examples/commands/main.go index e9b59d7..0b29791 100644 --- a/0-examples/commands/main.go +++ b/0-examples/commands/main.go @@ -22,11 +22,7 @@ func main() { log.Fatalln("No $BOT_TOKEN given.") } - s, err := state.New("Bot " + token) - if err != nil { - log.Fatalln("Session failed:", err) - return - } + s := state.New("Bot " + token) app, err := s.CurrentApplication() if err != nil { diff --git a/0-examples/sharded/main.go b/0-examples/sharded/main.go index bb32155..e9b66eb 100644 --- a/0-examples/sharded/main.go +++ b/0-examples/sharded/main.go @@ -8,7 +8,7 @@ import ( "os" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" + "github.com/diamondburned/arikawa/v3/session/shard" "github.com/diamondburned/arikawa/v3/state" ) diff --git a/0-examples/simple/main.go b/0-examples/simple/main.go index 32890d3..d288c50 100644 --- a/0-examples/simple/main.go +++ b/0-examples/simple/main.go @@ -19,11 +19,7 @@ func main() { log.Fatalln("No $BOT_TOKEN given.") } - s, err := session.New("Bot " + token) - if err != nil { - log.Fatalln("Session failed:", err) - } - + s := session.New("Bot " + token) s.AddHandler(func(c *gateway.MessageCreateEvent) { log.Println(c.Author.Username, "sent", c.Content) }) diff --git a/0-examples/undeleter/main.go b/0-examples/undeleter/main.go index b758d9f..c1ae06c 100644 --- a/0-examples/undeleter/main.go +++ b/0-examples/undeleter/main.go @@ -19,11 +19,7 @@ func main() { log.Fatalln("No $BOT_TOKEN given.") } - s, err := state.New("Bot " + token) - if err != nil { - log.Fatalln("Session failed:", err) - } - + s := state.New("Bot " + token) // Make a pre-handler s.PreHandler = handler.New() s.PreHandler.AddSyncHandler(func(c *gateway.MessageDeleteEvent) { diff --git a/README.md b/README.md index 9aed137..acd7ee4 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,46 @@ This example demonstrates the PreHandler feature of the state library. PreHandler calls all handlers that are registered (separately from the session), calling them before the state is updated. +### Bare Minimum Print Example + +The least amount of code recommended to have a bot that logs all messages to +console. + +```go +package main + +import ( + "context" + "log" + "os" + "os/signal" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/state" +) + +func main() { + s := state.New("Bot " + os.Getenv("DISCORD_TOKEN")) + s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages) + s.AddHandler(func(m *gateway.MessageCreateEvent) { + log.Printf("%s: %s", m.Author.Username, m.Content) + }) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err := s.Open(ctx); err != nil { + log.Println("cannot open:", err) + } + + <-ctx.Done() // block until Ctrl+C + + if err := s.Close(); err != nil { + log.Println("cannot close:", err) + } +} +``` + ### Bare Minimum Bot The least amount of code for a basic ping-pong bot. It's similar to Serenity's diff --git a/api/bot.go b/api/bot.go index bf7eeda..532370b 100644 --- a/api/bot.go +++ b/api/bot.go @@ -1,6 +1,8 @@ package api import ( + "context" + "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/utils/httputil" ) @@ -30,7 +32,8 @@ func (c *Client) BotURL() (*BotData, error) { } // GatewayURL asks Discord for a Websocket URL to the Gateway. -func GatewayURL() (string, error) { +func GatewayURL(ctx context.Context) (string, error) { var g BotData - return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway) + err := httputil.NewClient().WithContext(ctx).RequestJSON(&g, "GET", EndpointGateway) + return g.URL, err } diff --git a/discord/snowflake.go b/discord/snowflake.go index aacee7a..5ef04f5 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -15,7 +15,7 @@ func DurationSinceEpoch(t time.Time) time.Duration { return time.Duration(t.UnixNano()) - Epoch } -//go:generate go run ../utils/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID +//go:generate go run ../utils/cmd/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID // Mention generates the mention syntax for this channel ID. func (s ChannelID) Mention() string { return "<#" + s.String() + ">" } diff --git a/gateway/commands.go b/gateway/commands.go deleted file mode 100644 index ee2225a..0000000 --- a/gateway/commands.go +++ /dev/null @@ -1,167 +0,0 @@ -package gateway - -import ( - "context" - - "github.com/pkg/errors" - - "github.com/diamondburned/arikawa/v3/discord" -) - -// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent - -// Identify structure is at identify.go - -// Identify sends off the Identify command with the Gateway's IdentifyData. -func (g *Gateway) Identify() error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.IdentifyCtx(ctx) -} - -// IdentifyCtx sends off the Identify command with the Gateway's IdentifyData -// with the given context for time out. -func (g *Gateway) IdentifyCtx(ctx context.Context) error { - if err := g.Identifier.Wait(ctx); err != nil { - return errors.Wrap(err, "can't wait for identify()") - } - - return g.SendCtx(ctx, IdentifyOP, g.Identifier) -} - -type ResumeData struct { - Token string `json:"token"` - SessionID string `json:"session_id"` - Sequence int64 `json:"seq"` -} - -// Resume sends to the Websocket a Resume OP, but it doesn't actually resume -// from a dead connection. Start() resumes from a dead connection. -func (g *Gateway) Resume() error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.ResumeCtx(ctx) -} - -// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume -// from a dead connection. Start() resumes from a dead connection. -func (g *Gateway) ResumeCtx(ctx context.Context) error { - var ( - ses = g.SessionID() - seq = g.Sequence.Get() - ) - - if ses == "" || seq == 0 { - return ErrMissingForResume - } - - return g.SendCtx(ctx, ResumeOP, ResumeData{ - Token: g.Identifier.Token, - SessionID: ses, - Sequence: seq, - }) -} - -// HeartbeatData is the last sequence number to be sent. -type HeartbeatData int - -func (g *Gateway) Heartbeat() error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.HeartbeatCtx(ctx) -} - -func (g *Gateway) HeartbeatCtx(ctx context.Context) error { - return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get()) -} - -type RequestGuildMembersData struct { - // GuildIDs contains the ids of the guilds to request data from. Multiple - // guilds can only be requested when using user accounts. - GuildIDs []discord.GuildID `json:"guild_id"` - UserIDs []discord.UserID `json:"user_ids,omitempty"` - - Query string `json:"query"` - Limit uint `json:"limit"` - Presences bool `json:"presences,omitempty"` - Nonce string `json:"nonce,omitempty"` -} - -func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.RequestGuildMembersCtx(ctx, data) -} - -func (g *Gateway) RequestGuildMembersCtx( - ctx context.Context, data RequestGuildMembersData) error { - - return g.SendCtx(ctx, RequestGuildMembersOP, data) -} - -type UpdateVoiceStateData struct { - GuildID discord.GuildID `json:"guild_id"` - ChannelID discord.ChannelID `json:"channel_id"` // nullable - SelfMute bool `json:"self_mute"` - SelfDeaf bool `json:"self_deaf"` -} - -func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.UpdateVoiceStateCtx(ctx, data) -} - -func (g *Gateway) UpdateVoiceStateCtx(ctx context.Context, data UpdateVoiceStateData) error { - return g.SendCtx(ctx, VoiceStateUpdateOP, data) -} - -// UpdateStatusData is sent by this client to indicate a presence or status -// update. -type UpdateStatusData struct { - Since discord.UnixMsTimestamp `json:"since"` // 0 if not idle - - // Activities can be null or an empty slice. - Activities []discord.Activity `json:"activities"` - - Status discord.Status `json:"status"` - AFK bool `json:"afk"` -} - -func (g *Gateway) UpdateStatus(data UpdateStatusData) error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.UpdateStatusCtx(ctx, data) -} - -func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error { - return g.SendCtx(ctx, StatusUpdateOP, data) -} - -// Undocumented -type GuildSubscribeData struct { - Typing bool `json:"typing"` - Threads bool `json:"threads"` - Activities bool `json:"activities"` - GuildID discord.GuildID `json:"guild_id"` - - // Channels is not documented. It's used to fetch the right members sidebar. - Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"` -} - -func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.GuildSubscribeCtx(ctx, data) -} - -func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error { - return g.SendCtx(ctx, GuildSubscriptionsOP, data) -} diff --git a/gateway/event_methods.go b/gateway/event_methods.go new file mode 100644 index 0000000..e8e98fd --- /dev/null +++ b/gateway/event_methods.go @@ -0,0 +1,460 @@ +// Code generated by genevent. DO NOT EDIT. + +package gateway + +import "github.com/diamondburned/arikawa/v3/utils/ws" + +func init() { + OpUnmarshalers.Add( + func() ws.Event { return new(HeartbeatCommand) }, + func() ws.Event { return new(HeartbeatAckEvent) }, + func() ws.Event { return new(ReconnectEvent) }, + func() ws.Event { return new(HelloEvent) }, + func() ws.Event { return new(ResumeCommand) }, + func() ws.Event { return new(InvalidSessionEvent) }, + func() ws.Event { return new(RequestGuildMembersCommand) }, + func() ws.Event { return new(UpdateVoiceStateCommand) }, + func() ws.Event { return new(UpdatePresenceCommand) }, + func() ws.Event { return new(GuildSubscribeCommand) }, + func() ws.Event { return new(ResumedEvent) }, + func() ws.Event { return new(ChannelCreateEvent) }, + func() ws.Event { return new(ChannelUpdateEvent) }, + func() ws.Event { return new(ChannelDeleteEvent) }, + func() ws.Event { return new(ChannelPinsUpdateEvent) }, + func() ws.Event { return new(ChannelUnreadUpdateEvent) }, + func() ws.Event { return new(ThreadCreateEvent) }, + func() ws.Event { return new(ThreadUpdateEvent) }, + func() ws.Event { return new(ThreadDeleteEvent) }, + func() ws.Event { return new(ThreadListSyncEvent) }, + func() ws.Event { return new(ThreadMemberUpdateEvent) }, + func() ws.Event { return new(ThreadMembersUpdateEvent) }, + func() ws.Event { return new(GuildCreateEvent) }, + func() ws.Event { return new(GuildUpdateEvent) }, + func() ws.Event { return new(GuildDeleteEvent) }, + func() ws.Event { return new(GuildBanAddEvent) }, + func() ws.Event { return new(GuildBanRemoveEvent) }, + func() ws.Event { return new(GuildEmojisUpdateEvent) }, + func() ws.Event { return new(GuildIntegrationsUpdateEvent) }, + func() ws.Event { return new(GuildMemberAddEvent) }, + func() ws.Event { return new(GuildMemberRemoveEvent) }, + func() ws.Event { return new(GuildMemberUpdateEvent) }, + func() ws.Event { return new(GuildMembersChunkEvent) }, + func() ws.Event { return new(GuildRoleCreateEvent) }, + func() ws.Event { return new(GuildRoleUpdateEvent) }, + func() ws.Event { return new(GuildRoleDeleteEvent) }, + func() ws.Event { return new(InviteCreateEvent) }, + func() ws.Event { return new(InviteDeleteEvent) }, + func() ws.Event { return new(MessageCreateEvent) }, + func() ws.Event { return new(MessageUpdateEvent) }, + func() ws.Event { return new(MessageDeleteEvent) }, + func() ws.Event { return new(MessageDeleteBulkEvent) }, + func() ws.Event { return new(MessageReactionAddEvent) }, + func() ws.Event { return new(MessageReactionRemoveEvent) }, + func() ws.Event { return new(MessageReactionRemoveAllEvent) }, + func() ws.Event { return new(MessageReactionRemoveEmojiEvent) }, + func() ws.Event { return new(MessageAckEvent) }, + func() ws.Event { return new(PresenceUpdateEvent) }, + func() ws.Event { return new(PresencesReplaceEvent) }, + func() ws.Event { return new(SessionsReplaceEvent) }, + func() ws.Event { return new(TypingStartEvent) }, + func() ws.Event { return new(UserUpdateEvent) }, + func() ws.Event { return new(VoiceStateUpdateEvent) }, + func() ws.Event { return new(VoiceServerUpdateEvent) }, + func() ws.Event { return new(WebhooksUpdateEvent) }, + func() ws.Event { return new(InteractionCreateEvent) }, + func() ws.Event { return new(UserGuildSettingsUpdateEvent) }, + func() ws.Event { return new(UserSettingsUpdateEvent) }, + func() ws.Event { return new(UserNoteUpdateEvent) }, + func() ws.Event { return new(RelationshipAddEvent) }, + func() ws.Event { return new(RelationshipRemoveEvent) }, + func() ws.Event { return new(ReadyEvent) }, + func() ws.Event { return new(ReadySupplementalEvent) }, + func() ws.Event { return new(IdentifyCommand) }, + ) +} + +// Op implements Event. It always returns Op 1. +func (*HeartbeatCommand) Op() ws.OpCode { return 1 } + +// EventType implements Event. +func (*HeartbeatCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 11. +func (*HeartbeatAckEvent) Op() ws.OpCode { return 11 } + +// EventType implements Event. +func (*HeartbeatAckEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 7. +func (*ReconnectEvent) Op() ws.OpCode { return 7 } + +// EventType implements Event. +func (*ReconnectEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 10. +func (*HelloEvent) Op() ws.OpCode { return 10 } + +// EventType implements Event. +func (*HelloEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 6. +func (*ResumeCommand) Op() ws.OpCode { return 6 } + +// EventType implements Event. +func (*ResumeCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 9. +func (*InvalidSessionEvent) Op() ws.OpCode { return 9 } + +// EventType implements Event. +func (*InvalidSessionEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 8. +func (*RequestGuildMembersCommand) Op() ws.OpCode { return 8 } + +// EventType implements Event. +func (*RequestGuildMembersCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 4. +func (*UpdateVoiceStateCommand) Op() ws.OpCode { return 4 } + +// EventType implements Event. +func (*UpdateVoiceStateCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 3. +func (*UpdatePresenceCommand) Op() ws.OpCode { return 3 } + +// EventType implements Event. +func (*UpdatePresenceCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 14. +func (*GuildSubscribeCommand) Op() ws.OpCode { return 14 } + +// EventType implements Event. +func (*GuildSubscribeCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns 0. +func (*ResumedEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ResumedEvent) EventType() ws.EventType { return "RESUMED" } + +// Op implements Event. It always returns 0. +func (*ChannelCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ChannelCreateEvent) EventType() ws.EventType { return "CHANNEL_CREATE" } + +// Op implements Event. It always returns 0. +func (*ChannelUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ChannelUpdateEvent) EventType() ws.EventType { return "CHANNEL_UPDATE" } + +// Op implements Event. It always returns 0. +func (*ChannelDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ChannelDeleteEvent) EventType() ws.EventType { return "CHANNEL_DELETE" } + +// Op implements Event. It always returns 0. +func (*ChannelPinsUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ChannelPinsUpdateEvent) EventType() ws.EventType { return "CHANNEL_PINS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*ChannelUnreadUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ChannelUnreadUpdateEvent) EventType() ws.EventType { return "CHANNEL_UNREAD_UPDATE" } + +// Op implements Event. It always returns 0. +func (*ThreadCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadCreateEvent) EventType() ws.EventType { return "THREAD_CREATE" } + +// Op implements Event. It always returns 0. +func (*ThreadUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadUpdateEvent) EventType() ws.EventType { return "THREAD_UPDATE" } + +// Op implements Event. It always returns 0. +func (*ThreadDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadDeleteEvent) EventType() ws.EventType { return "THREAD_DELETE" } + +// Op implements Event. It always returns 0. +func (*ThreadListSyncEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadListSyncEvent) EventType() ws.EventType { return "THREAD_LIST_SYNC" } + +// Op implements Event. It always returns 0. +func (*ThreadMemberUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadMemberUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBER_UPDATE" } + +// Op implements Event. It always returns 0. +func (*ThreadMembersUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ThreadMembersUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBERS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildCreateEvent) EventType() ws.EventType { return "GUILD_CREATE" } + +// Op implements Event. It always returns 0. +func (*GuildUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildUpdateEvent) EventType() ws.EventType { return "GUILD_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildDeleteEvent) EventType() ws.EventType { return "GUILD_DELETE" } + +// Op implements Event. It always returns 0. +func (*GuildBanAddEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildBanAddEvent) EventType() ws.EventType { return "GUILD_BAN_ADD" } + +// Op implements Event. It always returns 0. +func (*GuildBanRemoveEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildBanRemoveEvent) EventType() ws.EventType { return "GUILD_BAN_REMOVE" } + +// Op implements Event. It always returns 0. +func (*GuildEmojisUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildEmojisUpdateEvent) EventType() ws.EventType { return "GUILD_EMOJIS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildIntegrationsUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildIntegrationsUpdateEvent) EventType() ws.EventType { return "GUILD_INTEGRATIONS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildMemberAddEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildMemberAddEvent) EventType() ws.EventType { return "GUILD_MEMBER_ADD" } + +// Op implements Event. It always returns 0. +func (*GuildMemberRemoveEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildMemberRemoveEvent) EventType() ws.EventType { return "GUILD_MEMBER_REMOVE" } + +// Op implements Event. It always returns 0. +func (*GuildMemberUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildMemberUpdateEvent) EventType() ws.EventType { return "GUILD_MEMBER_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildMembersChunkEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildMembersChunkEvent) EventType() ws.EventType { return "GUILD_MEMBERS_CHUNK" } + +// Op implements Event. It always returns 0. +func (*GuildRoleCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildRoleCreateEvent) EventType() ws.EventType { return "GUILD_ROLE_CREATE" } + +// Op implements Event. It always returns 0. +func (*GuildRoleUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildRoleUpdateEvent) EventType() ws.EventType { return "GUILD_ROLE_UPDATE" } + +// Op implements Event. It always returns 0. +func (*GuildRoleDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*GuildRoleDeleteEvent) EventType() ws.EventType { return "GUILD_ROLE_DELETE" } + +// Op implements Event. It always returns 0. +func (*InviteCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*InviteCreateEvent) EventType() ws.EventType { return "INVITE_CREATE" } + +// Op implements Event. It always returns 0. +func (*InviteDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*InviteDeleteEvent) EventType() ws.EventType { return "INVITE_DELETE" } + +// Op implements Event. It always returns 0. +func (*MessageCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageCreateEvent) EventType() ws.EventType { return "MESSAGE_CREATE" } + +// Op implements Event. It always returns 0. +func (*MessageUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageUpdateEvent) EventType() ws.EventType { return "MESSAGE_UPDATE" } + +// Op implements Event. It always returns 0. +func (*MessageDeleteEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageDeleteEvent) EventType() ws.EventType { return "MESSAGE_DELETE" } + +// Op implements Event. It always returns 0. +func (*MessageDeleteBulkEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageDeleteBulkEvent) EventType() ws.EventType { return "MESSAGE_DELETE_BULK" } + +// Op implements Event. It always returns 0. +func (*MessageReactionAddEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageReactionAddEvent) EventType() ws.EventType { return "MESSAGE_REACTION_ADD" } + +// Op implements Event. It always returns 0. +func (*MessageReactionRemoveEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageReactionRemoveEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE" } + +// Op implements Event. It always returns 0. +func (*MessageReactionRemoveAllEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageReactionRemoveAllEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE_ALL" } + +// Op implements Event. It always returns 0. +func (*MessageReactionRemoveEmojiEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageReactionRemoveEmojiEvent) EventType() ws.EventType { + return "MESSAGE_REACTION_REMOVE_EMOJI" +} + +// Op implements Event. It always returns 0. +func (*MessageAckEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*MessageAckEvent) EventType() ws.EventType { return "MESSAGE_ACK" } + +// Op implements Event. It always returns 0. +func (*PresenceUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*PresenceUpdateEvent) EventType() ws.EventType { return "PRESENCE_UPDATE" } + +// Op implements Event. It always returns 0. +func (*PresencesReplaceEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*PresencesReplaceEvent) EventType() ws.EventType { return "PRESENCES_REPLACE" } + +// Op implements Event. It always returns 0. +func (*SessionsReplaceEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*SessionsReplaceEvent) EventType() ws.EventType { return "SESSIONS_REPLACE" } + +// Op implements Event. It always returns 0. +func (*TypingStartEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*TypingStartEvent) EventType() ws.EventType { return "TYPING_START" } + +// Op implements Event. It always returns 0. +func (*UserUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*UserUpdateEvent) EventType() ws.EventType { return "USER_UPDATE" } + +// Op implements Event. It always returns 0. +func (*VoiceStateUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*VoiceStateUpdateEvent) EventType() ws.EventType { return "VOICE_STATE_UPDATE" } + +// Op implements Event. It always returns 0. +func (*VoiceServerUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*VoiceServerUpdateEvent) EventType() ws.EventType { return "VOICE_SERVER_UPDATE" } + +// Op implements Event. It always returns 0. +func (*WebhooksUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*WebhooksUpdateEvent) EventType() ws.EventType { return "WEBHOOKS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*InteractionCreateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*InteractionCreateEvent) EventType() ws.EventType { return "INTERACTION_CREATE" } + +// Op implements Event. It always returns 0. +func (*UserGuildSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*UserGuildSettingsUpdateEvent) EventType() ws.EventType { return "USER_GUILD_SETTINGS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*UserSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*UserSettingsUpdateEvent) EventType() ws.EventType { return "USER_SETTINGS_UPDATE" } + +// Op implements Event. It always returns 0. +func (*UserNoteUpdateEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*UserNoteUpdateEvent) EventType() ws.EventType { return "USER_NOTE_UPDATE" } + +// Op implements Event. It always returns 0. +func (*RelationshipAddEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*RelationshipAddEvent) EventType() ws.EventType { return "RELATIONSHIP_ADD" } + +// Op implements Event. It always returns 0. +func (*RelationshipRemoveEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*RelationshipRemoveEvent) EventType() ws.EventType { return "RELATIONSHIP_REMOVE" } + +// Op implements Event. It always returns 0. +func (*ReadyEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ReadyEvent) EventType() ws.EventType { return "READY" } + +// Op implements Event. It always returns 0. +func (*ReadySupplementalEvent) Op() ws.OpCode { return dispatchOp } + +// EventType implements Event. +func (*ReadySupplementalEvent) EventType() ws.EventType { return "READY_SUPPLEMENTAL" } + +// Op implements Event. It always returns Op 2. +func (*IdentifyCommand) Op() ws.OpCode { return 2 } + +// EventType implements Event. +func (*IdentifyCommand) EventType() ws.EventType { return "" } diff --git a/gateway/events.go b/gateway/events.go index f6a7825..8e3a9f3 100644 --- a/gateway/events.go +++ b/gateway/events.go @@ -1,424 +1,884 @@ package gateway import ( + "strconv" + "strings" + "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/utils/ws" ) -// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent +//go:generate go run ../utils/cmd/genevent -o event_methods.go +// Rule: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent. +// Ready is too big, so it's moved to ready.go. + +const ( + dispatchOp ws.OpCode = 0 // recv + heartbeatOp ws.OpCode = 1 // send/recv + identifyOp ws.OpCode = 2 // send... + statusUpdateOp ws.OpCode = 3 // + voiceStateUpdateOp ws.OpCode = 4 // + voiceServerPingOp ws.OpCode = 5 // + resumeOp ws.OpCode = 6 // + reconnectOp ws.OpCode = 7 // recv + requestGuildMembersOp ws.OpCode = 8 // send + invalidSessionOp ws.OpCode = 9 // recv... + helloOp ws.OpCode = 10 + heartbeatAckOp ws.OpCode = 11 + callConnectOp ws.OpCode = 13 + guildSubscriptionsOp ws.OpCode = 14 +) + +// OpUnmarshalers contains the Op unmarshalers for this gateway. +var OpUnmarshalers = ws.NewOpUnmarshalers() + +// HeartbeatCommand is a command for Op 1. It is the last sequence number to be +// sent. +type HeartbeatCommand int + +// HeartbeatAckEvent is an event for Op 11. +type HeartbeatAckEvent struct{} + +// ReconnectEvent is an event for Op 7. +type ReconnectEvent struct{} + +// HelloEvent is an event for Op 10. +// // https://discord.com/developers/docs/topics/gateway#connecting-and-resuming -type ( - HelloEvent struct { - HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"` - } +type HelloEvent struct { + HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"` +} - // Ready is too big, so it's moved to ready.go +// ResumeCommand is a command for Op 6. It describes the Resume send command. +// This is not to be confused with ResumedEvent, which is an event that Discord +// sends us. +type ResumeCommand struct { + Token string `json:"token"` + SessionID string `json:"session_id"` + Sequence int64 `json:"seq"` +} - ResumedEvent struct{} +// InvalidSessionEvent is an event for Op 9. It indicates if the event is +// resumable. +// +// https://discord.com/developers/docs/topics/gateway#connecting-and-resuming +type InvalidSessionEvent bool - // InvalidSessionEvent indicates if the event is resumable. - InvalidSessionEvent bool -) +// RequestGuildMembersCommand is a command for Op 8. +type RequestGuildMembersCommand struct { + // GuildIDs contains the ids of the guilds to request data from. Multiple + // guilds can only be requested when using user accounts. + GuildIDs []discord.GuildID `json:"guild_id"` + UserIDs []discord.UserID `json:"user_ids,omitempty"` + Query string `json:"query"` + Limit uint `json:"limit"` + Presences bool `json:"presences,omitempty"` + Nonce string `json:"nonce,omitempty"` +} + +// UpdateVoiceStateCommand is a command for Op 4. +type UpdateVoiceStateCommand struct { + GuildID discord.GuildID `json:"guild_id"` + ChannelID discord.ChannelID `json:"channel_id"` // nullable + SelfMute bool `json:"self_mute"` + SelfDeaf bool `json:"self_deaf"` +} + +// UpdatePresenceCommand is a command for Op 3. It is sent by this client to +// indicate a presence or status update. +type UpdatePresenceCommand struct { + Since discord.UnixMsTimestamp `json:"since"` // 0 if not idle + + // Activities can be null or an empty slice. + Activities []discord.Activity `json:"activities"` + + Status discord.Status `json:"status"` + AFK bool `json:"afk"` +} + +// GuildSubscribeCommand is a command for Op 14. It is undocumented. +type GuildSubscribeCommand struct { + Typing bool `json:"typing"` + Threads bool `json:"threads"` + Activities bool `json:"activities"` + GuildID discord.GuildID `json:"guild_id"` + + // Channels is not documented. It's used to fetch the right members sidebar. + Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"` +} + +// ResumedEvent is a dispatch event. It is sent by Discord whenever we've +// successfully caught up to all events after resuming. +type ResumedEvent struct{} + +// ChannelCreateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#channels -type ( - ChannelCreateEvent struct { - discord.Channel - } - ChannelUpdateEvent struct { - discord.Channel - } - ChannelDeleteEvent struct { - discord.Channel - } - ChannelPinsUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id,omitempty"` - ChannelID discord.ChannelID `json:"channel_id,omitempty"` - LastPin discord.Timestamp `json:"timestamp,omitempty"` - } +type ChannelCreateEvent struct { + discord.Channel +} - ChannelUnreadUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` +// ChannelUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#channels +type ChannelUpdateEvent struct { + discord.Channel +} - ChannelUnreadUpdates []struct { - ID discord.ChannelID `json:"id"` - LastMessageID discord.MessageID `json:"last_message_id"` - } +// ChannelDeleteEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#channels +type ChannelDeleteEvent struct { + discord.Channel +} + +// ChannelPinsUpdateEvent is a dispatch event. +type ChannelPinsUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id,omitempty"` + ChannelID discord.ChannelID `json:"channel_id,omitempty"` + LastPin discord.Timestamp `json:"timestamp,omitempty"` +} + +// ChannelUnreadUpdateEvent is a dispatch event. +type ChannelUnreadUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + + ChannelUnreadUpdates []struct { + ID discord.ChannelID `json:"id"` + LastMessageID discord.MessageID `json:"last_message_id"` } +} - // ThreadCreateEvent is sent when a thread is created, relevant to the - // current user, or when the current user is added to a thread. - ThreadCreateEvent struct { - discord.Channel - } +// ThreadCreateEvent is a dispatch event. It is sent when a thread is created, +// relevant to the current user, or when the current user is added to a thread. +type ThreadCreateEvent struct { + discord.Channel +} - // ThreadUpdateEvent is sent when a thread is updated. - ThreadUpdateEvent struct { - discord.Channel - } +// ThreadUpdateEvent is a dispatch event. It is sent when a thread is updated. +type ThreadUpdateEvent struct { + discord.Channel +} - // ThreadDeleteEvent is sent when a thread relevant to the current user is - // deleted. - ThreadDeleteEvent struct { - // ID is the id of this channel. - ID discord.ChannelID `json:"id"` - // GuildID is the id of the guild. - GuildID discord.GuildID `json:"guild_id,omitempty"` - // Type is the type of channel. - Type discord.ChannelType `json:"type,omitempty"` - // ParentID is the id of the text channel this thread was created. - ParentID discord.ChannelID `json:"parent_id,omitempty"` - } +// ThreadDeleteEvent is a dispatch event. It is sent when a thread relevant to +// the current user is deleted. +type ThreadDeleteEvent struct { + // ID is the id of this channel. + ID discord.ChannelID `json:"id"` + // GuildID is the id of the guild. + GuildID discord.GuildID `json:"guild_id,omitempty"` + // Type is the type of channel. + Type discord.ChannelType `json:"type,omitempty"` + // ParentID is the id of the text channel this thread was created. + ParentID discord.ChannelID `json:"parent_id,omitempty"` +} - // ThreadListSyncEvent is sent when the current user gains access to a - // channel. - ThreadListSyncEvent struct { - // GuildID is the id of the guild. - GuildID discord.GuildID `json:"guild_id"` - // ChannelIDs are the parent channel ids whose threads are being - // synced. If nil, then threads were synced for the entire guild. - // This slice may contain ChannelIDs that have no active threads as - // well, so you know to clear that data. - ChannelIDs []discord.ChannelID `json:"channel_ids,omitempty"` - Threads []discord.Channel `json:"threads"` - Members []discord.ThreadMember `json:"members"` - } +// ThreadListSyncEvent is a dispatch event. It is sent when the current user +// gains access to a channel. +type ThreadListSyncEvent struct { + // GuildID is the id of the guild. + GuildID discord.GuildID `json:"guild_id"` + // ChannelIDs are the parent channel ids whose threads are being + // synced. If nil, then threads were synced for the entire guild. + // This slice may contain ChannelIDs that have no active threads as + // well, so you know to clear that data. + ChannelIDs []discord.ChannelID `json:"channel_ids,omitempty"` + Threads []discord.Channel `json:"threads"` + Members []discord.ThreadMember `json:"members"` +} - // ThreadMemberUpdateEvent is sent when the thread member object for the - // current user is updated. - ThreadMemberUpdateEvent struct { - discord.ThreadMember - } +// ThreadMemberUpdateEvent is a dispatch event. It is sent when the thread +// member object for the current user is updated. +type ThreadMemberUpdateEvent struct { + discord.ThreadMember +} - // ThreadMembersUpdateEvent is sent when anyone is added to or removed from - // a thread. If the current user does not have the GUILD_MEMBERS Gateway - // Intent, then this event will only be sent if the current user was added - // to or removed from the thread. - ThreadMembersUpdateEvent struct { - // ID is the id of the thread. - ID discord.ChannelID - // GuildID is the id of the guild. - GuildID discord.GuildID - // MemberCount is the approximate number of members in the thread, - // capped at 50. - MemberCount int - // AddedMembers are the users who were added to the thread. - AddedMembers []discord.ThreadMember - // RemovedUserIDs are the ids of the users who were removed from the - // thread. - RemovedMemberIDs []discord.UserID - } -) +// ThreadMembersUpdateEvent is a dispatch event. It is sent when anyone is added +// to or removed from a thread. If the current user does not have the +// GUILD_MEMBERS Gateway Intent, then this event will only be sent if the +// current user was added to or removed from the thread. +type ThreadMembersUpdateEvent struct { + // ID is the id of the thread. + ID discord.ChannelID + // GuildID is the id of the guild. + GuildID discord.GuildID + // MemberCount is the approximate number of members in the thread, + // capped at 50. + MemberCount int + // AddedMembers are the users who were added to the thread. + AddedMembers []discord.ThreadMember + // RemovedUserIDs are the ids of the users who were removed from the + // thread. + RemovedMemberIDs []discord.UserID +} +// GuildCreateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#guilds -type ( - GuildCreateEvent struct { - discord.Guild +type GuildCreateEvent struct { + discord.Guild - Joined discord.Timestamp `json:"joined_at,omitempty"` - Large bool `json:"large,omitempty"` - Unavailable bool `json:"unavailable,omitempty"` - MemberCount uint64 `json:"member_count,omitempty"` + Joined discord.Timestamp `json:"joined_at,omitempty"` + Large bool `json:"large,omitempty"` + Unavailable bool `json:"unavailable,omitempty"` + MemberCount uint64 `json:"member_count,omitempty"` - VoiceStates []discord.VoiceState `json:"voice_states,omitempty"` - Members []discord.Member `json:"members,omitempty"` - Channels []discord.Channel `json:"channels,omitempty"` - Threads []discord.Channel `json:"threads,omitempty"` - Presences []discord.Presence `json:"presences,omitempty"` - } - GuildUpdateEvent struct { - discord.Guild - } - GuildDeleteEvent struct { - ID discord.GuildID `json:"id"` - // Unavailable if false == removed - Unavailable bool `json:"unavailable"` - } + VoiceStates []discord.VoiceState `json:"voice_states,omitempty"` + Members []discord.Member `json:"members,omitempty"` + Channels []discord.Channel `json:"channels,omitempty"` + Threads []discord.Channel `json:"threads,omitempty"` + Presences []discord.Presence `json:"presences,omitempty"` +} - GuildBanAddEvent struct { - GuildID discord.GuildID `json:"guild_id"` - User discord.User `json:"user"` - } - GuildBanRemoveEvent struct { - GuildID discord.GuildID `json:"guild_id"` - User discord.User `json:"user"` - } +// GuildUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildUpdateEvent struct { + discord.Guild +} - GuildEmojisUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - Emojis []discord.Emoji `json:"emojis"` - } +// GuildDeleteEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildDeleteEvent struct { + ID discord.GuildID `json:"id"` + // Unavailable if false == removed + Unavailable bool `json:"unavailable"` +} - GuildIntegrationsUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - } +// GuildBanAddEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildBanAddEvent struct { + GuildID discord.GuildID `json:"guild_id"` + User discord.User `json:"user"` +} - GuildMemberAddEvent struct { - discord.Member - GuildID discord.GuildID `json:"guild_id"` - } - GuildMemberRemoveEvent struct { - GuildID discord.GuildID `json:"guild_id"` - User discord.User `json:"user"` - } - GuildMemberUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - RoleIDs []discord.RoleID `json:"roles"` - User discord.User `json:"user"` - Nick string `json:"nick"` - Avatar discord.Hash `json:"avatar"` - } +// GuildBanRemoveEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildBanRemoveEvent struct { + GuildID discord.GuildID `json:"guild_id"` + User discord.User `json:"user"` +} - // GuildMembersChunkEvent is sent when Guild Request Members is called. - GuildMembersChunkEvent struct { - GuildID discord.GuildID `json:"guild_id"` - Members []discord.Member `json:"members"` +// GuildEmojisUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildEmojisUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + Emojis []discord.Emoji `json:"emojis"` +} - ChunkIndex int `json:"chunk_index"` - ChunkCount int `json:"chunk_count"` +// GuildIntegrationsUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildIntegrationsUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` +} - // Whatever's not found goes here - NotFound []string `json:"not_found,omitempty"` +// GuildMemberAddEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildMemberAddEvent struct { + discord.Member + GuildID discord.GuildID `json:"guild_id"` +} - // Only filled if requested - Presences []discord.Presence `json:"presences,omitempty"` - Nonce string `json:"nonce,omitempty"` - } +// GuildMemberRemoveEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildMemberRemoveEvent struct { + GuildID discord.GuildID `json:"guild_id"` + User discord.User `json:"user"` +} - // GuildMemberListUpdate is an undocumented event. It's received when the - // client sends over GuildSubscriptions with the Channels field used. - // The State package does not handle this event. - GuildMemberListUpdate struct { - ID string `json:"id"` - GuildID discord.GuildID `json:"guild_id"` - MemberCount uint64 `json:"member_count"` - OnlineCount uint64 `json:"online_count"` +// GuildMemberUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildMemberUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + RoleIDs []discord.RoleID `json:"roles"` + User discord.User `json:"user"` + Nick string `json:"nick"` + Avatar discord.Hash `json:"avatar"` +} - // Groups is all the visible role sections. - Groups []GuildMemberListGroup `json:"groups"` - - Ops []GuildMemberListOp `json:"ops"` - } - GuildMemberListGroup struct { - ID string `json:"id"` // either discord.RoleID, "online" or "offline" - Count uint64 `json:"count"` - } - GuildMemberListOp struct { - // Mysterious string, so far spotted to be [SYNC, INSERT, UPDATE, DELETE]. - Op string `json:"op"` - - // NON-SYNC ONLY - // Only available for Ops that aren't "SYNC". - Index int `json:"index,omitempty"` - Item GuildMemberListOpItem `json:"item,omitempty"` - - // SYNC ONLY - // Range requested in GuildSubscribeData. - Range [2]int `json:"range,omitempty"` - // Items is basically a linear list of roles and members, similarly to - // how the client renders it. No, it's not nested. - Items []GuildMemberListOpItem `json:"items,omitempty"` - } - // GuildMemberListOpItem is an enum. Either of the fields are provided, but - // never both. Refer to (*GuildMemberListUpdate).Ops for more. - GuildMemberListOpItem struct { - Group *GuildMemberListGroup `json:"group,omitempty"` - Member *struct { - discord.Member - HoistedRole string `json:"hoisted_role"` - Presence discord.Presence `json:"presence"` - } `json:"member,omitempty"` - } - - GuildRoleCreateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - Role discord.Role `json:"role"` - } - GuildRoleUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - Role discord.Role `json:"role"` - } - GuildRoleDeleteEvent struct { - GuildID discord.GuildID `json:"guild_id"` - RoleID discord.RoleID `json:"role_id"` - } -) - -func (u GuildMemberUpdateEvent) Update(m *discord.Member) { +// UpdateMember updates the given discord.Member. +func (u *GuildMemberUpdateEvent) UpdateMember(m *discord.Member) { m.RoleIDs = u.RoleIDs m.User = u.User m.Nick = u.Nick m.Avatar = u.Avatar } +// GuildMembersChunkEvent is a dispatch event. It is sent when the Guild Request +// Members command is sent. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildMembersChunkEvent struct { + GuildID discord.GuildID `json:"guild_id"` + Members []discord.Member `json:"members"` + + ChunkIndex int `json:"chunk_index"` + ChunkCount int `json:"chunk_count"` + + // Whatever's not found goes here + NotFound []string `json:"not_found,omitempty"` + + // Only filled if requested + Presences []discord.Presence `json:"presences,omitempty"` + Nonce string `json:"nonce,omitempty"` +} + +// GuildMemberListUpdate is a dispatch event. It is an undocumented event. It's +// received when the client sends over GuildSubscriptions with the Channels +// field used. The State package does not handle this event. +type GuildMemberListUpdate struct { + ID string `json:"id"` + GuildID discord.GuildID `json:"guild_id"` + MemberCount uint64 `json:"member_count"` + OnlineCount uint64 `json:"online_count"` + + // Groups is all the visible role sections. + Groups []GuildMemberListGroup `json:"groups"` + + Ops []GuildMemberListOp `json:"ops"` +} + +// GuildMemberListGroup is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildMemberListGroup struct { + ID string `json:"id"` // either discord.RoleID, "online" or "offline" + Count uint64 `json:"count"` +} + +// GuildMemberListOp is an entry of every operation in GuildMemberListUpdate. +type GuildMemberListOp struct { + // Mysterious string, so far spotted to be [SYNC, INSERT, UPDATE, DELETE]. + Op string `json:"op"` + + // NON-SYNC ONLY + // Only available for Ops that aren't "SYNC". + Index int `json:"index,omitempty"` + Item GuildMemberListOpItem `json:"item,omitempty"` + + // SYNC ONLY + // Range requested in GuildSubscribeCommand. + Range [2]int `json:"range,omitempty"` + // Items is basically a linear list of roles and members, similarly to + // how the client renders it. No, it's not nested. + Items []GuildMemberListOpItem `json:"items,omitempty"` +} + +// GuildMemberListOpItem is a union of either Group or Member. Refer to +// (*GuildMemberListUpdate).Ops for more. +type GuildMemberListOpItem struct { + Group *GuildMemberListGroup `json:"group,omitempty"` + Member *struct { + discord.Member + HoistedRole string `json:"hoisted_role"` + Presence discord.Presence `json:"presence"` + } `json:"member,omitempty"` +} + +// GuildRoleCreateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildRoleCreateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + Role discord.Role `json:"role"` +} + +// GuildRoleUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildRoleUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + Role discord.Role `json:"role"` +} + +// GuildRoleDeleteEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#guilds +type GuildRoleDeleteEvent struct { + GuildID discord.GuildID `json:"guild_id"` + RoleID discord.RoleID `json:"role_id"` +} + +// InviteCreateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#invites -type ( - InviteCreateEvent struct { - Code string `json:"code"` - CreatedAt discord.Timestamp `json:"created_at"` - ChannelID discord.ChannelID `json:"channel_id"` - GuildID discord.GuildID `json:"guild_id,omitempty"` +type InviteCreateEvent struct { + Code string `json:"code"` + CreatedAt discord.Timestamp `json:"created_at"` + ChannelID discord.ChannelID `json:"channel_id"` + GuildID discord.GuildID `json:"guild_id,omitempty"` - // Similar to discord.Invite - Inviter *discord.User `json:"inviter,omitempty"` - Target *discord.User `json:"target_user,omitempty"` - TargetType discord.InviteUserType `json:"target_user_type,omitempty"` + // Similar to discord.Invite + Inviter *discord.User `json:"inviter,omitempty"` + Target *discord.User `json:"target_user,omitempty"` + TargetType discord.InviteUserType `json:"target_user_type,omitempty"` - discord.InviteMetadata - } - InviteDeleteEvent struct { - Code string `json:"code"` - ChannelID discord.ChannelID `json:"channel_id"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } -) + discord.InviteMetadata +} +// InviteDeleteEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#invites +type InviteDeleteEvent struct { + Code string `json:"code"` + ChannelID discord.ChannelID `json:"channel_id"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} + +// MessageCreateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#messages -type ( - MessageCreateEvent struct { - discord.Message - Member *discord.Member `json:"member,omitempty"` - } - MessageUpdateEvent struct { - discord.Message - Member *discord.Member `json:"member,omitempty"` - } - MessageDeleteEvent struct { - ID discord.MessageID `json:"id"` - ChannelID discord.ChannelID `json:"channel_id"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } - MessageDeleteBulkEvent struct { - IDs []discord.MessageID `json:"ids"` - ChannelID discord.ChannelID `json:"channel_id"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } +type MessageCreateEvent struct { + discord.Message + Member *discord.Member `json:"member,omitempty"` +} - MessageReactionAddEvent struct { - UserID discord.UserID `json:"user_id"` - ChannelID discord.ChannelID `json:"channel_id"` - MessageID discord.MessageID `json:"message_id"` +// MessageUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageUpdateEvent struct { + discord.Message + Member *discord.Member `json:"member,omitempty"` +} - Emoji discord.Emoji `json:"emoji,omitempty"` +// MessageDeleteEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageDeleteEvent struct { + ID discord.MessageID `json:"id"` + ChannelID discord.ChannelID `json:"channel_id"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} - GuildID discord.GuildID `json:"guild_id,omitempty"` - Member *discord.Member `json:"member,omitempty"` - } - MessageReactionRemoveEvent struct { - UserID discord.UserID `json:"user_id"` - ChannelID discord.ChannelID `json:"channel_id"` - MessageID discord.MessageID `json:"message_id"` - Emoji discord.Emoji `json:"emoji"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } - MessageReactionRemoveAllEvent struct { - ChannelID discord.ChannelID `json:"channel_id"` - MessageID discord.MessageID `json:"message_id"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } - MessageReactionRemoveEmojiEvent struct { - ChannelID discord.ChannelID `json:"channel_id"` - MessageID discord.MessageID `json:"message_id"` - Emoji discord.Emoji `json:"emoji"` - GuildID discord.GuildID `json:"guild_id,omitempty"` - } +// MessageDeleteBulkEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageDeleteBulkEvent struct { + IDs []discord.MessageID `json:"ids"` + ChannelID discord.ChannelID `json:"channel_id"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} - MessageAckEvent struct { - MessageID discord.MessageID `json:"message_id"` - ChannelID discord.ChannelID `json:"channel_id"` - } -) +// MessageReactionAddEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageReactionAddEvent struct { + UserID discord.UserID `json:"user_id"` + ChannelID discord.ChannelID `json:"channel_id"` + MessageID discord.MessageID `json:"message_id"` + Emoji discord.Emoji `json:"emoji,omitempty"` + + GuildID discord.GuildID `json:"guild_id,omitempty"` + Member *discord.Member `json:"member,omitempty"` +} + +// MessageReactionRemoveEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageReactionRemoveEvent struct { + UserID discord.UserID `json:"user_id"` + ChannelID discord.ChannelID `json:"channel_id"` + MessageID discord.MessageID `json:"message_id"` + Emoji discord.Emoji `json:"emoji"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} + +// MessageReactionRemoveAllEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageReactionRemoveAllEvent struct { + ChannelID discord.ChannelID `json:"channel_id"` + MessageID discord.MessageID `json:"message_id"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} + +// MessageReactionRemoveEmojiEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#messages +type MessageReactionRemoveEmojiEvent struct { + ChannelID discord.ChannelID `json:"channel_id"` + MessageID discord.MessageID `json:"message_id"` + Emoji discord.Emoji `json:"emoji"` + GuildID discord.GuildID `json:"guild_id,omitempty"` +} + +// MessageAckEvent is a dispatch event. +type MessageAckEvent struct { + MessageID discord.MessageID `json:"message_id"` + ChannelID discord.ChannelID `json:"channel_id"` +} + +// PresenceUpdateEvent is a dispatch event. It represents the structure of the +// Presence Update Event object. +// +// https://discord.com/developers/docs/topics/gateway#presence-update-presence-update-event-fields +type PresenceUpdateEvent struct { + discord.Presence +} + +// PresencesReplaceEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#presence -type ( - // ClientStatus is the user's platform-dependent status. - // - // https://discord.com/developers/docs/topics/gateway#client-status-object +type PresencesReplaceEvent []PresenceUpdateEvent - // PresenceUpdateEvent represents the structure of the Presence Update Event - // object. - // - // https://discord.com/developers/docs/topics/gateway#presence-update-presence-update-event-fields - PresenceUpdateEvent struct { - discord.Presence - } +// SessionsReplaceEvent is a dispatch event. It is undocumented. It's likely +// used for current user's presence updates. +type SessionsReplaceEvent []struct { + Status discord.Status `json:"status"` + SessionID string `json:"session_id"` - PresencesReplaceEvent []PresenceUpdateEvent + Activities []discord.Activity `json:"activities"` - // SessionsReplaceEvent is an undocumented user event. It's likely used for - // current user's presence updates. - SessionsReplaceEvent []struct { - Status discord.Status `json:"status"` - SessionID string `json:"session_id"` + ClientInfo struct { + Version int `json:"version"` + OS string `json:"os"` + Client string `json:"client"` + } `json:"client_info"` - Activities []discord.Activity `json:"activities"` + Active bool `json:"active"` +} - ClientInfo struct { - Version int `json:"version"` - OS string `json:"os"` - Client string `json:"client"` - } `json:"client_info"` +// TypingStartEvent is a dispatch event. +type TypingStartEvent struct { + ChannelID discord.ChannelID `json:"channel_id"` + UserID discord.UserID `json:"user_id"` + Timestamp discord.UnixTimestamp `json:"timestamp"` - Active bool `json:"active"` - } + GuildID discord.GuildID `json:"guild_id,omitempty"` + Member *discord.Member `json:"member,omitempty"` +} - TypingStartEvent struct { - ChannelID discord.ChannelID `json:"channel_id"` - UserID discord.UserID `json:"user_id"` - Timestamp discord.UnixTimestamp `json:"timestamp"` - - GuildID discord.GuildID `json:"guild_id,omitempty"` - Member *discord.Member `json:"member,omitempty"` - } - - UserUpdateEvent struct { - discord.User - } -) +// UserUpdateEvent is a dispatch event. +type UserUpdateEvent struct { + discord.User +} +// VoiceStateUpdateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#voice -type ( - VoiceStateUpdateEvent struct { - discord.VoiceState - } - VoiceServerUpdateEvent struct { - Token string `json:"token"` - GuildID discord.GuildID `json:"guild_id"` - Endpoint string `json:"endpoint"` - } -) +type VoiceStateUpdateEvent struct { + discord.VoiceState +} +// VoiceServerUpdateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#voice +type VoiceServerUpdateEvent struct { + Token string `json:"token"` + GuildID discord.GuildID `json:"guild_id"` + Endpoint string `json:"endpoint"` +} + +// WebhooksUpdateEvent is a dispatch event. +// // https://discord.com/developers/docs/topics/gateway#webhooks -type ( - WebhooksUpdateEvent struct { - GuildID discord.GuildID `json:"guild_id"` - ChannelID discord.ChannelID `json:"channel_id"` - } -) +type WebhooksUpdateEvent struct { + GuildID discord.GuildID `json:"guild_id"` + ChannelID discord.ChannelID `json:"channel_id"` +} +// InteractionCreateEvent is a dispatch event. +// +// https://discord.com/developers/docs/topics/gateway#webhooks type InteractionCreateEvent struct { discord.InteractionEvent } // Undocumented + +// UserGuildSettingsUpdateEvent is a dispatch event. It is undocumented. +type UserGuildSettingsUpdateEvent struct { + UserGuildSetting +} + +// UserSettingsUpdateEvent is a dispatch event. It is undocumented. +type UserSettingsUpdateEvent struct { + UserSettings +} + +// UserNoteUpdateEvent is a dispatch event. It is undocumented. +type UserNoteUpdateEvent struct { + ID discord.UserID `json:"id"` + Note string `json:"note"` +} + +// RelationshipAddEvent is a dispatch event. It is undocumented. +type RelationshipAddEvent struct { + discord.Relationship +} + +// RelationshipRemoveEvent is a dispatch event. It is undocumented. +type RelationshipRemoveEvent struct { + discord.Relationship +} + +// ReadyEvent is a dispatch event for READY. +type ReadyEvent struct { + Version int `json:"version"` + + User discord.User `json:"user"` + SessionID string `json:"session_id"` + + PrivateChannels []discord.Channel `json:"private_channels"` + Guilds []GuildCreateEvent `json:"guilds"` + + Shard *Shard `json:"shard,omitempty"` + + // Undocumented fields + + UserSettings *UserSettings `json:"user_settings,omitempty"` + ReadStates []ReadState `json:"read_state,omitempty"` + UserGuildSettings []UserGuildSetting `json:"user_guild_settings,omitempty"` + Relationships []discord.Relationship `json:"relationships,omitempty"` + Presences []discord.Presence `json:"presences,omitempty"` + + FriendSuggestionCount int `json:"friend_suggestion_count,omitempty"` + GeoOrderedRTCRegions []string `json:"geo_ordered_rtc_regions,omitempty"` +} + +// Ready subtypes. type ( - UserGuildSettingsUpdateEvent struct { - UserGuildSetting + // ReadState is a single ReadState entry. It is undocumented. + ReadState struct { + ChannelID discord.ChannelID `json:"id"` + LastMessageID discord.MessageID `json:"last_message_id"` + LastPinTimestamp discord.Timestamp `json:"last_pin_timestamp"` + MentionCount int `json:"mention_count"` } - UserSettingsUpdateEvent struct { - UserSettings + + // UserSettings is the struct for (almost) all user settings. It is + // undocumented. + UserSettings struct { + ShowCurrentGame bool `json:"show_current_game"` + DefaultGuildsRestricted bool `json:"default_guilds_restricted"` + InlineAttachmentMedia bool `json:"inline_attachment_media"` + InlineEmbedMedia bool `json:"inline_embed_media"` + GIFAutoPlay bool `json:"gif_auto_play"` + RenderEmbeds bool `json:"render_embeds"` + RenderReactions bool `json:"render_reactions"` + AnimateEmoji bool `json:"animate_emoji"` + AnimateStickers int `json:"animate_stickers"` + EnableTTSCommand bool `json:"enable_tts_command"` + MessageDisplayCompact bool `json:"message_display_compact"` + ConvertEmoticons bool `json:"convert_emoticons"` + ExplicitContentFilter uint8 `json:"explicit_content_filter"` // ??? + DisableGamesTab bool `json:"disable_games_tab"` + DeveloperMode bool `json:"developer_mode"` + DetectPlatformAccounts bool `json:"detect_platform_accounts"` + StreamNotification bool `json:"stream_notification_enabled"` + AccessibilityDetection bool `json:"allow_accessibility_detection"` + ContactSync bool `json:"contact_sync_enabled"` + NativePhoneIntegration bool `json:"native_phone_integration_enabled"` + + TimezoneOffset int `json:"timezone_offset"` + + Locale string `json:"locale"` + Theme string `json:"theme"` + + GuildPositions []discord.GuildID `json:"guild_positions"` + GuildFolders []GuildFolder `json:"guild_folders"` + RestrictedGuilds []discord.GuildID `json:"restricted_guilds"` + + FriendSourceFlags FriendSourceFlags `json:"friend_source_flags"` + + Status discord.Status `json:"status"` + CustomStatus *CustomUserStatus `json:"custom_status"` } - UserNoteUpdateEvent struct { - ID discord.UserID `json:"id"` - Note string `json:"note"` + + // CustomUserStatus is the custom user status that allows setting an emoji + // and a piece of text on each user. + CustomUserStatus struct { + Text string `json:"text"` + ExpiresAt discord.Timestamp `json:"expires_at,omitempty"` + EmojiID discord.EmojiID `json:"emoji_id,string"` + EmojiName string `json:"emoji_name"` + } + + // UserGuildSetting stores the settings for a single guild. It is + // undocumented. + UserGuildSetting struct { + GuildID discord.GuildID `json:"guild_id"` + + SuppressRoles bool `json:"suppress_roles"` + SuppressEveryone bool `json:"suppress_everyone"` + Muted bool `json:"muted"` + MuteConfig *UserMuteConfig `json:"mute_config"` + + MobilePush bool `json:"mobile_push"` + Notifications UserNotification `json:"message_notifications"` + + ChannelOverrides []UserChannelOverride `json:"channel_overrides"` + } + + // A UserChannelOverride struct describes a channel settings override for a + // users guild settings. + UserChannelOverride struct { + Muted bool `json:"muted"` + MuteConfig *UserMuteConfig `json:"mute_config"` + Notifications UserNotification `json:"message_notifications"` + ChannelID discord.ChannelID `json:"channel_id"` + } + + // UserMuteConfig seems to describe the mute settings. It belongs to the + // UserGuildSettingEntry and UserChannelOverride structs and is + // undocumented. + UserMuteConfig struct { + SelectedTimeWindow int `json:"selected_time_window"` + EndTime discord.Timestamp `json:"end_time"` + } + + // GuildFolder holds a single folder that you see in the left guild panel. + GuildFolder struct { + Name string `json:"name"` + ID GuildFolderID `json:"id"` + GuildIDs []discord.GuildID `json:"guild_ids"` + Color discord.Color `json:"color"` + } + + // FriendSourceFlags describes sources that friend requests could be sent + // from. It belongs to the UserSettings struct and is undocumented. + FriendSourceFlags struct { + All bool `json:"all,omitempty"` + MutualGuilds bool `json:"mutual_guilds,omitempty"` + MutualFriends bool `json:"mutual_friends,omitempty"` } ) -type ( - RelationshipAddEvent struct { - discord.Relationship +// UserNotification is the notification setting for a channel or guild. +type UserNotification uint8 + +const ( + AllNotifications UserNotification = iota + OnlyMentions + NoNotifications + GuildDefaults +) + +// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low +// number of unknown significance. +type GuildFolderID int64 + +func (g *GuildFolderID) UnmarshalJSON(b []byte) error { + var body = string(b) + if body == "null" { + return nil } - RelationshipRemoveEvent struct { - discord.Relationship + + body = strings.Trim(body, `"`) + + u, err := strconv.ParseInt(body, 10, 64) + if err != nil { + return err + } + + *g = GuildFolderID(u) + return nil +} + +func (g GuildFolderID) MarshalJSON() ([]byte, error) { + if g == 0 { + return []byte("null"), nil + } + + return []byte(strconv.FormatInt(int64(g), 10)), nil +} + +// ReadySupplementalEvent is a dispatch event for READY_SUPPLEMENTAL. It is an +// undocumented event. For now, this event is never used, and its usage have yet +// been discovered. +type ReadySupplementalEvent struct { + Guilds []GuildCreateEvent `json:"guilds"` // only have ID and VoiceStates + MergedMembers [][]SupplementalMember `json:"merged_members"` + MergedPresences MergedPresences `json:"merged_presences"` +} + +// ReadySupplemental event structs. +type ( + // SupplementalMember is the struct for a member in the MergedMembers field + // of ReadySupplementalEvent. It has slight differences to discord.Member. + SupplementalMember struct { + UserID discord.UserID `json:"user_id"` + Nick string `json:"nick,omitempty"` + RoleIDs []discord.RoleID `json:"roles"` + + GuildID discord.GuildID `json:"guild_id,omitempty"` + IsPending bool `json:"pending,omitempty"` + HoistedRole discord.RoleID `json:"hoisted_role"` + + Mute bool `json:"mute"` + Deaf bool `json:"deaf"` + + // Joined specifies when the user joined the guild. + Joined discord.Timestamp `json:"joined_at"` + + // BoostedSince specifies when the user started boosting the guild. + BoostedSince discord.Timestamp `json:"premium_since,omitempty"` + } + + // MergedPresences is the struct for presences of guilds' members and + // friends. It is undocumented. + MergedPresences struct { + Guilds [][]SupplementalPresence `json:"guilds"` + Friends []SupplementalPresence `json:"friends"` + } + + // SupplementalPresence is a single presence for either a guild member or + // friend. It is used in MergedPresences and is undocumented. + SupplementalPresence struct { + UserID discord.UserID `json:"user_id"` + + // Status is either "idle", "dnd", "online", or "offline". + Status discord.Status `json:"status"` + // Activities are the user's current activities. + Activities []discord.Activity `json:"activities"` + // ClientStaus is the user's platform-dependent status. + ClientStatus discord.ClientStatus `json:"client_status"` + + // LastModified is only present in Friends. + LastModified discord.UnixMsTimestamp `json:"last_modified,omitempty"` } ) + +// ConvertSupplementalMembers converts a SupplementalMember to a regular Member. +func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member { + members := make([]discord.Member, len(sms)) + for i, sm := range sms { + members[i] = discord.Member{ + User: discord.User{ID: sm.UserID}, + Nick: sm.Nick, + RoleIDs: sm.RoleIDs, + Joined: sm.Joined, + BoostedSince: sm.BoostedSince, + Deaf: sm.Deaf, + Mute: sm.Mute, + IsPending: sm.IsPending, + } + } + return members +} + +// ConvertSupplementalPresences converts a SupplementalPresence to a regular +// Presence with an empty GuildID. +func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence { + presences := make([]discord.Presence, len(sps)) + for i, sp := range sps { + presences[i] = discord.Presence{ + User: discord.User{ID: sp.UserID}, + Status: sp.Status, + Activities: sp.Activities, + ClientStatus: sp.ClientStatus, + } + } + return presences +} diff --git a/gateway/events_map.go b/gateway/events_map.go deleted file mode 100644 index 91341a8..0000000 --- a/gateway/events_map.go +++ /dev/null @@ -1,76 +0,0 @@ -package gateway - -// Event is any event struct. They have an "Event" suffixed to them. -type Event = interface{} - -// EventCreator maps an event type string to a constructor. -var EventCreator = map[string]func() Event{ - "HELLO": func() Event { return new(HelloEvent) }, - "READY": func() Event { return new(ReadyEvent) }, - "READY_SUPPLEMENTAL": func() Event { return new(ReadySupplementalEvent) }, - "RESUMED": func() Event { return new(ResumedEvent) }, - "INVALID_SESSION": func() Event { return new(InvalidSessionEvent) }, - - "CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) }, - "CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) }, - "CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) }, - "CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) }, - "CHANNEL_UNREAD_UPDATE": func() Event { return new(ChannelUnreadUpdateEvent) }, - - "GUILD_CREATE": func() Event { return new(GuildCreateEvent) }, - "GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) }, - "GUILD_DELETE": func() Event { return new(GuildDeleteEvent) }, - - "GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) }, - "GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) }, - - "GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) }, - "GUILD_INTEGRATIONS_UPDATE": func() Event { return new(GuildIntegrationsUpdateEvent) }, - - "GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) }, - "GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) }, - "GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) }, - "GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) }, - - "GUILD_MEMBER_LIST_UPDATE": func() Event { return new(GuildMemberListUpdate) }, - - "GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) }, - "GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) }, - "GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) }, - - "INVITE_CREATE": func() Event { return new(InviteCreateEvent) }, - "INVITE_DELETE": func() Event { return new(InviteDeleteEvent) }, - - "MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) }, - "MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) }, - "MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) }, - "MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) }, - - "MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) }, - "MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) }, - "MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) }, - "MESSAGE_REACTION_REMOVE_EMOJI": func() Event { return new(MessageReactionRemoveEmojiEvent) }, - - "MESSAGE_ACK": func() Event { return new(MessageAckEvent) }, - - "PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) }, - "PRESENCES_REPLACE": func() Event { return new(PresencesReplaceEvent) }, - "SESSIONS_REPLACE": func() Event { return new(SessionsReplaceEvent) }, - - "TYPING_START": func() Event { return new(TypingStartEvent) }, - - "VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) }, - "VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) }, - - "WEBHOOKS_UPDATE": func() Event { return new(WebhooksUpdateEvent) }, - - "INTERACTION_CREATE": func() Event { return new(InteractionCreateEvent) }, - - "USER_UPDATE": func() Event { return new(UserUpdateEvent) }, - "USER_SETTINGS_UPDATE": func() Event { return new(UserSettingsUpdateEvent) }, - "USER_GUILD_SETTINGS_UPDATE": func() Event { return new(UserGuildSettingsUpdateEvent) }, - "USER_NOTE_UPDATE": func() Event { return new(UserNoteUpdateEvent) }, - - "RELATIONSHIP_ADD": func() Event { return new(RelationshipAddEvent) }, - "RELATIONSHIP_REMOVE": func() Event { return new(RelationshipRemoveEvent) }, -} diff --git a/gateway/gateway.go b/gateway/gateway.go index 877b79e..4ac9112 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -9,19 +9,15 @@ package gateway import ( "context" + "math/rand" "net/url" - "strings" "sync" "time" - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/diamondburned/arikawa/v3/api" - "github.com/diamondburned/arikawa/v3/internal/moreatomic" - "github.com/diamondburned/arikawa/v3/utils/json" - "github.com/diamondburned/arikawa/v3/utils/json/option" - "github.com/diamondburned/arikawa/v3/utils/wsutil" + "github.com/diamondburned/arikawa/v3/internal/lazytime" + "github.com/diamondburned/arikawa/v3/utils/ws" + "github.com/pkg/errors" ) var ( @@ -29,26 +25,24 @@ var ( Encoding = "json" ) -var ( - ErrMissingForResume = errors.New("missing session ID or sequence for resuming") - ErrWSMaxTries = errors.New( - "could not connect to the Discord gateway before reaching the timeout") - ErrClosed = errors.New("the gateway is closed and cannot reconnect") -) +// CodeInvalidSequence is the code returned by Discord to signal that the given +// sequence number is invalid. +const CodeInvalidSequence = 4007 -// see -// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes -const errCodeShardingRequired = 4011 +// CodeShardingRequired is the code returned by Discord to signal that the bot +// must reshard before proceeding. For more information, see +// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes. +const CodeShardingRequired = 4011 // URL asks Discord for a Websocket URL to the Gateway. -func URL() (string, error) { - return api.GatewayURL() +func URL(ctx context.Context) (string, error) { + return api.GatewayURL(ctx) } // BotURL fetches the Gateway URL along with extra metadata. The token // passed in will NOT be prefixed with Bot. -func BotURL(token string) (*api.BotData, error) { - return api.NewClient(token).BotURL() +func BotURL(ctx context.Context, token string) (*api.BotData, error) { + return api.NewClient(token).WithContext(ctx).BotURL() } // AddGatewayParams appends into the given URL string the gateway URL @@ -62,469 +56,371 @@ func AddGatewayParams(baseURL string) string { return baseURL + "?" + param.Encode() } -type Gateway struct { - WS *wsutil.Websocket - - // WSTimeout is a timeout for an arbitrary action. An example of this is the - // timeout for Start and the timeout for sending each Gateway command - // independently. - WSTimeout time.Duration - - // ReconnectAttempts are the amount of attempts made to Reconnect, before - // aborting. If this set to 0, unlimited attempts will be made. - ReconnectAttempts uint - - // All events sent over are pointers to Event structs (structs suffixed with - // "Event"). This shouldn't be accessed if the Gateway is created with a - // Session. - Events chan Event - - sessionMu sync.RWMutex - sessionID string - - Identifier *Identifier - Sequence *moreatomic.Int64 - - PacerLoop wsutil.PacemakerLoop - - ErrorLog func(err error) // default to log.Println - // FatalErrorCallback is called, if the Gateway exits fatally. At the point - // of calling, the gateway will be already closed. - // - // Currently this will only be called, if the ReconnectTimeout was changed - // to a definite timeout, and connection could not be established during - // that time. - // err will be ErrWSMaxTries in that case. - // - // Defaults to noop. - FatalErrorCallback func(err error) - - // AfterClose is called after each close or pause. It is used mainly for - // reconnections or any type of connection interruptions. - // - // Constructors will use a no-op function by default. - AfterClose func(err error) - - onShardingRequired func() - - waitGroup sync.WaitGroup - closed chan struct{} +// State contains the gateway state. It is a piece of data that can be shared +// across gateways during construction to be used for resuming a connection or +// starting a new one with the previous data. +// +// The data structure itself is not thread-safe, so they may only be pulled from +// the gateway after it's done and set before it's done. +type State struct { + Identifier Identifier + SessionID string + Sequence int64 } -// NewGatewayWithIntents creates a new Gateway with the given intents and the -// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents. -func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) { - g, err := NewGateway(token) +// Gateway describes an instance that handles the Discord gateway. It is +// basically an abstracted concurrent event loop that the user could signal to +// start connecting to the Discord gateway server. +type Gateway struct { + gateway *ws.Gateway + state State + + // non-mutex-guarded states + // TODO: make lastBeat part of ws.Gateway so it can keep track of whether or + // not the websocket is dead. + beatMutex sync.Mutex + sentBeat time.Time + echoBeat time.Time + retryTimer lazytime.Timer +} + +// NewWithIntents creates a new Gateway with the given intents and the default +// stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents. +func NewWithIntents(ctx context.Context, token string, intents ...Intents) (*Gateway, error) { + var allIntents Intents + for _, intent := range intents { + allIntents |= intent + } + + g, err := New(ctx, token) if err != nil { return nil, err } - for _, intent := range intents { - g.AddIntents(intent) - } - + g.AddIntents(allIntents) return g, nil } -// NewGateway creates a new Gateway to the default Discord server. -func NewGateway(token string) (*Gateway, error) { - return NewIdentifiedGateway(DefaultIdentifier(token)) +// New creates a new Gateway to the default Discord server. +func New(ctx context.Context, token string) (*Gateway, error) { + return NewWithIdentifier(ctx, DefaultIdentifier(token)) } -// NewIdentifiedGateway creates a new Gateway with the given gateway identifier -// and the default everything. Sharded bots should prefer this function for the -// shared identifier. -func NewIdentifiedGateway(id *Identifier) (*Gateway, error) { - var gatewayURL string - var botData *api.BotData - var err error - - if strings.HasPrefix(id.Token, "Bot ") { - botData, err = BotURL(id.Token) - if err != nil { - return nil, errors.Wrap(err, "failed to get bot data") - } - gatewayURL = botData.URL - - } else { - gatewayURL, err = URL() - if err != nil { - return nil, errors.Wrap(err, "failed to get gateway endpoint") - } +// NewWithIdentifier creates a new Gateway with the given gateway identifier and +// the default everything. Sharded bots should prefer this function for the +// shared identifier. The given Identifier will be modified. +func NewWithIdentifier(ctx context.Context, id Identifier) (*Gateway, error) { + gatewayURL, err := id.QueryGateway(ctx) + if err != nil { + return nil, err } gatewayURL = AddGatewayParams(gatewayURL) - gateway := NewCustomIdentifiedGateway(gatewayURL, id) - - // Use the supplied connect rate limit, if any. - if botData != nil && botData.StartLimit != nil { - resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration()) - limiter := gateway.Identifier.IdentifyGlobalLimit - - // Update the burst to be the current given time and reset it back to - // the default when the given time is reached. - limiter.SetBurst(botData.StartLimit.Remaining) - limiter.SetBurstAt(resetAt, botData.StartLimit.Total) - - // Update the maximum number of identify requests allowed per 5s. - gateway.Identifier.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency) - } + gateway := NewCustomWithIdentifier(gatewayURL, id, nil) return gateway, nil } -// NewCustomGateway creates a new Gateway with a custom gateway URL and a new +// NewCustom creates a new Gateway with a custom gateway URL and a new // Identifier. Most bots connecting to the official server should not use these // custom functions. -func NewCustomGateway(gatewayURL, token string) *Gateway { - return NewCustomIdentifiedGateway(gatewayURL, DefaultIdentifier(token)) +func NewCustom(gatewayURL, token string) *Gateway { + return NewCustomWithIdentifier(gatewayURL, DefaultIdentifier(token), nil) } -// NewCustomIdentifiedGateway creates a new Gateway with a custom gateway URL -// and a pre-existing Identifier. Refer to NewCustomGateway. -func NewCustomIdentifiedGateway(gatewayURL string, id *Identifier) *Gateway { +// DefaultGatewayOpts contains the default options to be used for connecting to +// the gateway. +var DefaultGatewayOpts = ws.GatewayOpts{ + ReconnectDelay: func(try int) time.Duration { + // minimum 4 seconds + return time.Duration(4+(2*try)) * time.Second + }, + // FatalCloseCodes contains the default gateway close codes that will cause + // the gateway to exit. In other words, it's a list of unrecoverable close + // codes. + FatalCloseCodes: []int{ + 4003, // not authenticated + 4004, // authentication failed + 4010, // invalid shard sent + 4011, // sharding required + 4012, // invalid API version + 4013, // invalid intents + 4014, // disallowed intents + }, + DialTimeout: 0, + ReconnectAttempt: 0, + AlwaysCloseGracefully: true, +} + +// NewCustomWithIdentifier creates a new Gateway with a custom gateway URL and a +// pre-existing Identifier. If opts is nil, then DefaultGatewayOpts is used. +func NewCustomWithIdentifier(gatewayURL string, id Identifier, opts *ws.GatewayOpts) *Gateway { + return NewFromState(gatewayURL, State{Identifier: id}, opts) +} + +// NewFromState creates a new gateway from the given state and optionally +// gateway options. If opts is nil, then DefaultGatewayOpts is used. +func NewFromState(gatewayURL string, state State, opts *ws.GatewayOpts) *Gateway { + if opts == nil { + opts = &DefaultGatewayOpts + } + + gw := ws.NewGateway(ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), gatewayURL), opts) + return &Gateway{ - WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL), - WSTimeout: wsutil.WSTimeout, - - Events: make(chan Event, wsutil.WSBuffer), - Identifier: id, - Sequence: moreatomic.NewInt64(0), - - ErrorLog: wsutil.WSError, - AfterClose: func(error) {}, - PacerLoop: wsutil.PacemakerLoop{ErrorLog: wsutil.WSError}, + gateway: gw, + state: state, } } -// AddIntents adds a Gateway Intent before connecting to the Gateway. As such, -// this function will only work before Open() is called. +// State returns a copy of the gateway's internal state. It panics if the +// gateway is currently running. +func (g *Gateway) State() State { + g.gateway.AssertIsNotRunning() + return g.state +} + +// SetState sets the gateway's state. +func (g *Gateway) SetState(state State) { + g.gateway.AssertIsNotRunning() + g.state = state +} + +// AddIntents adds a Gateway Intent before connecting to the Gateway. This +// function will only work before Connect() is called. Calling it once Connect() +// is called will result in a panic. func (g *Gateway) AddIntents(i Intents) { - if g.Identifier.Intents == nil { - g.Identifier.Intents = option.NewUint(uint(i)) - } else { - *g.Identifier.Intents |= uint(i) - } + g.gateway.AssertIsNotRunning() + g.state.Identifier.AddIntents(i) } -// HasIntents reports if the Gateway has the passed Intents. +// SentBeat returns the last time that the heart was beaten. If the gateway has +// never connected, then a zero-value time is returned. +func (g *Gateway) SentBeat() time.Time { + g.beatMutex.Lock() + defer g.beatMutex.Unlock() + + return g.sentBeat +} + +// EchoBeat returns the last time that the heartbeat was acknowledged. It is +// similar to SentBeat. +func (g *Gateway) EchoBeat() time.Time { + g.beatMutex.Lock() + defer g.beatMutex.Unlock() + + return g.echoBeat +} + +// Latency is a convenient function around SentBeat and EchoBeat. It subtracts +// the EchoBeat with the SentBeat. +func (g *Gateway) Latency() time.Duration { + g.beatMutex.Lock() + defer g.beatMutex.Unlock() + + return g.echoBeat.Sub(g.sentBeat) +} + +// LastError returns the last error that the gateway has received. It only +// returns a valid error if the gateway's event loop as exited. If the event +// loop hasn't been started AND stopped, the function will panic. +func (g *Gateway) LastError() error { + return g.gateway.LastError() +} + +// Send is a function to send an Op payload to the Gateway. +func (g *Gateway) Send(ctx context.Context, data ws.Event) error { + return g.gateway.Send(ctx, data) +} + +// Connect starts the background goroutine that tries its best to maintain a +// stable connection to the Discord gateway. To the user, the gateway should +// appear to be working seamlessly. // -// If no intents are set, i.e. if using a user account HasIntents will always -// return true. -func (g *Gateway) HasIntents(intents Intents) bool { - if g.Identifier.Intents == nil { - return true - } - - return Intents(*g.Identifier.Intents).Has(intents) -} - -// Close closes the underlying Websocket connection, invalidating the session -// ID. +// Behaviors // -// It will send a closing frame before ending the connection, closing it -// gracefully. This will cause the bot to appear as offline instantly. -func (g *Gateway) Close() error { - return g.close(true) -} - -// Pause pauses the Gateway connection, by ending the connection without -// sending a closing frame. This allows the connection to be resumed at a later -// point, by calling Reconnect or ReconnectCtx. -func (g *Gateway) Pause() error { - return g.close(false) -} - -func (g *Gateway) close(graceful bool) (err error) { - wsutil.WSDebug("Trying to close. Pacemaker check skipped.") - wsutil.WSDebug("Closing the Websocket...") - - if graceful { - err = g.WS.CloseGracefully() - } else { - err = g.WS.Close() - } - - if errors.Is(err, wsutil.ErrWebsocketClosed) { - wsutil.WSDebug("Websocket already closed.") - return nil - } - - // Explicitly signal the pacemaker loop to stop. We should do this in case - // the Start function exited before it could bind the event channel into the - // loop. - g.PacerLoop.Stop() - wsutil.WSDebug("Websocket closed; error:", err) - - wsutil.WSDebug("Waiting for the Pacemaker loop to exit.") - g.waitGroup.Wait() - wsutil.WSDebug("Pacemaker loop exited.") - - g.AfterClose(err) - wsutil.WSDebug("AfterClose callback finished.") - - if graceful { - // If a Reconnect is in progress, signal to cancel. - close(g.closed) - - // Delete our session id, as we just invalidated it. - g.sessionMu.Lock() - g.sessionID = "" - g.sessionMu.Unlock() - } - - return err -} - -// SessionID returns the session ID received after Ready. This function is -// concurrently safe. -func (g *Gateway) SessionID() string { - g.sessionMu.RLock() - defer g.sessionMu.RUnlock() - - return g.sessionID -} - -// UseSessionID overrides the internal session ID for the one the user provides. -func (g *Gateway) UseSessionID(sessionID string) { - g.sessionMu.Lock() - defer g.sessionMu.Unlock() - - g.sessionID = sessionID -} - -// OnShardingRequired sets the function to be called if Discord closes with -// error code 4011 aka Sharding Required. When called, the Gateway will already -// be closed, and can (after increasing the number of shards) be reopened using -// Open. Reconnect or ReconnectCtx, however, will not be available as the -// session is invalidated. +// There are several behaviors that the gateway will overload onto the channel. // -// The gateway will completely halt what it's doing in the background when this -// callback is called. -func (g *Gateway) OnShardingRequired(fn func()) { - g.onShardingRequired = fn +// Once the gateway has exited, fatally or not, the event channel returned by +// Connect will be closed. The user should therefore know whether or not the +// gateway has exited by spinning on the channel until it is closed. +// +// If Connect is called twice, the second call will return the same exact +// channel that the first call has made without starting any new goroutines, +// except if the gateway is already closed, then a new gateway will spin up with +// the existing gateway state. +// +// If the gateway stumbles upon any background errors, it will do its best to +// recover from it, but errors will be notified to the user using the +// BackgroundErrorEvent event. The user can type-assert the Op's data field, +// like so: +// +// switch data := ev.Data.(type) { +// case *gateway.BackgroundErrorEvent: +// log.Println("gateway error:", data.Error) +// } +// +// Closing +// +// As outlined in the first paragraph, closing the gateway would involve +// cancelling the context that's given to gateway. If AlwaysCloseGracefully is +// true (which it is by default), then the gateway is closed gracefully, and the +// session ID is invalidated. +// +// To wait until the gateway has completely successfully exited, the user can +// keep spinning on the event loop: +// +// for op := range ch { +// select op.Data.(type) { +// case *gateway.ReadyEvent: +// // Close the gateway on READY. +// cancel() +// } +// } +// +// // Gateway is now completely closed. +// +// To capture the final close errors, the user can use the Error method once the +// event channel is closed, like so: +// +// var err error +// +// for op := range ch { +// switch data := op.Data.(type) { +// case *gateway.ReadyEvent: +// cancel() +// } +// } +// +// // Gateway is now completely closed. +// if gateway.LastError() != nil { +// return gateway.LastError() +// } +// +func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op { + return g.gateway.Connect(ctx, &gatewayImpl{Gateway: g}) } -// Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are -// reached. -func (g *Gateway) Reconnect() { - g.ReconnectCtx(context.Background()) +type gatewayImpl struct { + *Gateway + lastSentBeat time.Time } -// ReconnectCtx attempts to Reconnect until context expires. -// If the context expires FatalErrorCallback will be called with ErrWSMaxTries, -// and the last error returned by Open will be returned. -func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) { - wsutil.WSDebug("Reconnecting...") +func (g *gatewayImpl) invalidate() { + g.state.SessionID = "" + g.state.Sequence = 0 +} - // Guarantee the gateway is already closed. Ignore its error, as we're - // redialing anyway. - g.Pause() +// sendIdentify sends off the Identify command with the Gateway's IdentifyData +// with the given context for timeout. +func (g *gatewayImpl) sendIdentify(ctx context.Context) error { + if err := g.state.Identifier.Wait(ctx); err != nil { + return errors.Wrap(err, "can't wait for identify()") + } - for try := uint(1); g.ReconnectAttempts == 0 || g.ReconnectAttempts >= try; try++ { - select { - case <-g.closed: - g.ErrorLog(ErrClosed) - return ErrClosed - case <-ctx.Done(): - wsutil.WSDebug("Unable to Reconnect after", try, "attempts, aborting") - g.FatalErrorCallback(ErrWSMaxTries) - return err - default: + return g.gateway.Send(ctx, &g.state.Identifier.IdentifyCommand) +} + +func (g *gatewayImpl) sendResume(ctx context.Context) error { + return g.gateway.Send(ctx, &ResumeCommand{ + Token: g.state.Identifier.Token, + SessionID: g.state.SessionID, + Sequence: g.state.Sequence, + }) +} + +func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool { + if op.Code == dispatchOp { + g.state.Sequence = op.Sequence + } + + switch data := op.Data.(type) { + case *ws.CloseEvent: + if data.Code == CodeInvalidSequence { + // Invalid sequence. + g.invalidate() } - wsutil.WSDebug("Trying to dial, attempt", try) + g.gateway.QueueReconnect() - // if we encounter an error, make sure we return it, and not nil - if oerr := g.Open(ctx); oerr != nil { - err = oerr - g.ErrorLog(oerr) + case *HelloEvent: + g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration()) - wait := time.Duration(4+2*try) * time.Second - if wait > 60*time.Second { - wait = 60 * time.Second + // Send Discord either the Identify packet (if it's a fresh + // connection), or a Resume packet (if it's a dead connection). + if g.state.SessionID == "" || g.state.Sequence == 0 { + // SessionID is empty, so this is a completely new session. + if err := g.sendIdentify(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to send identify") + g.gateway.QueueReconnect() } - - time.Sleep(wait) - continue - } - - wsutil.WSDebug("Started after attempt:", try) - return nil - } - - wsutil.WSDebug("Unable to Reconnect after", g.ReconnectAttempts, "attempts, aborting") - return err -} - -// Open connects to the Websocket and authenticates it. You should usually use -// this function over Start(). The given context provides cancellation and -// timeout. -func (g *Gateway) Open(ctx context.Context) error { - // Reconnect to the Gateway - if err := g.WS.Dial(ctx); err != nil { - return errors.Wrap(err, "failed to Reconnect") - } - - wsutil.WSDebug("Trying to start...") - - // Try to resume the connection - if err := g.StartCtx(ctx); err != nil { - return err - } - - // Started successfully, return - return nil -} - -// Start calls StartCtx with a background context. You wouldn't usually use this -// function, but Open() instead. -func (g *Gateway) Start() error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - return g.StartCtx(ctx) -} - -// StartCtx authenticates with the websocket, or resume from a dead Websocket -// connection. You wouldn't usually use this function, but OpenCtx() instead. -func (g *Gateway) StartCtx(ctx context.Context) error { - g.closed = make(chan struct{}) - - if err := g.start(ctx); err != nil { - wsutil.WSDebug("Start failed:", err) - - // Close can be called with the mutex still acquired here, as the - // pacemaker hasn't started yet. - if err := g.Close(); err != nil { - wsutil.WSDebug("Failed to close after start fail:", err) - } - return err - } - - return nil -} - -func (g *Gateway) start(ctx context.Context) error { - // This is where we'll get our events - ch := g.WS.Listen() - - // Create a new Hello event and wait for it. - var hello HelloEvent - // Wait for an OP 10 Hello. - select { - case e, ok := <-ch: - if !ok { - return errors.New("unexpected ws close while waiting for Hello") - } - if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil { - return errors.Wrap(err, "error at Hello") - } - case <-ctx.Done(): - return errors.Wrap(ctx.Err(), "failed to wait for Hello event") - } - - wsutil.WSDebug("Hello received; duration:", hello.HeartbeatInterval) - - // Start the event handler, which also handles the pacemaker death signal. - g.waitGroup.Add(1) - - // Use the pacemaker loop. - g.PacerLoop.StartBeating(hello.HeartbeatInterval.Duration(), g, func(err error) { - g.waitGroup.Done() // mark so Close() can exit. - wsutil.WSDebug("Event loop stopped with error:", err) - - if err != nil && g.onShardingRequired != nil { - // If Discord signals us sharding is required, do not attempt to - // Reconnect, unless we don't know what to do. Instead invalidate - // our session ID, as we cannot resume, call OnShardingRequired, and - // exit. - var cerr *websocket.CloseError - if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired { - g.ErrorLog(cerr) - g.UseSessionID("") - g.onShardingRequired() - return + } else { + if err := g.sendResume(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to send resume") + g.gateway.QueueReconnect() } } - // Bail if there is no error or if the error is an explicit close, as - // there might be an ongoing reconnection. - if err == nil || errors.Is(err, wsutil.ErrWebsocketClosed) { - return + case *InvalidSessionEvent: + // Wipe the session state. + g.invalidate() + + if !*data { + g.gateway.QueueReconnect() + break } - // Only attempt to Reconnect if we have a session ID at all. We may not - // have one if we haven't even connected successfully once. - if g.SessionID() != "" { - g.ErrorLog(err) - g.Reconnect() + // Discord expects us to wait before reconnecting. + g.retryTimer.Reset(time.Duration(rand.Intn(5)+1) * time.Second) + if err := g.retryTimer.Wait(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to wait before identifying") + g.gateway.QueueReconnect() + break } - }) - // Send Discord either the Identify packet (if it's a fresh connection), or - // a Resume packet (if it's a dead connection). - if g.SessionID() == "" { - // SessionID is empty, so this is a completely new session. - if err := g.IdentifyCtx(ctx); err != nil { - return errors.Wrap(err, "failed to identify") - } - } else { - if err := g.ResumeCtx(ctx); err != nil { - return errors.Wrap(err, "failed to resume") + // If we fail to identify, then the gateway cannot continue with + // a bad identification, since it's likely a user error. + if err := g.sendIdentify(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to identify") + g.gateway.QueueReconnect() + break } + + case *HeartbeatCommand: + g.SendHeartbeat(ctx) + + case *HeartbeatAckEvent: + now := time.Now() + + g.beatMutex.Lock() + g.sentBeat = g.lastSentBeat + g.echoBeat = now + g.beatMutex.Unlock() + + case *ReconnectEvent: + g.gateway.QueueReconnect() + + case *ReadyEvent: + g.state.SessionID = data.SessionID } - // Expect either READY or RESUMED before continuing. - wsutil.WSDebug("Waiting for either READY or RESUMED.") + return true +} - // WaitForEvent should until the bot becomes ready or resumes (if a - // previous ready event has already been called). - err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool { - switch op.EventName { - case "READY": - wsutil.WSDebug("Found READY event.") - return true - case "RESUMED": - wsutil.WSDebug("Found RESUMED event.") - return true - } - return false - }) +// SendHeartbeat sends a heartbeat with the gateway's current sequence. +func (g *gatewayImpl) SendHeartbeat(ctx context.Context) { + g.lastSentBeat = time.Now() - if err != nil { - return errors.Wrap(err, "first error") + sequence := HeartbeatCommand(g.state.Sequence) + if err := g.gateway.Send(ctx, &sequence); err != nil { + g.gateway.SendErrorWrap(err, "heartbeat error") + g.gateway.QueueReconnect() } +} - // Bind the event channel to the pacemaker loop. - g.PacerLoop.SetEventChannel(ch) - - wsutil.WSDebug("Started successfully.") - +// Close closes the state. +func (g *gatewayImpl) Close() error { + g.retryTimer.Stop() + g.invalidate() return nil } - -// SendCtx is a low-level function to send an OP payload to the Gateway. Most -// users shouldn't touch this, unless they know what they're doing. -func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error { - var op = wsutil.OP{ - Code: code, - } - - if v != nil { - b, err := json.Marshal(v) - if err != nil { - return errors.Wrap(err, "failed to encode v") - } - - op.Data = b - } - - b, err := json.Marshal(op) - if err != nil { - return errors.Wrap(err, "failed to encode payload") - } - - // WS should already be thread-safe. - return g.WS.SendCtx(ctx, b) -} diff --git a/gateway/gateway_example_test.go b/gateway/gateway_example_test.go new file mode 100644 index 0000000..206285f --- /dev/null +++ b/gateway/gateway_example_test.go @@ -0,0 +1,31 @@ +package gateway_test + +import ( + "context" + "log" + "os" + "os/signal" + + "github.com/diamondburned/arikawa/v3/gateway" +) + +func Example() { + token := os.Getenv("BOT_TOKEN") + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + g, err := gateway.NewWithIntents(ctx, token, gateway.IntentGuilds) + if err != nil { + log.Fatalln("failed to initialize gateway:", err) + } + + for op := range g.Connect(ctx) { + switch data := op.Data.(type) { + case *gateway.ReadyEvent: + log.Println("logged in as", data.User.Username) + case *gateway.MessageCreateEvent: + log.Println("got message", data.Content) + } + } +} diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go new file mode 100644 index 0000000..a72c791 --- /dev/null +++ b/gateway/gateway_test.go @@ -0,0 +1,197 @@ +package gateway + +import ( + "context" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/diamondburned/arikawa/v3/internal/testenv" + "github.com/diamondburned/arikawa/v3/utils/ws" +) + +var doLogOnce sync.Once + +func doLog() { + doLogOnce.Do(func() { + if testing.Verbose() { + ws.WSDebug = func(v ...interface{}) { + log.Println(append([]interface{}{"Debug:"}, v...)...) + } + } + }) +} + +func TestURL(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + t.Cleanup(cancel) + + u, err := URL(ctx) + if err != nil { + t.Fatal("failed to get gateway URL:", err) + } + + if u == "" { + t.Fatal("gateway URL is empty") + } + + if !strings.HasPrefix(u, "wss://") { + t.Fatal("gatewayURL is invalid:", u) + } +} + +func TestInvalidToken(t *testing.T) { + doLog() + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + t.Cleanup(cancel) + + g, err := New(ctx, "bad token") + if err != nil { + t.Fatal("failed to make a Gateway:", err) + } + + assertIsClose := func(err error) { + if err == nil { + t.Fatal("unexpected nil error") + } + + // 4004 Authentication Failed. + if !strings.Contains(err.Error(), "4004") { + t.Fatal("unexpected error:", err) + } + } + + for op := range g.Connect(ctx) { + if op.Data == nil { + // This shouldn't happen; the loop should've broken out. + t.Fatal("nil event received") + } + + switch data := op.Data.(type) { + case *ws.CloseEvent: + assertIsClose(data) + case *ws.BackgroundErrorEvent: + t.Error("gateway error:", data) + case *HelloEvent: + t.Log("got Hello") + case *InvalidSessionEvent: + t.Log("got InvalidSession") + default: + t.Errorf("got unexpected event %#v", data) + } + } + + assertIsClose(g.LastError()) +} + +func TestIntegration(t *testing.T) { + doLog() + + config := testenv.Must(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + + // NewGateway should call Start for us. + g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds) + if err != nil { + t.Fatal("failed to make a Gateway:", err) + } + + gatewayOpenAndSpin(t, ctx, g) + cancel() +} + +func TestReuseGateway(t *testing.T) { + doLog() + + config := testenv.Must(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + + // NewGateway should call Start for us. + g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds) + if err != nil { + t.Fatal("failed to make a Gateway:", err) + } + + // Reuse this 3 times. + for i := 0; i < 3; i++ { + cctx, cancel := context.WithCancel(ctx) + gatewayOpenAndSpin(t, cctx, g) + cancel() + } +} + +func gatewayOpenAndSpin(t *testing.T, ctx context.Context, g *Gateway) { + ch := g.Connect(ctx) + + var reconnected bool + reconnect := func() { + if !reconnected { + reconnected = true + g.gateway.QueueReconnect() + } + } + + for op := range ch { + if op.Data == nil { + // This shouldn't happen; the loop should've broken out. + t.Fatal("nil event received") + } + + switch data := op.Data.(type) { + case *ReadyEvent: + t.Log("got Ready") + if g.state.SessionID != data.SessionID { + t.Fatal("missing SessionID") + } + log.Println("Bot's username is", data.User.Username) + reconnect() + case *ResumedEvent: + t.Log("got Resumed, test done") + return + case *HelloEvent: + t.Log("got Hello") + case *ws.BackgroundErrorEvent: + t.Error("gateway error:", data) + default: + t.Logf("got event %T", data) + } + } +} + +func wait(t *testing.T, evCh chan interface{}) interface{} { + select { + case ev := <-evCh: + return ev + case <-time.After(20 * time.Second): + t.Fatal("timed out waiting for event") + return nil + } +} + +func gotimeout(t *testing.T, fn func(context.Context)) { + t.Helper() + + // Try and reconnect for 20 seconds maximum. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + var done = make(chan struct{}) + go func() { + fn(ctx) + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for function.") + case <-done: + return + } +} diff --git a/gateway/identify.go b/gateway/identify.go index 8cb8b29..cb89a54 100644 --- a/gateway/identify.go +++ b/gateway/identify.go @@ -3,8 +3,10 @@ package gateway import ( "context" "runtime" + "strings" "time" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/utils/json/option" "github.com/pkg/errors" @@ -13,27 +15,27 @@ import ( // DefaultPresence is used as the default presence when initializing a new // Gateway. -var DefaultPresence *UpdateStatusData +var DefaultPresence *UpdatePresenceCommand -// Identifier is a wrapper around IdentifyData to add in appropriate rate +// Identifier is a wrapper around IdentifyCommand to add in appropriate rate // limiters. type Identifier struct { - IdentifyData + IdentifyCommand IdentifyShortLimit *rate.Limiter `json:"-"` // optional IdentifyGlobalLimit *rate.Limiter `json:"-"` // optional } // DefaultIdentifier creates a new default Identifier -func DefaultIdentifier(token string) *Identifier { - return NewIdentifier(DefaultIdentifyData(token)) +func DefaultIdentifier(token string) Identifier { + return NewIdentifier(DefaultIdentifyCommand(token)) } -// NewIdentifier creates a new identifier with the given IdentifyData and +// NewIdentifier creates a new identifier with the given IdentifyCommand and // default rate limiters. -func NewIdentifier(data IdentifyData) *Identifier { - return &Identifier{ - IdentifyData: data, +func NewIdentifier(data IdentifyCommand) Identifier { + return Identifier{ + IdentifyCommand: data, IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1), IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000), } @@ -41,15 +43,15 @@ func NewIdentifier(data IdentifyData) *Identifier { // Wait waits for the rate limiters to pass. If a limiter is nil, then it will // not be used to wait. This is useful -func (i *Identifier) Wait(ctx context.Context) error { - if i.IdentifyShortLimit != nil { - if err := i.IdentifyShortLimit.Wait(ctx); err != nil { +func (id *Identifier) Wait(ctx context.Context) error { + if id.IdentifyShortLimit != nil { + if err := id.IdentifyShortLimit.Wait(ctx); err != nil { return errors.Wrap(err, "can't wait for short limit") } } - if i.IdentifyGlobalLimit != nil { - if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil { + if id.IdentifyGlobalLimit != nil { + if err := id.IdentifyGlobalLimit.Wait(ctx); err != nil { return errors.Wrap(err, "can't wait for global limit") } } @@ -57,6 +59,41 @@ func (i *Identifier) Wait(ctx context.Context) error { return nil } +// QueryGateway queries the gateway for the URL and updates the Identifier with +// the appropriate information. +func (id *Identifier) QueryGateway(ctx context.Context) (gatewayURL string, err error) { + var botData *api.BotData + + if strings.HasPrefix(id.Token, "Bot ") { + botData, err = BotURL(ctx, id.Token) + if err != nil { + return "", errors.Wrap(err, "failed to get bot data") + } + gatewayURL = botData.URL + } else { + gatewayURL, err = URL(ctx) + if err != nil { + return "", errors.Wrap(err, "failed to get gateway endpoint") + } + } + + // Use the supplied connect rate limit, if any. + if botData != nil && botData.StartLimit != nil { + resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration()) + limiter := id.IdentifyGlobalLimit + + // Update the burst to be the current given time and reset it back to + // the default when the given time is reached. + limiter.SetBurst(botData.StartLimit.Remaining) + limiter.SetBurstAt(resetAt, botData.StartLimit.Total) + + // Update the maximum number of identify requests allowed per 5s. + id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency) + } + + return +} + // DefaultIdentity is used as the default identity when initializing a new // Gateway. var DefaultIdentity = IdentifyProperties{ @@ -65,9 +102,9 @@ var DefaultIdentity = IdentifyProperties{ Device: "Arikawa", } -// IdentifyData is the struct for a data that's sent over in an Identify -// command. -type IdentifyData struct { +// IdentifyCommand is a command for Op 2. It is the struct for a data that's +// sent over in an Identify command. +type IdentifyCommand struct { Token string `json:"token"` Properties IdentifyProperties `json:"properties"` @@ -76,7 +113,7 @@ type IdentifyData struct { Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ] - Presence *UpdateStatusData `json:"presence,omitempty"` + Presence *UpdatePresenceCommand `json:"presence,omitempty"` // ClientState is the client state for a user's accuont. Bot accounts should // NOT touch this field. @@ -97,9 +134,9 @@ type IdentifyData struct { Intents option.Uint `json:"intents"` } -// DefaultIdentifyData creates a default IdentifyData with the given token. -func DefaultIdentifyData(token string) IdentifyData { - return IdentifyData{ +// DefaultIdentifyCommand creates a default IdentifyCommand with the given token. +func DefaultIdentifyCommand(token string) IdentifyCommand { + return IdentifyCommand{ Token: token, Properties: DefaultIdentity, Presence: DefaultPresence, @@ -110,14 +147,35 @@ func DefaultIdentifyData(token string) IdentifyData { } // SetShard is a helper function to set the shard configuration inside -// IdentifyData. -func (i *IdentifyData) SetShard(id, num int) { +// IdentifyCommand. +func (i *IdentifyCommand) SetShard(id, num int) { if i.Shard == nil { i.Shard = new(Shard) } i.Shard[0], i.Shard[1] = id, num } +// AddIntents adds gateway intents into the identify data. +func (i *IdentifyCommand) AddIntents(intents Intents) { + if i.Intents == nil { + i.Intents = option.NewUint(uint(intents)) + } else { + *i.Intents |= uint(intents) + } +} + +// HasIntents reports if the Gateway has the passed Intents. +// +// If no intents are set, e.g. if using a user account, HasIntents will always +// return true. +func (i *IdentifyCommand) HasIntents(intents Intents) bool { + if i.Intents == nil { + return true + } + + return Intents(*i.Intents).Has(intents) +} + type IdentifyProperties struct { // Required OS string `json:"os"` // GOOS diff --git a/gateway/integration_test.go b/gateway/integration_test.go deleted file mode 100644 index bf596a8..0000000 --- a/gateway/integration_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package gateway - -import ( - "context" - "log" - "strings" - "testing" - "time" - - "github.com/diamondburned/arikawa/v3/internal/heart" - "github.com/diamondburned/arikawa/v3/internal/testenv" - "github.com/diamondburned/arikawa/v3/utils/wsutil" -) - -func init() { - wsutil.WSDebug = func(v ...interface{}) { - log.Println(append([]interface{}{"Debug:"}, v...)...) - } - heart.Debug = func(v ...interface{}) { - log.Println(append([]interface{}{"Heart:"}, v...)...) - } -} - -func TestURL(t *testing.T) { - u, err := URL() - if err != nil { - t.Fatal("failed to get gateway URL:", err) - } - - if u == "" { - t.Fatal("gateway URL is empty") - } - - if !strings.HasPrefix(u, "wss://") { - t.Fatal("gatewayURL is invalid:", u) - } -} - -func TestInvalidToken(t *testing.T) { - g, err := NewGateway("bad token") - if err != nil { - t.Fatal("failed to make a Gateway:", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - if err = g.Open(ctx); err == nil { - t.Fatal("unexpected success while opening with a bad token.") - } - - // 4004 Authentication Failed. - if !strings.Contains(err.Error(), "4004") { - t.Fatal("unexpected error:", err) - } -} - -func TestIntegration(t *testing.T) { - config := testenv.Must(t) - - var gateway *Gateway - - // NewGateway should call Start for us. - g, err := NewGateway("Bot " + config.BotToken) - if err != nil { - t.Fatal("failed to make a Gateway:", err) - } - g.AddIntents(IntentGuilds) - g.AfterClose = func(err error) { - t.Log("closed.") - } - g.ErrorLog = func(err error) { - t.Log("gateway error:", err) - } - gateway = g - - gotimeout(t, func(ctx context.Context) { - if err := g.Open(ctx); err != nil { - t.Fatal("failed to authenticate with Discord:", err) - } - }) - - ev := wait(t, gateway.Events) - ready, ok := ev.(*ReadyEvent) - if !ok { - t.Fatal("event received is not of type Ready:", ev) - } - - if gateway.SessionID() == "" { - t.Fatal("session ID is empty") - } - - log.Println("Bot's username is", ready.User.Username) - - // Send a faster heartbeat every second for testing. - g.PacerLoop.SetPace(time.Second) - - // Sleep past the rate limiter before reconnecting: - time.Sleep(5 * time.Second) - - gotimeout(t, func(ctx context.Context) { - g.ErrorLog = func(err error) { - t.Error("unexpected error while reconnecting:", err) - } - - if err := gateway.ReconnectCtx(ctx); err != nil { - t.Error("failed to reconnect Gateway:", err) - } - }) - - g.ErrorLog = func(err error) { t.Log("warning:", err) } - - // Wait for the desired event: - gotimeout(t, func(context.Context) { - for ev := range gateway.Events { - switch ev.(type) { - // Accept only a Resumed event. - case *ResumedEvent: - return // exit - case *ReadyEvent: - t.Fatal("Ready event received instead of Resumed.") - } - } - }) - - if err := g.Close(); err != nil { - t.Fatal("failed to close Gateway:", err) - } -} - -func wait(t *testing.T, evCh chan interface{}) interface{} { - select { - case ev := <-evCh: - return ev - case <-time.After(20 * time.Second): - t.Fatal("timed out waiting for event") - return nil - } -} - -func gotimeout(t *testing.T, fn func(context.Context)) { - t.Helper() - - // Try and reconnect for 20 seconds maximum. - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - var done = make(chan struct{}) - go func() { - fn(ctx) - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - t.Fatal("timed out waiting for function.") - case <-done: - return - } -} diff --git a/gateway/intents.go b/gateway/intents.go index 234caf7..3d44749 100644 --- a/gateway/intents.go +++ b/gateway/intents.go @@ -1,6 +1,9 @@ package gateway -import "github.com/diamondburned/arikawa/v3/discord" +import ( + "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/utils/ws" +) // Intents for the new Discord API feature, documented at // https://discord.com/developers/docs/topics/gateway#gateway-intents. @@ -44,7 +47,7 @@ func (i Intents) IsPrivileged() (presences, member bool) { } // EventIntents maps event types to intents. -var EventIntents = map[string]Intents{ +var EventIntents = map[ws.EventType]Intents{ "GUILD_CREATE": IntentGuilds, "GUILD_UPDATE": IntentGuilds, "GUILD_DELETE": IntentGuilds, diff --git a/gateway/op.go b/gateway/op.go deleted file mode 100644 index 01ed3b4..0000000 --- a/gateway/op.go +++ /dev/null @@ -1,118 +0,0 @@ -package gateway - -import ( - "context" - "fmt" - "math/rand" - "time" - - "github.com/diamondburned/arikawa/v3/utils/json" - "github.com/diamondburned/arikawa/v3/utils/wsutil" - "github.com/pkg/errors" -) - -type OPCode = wsutil.OPCode - -const ( - DispatchOP OPCode = 0 // recv - HeartbeatOP OPCode = 1 // send/recv - IdentifyOP OPCode = 2 // send... - StatusUpdateOP OPCode = 3 // - VoiceStateUpdateOP OPCode = 4 // - VoiceServerPingOP OPCode = 5 // - ResumeOP OPCode = 6 // - ReconnectOP OPCode = 7 // recv - RequestGuildMembersOP OPCode = 8 // send - InvalidSessionOP OPCode = 9 // recv... - HelloOP OPCode = 10 - HeartbeatAckOP OPCode = 11 - CallConnectOP OPCode = 13 - GuildSubscriptionsOP OPCode = 14 -) - -// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This -// is used mostly internally to signal the heartbeat loop to reconnect, if -// needed. It is not a fatal error. -var ErrReconnectRequest = errors.New("ReconnectOP received") - -func (g *Gateway) HandleOP(op *wsutil.OP) error { - switch op.Code { - case HeartbeatAckOP: - // Heartbeat from the server? - g.PacerLoop.Echo() - - case HeartbeatOP: - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - // Server requesting a heartbeat. - if err := g.PacerLoop.Pace(ctx); err != nil { - return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace")) - } - - case ReconnectOP: - // Server requests to Reconnect, die and retry. - wsutil.WSDebug("ReconnectOP received.") - - // Exit with the ReconnectOP error to force the heartbeat event loop to - // Reconnect synchronously. Not really a fatal error. - return wsutil.ErrBrokenConnection(ErrReconnectRequest) - - case InvalidSessionOP: - // Discord expects us to sleep for no reason - time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - - // Invalid session, try and Identify. - if err := g.IdentifyCtx(ctx); err != nil { - // Can't identify, Reconnect. - return wsutil.ErrBrokenConnection(ErrReconnectRequest) - } - - return nil - - case HelloOP: - return nil - - case DispatchOP: - // Set the sequence - if op.Sequence > 0 { - g.Sequence.Set(op.Sequence) - } - - // Check if we know the event - fn, ok := EventCreator[op.EventName] - if !ok { - return &wsutil.UnknownEventError{ - Name: op.EventName, - Data: op.Data, - } - } - - // Make a new pointer to the event - var ev = fn() - - // Try and parse the event - if err := json.Unmarshal(op.Data, ev); err != nil { - return errors.Wrap(err, "failed to parse event "+op.EventName) - } - - // If the event is a ready, we'll want its sessionID - if ev, ok := ev.(*ReadyEvent); ok { - g.sessionMu.Lock() - g.sessionID = ev.SessionID - g.sessionMu.Unlock() - } - - // Throw the event into a channel; it's valid now. - g.Events <- ev - return nil - - default: - return fmt.Errorf("unknown OP code %d (event %s)", op.Code, op.EventName) - } - - return nil -} diff --git a/gateway/perseverance_test.go b/gateway/perseverance_test.go deleted file mode 100644 index b7f1995..0000000 --- a/gateway/perseverance_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// +build perseverance - -package gateway - -import ( - "testing" - "time" - - "github.com/diamondburned/arikawa/v3/internal/testenv" -) - -func TestPerseverance(t *testing.T) { - t.Parallel() - - config := testenv.Must(t) - - g, err := NewGateway("Bot " + config.BotToken) - if err != nil { - t.Fatal("failed to make the gateway:", err) - } - g.AddIntents(IntentGuilds) - - if err := g.Open(); err != nil { - t.Fatal("failed to open the gateway:", err) - } - - timeout := make(chan struct{}, 1) - - // Automatically close the gateway after set duration. - time.AfterFunc(testenv.PerseveranceTime, func() { - t.Log("Perserverence test finshed. Closing gateway.") - timeout <- struct{}{} - - if err := g.Close(); err != nil { - t.Error("failed to close gateway:", err) - } - }) - - // Spin on events. - for ev := range g.Events { - t.Logf("Received event %T.", ev) - } - - // Exit gracefully if we have not. - select { - case <-timeout: - return - default: - } - - if err := g.Close(); err != nil { - t.Fatal("failed to clean up gateway after fail:", err) - } - - t.Fatal("Test failed before timeout.") -} diff --git a/gateway/ready.go b/gateway/ready.go deleted file mode 100644 index a818fa7..0000000 --- a/gateway/ready.go +++ /dev/null @@ -1,267 +0,0 @@ -package gateway - -import ( - "strconv" - "strings" - - "github.com/diamondburned/arikawa/v3/discord" -) - -type ( - // ReadyEvent is the struct for a READY event. - ReadyEvent struct { - Version int `json:"version"` - - User discord.User `json:"user"` - SessionID string `json:"session_id"` - - PrivateChannels []discord.Channel `json:"private_channels"` - Guilds []GuildCreateEvent `json:"guilds"` - - Shard *Shard `json:"shard,omitempty"` - - // Undocumented fields - - UserSettings *UserSettings `json:"user_settings,omitempty"` - ReadStates []ReadState `json:"read_state,omitempty"` - UserGuildSettings []UserGuildSetting `json:"user_guild_settings,omitempty"` - Relationships []discord.Relationship `json:"relationships,omitempty"` - Presences []discord.Presence `json:"presences,omitempty"` - - FriendSuggestionCount int `json:"friend_suggestion_count,omitempty"` - GeoOrderedRTCRegions []string `json:"geo_ordered_rtc_regions,omitempty"` - } - - // ReadState is a single ReadState entry. It is undocumented. - ReadState struct { - ChannelID discord.ChannelID `json:"id"` - LastMessageID discord.MessageID `json:"last_message_id"` - LastPinTimestamp discord.Timestamp `json:"last_pin_timestamp"` - MentionCount int `json:"mention_count"` - } - - // UserSettings is the struct for (almost) all user settings. It is - // undocumented. - UserSettings struct { - ShowCurrentGame bool `json:"show_current_game"` - DefaultGuildsRestricted bool `json:"default_guilds_restricted"` - InlineAttachmentMedia bool `json:"inline_attachment_media"` - InlineEmbedMedia bool `json:"inline_embed_media"` - GIFAutoPlay bool `json:"gif_auto_play"` - RenderEmbeds bool `json:"render_embeds"` - RenderReactions bool `json:"render_reactions"` - AnimateEmoji bool `json:"animate_emoji"` - AnimateStickers int `json:"animate_stickers"` - EnableTTSCommand bool `json:"enable_tts_command"` - MessageDisplayCompact bool `json:"message_display_compact"` - ConvertEmoticons bool `json:"convert_emoticons"` - ExplicitContentFilter uint8 `json:"explicit_content_filter"` // ??? - DisableGamesTab bool `json:"disable_games_tab"` - DeveloperMode bool `json:"developer_mode"` - DetectPlatformAccounts bool `json:"detect_platform_accounts"` - StreamNotification bool `json:"stream_notification_enabled"` - AccessibilityDetection bool `json:"allow_accessibility_detection"` - ContactSync bool `json:"contact_sync_enabled"` - NativePhoneIntegration bool `json:"native_phone_integration_enabled"` - - TimezoneOffset int `json:"timezone_offset"` - - Locale string `json:"locale"` - Theme string `json:"theme"` - - GuildPositions []discord.GuildID `json:"guild_positions"` - GuildFolders []GuildFolder `json:"guild_folders"` - RestrictedGuilds []discord.GuildID `json:"restricted_guilds"` - - FriendSourceFlags FriendSourceFlags `json:"friend_source_flags"` - - Status discord.Status `json:"status"` - CustomStatus *CustomUserStatus `json:"custom_status"` - } - - // CustomUserStatus is the custom user status that allows setting an emoji - // and a piece of text on each user. - CustomUserStatus struct { - Text string `json:"text"` - ExpiresAt discord.Timestamp `json:"expires_at,omitempty"` - EmojiID discord.EmojiID `json:"emoji_id,string"` - EmojiName string `json:"emoji_name"` - } - - // UserGuildSetting stores the settings for a single guild. It is - // undocumented. - UserGuildSetting struct { - GuildID discord.GuildID `json:"guild_id"` - - SuppressRoles bool `json:"suppress_roles"` - SuppressEveryone bool `json:"suppress_everyone"` - Muted bool `json:"muted"` - MuteConfig *UserMuteConfig `json:"mute_config"` - - MobilePush bool `json:"mobile_push"` - Notifications UserNotification `json:"message_notifications"` - - ChannelOverrides []UserChannelOverride `json:"channel_overrides"` - } - - // A UserChannelOverride struct describes a channel settings override for a - // users guild settings. - UserChannelOverride struct { - Muted bool `json:"muted"` - MuteConfig *UserMuteConfig `json:"mute_config"` - Notifications UserNotification `json:"message_notifications"` - ChannelID discord.ChannelID `json:"channel_id"` - } - - // UserMuteConfig seems to describe the mute settings. It belongs to the - // UserGuildSettingEntry and UserChannelOverride structs and is - // undocumented. - UserMuteConfig struct { - SelectedTimeWindow int `json:"selected_time_window"` - EndTime discord.Timestamp `json:"end_time"` - } - - // GuildFolder holds a single folder that you see in the left guild panel. - GuildFolder struct { - Name string `json:"name"` - ID GuildFolderID `json:"id"` - GuildIDs []discord.GuildID `json:"guild_ids"` - Color discord.Color `json:"color"` - } - - // FriendSourceFlags describes sources that friend requests could be sent - // from. It belongs to the UserSettings struct and is undocumented. - FriendSourceFlags struct { - All bool `json:"all,omitempty"` - MutualGuilds bool `json:"mutual_guilds,omitempty"` - MutualFriends bool `json:"mutual_friends,omitempty"` - } -) - -// UserNotification is the notification setting for a channel or guild. -type UserNotification uint8 - -const ( - AllNotifications UserNotification = iota - OnlyMentions - NoNotifications - GuildDefaults -) - -// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low -// number of unknown significance. -type GuildFolderID int64 - -func (g *GuildFolderID) UnmarshalJSON(b []byte) error { - var body = string(b) - if body == "null" { - return nil - } - - body = strings.Trim(body, `"`) - - u, err := strconv.ParseInt(body, 10, 64) - if err != nil { - return err - } - - *g = GuildFolderID(u) - return nil -} - -func (g GuildFolderID) MarshalJSON() ([]byte, error) { - if g == 0 { - return []byte("null"), nil - } - - return []byte(strconv.FormatInt(int64(g), 10)), nil -} - -// ReadySupplemental event structs. For now, this event is never used, and its -// usage have yet been discovered. -type ( - // ReadySupplementalEvent is the struct for a READY_SUPPLEMENTAL event, - // which is an undocumented event. - ReadySupplementalEvent struct { - Guilds []GuildCreateEvent `json:"guilds"` // only have ID and VoiceStates - MergedMembers [][]SupplementalMember `json:"merged_members"` - MergedPresences MergedPresences `json:"merged_presences"` - } - - // SupplementalMember is the struct for a member in the MergedMembers field - // of ReadySupplementalEvent. It has slight differences to discord.Member. - SupplementalMember struct { - UserID discord.UserID `json:"user_id"` - Nick string `json:"nick,omitempty"` - RoleIDs []discord.RoleID `json:"roles"` - - GuildID discord.GuildID `json:"guild_id,omitempty"` - IsPending bool `json:"pending,omitempty"` - HoistedRole discord.RoleID `json:"hoisted_role"` - - Mute bool `json:"mute"` - Deaf bool `json:"deaf"` - - // Joined specifies when the user joined the guild. - Joined discord.Timestamp `json:"joined_at"` - - // BoostedSince specifies when the user started boosting the guild. - BoostedSince discord.Timestamp `json:"premium_since,omitempty"` - } - - // MergedPresences is the struct for presences of guilds' members and - // friends. It is undocumented. - MergedPresences struct { - Guilds [][]SupplementalPresence `json:"guilds"` - Friends []SupplementalPresence `json:"friends"` - } - - // SupplementalPresence is a single presence for either a guild member or - // friend. It is used in MergedPresences and is undocumented. - SupplementalPresence struct { - UserID discord.UserID `json:"user_id"` - - // Status is either "idle", "dnd", "online", or "offline". - Status discord.Status `json:"status"` - // Activities are the user's current activities. - Activities []discord.Activity `json:"activities"` - // ClientStaus is the user's platform-dependent status. - ClientStatus discord.ClientStatus `json:"client_status"` - - // LastModified is only present in Friends. - LastModified discord.UnixMsTimestamp `json:"last_modified,omitempty"` - } -) - -// ConvertSupplementalMembers converts a SupplementalMember to a regular Member. -func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member { - members := make([]discord.Member, len(sms)) - for i, sm := range sms { - members[i] = discord.Member{ - User: discord.User{ID: sm.UserID}, - Nick: sm.Nick, - RoleIDs: sm.RoleIDs, - Joined: sm.Joined, - BoostedSince: sm.BoostedSince, - Deaf: sm.Deaf, - Mute: sm.Mute, - IsPending: sm.IsPending, - } - } - return members -} - -// ConvertSupplementalPresences converts a SupplementalPresence to a regular -// Presence with an empty GuildID. -func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence { - presences := make([]discord.Presence, len(sps)) - for i, sp := range sps { - presences[i] = discord.Presence{ - User: discord.User{ID: sp.UserID}, - Status: sp.Status, - Activities: sp.Activities, - ClientStatus: sp.ClientStatus, - } - } - return presences -} diff --git a/internal/handleloop/handleloop.go b/internal/handleloop/handleloop.go deleted file mode 100644 index d545ed1..0000000 --- a/internal/handleloop/handleloop.go +++ /dev/null @@ -1,60 +0,0 @@ -// Package handleloop provides clean abstractions to handle listening to -// channels and passing them onto event handlers. -package handleloop - -import "github.com/diamondburned/arikawa/v3/utils/handler" - -// Loop provides a reusable event looper abstraction. It is thread-safe to use -// concurrently. -type Loop struct { - dst *handler.Handler - run chan struct{} - stop chan struct{} -} - -func NewLoop(dst *handler.Handler) *Loop { - return &Loop{ - dst: dst, - run: make(chan struct{}, 1), // intentional 1 buffer - stop: make(chan struct{}), // intentional unbuffer - } -} - -// Start starts a new event loop. It will try to stop existing loops before. -func (l *Loop) Start(src <-chan interface{}) { - // Ensure we're stopped. - l.Stop() - - // Mark that we're running. - l.run <- struct{}{} - - go func() { - for { - select { - case event := <-src: - l.dst.Call(event) - - case <-l.stop: - l.stop <- struct{}{} - return - } - } - }() -} - -// Stop tries to stop the Loop. If the Loop is not running, then it does -// nothing; thus, it can be called multiple times. -func (l *Loop) Stop() { - // Ensure that we are running before stopping. - select { - case <-l.run: - // running - default: - return - } - - // send a close request - l.stop <- struct{}{} - // wait for a reply - <-l.stop -} diff --git a/internal/heart/heart.go b/internal/heart/heart.go deleted file mode 100644 index 4fb706a..0000000 --- a/internal/heart/heart.go +++ /dev/null @@ -1,135 +0,0 @@ -// Package heart implements a general purpose pacemaker. -package heart - -import ( - "context" - "sync/atomic" - "time" - - "github.com/pkg/errors" -) - -// Debug is the default logger that Pacemaker uses. -var Debug = func(v ...interface{}) {} - -var ErrDead = errors.New("no heartbeat replied") - -// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic. -type AtomicTime struct { - unixnano int64 -} - -func (t *AtomicTime) Get() int64 { - return atomic.LoadInt64(&t.unixnano) -} - -func (t *AtomicTime) Set(time time.Time) { - atomic.StoreInt64(&t.unixnano, time.UnixNano()) -} - -func (t *AtomicTime) Time() time.Time { - return time.Unix(0, t.Get()) -} - -// AtomicDuration is a thread-safe Duration guarded by atomic. -type AtomicDuration struct { - duration int64 -} - -func (d *AtomicDuration) Get() time.Duration { - return time.Duration(atomic.LoadInt64(&d.duration)) -} - -func (d *AtomicDuration) Set(dura time.Duration) { - atomic.StoreInt64(&d.duration, int64(dura)) -} - -// Pacemaker is the internal pacemaker state. All fields are not thread-safe -// unless they're atomic. -type Pacemaker struct { - // Heartrate is the received duration between heartbeats. - Heartrate AtomicDuration - - ticker time.Ticker - Ticks <-chan time.Time - - // Time in nanoseconds, guarded by atomic read/writes. - SentBeat AtomicTime - EchoBeat AtomicTime - - // Any callback that returns an error will stop the pacer. - Pacer func(context.Context) error -} - -func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker { - p := Pacemaker{ - Heartrate: AtomicDuration{int64(heartrate)}, - Pacer: pacer, - ticker: *time.NewTicker(heartrate), - } - p.Ticks = p.ticker.C - // Reset states to its old position. - now := time.Now() - p.EchoBeat.Set(now) - p.SentBeat.Set(now) - - return p -} - -func (p *Pacemaker) Echo() { - // Swap our received heartbeats - p.EchoBeat.Set(time.Now()) -} - -// Dead, if true, will have Pace return an ErrDead. -func (p *Pacemaker) Dead() bool { - var ( - echo = p.EchoBeat.Get() - sent = p.SentBeat.Get() - ) - - if echo == 0 || sent == 0 { - return false - } - - return sent-echo > int64(p.Heartrate.Get())*2 -} - -// SetHeartRate sets the ticker's heart rate. -func (p *Pacemaker) SetPace(heartrate time.Duration) { - p.Heartrate.Set(heartrate) - - // To uncomment when 1.16 releases and we drop support for 1.14. - // p.ticker.Reset(heartrate) - - p.ticker.Stop() - p.ticker = *time.NewTicker(heartrate) - p.Ticks = p.ticker.C -} - -// Stop stops the pacemaker, or it does nothing if the pacemaker is not started. -func (p *Pacemaker) StopTicker() { - p.ticker.Stop() -} - -// pace sends a heartbeat with the appropriate timeout for the context. -func (p *Pacemaker) Pace() error { - ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate.Get()) - defer cancel() - - return p.PaceCtx(ctx) -} - -func (p *Pacemaker) PaceCtx(ctx context.Context) error { - if err := p.Pacer(ctx); err != nil { - return err - } - - p.SentBeat.Set(time.Now()) - - if p.Dead() { - return ErrDead - } - - return nil -} diff --git a/internal/lazytime/ticker.go b/internal/lazytime/ticker.go new file mode 100644 index 0000000..465420a --- /dev/null +++ b/internal/lazytime/ticker.go @@ -0,0 +1,30 @@ +package lazytime + +import "time" + +type Ticker struct { + C <-chan time.Time + + ticker *time.Ticker +} + +// Reset resets the ticker. If this is the first time calling, then a new timer +// is created. +func (t *Ticker) Reset(d time.Duration) { + if t.ticker == nil { + t.ticker = time.NewTicker(d) + t.C = t.ticker.C + } else { + t.ticker.Reset(d) + } +} + +// Stop stops the ticker. If the ticker has never been used, then it does +// nothing. +func (t *Ticker) Stop() { + if t.ticker == nil { + return + } + + t.ticker.Stop() +} diff --git a/internal/lazytime/timer.go b/internal/lazytime/timer.go new file mode 100644 index 0000000..676ca16 --- /dev/null +++ b/internal/lazytime/timer.go @@ -0,0 +1,50 @@ +package lazytime + +import ( + "context" + "time" +) + +type Timer struct { + C <-chan time.Time + + timer *time.Timer +} + +// Reset resets the timer by draining it and resetting the internal channel. If +// this is the first time calling, then a new timer is created. +func (t *Timer) Reset(d time.Duration) { + if t.timer == nil { + t.timer = time.NewTimer(d) + t.C = t.timer.C + return + } + + t.Stop() + t.timer.Reset(d) +} + +// Stop stops the timer and drains it. If the timer has never been used, then it +// does nothing. +func (t *Timer) Stop() { + if t.timer == nil { + return + } + + if !t.timer.Stop() { + select { + case <-t.timer.C: + default: + } + } +} + +// Wait blocks until the timer fires or until the context expires. +func (t *Timer) Wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return nil + } +} diff --git a/internal/moreatomic/bool.go b/internal/moreatomic/bool.go index c98a583..874aaee 100644 --- a/internal/moreatomic/bool.go +++ b/internal/moreatomic/bool.go @@ -17,3 +17,17 @@ func (b *Bool) Set(val bool) { } atomic.StoreUint32(&b.val, x) } + +func (b *Bool) SetTrue() { + atomic.StoreUint32(&b.val, 1) +} + +func (b *Bool) SetFalse() { + atomic.StoreUint32(&b.val, 0) +} + +// Acquire sets bool to true if it's false and returns true, otherwise returns +// false. +func (b *Bool) Acquire() bool { + return atomic.CompareAndSwapUint32(&b.val, 0, 1) +} diff --git a/session/session.go b/session/session.go index a50cf35..415539f 100644 --- a/session/session.go +++ b/session/session.go @@ -5,89 +5,64 @@ package session import ( "context" + "sync" "github.com/pkg/errors" "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" - "github.com/diamondburned/arikawa/v3/internal/handleloop" "github.com/diamondburned/arikawa/v3/utils/handler" + "github.com/diamondburned/arikawa/v3/utils/json/option" + "github.com/diamondburned/arikawa/v3/utils/ws/ophandler" ) +// ErrMFA is returned if the account requires a 2FA code to log in. var ErrMFA = errors.New("account has 2FA enabled") -// Closed is an event that's sent to Session's command handler. This works by -// using (*Gateway).AfterClose. If the user sets this callback, no Closed events -// would be sent. -// -// Usage -// -// ses.AddHandler(func(*session.Closed) {}) -// -type Closed struct { - Error error -} - -// NewShardFunc creates a shard constructor for a session. -func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc { - return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { - s := NewCustomShard(m, id) - if f != nil { - f(m, s) - } - return s, nil - } -} - -// NewCustomShard creates a new session from the given shard manager and other -// parameters. -func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session { - return NewCustomSession( - shard.NewGatewayShard(m, id), - api.NewClient(id.Token), - handler.New(), - ) -} - // Session manages both the API and Gateway. As such, Session inherits all of // API's methods, as well has the Handler used for Gateway. type Session struct { *api.Client - *gateway.Gateway - - // Command handler with inherited methods. *handler.Handler // internal state to not be copied around. - looper *handleloop.Loop + state *sessionState } -func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) { - g, err := gateway.NewGatewayWithIntents(token, intents...) - if err != nil { - return nil, errors.Wrap(err, "failed to connect to Gateway") +type sessionState struct { + sync.Mutex + id gateway.Identifier + gateway *gateway.Gateway + + ctx context.Context + cancel context.CancelFunc + doneCh <-chan struct{} +} + +// NewWithIntents is similar to New but adds the given intents in during +// construction. +func NewWithIntents(token string, intents ...gateway.Intents) *Session { + var allIntent gateway.Intents + for _, intent := range intents { + allIntent |= intent } - return NewWithGateway(g), nil + id := gateway.DefaultIdentifier(token) + id.Intents = option.NewUint(uint(allIntent)) + + return NewWithIdentifier(id) } // New creates a new session from a given token. Most bots should be using // NewWithIntents instead. -func New(token string) (*Session, error) { - // Create a gateway - g, err := gateway.NewGateway(token) - if err != nil { - return nil, errors.Wrap(err, "failed to connect to Gateway") - } - - return NewWithGateway(g), nil +func New(token string) *Session { + return NewWithIdentifier(gateway.DefaultIdentifier(token)) } // Login tries to log in as a normal user account; MFA is optional. -func Login(email, password, mfa string) (*Session, error) { +func Login(ctx context.Context, email, password, mfa string) (*Session, error) { // Make a scratch HTTP client without a token - client := api.NewClient("") + client := api.NewClient("").WithContext(ctx) // Try to login without TOTP l, err := client.Login(email, password) @@ -97,7 +72,7 @@ func Login(email, password, mfa string) (*Session, error) { if l.Token != "" && !l.MFA { // We got the token, return with a new Session. - return New(l.Token) + return New(l.Token), nil } // Discord requests MFA, so we need the MFA token. @@ -111,40 +86,117 @@ func Login(email, password, mfa string) (*Session, error) { return nil, errors.Wrap(err, "failed to login with 2FA") } - return New(l.Token) + return New(l.Token), nil } -// NewWithGateway creates a new Session with the given Gateway. -func NewWithGateway(gw *gateway.Gateway) *Session { - return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New()) +// NewWithIdentifier creates a bare Session with the given identifier. +func NewWithIdentifier(id gateway.Identifier) *Session { + return NewCustom(id, api.NewClient(id.Token), handler.New()) } -// NewCustomSession constructs a bare Session from the given parameters. -func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session { +// NewWithGateway constructs a bare Session from the given UNOPENED gateway. +func NewWithGateway(g *gateway.Gateway, h *handler.Handler) *Session { + state := g.State() + return &Session{ + Client: api.NewClient(state.Identifier.Token), + Handler: h, + state: &sessionState{ + gateway: g, + id: state.Identifier, + }, + } +} + +// NewCustom constructs a bare Session from the given parameters. +func NewCustom(id gateway.Identifier, cl *api.Client, h *handler.Handler) *Session { return &Session{ - Gateway: gw, Client: cl, Handler: h, - looper: handleloop.NewLoop(h), + state: &sessionState{id: id}, } } +// AddIntents adds the given intents into the gateway. Calling it after Open has +// already been called will result in a panic. +func (s *Session) AddIntents(intents gateway.Intents) { + s.state.Lock() + + s.state.id.AddIntents(intents) + + if s.state.gateway != nil { + s.state.gateway.AddIntents(intents) + } + + s.state.Unlock() +} + +// HasIntents reports if the Gateway has the passed Intents. +// +// If no intents are set, e.g. if using a user account, HasIntents will always +// return true. +func (s *Session) HasIntents(intents gateway.Intents) bool { + return s.state.id.HasIntents(intents) +} + +// Gateway returns the current session's gateway. If Open has never been called +// or Session was never constructed with a gateway, then nil is returned. +func (s *Session) Gateway() *gateway.Gateway { + s.state.Lock() + g := s.state.gateway + s.state.Unlock() + return g +} + +// Open opens the Discord gateway and its handler, then waits until either the +// Ready or Resumed event gets through. func (s *Session) Open(ctx context.Context) error { - // Start the handler beforehand so no events are missed. - s.looper.Start(s.Gateway.Events) + evCh := make(chan interface{}) - // Set the AfterClose's handler. - s.Gateway.AfterClose = func(err error) { - s.Handler.Call(&Closed{ - Error: err, - }) + s.state.Lock() + defer s.state.Unlock() + + if s.state.cancel != nil { + if err := s.close(ctx); err != nil { + return err + } } - if err := s.Gateway.Open(ctx); err != nil { - return errors.Wrap(err, "failed to start gateway") + if s.state.gateway == nil { + g, err := gateway.NewWithIdentifier(ctx, s.state.id) + if err != nil { + return err + } + s.state.gateway = g } - return nil + ctx, cancel := context.WithCancel(context.Background()) + s.state.ctx = ctx + s.state.cancel = cancel + + // TODO: change this to AddSyncHandler. + rm := s.AddHandler(evCh) + defer rm() + + opCh := s.state.gateway.Connect(s.state.ctx) + s.state.doneCh = ophandler.Loop(opCh, s.Handler) + + for { + select { + case <-ctx.Done(): + s.close(ctx) + return ctx.Err() + + case <-s.state.doneCh: + // Event loop died. + return s.state.gateway.LastError() + + case ev := <-evCh: + switch ev.(type) { + case *gateway.ReadyEvent, *gateway.ResumedEvent: + return nil + } + } + } } // WithContext returns a shallow copy of Session with the context replaced in @@ -160,21 +212,34 @@ func (s *Session) WithContext(ctx context.Context) *Session { } // Close closes the underlying Websocket connection, invalidating the session -// ID. -// -// It will send a closing frame before ending the connection, closing it -// gracefully. This will cause the bot to appear as offline instantly. +// ID. It will send a closing frame before ending the connection, closing it +// gracefully. This will cause the bot to appear as offline instantly. To +// prevent this behavior, change Gateway.AlwaysCloseGracefully. func (s *Session) Close() error { - // Stop the event handler - s.looper.Stop() - return s.Gateway.Close() + s.state.Lock() + defer s.state.Unlock() + + return s.close(context.Background()) } -// Pause pauses the Gateway connection, by ending the connection without -// sending a closing frame. This allows the connection to be resumed at a later -// point. -func (s *Session) Pause() error { - // Stop the event handler - s.looper.Stop() - return s.Gateway.Pause() +func (s *Session) close(ctx context.Context) error { + if s.state.cancel == nil { + return errors.New("Session is already closed") + } + + s.state.cancel() + s.state.cancel = nil + s.state.ctx = nil + + // Wait until we've successfully disconnected. + select { + case <-ctx.Done(): + return errors.Wrap(ctx.Err(), "cannot wait for gateway exit") + case <-s.state.doneCh: + // ok + } + + s.state.doneCh = nil + + return s.state.gateway.LastError() } diff --git a/session/session_test.go b/session/session_test.go new file mode 100644 index 0000000..a7b9ceb --- /dev/null +++ b/session/session_test.go @@ -0,0 +1,50 @@ +package session + +import ( + "context" + "testing" + "time" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/internal/testenv" +) + +func TestSession(t *testing.T) { + attempts := 1 + timeout := 15 * time.Second + + if !testing.Short() { + attempts = 5 + timeout = time.Minute // 5s-10s each reconnection + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + env := testenv.Must(t) + + readyCh := make(chan *gateway.ReadyEvent, 1) + + s := NewWithIntents(env.BotToken, gateway.IntentGuilds) + s.AddHandler(readyCh) + + for i := 0; i < attempts; i++ { + if err := s.Open(ctx); err != nil { + t.Fatal("failed to open:", err) + } + + if ready, ok := <-readyCh; !ok { + t.Fatal("ready not received") + } else { + now := time.Now() + t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username) + } + + if err := s.Close(); err != nil { + t.Fatal("failed to close:", err) + } + + // Hold for an additional one second. + time.Sleep(time.Second) + } +} diff --git a/gateway/shard/manager.go b/session/shard/manager.go similarity index 92% rename from gateway/shard/manager.go rename to session/shard/manager.go index 6fcf7f9..0527637 100644 --- a/gateway/shard/manager.go +++ b/session/shard/manager.go @@ -61,7 +61,7 @@ type rescalingState struct { func NewManager(token string, fn NewShardFunc) (*Manager, error) { id := gateway.DefaultIdentifier(token) - url, err := updateIdentifier(context.Background(), id) + url, err := updateIdentifier(context.Background(), &id) if err != nil { return nil, errors.Wrap(err, "failed to get gateway info") } @@ -76,20 +76,20 @@ func NewManager(token string, fn NewShardFunc) (*Manager, error) { // // This function should rarely be used, since the shard information will be // queried from Discord if it's required to shard anyway. -func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) { +func NewIdentifiedManager(idData gateway.IdentifyCommand, fn NewShardFunc) (*Manager, error) { // Ensure id.Shard is never nil. - if data.Shard == nil { - data.Shard = gateway.DefaultShard + if idData.Shard == nil { + idData.Shard = gateway.DefaultShard } - id := gateway.NewIdentifier(data) + id := gateway.NewIdentifier(idData) - url, err := updateIdentifier(context.Background(), id) + url, err := updateIdentifier(context.Background(), &id) if err != nil { return nil, errors.Wrap(err, "failed to get gateway info") } - id.Shard = data.Shard + id.Shard = idData.Shard return NewIdentifiedManagerWithURL(url, id, fn) } @@ -97,7 +97,7 @@ func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, // NewIdentifiedManagerWithURL creates a new Manager with the given Identifier // and gateway URL. It behaves similarly to NewIdentifiedManager. func NewIdentifiedManagerWithURL( - url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) { + url string, id gateway.Identifier, fn NewShardFunc) (*Manager, error) { m := Manager{ gatewayURL: gateway.AddGatewayParams(url), @@ -108,12 +108,12 @@ func NewIdentifiedManagerWithURL( var err error for i := range m.shards { - data := id.IdentifyData + data := id.IdentifyCommand data.Shard = &gateway.Shard{i, len(m.shards)} m.shards[i] = ShardState{ ID: gateway.Identifier{ - IdentifyData: data, + IdentifyCommand: data, IdentifyShortLimit: id.IdentifyShortLimit, IdentifyGlobalLimit: id.IdentifyGlobalLimit, }, @@ -263,10 +263,10 @@ func (m *Manager) rescale() { func (m *Manager) tryRescale(ctx context.Context) bool { m.mutex.Lock() - data := m.shards[0].ID.IdentifyData + data := m.shards[0].ID.IdentifyCommand newID := gateway.NewIdentifier(data) - url, err := updateIdentifier(ctx, newID) + url, err := updateIdentifier(ctx, &newID) if err != nil { m.mutex.Unlock() return false @@ -282,12 +282,12 @@ func (m *Manager) tryRescale(ctx context.Context) bool { newShards := make([]ShardState, numShards) for i := 0; i < numShards; i++ { - data := newID.IdentifyData + data := newID.IdentifyCommand data.Shard = &gateway.Shard{i, len(m.shards)} newShards[i] = ShardState{ ID: gateway.Identifier{ - IdentifyData: data, + IdentifyCommand: data, IdentifyShortLimit: newID.IdentifyShortLimit, IdentifyGlobalLimit: newID.IdentifyGlobalLimit, }, diff --git a/gateway/shard/shard.go b/session/shard/shard.go similarity index 79% rename from gateway/shard/shard.go rename to session/shard/shard.go index 958ce9e..dc89cc0 100644 --- a/gateway/shard/shard.go +++ b/session/shard/shard.go @@ -3,7 +3,10 @@ package shard import ( "context" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/session" + "github.com/diamondburned/arikawa/v3/utils/handler" "github.com/pkg/errors" ) @@ -23,17 +26,14 @@ type Shard interface { // methods without deadlocking. type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error) -// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with -// NewShardFunc. -var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) { - return NewGatewayShard(m, id), nil -} - -// NewGatewayShard creates a new gateway that's plugged into the shard manager. -func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway { - gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id) - gw.OnShardingRequired(m.Rescale) - return gw +// NewSessionShard creates a shard constructor for a session. +// Accessing any shard and adding a handler will add a handler for all shards. +func NewSessionShard(f func(m *Manager, s *session.Session)) NewShardFunc { + return func(m *Manager, id *gateway.Identifier) (Shard, error) { + s := session.NewCustom(*id, api.NewClient(id.Token), handler.New()) + f(m, s) + return s, nil + } } // ShardState wraps around the Gateway interface to provide additional state. diff --git a/session/session_shard_test.go b/session/shard/shard_test.go similarity index 78% rename from session/session_shard_test.go rename to session/shard/shard_test.go index 1b74856..a5219b8 100644 --- a/session/session_shard_test.go +++ b/session/shard/shard_test.go @@ -1,4 +1,4 @@ -package session +package shard import ( "context" @@ -6,29 +6,28 @@ import ( "time" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/internal/testenv" + "github.com/diamondburned/arikawa/v3/session" ) func TestSharding(t *testing.T) { env := testenv.Must(t) - data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken) data.Shard = &gateway.Shard{0, env.ShardCount} readyCh := make(chan *gateway.ReadyEvent) - m, err := shard.NewIdentifiedManager(data, NewShardFunc( - func(m *shard.Manager, s *Session) { + m, err := NewIdentifiedManager(data, NewSessionShard( + func(m *Manager, s *session.Session) { now := time.Now().Format(time.StampMilli) t.Log(now, "initializing shard") - s.Gateway.ErrorLog = func(err error) { - t.Error("gateway error:", err) - } - s.AddIntents(gateway.IntentGuilds) s.AddHandler(readyCh) + s.AddHandler(func(err error) { + t.Error("unexpected error:", err) + }) }, )) if err != nil { diff --git a/state/state.go b/state/state.go index ce50443..f0cce6f 100644 --- a/state/state.go +++ b/state/state.go @@ -6,10 +6,11 @@ import ( "context" "sync" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/session" + "github.com/diamondburned/arikawa/v3/session/shard" "github.com/diamondburned/arikawa/v3/state/store" "github.com/diamondburned/arikawa/v3/state/store/defaultstore" "github.com/diamondburned/arikawa/v3/utils/handler" @@ -27,7 +28,8 @@ var ( // The user should initialize handlers and intents in the opts function. func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc { return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { - state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New()) + sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New()) + state := NewFromSession(sessn, defaultstore.New()) opts(m, state) return state, nil } @@ -110,29 +112,21 @@ type State struct { } // New creates a new state. -func New(token string) (*State, error) { +func New(token string) *State { return NewWithStore(token, defaultstore.New()) } // NewWithIntents creates a new state with the given gateway intents. For more // information, refer to gateway.Intents. -func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) { - s, err := session.NewWithIntents(token, intents...) - if err != nil { - return nil, err - } - - return NewFromSession(s, defaultstore.New()), nil +func NewWithIntents(token string, intents ...gateway.Intents) *State { + s := session.NewWithIntents(token, intents...) + return NewFromSession(s, defaultstore.New()) } // NewWithStore creates a new state with the given store cabinet. -func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) { - s, err := session.New(token) - if err != nil { - return nil, err - } - - return NewFromSession(s, cabinet), nil +func NewWithStore(token string, cabinet *store.Cabinet) *State { + s := session.New(token) + return NewFromSession(s, cabinet) } // NewFromSession creates a new State from the passed Session and Cabinet. @@ -239,11 +233,11 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis merr = store.ErrNotFound ) - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { g, gerr = s.Cabinet.Guild(guildID) } - if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if s.HasIntents(gateway.IntentGuildMembers) { m, merr = s.Cabinet.Member(guildID, userID) } @@ -300,11 +294,11 @@ func (s *State) Permissions( merr = store.ErrNotFound ) - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { g, gerr = s.Cabinet.Guild(ch.GuildID) } - if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if s.HasIntents(gateway.IntentGuildMembers) { m, merr = s.Cabinet.Member(ch.GuildID, userID) } @@ -372,7 +366,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) { } func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) { - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { cs, err = s.Cabinet.Channels(guildID) if err == nil { return @@ -384,7 +378,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err return } - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { for i := range cs { if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil { return @@ -436,7 +430,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) { func (s *State) Emoji( guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) { - if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + if s.HasIntents(gateway.IntentGuildEmojis) { e, err = s.Cabinet.Emoji(guildID, emojiID) if err == nil { return @@ -464,7 +458,7 @@ func (s *State) Emoji( } func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) { - if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + if s.HasIntents(gateway.IntentGuildEmojis) { es, err = s.Cabinet.Emojis(guildID) if err == nil { return @@ -476,7 +470,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) return } - if s.Gateway.HasIntents(gateway.IntentGuildEmojis) { + if s.HasIntents(gateway.IntentGuildEmojis) { err = s.Cabinet.EmojiSet(guildID, es, false) } @@ -486,7 +480,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) //// func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { c, err := s.Cabinet.Guild(id) if err == nil { return c, nil @@ -498,7 +492,7 @@ func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) { // Guilds will only fill a maximum of 100 guilds from the API. func (s *State) Guilds() (gs []discord.Guild, err error) { - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { gs, err = s.Cabinet.Guilds() if err == nil { return @@ -510,7 +504,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { return } - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { for i := range gs { if err = s.Cabinet.GuildSet(&gs[i], false); err != nil { return @@ -524,7 +518,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) { //// func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) { - if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if s.HasIntents(gateway.IntentGuildMembers) { m, err := s.Cabinet.Member(guildID, userID) if err == nil { return m, nil @@ -535,7 +529,7 @@ func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord } func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) { - if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if s.HasIntents(gateway.IntentGuildMembers) { ms, err = s.Cabinet.Members(guildID) if err == nil { return @@ -547,7 +541,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error return } - if s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if s.HasIntents(gateway.IntentGuildMembers) { for i := range ms { if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil { return @@ -580,7 +574,7 @@ func (s *State) Message( wg.Add(1) go func() { c, cerr = s.Session.Channel(channelID) - if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { + if cerr == nil && s.HasIntents(gateway.IntentGuilds) { cerr = s.Cabinet.ChannelSet(c, false) } @@ -706,13 +700,13 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes // Presence checks the state for user presences. If no guildID is given, it // will look for the presence in all cached guilds. func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Presence, error) { - if !s.Gateway.HasIntents(gateway.IntentGuildPresences) { + if !s.HasIntents(gateway.IntentGuildPresences) { return nil, store.ErrNotFound } // If there's no guild ID, look in all guilds if !gID.IsValid() { - if !s.Gateway.HasIntents(gateway.IntentGuilds) { + if !s.HasIntents(gateway.IntentGuilds) { return nil, store.ErrNotFound } @@ -736,7 +730,7 @@ func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Pres //// func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) { - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { target, err = s.Cabinet.Role(guildID, roleID) if err == nil { return @@ -754,7 +748,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di target = &r } - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { if err = s.RoleSet(guildID, &rs[i], false); err != nil { return } @@ -779,7 +773,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { return nil, err } - if s.Gateway.HasIntents(gateway.IntentGuilds) { + if s.HasIntents(gateway.IntentGuilds) { for i := range rs { if err := s.RoleSet(guildID, &rs[i], false); err != nil { return rs, err @@ -792,7 +786,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) { func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { g, err = s.Session.Guild(id) - if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) { + if err == nil && s.HasIntents(gateway.IntentGuilds) { err = s.Cabinet.GuildSet(g, false) } @@ -801,7 +795,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) { func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) { m, err = s.Session.Member(gID, uID) - if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) { + if err == nil && s.HasIntents(gateway.IntentGuildMembers) { err = s.Cabinet.MemberSet(gID, m, false) } @@ -811,14 +805,12 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord // tracksMessage reports whether the state would track the passed message and // messages from the same channel. func (s *State) tracksMessage(m *discord.Message) bool { - return s.Gateway.Identifier.Intents == nil || - (m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) || - (!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages)) + return (m.GuildID.IsValid() && s.HasIntents(gateway.IntentGuildMessages)) || + (!m.GuildID.IsValid() && s.HasIntents(gateway.IntentDirectMessages)) } // tracksChannel reports whether the state would track the passed channel. func (s *State) tracksChannel(c *discord.Channel) bool { - return s.Gateway.Identifier.Intents == nil || - (c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) || + return (c.GuildID.IsValid() && s.HasIntents(gateway.IntentGuilds)) || !c.GuildID.IsValid() } diff --git a/state/state_events.go b/state/state_events.go index 036bb36..fdb6cd6 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -152,7 +152,7 @@ func (s *State) onEvent(iface interface{}) { } // Update available fields from ev into m - ev.Update(m) + ev.UpdateMember(m) if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil { s.stateErr(err, "failed to update a member in state") diff --git a/state/state_example_test.go b/state/state_example_test.go new file mode 100644 index 0000000..da59ccb --- /dev/null +++ b/state/state_example_test.go @@ -0,0 +1,32 @@ +package state_test + +import ( + "context" + "log" + "os" + "os/signal" + + "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/state" +) + +func Example() { + s := state.New("Bot " + os.Getenv("DISCORD_TOKEN")) + s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages) + s.AddHandler(func(m *gateway.MessageCreateEvent) { + log.Printf("%s: %s", m.Author.Username, m.Content) + }) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err := s.Open(ctx); err != nil { + log.Println("cannot open:", err) + } + + <-ctx.Done() // block until Ctrl+C + + if err := s.Close(); err != nil { + log.Println("cannot close:", err) + } +} diff --git a/state/state_shard_test.go b/state/state_shard_test.go index 08cd46d..d2f59a4 100644 --- a/state/state_shard_test.go +++ b/state/state_shard_test.go @@ -5,16 +5,24 @@ import ( "testing" "time" + "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/internal/testenv" + "github.com/diamondburned/arikawa/v3/session/shard" ) func TestSharding(t *testing.T) { env := testenv.Must(t) - data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken) data.Shard = &gateway.Shard{0, env.ShardCount} + data.Presence = &gateway.UpdatePresenceCommand{ + Status: discord.DoNotDisturbStatus, + Activities: []discord.Activity{{ + Name: "Testing shards...", + Type: discord.CustomActivity, + }}, + } readyCh := make(chan *gateway.ReadyEvent) @@ -23,19 +31,18 @@ func TestSharding(t *testing.T) { now := time.Now().Format(time.StampMilli) t.Log(now, "initializing shard") - s.Gateway.ErrorLog = func(err error) { - t.Error("gateway error:", err) - } - s.AddIntents(gateway.IntentGuilds) - s.AddHandler(readyCh) + s.AddSyncHandler(readyCh) + s.AddSyncHandler(func(err error) { + t.Log("background error:", err) + }) }, )) if err != nil { t.Fatal("failed to make shard manager:", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() go func() { diff --git a/utils/bot/command.go b/utils/bot/command.go index b7f33a7..7970e51 100644 --- a/utils/bot/command.go +++ b/utils/bot/command.go @@ -2,22 +2,29 @@ package bot import ( "reflect" + "sync" "github.com/diamondburned/arikawa/v3/gateway" + "github.com/diamondburned/arikawa/v3/utils/ws" ) -// eventIntents maps event pointer types to intents. -var eventIntents = map[reflect.Type]gateway.Intents{} +var ( + // eventIntents maps event pointer types to intents. + eventIntents map[reflect.Type]gateway.Intents + eventIntentsOnce sync.Once +) -func init() { - for event, intent := range gateway.EventIntents { - fn, ok := gateway.EventCreator[event] - if !ok { - continue - } - - eventIntents[reflect.TypeOf(fn())] = intent - } +func ensureEventIntents() { + eventIntentsOnce.Do(func() { + eventIntents = map[reflect.Type]gateway.Intents{} + gateway.OpUnmarshalers.Each(func(_ ws.OpCode, t ws.EventType, f ws.OpFunc) bool { + intent, ok := gateway.EventIntents[t] + if ok { + eventIntents[reflect.TypeOf(f())] = intent + } + return false + }) + }) } type command struct { @@ -44,6 +51,8 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er // intents returns the command's intents from the event. func (c *command) intents() gateway.Intents { + ensureEventIntents() + intents, ok := eventIntents[c.event] if !ok { return 0 diff --git a/utils/bot/ctx.go b/utils/bot/ctx.go index c969f11..1e0071e 100644 --- a/utils/bot/ctx.go +++ b/utils/bot/ctx.go @@ -14,13 +14,13 @@ import ( "github.com/pkg/errors" "github.com/diamondburned/arikawa/v3/api" - "github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/session" + "github.com/diamondburned/arikawa/v3/session/shard" "github.com/diamondburned/arikawa/v3/state" - "github.com/diamondburned/arikawa/v3/state/store" "github.com/diamondburned/arikawa/v3/state/store/defaultstore" + "github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords" + "github.com/diamondburned/arikawa/v3/utils/handler" ) // Prefixer checks a message if it starts with the desired prefix. By default, @@ -55,22 +55,15 @@ func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc { panic("bot.NewShardFunc missing fn") } - var once sync.Once - var cab *store.Cabinet - return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) { - state := state.NewFromSession(session.NewCustomShard(m, id), nil) + sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New()) + state := state.NewFromSession(sessn, defaultstore.New()) bot, err := fn(state) if err != nil { return nil, errors.Wrap(err, "failed to create bot instance") } - if state.Cabinet == nil { - once.Do(func() { cab = defaultstore.New() }) - state.Cabinet = cab - } - return bot, nil } } @@ -206,10 +199,6 @@ func Start( // fail api request if they (will) take up more than 5 minutes ctx.Client.Client.Timeout = 5 * time.Minute - ctx.Gateway.ErrorLog = func(err error) { - ctx.ErrorLogger(err) - } - if opts != nil { if err := opts(ctx); err != nil { return nil, err @@ -317,7 +306,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) { // AddIntents adds the given Gateway Intent into the Gateway. This is a // convenient function that calls Gateway's AddIntent. func (ctx *Context) AddIntents(i gateway.Intents) { - ctx.Gateway.AddIntents(i) + ctx.Session.AddIntents(i) } // Subcommands returns the slice of subcommands. To add subcommands, use diff --git a/utils/bot/ctx_shard_test.go b/utils/bot/ctx_shard_test.go index 6e0cb84..4d81246 100644 --- a/utils/bot/ctx_shard_test.go +++ b/utils/bot/ctx_shard_test.go @@ -6,8 +6,8 @@ import ( "time" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/gateway/shard" "github.com/diamondburned/arikawa/v3/internal/testenv" + "github.com/diamondburned/arikawa/v3/session/shard" "github.com/diamondburned/arikawa/v3/state" ) @@ -24,7 +24,7 @@ func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) { func TestSharding(t *testing.T) { env := testenv.Must(t) - data := gateway.DefaultIdentifyData("Bot " + env.BotToken) + data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken) data.Shard = &gateway.Shard{0, env.ShardCount} readyCh := make(chan *gateway.ReadyEvent) diff --git a/utils/bot/extras/middlewares/middlewares_test.go b/utils/bot/extras/middlewares/middlewares_test.go index a323c30..fc4f8d3 100644 --- a/utils/bot/extras/middlewares/middlewares_test.go +++ b/utils/bot/extras/middlewares/middlewares_test.go @@ -4,27 +4,18 @@ import ( "errors" "testing" - "github.com/diamondburned/arikawa/v3/utils/bot" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" "github.com/diamondburned/arikawa/v3/session" "github.com/diamondburned/arikawa/v3/state" "github.com/diamondburned/arikawa/v3/state/store" - "github.com/diamondburned/arikawa/v3/utils/json/option" + "github.com/diamondburned/arikawa/v3/utils/bot" ) func TestAdminOnly(t *testing.T) { var ctx = &bot.Context{ State: &state.State{ - Session: &session.Session{ - Gateway: &gateway.Gateway{ - Identifier: &gateway.Identifier{ - IdentifyData: gateway.IdentifyData{ - Intents: option.NewUint(uint(gateway.IntentGuilds | gateway.IntentGuildMembers)), - }, - }, - }, - }, + Session: session.New(""), Cabinet: mockCabinet(), }, } @@ -62,15 +53,7 @@ func TestAdminOnly(t *testing.T) { func TestGuildOnly(t *testing.T) { var ctx = &bot.Context{ State: &state.State{ - Session: &session.Session{ - Gateway: &gateway.Gateway{ - Identifier: &gateway.Identifier{ - IdentifyData: gateway.IdentifyData{ - Intents: option.NewUint(uint(gateway.IntentGuilds)), - }, - }, - }, - }, + Session: session.New(""), Cabinet: mockCabinet(), }, } diff --git a/utils/cmd/genevent/genevent.go b/utils/cmd/genevent/genevent.go new file mode 100644 index 0000000..59e5b5f --- /dev/null +++ b/utils/cmd/genevent/genevent.go @@ -0,0 +1,172 @@ +package main + +import ( + "bytes" + "flag" + "go/format" + "log" + "os" + "regexp" + "strconv" + "strings" + "text/template" + "unicode" + + _ "embed" + + "github.com/pkg/errors" +) + +var ( + pkg = "gateway" + out = "-" +) + +type registry struct { + PackageName string + EventTypes []EventType +} + +type EventType struct { + StructName string + EventName string + IsDispatch bool + OpCode int +} + +func (t *EventType) MethodRecv() string { + if len(t.StructName) == 0 { + return "e" + } + return string(unicode.ToLower([]rune(t.StructName)[0])) +} + +//go:embed template.tmpl +var packageTmpl string + +var tmpl = template.Must(template.New("").Parse(packageTmpl)) + +const eventStructRegex = "(?m)" + + `^// ([A-Za-z]+(?:Event|Command)) is (a dispatch event|an event|a command)` + + `(?:` + + ` for ([A-Z_]+)` + "|" + + ` for Op (\d+)` + + `)?` + + `\.(?:.|\n)*?\ntype ([A-Za-z]+(?:Event|Command)) .*` + +func main() { + flag.StringVar(&pkg, "p", pkg, "the package name to use") + flag.StringVar(&out, "o", out, "output file, - for stdout") + flag.Parse() + + log.Println(eventStructRegex) + + r := registry{ + PackageName: pkg, + } + + files, err := os.ReadDir(".") + if err != nil { + log.Fatalln("failed to read current directory:", err) + } + + for _, file := range files { + if file.IsDir() || !strings.HasSuffix(file.Name(), ".go") { + continue + } + if err := r.CrawlFile(file.Name()); err != nil { + log.Fatalln("failed to crawl file:", err) + } + } + + buf := bytes.Buffer{} + if err := tmpl.Execute(&buf, &r); err != nil { + log.Fatalln("failed to execute template:", err) + } + + b, err := format.Source(buf.Bytes()) + if err != nil { + log.Fatalln("failed to fmt:", err) + } + + output := os.Stdout + if out != "-" { + f, err := os.Create(out) + if err != nil { + log.Fatalln("failed to create output:", err) + } + defer f.Close() + + output = f + } + + if _, err := output.Write(b); err != nil { + log.Fatalln("failed to write rendered:", err) + } +} + +var reEventStruct = regexp.MustCompile(eventStructRegex) + +func (r *registry) CrawlFile(name string) error { + f, err := os.ReadFile(name) + if err != nil { + return errors.Wrap(err, "failed to read file") + } + + for _, match := range reEventStruct.FindAllSubmatch(f, -1) { + // Validity check. + if string(match[1]) != string(match[5]) { + continue + } + + if strings.HasSuffix(string(match[1]), "Command") && string(match[2]) != "a command" { + log.Println(string(match[1]), "has invalid comment %q", string(match[2])) + continue + } + + t := EventType{ + StructName: string(match[1]), + EventName: string(match[3]), + IsDispatch: string(match[2]) == "a dispatch event", + OpCode: -1, + } + + if op := string(match[4]); op != "" && !t.IsDispatch { + i, err := strconv.Atoi(op) + if err != nil { + log.Printf("error at struct %s: error parsing Op %v", t.StructName, err) + } + t.OpCode = i + } + + if t.IsDispatch && t.EventName == "" { + t.EventName = guessEventName(t.StructName) + } + + r.EventTypes = append(r.EventTypes, t) + } + + return nil +} + +func guessEventName(structName string) string { + name := strings.TrimSuffix(structName, "Event") + + var newName strings.Builder + newName.Grow(len(name) * 2) + + for i, r := range name { + if unicode.IsLower(r) { + newName.WriteRune(unicode.ToUpper(r)) + continue + } + + if i > 0 { + newName.WriteByte('_') + } + + newName.WriteRune(r) + } + + return newName.String() +} diff --git a/utils/cmd/genevent/template.tmpl b/utils/cmd/genevent/template.tmpl new file mode 100644 index 0000000..31af7bb --- /dev/null +++ b/utils/cmd/genevent/template.tmpl @@ -0,0 +1,27 @@ +// Code generated by genevent. DO NOT EDIT. + +package {{ .PackageName }} + +import "github.com/diamondburned/arikawa/v3/utils/ws" + +func init() { + OpUnmarshalers.Add( + {{ range .EventTypes -}} + func() ws.Event { return new({{ .StructName }}) }, + {{ end -}} + ) +} + +{{ range .EventTypes }} + +{{ if .IsDispatch }} +// Op implements Event. It always returns 0. +func (*{{ .StructName }}) Op() ws.OpCode { return dispatchOp } +{{ else if (gt .OpCode -1) }} +// Op implements Event. It always returns Op {{ .OpCode }}. +func (*{{ .StructName }}) Op() ws.OpCode { return {{ .OpCode }} } +{{ end }} + +// EventType implements Event. +func (*{{ .StructName }}) EventType() ws.EventType { return "{{ .EventName }}" } +{{ end }} diff --git a/utils/gensnowflake/main.go b/utils/cmd/gensnowflake/main.go similarity index 100% rename from utils/gensnowflake/main.go rename to utils/cmd/gensnowflake/main.go diff --git a/utils/gensnowflake/template.tmpl b/utils/cmd/gensnowflake/template.tmpl similarity index 100% rename from utils/gensnowflake/template.tmpl rename to utils/cmd/gensnowflake/template.tmpl diff --git a/utils/handler/handler.go b/utils/handler/handler.go index c56b17a..47ab8e8 100644 --- a/utils/handler/handler.go +++ b/utils/handler/handler.go @@ -239,6 +239,7 @@ type handler struct { chanclose reflect.Value // IsValid() if chan isIface bool isSync bool + isOnce bool } // newHandler reflects either a channel or a function into a handler. A function diff --git a/utils/ws/codec.go b/utils/ws/codec.go new file mode 100644 index 0000000..0cf1cd3 --- /dev/null +++ b/utils/ws/codec.go @@ -0,0 +1,101 @@ +package ws + +import ( + "io" + "net/http" + + "github.com/diamondburned/arikawa/v3/utils/json" + "github.com/pkg/errors" +) + +// Codec holds the codec states for Websocket implementations to share with the +// manager. It is used internally in the Websocket and the Connection +// implementation. +type Codec struct { + Unmarshalers OpUnmarshalers + Headers http.Header +} + +// NewCodec creates a new default Codec instance. +func NewCodec(unmarshalers OpUnmarshalers) Codec { + return Codec{ + Unmarshalers: unmarshalers, + Headers: http.Header{ + "Accept-Encoding": {"zlib"}, + }, + } +} + +type codecOp struct { + Op + Data json.Raw `json:"d,omitempty"` +} + +const maxSharedBufferSize = 1 << 15 // 32KB + +// DecodeBuffer boxes a byte slice to provide a shared and thread-unsafe buffer. +// It is used internally and should only be handled around as an opaque thing. +type DecodeBuffer struct { + buf []byte +} + +// NewDecodeBuffer creates a new preallocated DecodeBuffer. +func NewDecodeBuffer(cap int) DecodeBuffer { + if cap > maxSharedBufferSize { + cap = maxSharedBufferSize + } + + return DecodeBuffer{ + buf: make([]byte, 0, cap), + } +} + +// DecodeFrom reads the given reader and decodes it into an Op. +// +// buf is optional. +func (c Codec) DecodeFrom(r io.Reader, buf *DecodeBuffer) Op { + var op codecOp + op.Data = json.Raw(buf.buf) + + if err := json.DecodeStream(r, &op); err != nil { + return newErrOp(err, "cannot read JSON stream") + } + + // buf isn't grown from here out. Set it back right now. If Data hasn't been + // grown, then this will just set buf back to what it was. + if cap(op.Data) < maxSharedBufferSize { + buf.buf = op.Data[:0] + } + + fn := c.Unmarshalers.Lookup(op.Code, op.Type) + if fn == nil { + err := UnknownEventError{ + Op: op.Code, + Type: op.Type, + } + return newErrOp(err, "") + } + + op.Op.Data = fn() + if err := op.Data.UnmarshalTo(op.Op.Data); err != nil { + return newErrOp(err, "cannot unmarshal JSON data from gateway") + } + + return op.Op +} + +func newErrOp(err error, wrap string) Op { + if wrap != "" { + err = errors.Wrap(err, wrap) + } + + ev := &BackgroundErrorEvent{ + Err: err, + } + + return Op{ + Code: ev.Op(), + Type: ev.EventType(), + Data: ev, + } +} diff --git a/utils/ws/conn.go b/utils/ws/conn.go new file mode 100644 index 0000000..0027bca --- /dev/null +++ b/utils/ws/conn.go @@ -0,0 +1,278 @@ +package ws + +import ( + "compress/zlib" + "context" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +const rwBufferSize = 1 << 15 // 32KB + +// ErrWebsocketClosed is returned if the websocket is already closed. +var ErrWebsocketClosed = errors.New("websocket is closed") + +// Connection is an interface that abstracts around a generic Websocket driver. +// This connection expects the driver to handle compression by itself, including +// modifying the connection URL. The implementation doesn't have to be safe for +// concurrent use. +type Connection interface { + // Dial dials the address (string). Context needs to be passed in for + // timeout. This method should also be re-usable after Close is called. + Dial(context.Context, string) (<-chan Op, error) + + // Send allows the caller to send bytes. + Send(context.Context, []byte) error + + // Close should close the websocket connection. The underlying connection + // may be reused, but this Connection instance will be reused with Dial. The + // Connection must still be reusable even if Close returns an error. If + // gracefully is true, then the implementation must send a close frame + // prior. + Close(gracefully bool) error +} + +// Conn is the default Websocket connection. It tries to compresses all payloads +// using zlib. +type Conn struct { + dialer websocket.Dialer + codec Codec + + // conn is used for synchronizing the conn instance itself. Any use of conn + // must copy conn out. + conn *connMutex + // mut is used for synchronizing the conn field. + mut sync.Mutex + + // CloseTimeout is the timeout for graceful closing. It's defaulted to 5s. + CloseTimeout time.Duration +} + +type connMutex struct { + wrmut chan struct{} + *websocket.Conn +} + +var _ Connection = (*Conn)(nil) + +// NewConn creates a new default websocket connection with a default dialer. +func NewConn(codec Codec) *Conn { + return NewConnWithDialer(codec, websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 10 * time.Second, + ReadBufferSize: rwBufferSize, + WriteBufferSize: rwBufferSize, + EnableCompression: true, + }) +} + +// NewConnWithDialer creates a new default websocket connection with a custom +// dialer. +func NewConnWithDialer(codec Codec, dialer websocket.Dialer) *Conn { + return &Conn{ + dialer: dialer, + codec: codec, + CloseTimeout: 5 * time.Second, + } +} + +// Dial starts a new connection and returns the listening channel for it. If the +// websocket is already dialed, then the connection is closed first. +func (c *Conn) Dial(ctx context.Context, addr string) (<-chan Op, error) { + // BUG which prevents stream compression. + // See https://github.com/golang/go/issues/31514. + + c.mut.Lock() + defer c.mut.Unlock() + + // Ensure that the connection is already closed. + if c.conn != nil { + c.conn.close(c.CloseTimeout, false) + } + + conn, _, err := c.dialer.DialContext(ctx, addr, c.codec.Headers) + if err != nil { + return nil, errors.Wrap(err, "failed to dial WS") + } + + events := make(chan Op, 1) + go readLoop(conn, c.codec, events) + + c.conn = &connMutex{ + wrmut: make(chan struct{}, 1), + Conn: conn, + } + + return events, err +} + +// Close implements Connection. +func (c *Conn) Close(gracefully bool) error { + c.mut.Lock() + defer c.mut.Unlock() + + return c.conn.close(c.CloseTimeout, gracefully) +} + +func (c *connMutex) close(timeout time.Duration, gracefully bool) error { + if c == nil || c.Conn == nil { + WSDebug("Conn: Close is called on already closed connection") + return ErrWebsocketClosed + } + + WSDebug("Conn: Close is called; shutting down the Websocket connection.") + + if gracefully { + // Have a deadline before closing. + deadline := time.Now().Add(timeout) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + select { + case c.wrmut <- struct{}{}: + // Lock acquired. We can now safely set the deadline and write. + c.SetWriteDeadline(deadline) + + WSDebug("Conn: Graceful closing requested, sending close frame.") + + if err := c.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ); err != nil { + WSError(err) + } + + // Release the lock. + <-c.wrmut + + case <-ctx.Done(): + // We couldn't acquire the lock. Resort to just closing the + // connection directly. + } + } + + // Close the WS. + err := c.Conn.Close() + + if err != nil { + WSDebug("Conn: Websocket closed; error:", err) + } else { + WSDebug("Conn: Websocket closed successfully") + } + + c.Conn = nil + return err +} + +// resetDeadline is used to reset the write deadline after using the context's. +var resetDeadline = time.Time{} + +// Send implements Connection. +func (c *Conn) Send(ctx context.Context, b []byte) error { + c.mut.Lock() + conn := c.conn + c.mut.Unlock() + + select { + case conn.wrmut <- struct{}{}: + defer func() { <-conn.wrmut }() + + if ctx != context.Background() { + d, ok := ctx.Deadline() + if ok { + conn.SetWriteDeadline(d) + defer conn.SetWriteDeadline(resetDeadline) + } + } + + return conn.WriteMessage(websocket.TextMessage, b) + case <-ctx.Done(): + return ctx.Err() + } +} + +// loopState is a thread-unsafe disposable state container for the read loop. +// It's made to completely separate the read loop of any synchronization that +// doesn't involve the websocket connection itself. +type loopState struct { + conn *websocket.Conn + codec Codec + zlib io.ReadCloser + buf DecodeBuffer +} + +func readLoop(conn *websocket.Conn, codec Codec, opCh chan<- Op) { + // Clean up the events channel in the end. + defer close(opCh) + + // Allocate the read loop its own private resources. + state := loopState{ + conn: conn, + codec: codec, + buf: NewDecodeBuffer(1 << 14), // 16KB + } + + for { + b, err := state.handle() + if err != nil { + WSDebug("Conn: fatal Conn error:", err) + + closeEv := &CloseEvent{ + Err: err, + Code: -1, + } + + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + closeEv.Code = closeErr.Code + closeEv.Err = fmt.Errorf("%d %s", closeErr.Code, closeErr.Text) + } + + opCh <- Op{ + Code: closeEv.Op(), + Type: closeEv.EventType(), + Data: closeEv, + } + + return + } + + opCh <- b + } +} + +func (state *loopState) handle() (Op, error) { + // skip message type + t, r, err := state.conn.NextReader() + if err != nil { + return Op{}, err + } + + if t == websocket.BinaryMessage { + // Probably a zlib payload. + + if state.zlib == nil { + z, err := zlib.NewReader(r) + if err != nil { + return Op{}, errors.Wrap(err, "failed to create a zlib reader") + } + state.zlib = z + } else { + if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil { + return Op{}, errors.Wrap(err, "failed to reset zlib reader") + } + } + + defer state.zlib.Close() + r = state.zlib + } + + return state.codec.DecodeFrom(r, &state.buf), nil +} diff --git a/utils/ws/gateway.go b/utils/ws/gateway.go new file mode 100644 index 0000000..d86c02a --- /dev/null +++ b/utils/ws/gateway.go @@ -0,0 +1,364 @@ +package ws + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/diamondburned/arikawa/v3/internal/lazytime" + "github.com/diamondburned/arikawa/v3/utils/json" + "github.com/pkg/errors" +) + +// ConnectionError is given to the user if the gateway fails to connect to the +// gateway for any reason, including during an initial connection or a +// reconnection. To check for this error, use the errors.As function. +type ConnectionError struct { + Err error +} + +// Unwrap unwraps the ConnectionError. +func (err ConnectionError) Unwrap() error { return err.Err } + +// Error formats the error. +func (err ConnectionError) Error() string { + return fmt.Sprintf("error reconnecting: %s", err.Err) +} + +// BackgroundErrorEvent describes an error that the gateway event loop might +// stumble upon while it's running. See Gateway's documentation for possible +// usages. +type BackgroundErrorEvent struct { + Err error +} + +var _ Event = (*BackgroundErrorEvent)(nil) + +// Unwrap returns err.Err. +func (err *BackgroundErrorEvent) Unwrap() error { return err.Err } + +// Error formats the BackgroundErrorEvent. +func (err *BackgroundErrorEvent) Error() string { + return "background gateway error: " + err.Err.Error() +} + +// Op implements Op. It returns -1. +func (err *BackgroundErrorEvent) Op() OpCode { return -1 } + +// EventType implements Op. It returns an opaque unique string. +func (err *BackgroundErrorEvent) EventType() EventType { + return "__ws.BackgroundErrorEvent" +} + +// GatewayOpts describes the gateway event loop options. +type GatewayOpts struct { + // ReconnectDelay determines the duration to idle after each failed retry. + // This can be used to implement exponential backoff. The default is already + // sane, so this field rarely needs to be changed. + ReconnectDelay func(try int) time.Duration + + // FatalCloseCodes is a list of close codes that will cause the gateway to + // exit out if it stumbles on one of these. It is a copy of FatalCloseCodes + // (the global variable) by default. + FatalCloseCodes []int + + // DialTimeout is the timeout to wait for each websocket dial before failing + // it and retrying. Default is 0. + DialTimeout time.Duration + + // ReconnectAttempt is the maximum number of attempts made to Reconnect + // before aborting the whole gateway. If this set to 0, unlimited attempts + // will be made. Default is 0. + ReconnectAttempt int + + // AlwaysCloseGracefully, if true, will always make the Gateway close + // gracefully once the context given to Open is cancelled. It governs the + // Close behavior. The default is true. + AlwaysCloseGracefully bool +} + +// DefaultGatewayOpts is the default event loop options. +var DefaultGatewayOpts = GatewayOpts{ + ReconnectDelay: func(try int) time.Duration { + // minimum 4 seconds + return time.Duration(4+(2*try)) * time.Second + }, + DialTimeout: 0, + ReconnectAttempt: 0, + AlwaysCloseGracefully: true, +} + +// Gateway describes an instance that handles the Discord gateway. It is +// basically an abstracted concurrent event loop that the user could signal to +// start connecting to the Discord gateway server. +type Gateway struct { + ws *Websocket + + reconnect chan struct{} + heart lazytime.Ticker + srcOp <-chan Op // from WS + outer outerState + lastError error + + opts GatewayOpts +} + +// outerState holds gateway state that the caller may change concurrently. As +// such, it holds a mutex to allow that. The main purpose of this +// synchronization is to allow the caller to use the gateway while the event +// loop is still running without having the event loop muddle in without locking +// properly. For example, opCh is given to the event loop as a copy; the event +// loop must never access the outerState directly. +type outerState struct { + sync.Mutex + ch chan Op + started bool +} + +// Handler describes a gateway handler. It describes the core that governs the +// behavior of the gateway event loop. +type Handler interface { + // OnOp is called by the gateway event loop on every new Op. If the returned + // boolean is false, then the loop fatally exits. + OnOp(context.Context, Op) (canContinue bool) + // SendHeartbeat is called by the gateway event loop everytime a heartbeat + // needs to be sent over. + SendHeartbeat(context.Context) + // Close closes the handler. + Close() error +} + +// NewGateway creates a new Gateway with a custom gateway URL and a pre-existing +// Identifier. If opts is nil, then DefaultOpts is used. +func NewGateway(ws *Websocket, opts *GatewayOpts) *Gateway { + if opts == nil { + opts = &DefaultGatewayOpts + } + + return &Gateway{ + ws: ws, + opts: *opts, + } +} + +// Send is a function to send an Op payload to the Gateway. +func (g *Gateway) Send(ctx context.Context, data Event) error { + op := Op{ + Code: data.Op(), + Type: data.EventType(), + Data: data, + } + + b, err := json.Marshal(op) + if err != nil { + return errors.Wrap(err, "failed to encode payload") + } + + // WS should already be thread-safe. + return g.ws.Send(ctx, b) +} + +// HasStarted returns true if the gateway event loop is currently spinning. +func (g *Gateway) HasStarted() bool { + g.outer.Lock() + defer g.outer.Unlock() + + return g.outer.started +} + +// AssertIsNotRunning asserts that the gateway is currently not running. If the +// gateway is running, the method will panic. Since a gateway cannot be started +// back up, this method can be used to detect whether or not the caller in a +// single goroutine can read the state safely. +func (g *Gateway) AssertIsNotRunning() { + g.outer.Lock() + defer g.outer.Unlock() + + if !g.outer.started { + return + } + + // Hack to ensure that the event channel is closed. + select { + case _, ok := <-g.outer.ch: + if !ok { + return + } + // The panic behavior is a must, because if this branch is hit, then + // we've actually stolen an event from the channel unexpectedly, putting + // the event loop under a weird state. + // + // An alternative solution to this bug would be to mutex-guard the error + // field, but the purpose of this method isn't to be called before the + // gateway has been stopped. + panic("ws: Error called while Gateway is still running") + default: + panic("ws: Error called while Gateway is still running") + } +} + +// Connect starts the background goroutine that tries its best to maintain a +// stable connection to the Websocket gateway. To the user, the gateway should +// appear to be working seamlessly. +// +// For more documentation, refer to (*gateway.Gateway).Connect. +func (g *Gateway) Connect(ctx context.Context, h Handler) <-chan Op { + g.outer.Lock() + defer g.outer.Unlock() + + if !g.outer.started { + g.outer.started = true + g.outer.ch = make(chan Op, 1) + go g.spin(ctx, h) + } + + return g.outer.ch +} + +// LastError returns the last error that the gateway has received. +func (g *Gateway) LastError() error { + g.AssertIsNotRunning() + return g.lastError +} + +// finalize closes the gateway permanently. +func (g *Gateway) finalize(h Handler) { + var err error + + if g.opts.AlwaysCloseGracefully { + err = g.ws.CloseGracefully() + } else { + err = g.ws.Close() + } + + if err != nil { + g.SendErrorWrap(err, "failed to finalize websocket") + } + + if err := h.Close(); err != nil { + g.SendError(err) + } + + g.outer.Lock() + close(g.outer.ch) + g.outer.started = false + g.outer.Unlock() +} + +// QueueReconnect queues a reconnection in the gateway loop. This method should +// only be called in the event loop ONCE; calling more than once will deadlock +// the loop. +func (g *Gateway) QueueReconnect() { + select { + case g.reconnect <- struct{}{}: + default: + } + + g.heart.Stop() +} + +// ResetHeartbeat resets the heartbeat to be the given duration. +func (g *Gateway) ResetHeartbeat(d time.Duration) { + g.heart.Reset(d) +} + +// SendError sends the given error wrapped in a BackgroundErrorEvent into the +// event channel. +func (g *Gateway) SendError(err error) { + event := &BackgroundErrorEvent{err} + + g.outer.ch <- Op{ + Code: event.Op(), + Type: event.EventType(), + Data: event, + } + g.lastError = err +} + +// SendErrorWrap is a convenient function over SendError. +func (g *Gateway) SendErrorWrap(err error, message string) { + g.SendError(errors.Wrap(err, message)) +} + +func (g *Gateway) spin(ctx context.Context, h Handler) { + // Always close the event channel once we exit. + defer g.finalize(h) + + var retryTimer lazytime.Timer + defer retryTimer.Stop() + + g.reconnect = make(chan struct{}, 1) + g.reconnect <- struct{}{} + + for { + select { + case <-ctx.Done(): + return + + case op, ok := <-g.srcOp: + if !ok { + // Skip zero-value Ops that may happen on gateway closure. + continue + } + + switch data := op.Data.(type) { + case *CloseEvent: + for _, code := range g.opts.FatalCloseCodes { + if code == data.Code { + // Don't wrap the error, but instead, just pipe it as-is + // through the channel. + g.outer.ch <- op + g.lastError = data + return + } + } + } + + ok = h.OnOp(ctx, op) + g.outer.ch <- op + if !ok { + return + } + + case <-g.heart.C: + h.SendHeartbeat(ctx) + + case <-g.reconnect: + // Close the previous connection if it's not already. Ignore the + // already closed error. + if err := g.ws.Close(); err != nil && !errors.Is(err, ErrWebsocketClosed) { + g.SendErrorWrap(err, "error closing before reconnecting") + } + + // Invalidate our srcOp. + g.srcOp = nil + + // Keep track of the last error for notifying. + var err error + + for try := 0; g.opts.ReconnectAttempt == 0 || try < g.opts.ReconnectAttempt; try++ { + g.srcOp, err = g.ws.Dial(ctx) + if err == nil { + break + } + + // Signal an error before retrying. + g.SendError(ConnectionError{err}) + + retryTimer.Reset(g.opts.ReconnectDelay(try)) + if err := retryTimer.Wait(ctx); err != nil { + g.SendError(ConnectionError{ctx.Err()}) + return + } + } + + // Ensure that we've reconnected successfully. Exit otherwise. + if g.srcOp == nil { + err = errors.Wrap(err, "failed to reconnect after max attempts") + g.SendError(ConnectionError{err}) + return + } + } + } +} diff --git a/utils/ws/op.go b/utils/ws/op.go new file mode 100644 index 0000000..08c99a8 --- /dev/null +++ b/utils/ws/op.go @@ -0,0 +1,212 @@ +package ws + +import ( + "context" + "fmt" + "sync" + + "github.com/pkg/errors" +) + +// OpCode is the type for websocket Op codes. Op codes less than 0 are +// internal Op codes and should usually be ignored. +type OpCode int + +// CloseEvent is an event that is given from wsutil when the websocket is +// closed. +type CloseEvent struct { + // Err is the underlying error. + Err error + // Code is the websocket close code, if any. + Code int +} + +// Unwrap returns err.Err. +func (e *CloseEvent) Unwrap() error { return e.Err } + +// Error formats the CloseEvent. A CloseEvent is also an error. +func (e *CloseEvent) Error() string { + return fmt.Sprintf("websocket closed, reason: %s", e.Err) +} + +// Op implements Event. It returns -1. +func (e *CloseEvent) Op() OpCode { return -1 } + +// EventType implements Event. It returns an emty string. +func (e *CloseEvent) EventType() EventType { return "__ws.CloseEvent" } + +// EventType is a type for event types, which is the "t" field in the payload. +type EventType string + +// Event describes an Event data that comes from a gateway Operation. +type Event interface { + Op() OpCode + EventType() EventType +} + +// OpFunc is a constructor function for an Operation. +type OpFunc func() Event + +// OpUnmarshalers contains a map of event constructor function. +type OpUnmarshalers struct { + r map[opFuncID]OpFunc +} + +type opFuncID struct { + Op OpCode `json:"op"` + T EventType `json:"t"` +} + +// NewOpUnmarshalers creates a nwe OpUnmarshalers instance from the given +// constructor functions. +func NewOpUnmarshalers(funcs ...OpFunc) OpUnmarshalers { + m := OpUnmarshalers{r: make(map[opFuncID]OpFunc)} + m.Add(funcs...) + return m +} + +// Each iterates over the marshaler map. +func (m OpUnmarshalers) Each(f func(OpCode, EventType, OpFunc) (done bool)) { + for id, fn := range m.r { + if f(id.Op, id.T, fn) { + return + } + } +} + +// Add adds the given functions into the unmarshaler registry. +func (m OpUnmarshalers) Add(funcs ...OpFunc) { + for _, fn := range funcs { + ev := fn() + id := opFuncID{ + Op: ev.Op(), + T: ev.EventType(), + } + + m.r[id] = fn + } +} + +// Lookup searches the OpMarshalers map for the given constructor function. +func (m OpUnmarshalers) Lookup(op OpCode, t EventType) OpFunc { + return m.r[opFuncID{op, t}] +} + +// Op is a gateway Operation. +type Op struct { + Code OpCode `json:"op"` + Data Event `json:"d,omitempty"` + + // Type is only for gateway dispatch events. + Type EventType `json:"t,omitempty"` + // Sequence is only for gateway dispatch events (Op 0). + Sequence int64 `json:"s,omitempty"` +} + +// UnknownEventError is required by HandleOp if an event is encountered that is +// not known. Internally, unknown events are logged and ignored. It is not a +// fatal error. +type UnknownEventError struct { + Op OpCode + Type EventType +} + +// Error formats the unknown event error to with the event name and payload +func (err UnknownEventError) Error() string { + return fmt.Sprintf("unknown op %d, event %s", err.Op, err.Type) +} + +// IsBrokenConnection returns true if the error is a broken connection error. +func IsUnknownEvent(err error) bool { + var uevent *UnknownEventError + return errors.As(err, &uevent) +} + +// ReadOps reads maximum n Ops and accumulate them into a slice. +func ReadOps(ctx context.Context, ch <-chan Op, n int) ([]Op, error) { + ops := make([]Op, 0, n) + for { + select { + case <-ctx.Done(): + return ops, ctx.Err() + case op := <-ch: + ops = append(ops, op) + if len(ops) == n { + return ops, nil + } + } + } +} + +// ReadOp reads a single Op. +func ReadOp(ctx context.Context, ch <-chan Op) (Op, error) { + select { + case <-ctx.Done(): + return Op{}, ctx.Err() + case op := <-ch: + return op, nil + } +} + +// Broadcaster is primarily used for debugging. +type Broadcaster struct { + src <-chan Op + dst map[chan<- Op]struct{} + mut sync.Mutex + void bool +} + +// NewBroadcaster creates a new broadcaster. +func NewBroadcaster(src <-chan Op) *Broadcaster { + return &Broadcaster{ + src: src, + dst: make(map[chan<- Op]struct{}), + } +} + +// Start starts the broadcasting loop. +func (b *Broadcaster) Start() { + b.mut.Lock() + if b.void { + panic("Start called on voided Broadcaster") + } + b.mut.Unlock() + + go func() { + for op := range b.src { + b.mut.Lock() + + for ch := range b.dst { + ch <- op + } + + b.mut.Unlock() + } + + b.mut.Lock() + b.void = true + + for ch := range b.dst { + close(ch) + } + + b.mut.Unlock() + }() +} + +// Subscribe subscribes the given channel +func (b *Broadcaster) Subscribe(ch chan<- Op) { + b.mut.Lock() + if b.void { + panic("Subscribe called on voided Broadcaster") + } + b.dst[ch] = struct{}{} + b.mut.Unlock() +} + +// NewSubscribed creates a newly subscribed Op channel. +func (b *Broadcaster) NewSubscribed() <-chan Op { + ch := make(chan Op, 1) + b.Subscribe(ch) + return ch +} diff --git a/utils/ws/ophandler/ophandler.go b/utils/ws/ophandler/ophandler.go new file mode 100644 index 0000000..920d508 --- /dev/null +++ b/utils/ws/ophandler/ophandler.go @@ -0,0 +1,35 @@ +// Package ophandler provides an Op channel reader that redistributes the events +// into handlers. +package ophandler + +import ( + "context" + + "github.com/diamondburned/arikawa/v3/utils/handler" + "github.com/diamondburned/arikawa/v3/utils/ws" +) + +// Loop starts a background goroutine that starts reading from src and +// distributes received events into the given handler. It's stopped once src is +// closed. The returned channel will be closed once src is closed. +func Loop(src <-chan ws.Op, dst *handler.Handler) <-chan struct{} { + done := make(chan struct{}) + go func() { + for op := range src { + dst.Call(op.Data) + } + close(done) + }() + return done +} + +// WaitForDone waits for the done channel returned by Loop until the channel is +// closed or the context expires. +func WaitForDone(ctx context.Context, done <-chan struct{}) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} diff --git a/utils/wsutil/throttler.go b/utils/ws/throttler.go similarity index 96% rename from utils/wsutil/throttler.go rename to utils/ws/throttler.go index 936553c..2e93f21 100644 --- a/utils/wsutil/throttler.go +++ b/utils/ws/throttler.go @@ -1,4 +1,4 @@ -package wsutil +package ws import ( "time" diff --git a/utils/ws/ws.go b/utils/ws/ws.go new file mode 100644 index 0000000..774dd0e --- /dev/null +++ b/utils/ws/ws.go @@ -0,0 +1,118 @@ +// Package wsutil provides abstractions around the Websocket, including rate +// limits. +package ws + +import ( + "context" + "log" + "sync" + + "github.com/pkg/errors" + "golang.org/x/time/rate" +) + +var ( + // WSError is the default error handler + WSError = func(err error) { log.Println("Gateway error:", err) } + // WSDebug is used for extra debug logging. This is expected to behave + // similarly to log.Println(). + WSDebug = func(v ...interface{}) {} +) + +// Websocket is a wrapper around a websocket Conn with thread safety and rate +// limiting for sending and throttling. +type Websocket struct { + mutex sync.Mutex + conn Connection + addr string + + // If you ever need access to these fields from outside the package, please + // open an issue. It might be worth it to refactor these out for distributed + // sharding. + + sendLimiter *rate.Limiter + dialLimiter *rate.Limiter +} + +// NewWebsocket creates a default Websocket with the given address. +func NewWebsocket(c Codec, addr string) *Websocket { + return NewCustomWebsocket(NewConn(c), addr) +} + +// NewCustomWebsocket creates a new undialed Websocket. +func NewCustomWebsocket(conn Connection, addr string) *Websocket { + return &Websocket{ + conn: conn, + addr: addr, + + sendLimiter: NewSendLimiter(), + dialLimiter: NewDialLimiter(), + } +} + +// Dial waits until the rate limiter allows then dials the websocket. +func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) { + if err := ws.dialLimiter.Wait(ctx); err != nil { + // Expired, fatal error + return nil, errors.Wrap(err, "failed to wait for dial rate limiter") + } + + ws.mutex.Lock() + defer ws.mutex.Unlock() + + // Reset the send limiter. + // TODO: see if each limit only applies to one connection or not. + ws.sendLimiter = NewSendLimiter() + + return ws.conn.Dial(ctx, ws.addr) +} + +// Send sends b over the Websocket with a deadline. It closes the internal +// Websocket if the Send method errors out. +func (ws *Websocket) Send(ctx context.Context, b []byte) error { + WSDebug("Acquiring the websoccket mutex for sending.") + + ws.mutex.Lock() + WSDebug("Mutex lock acquired.") + sendLimiter := ws.sendLimiter + conn := ws.conn + ws.mutex.Unlock() + + WSDebug("Waiting for the send rate limiter...") + + if err := sendLimiter.Wait(ctx); err != nil { + WSDebug("Send rate limiter timed out.") + return errors.Wrap(err, "SendLimiter failed") + } + + WSDebug("Send has passed the rate limiting. Waiting on mutex.") + + return conn.Send(ctx, b) +} + +// Close closes the websocket connection. It assumes that the Websocket is +// closed even when it returns an error. If the Websocket was already closed +// before, ErrWebsocketClosed will be returned. +func (ws *Websocket) Close() error { + WSDebug("Conn: Acquiring mutex lock to close...") + + ws.mutex.Lock() + defer ws.mutex.Unlock() + + WSDebug("Conn: Write mutex acquired") + + return ws.conn.Close(false) +} + +// CloseGracefully is similar to Close, but a proper close frame is sent to +// Discord, invalidating the internal session ID and voiding resumes. +func (ws *Websocket) CloseGracefully() error { + WSDebug("Conn: Acquiring mutex lock to close...") + + ws.mutex.Lock() + defer ws.mutex.Unlock() + + WSDebug("Conn: Write mutex acquired") + + return ws.conn.Close(true) +} diff --git a/utils/wsutil/conn.go b/utils/wsutil/conn.go deleted file mode 100644 index d6e773b..0000000 --- a/utils/wsutil/conn.go +++ /dev/null @@ -1,270 +0,0 @@ -package wsutil - -import ( - "bytes" - "compress/zlib" - "context" - "io" - "net/http" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/pkg/errors" -) - -// CopyBufferSize is used for the initial size of the internal WS' buffer. Its -// size is 4KB. -var CopyBufferSize = 4096 - -// MaxCapUntilReset determines the maximum capacity before the bytes buffer is -// re-allocated. It is roughly 16KB, quadruple CopyBufferSize. -var MaxCapUntilReset = CopyBufferSize * 4 - -// CloseDeadline controls the deadline to wait for sending the Close frame. -var CloseDeadline = time.Second - -// ErrWebsocketClosed is returned if the websocket is already closed. -var ErrWebsocketClosed = errors.New("websocket is closed") - -// Connection is an interface that abstracts around a generic Websocket driver. -// This connection expects the driver to handle compression by itself, including -// modifying the connection URL. The implementation doesn't have to be safe for -// concurrent use. -type Connection interface { - // Dial dials the address (string). Context needs to be passed in for - // timeout. This method should also be re-usable after Close is called. - Dial(context.Context, string) error - - // Listen returns an event channel that sends over events constantly. It can - // return nil if there isn't an ongoing connection. - Listen() <-chan Event - - // Send allows the caller to send bytes. It does not need to clean itself - // up on errors, as the Websocket wrapper will do that. - // - // If the data is nil, it should send a close frame - Send(context.Context, []byte) error - - // Close should close the websocket connection. The underlying connection - // may be reused, but this Connection instance will be reused with Dial. The - // Connection must still be reusable even if Close returns an error. - Close() error - // CloseGracefully sends a close frame and then closes the websocket - // connection. - CloseGracefully() error -} - -// Conn is the default Websocket connection. It tries to compresses all payloads -// using zlib. -type Conn struct { - Dialer websocket.Dialer - Header http.Header - Conn *websocket.Conn - events chan Event -} - -var _ Connection = (*Conn)(nil) - -// NewConn creates a new default websocket connection with a default dialer. -func NewConn() *Conn { - return NewConnWithDialer(websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: WSTimeout, - ReadBufferSize: CopyBufferSize, - WriteBufferSize: CopyBufferSize, - EnableCompression: true, - }) -} - -// NewConnWithDialer creates a new default websocket connection with a custom -// dialer. -func NewConnWithDialer(dialer websocket.Dialer) *Conn { - return &Conn{ - Dialer: dialer, - Header: http.Header{ - "Accept-Encoding": {"zlib"}, - }, - } -} - -func (c *Conn) Dial(ctx context.Context, addr string) (err error) { - // BUG which prevents stream compression. - // See https://github.com/golang/go/issues/31514. - - c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header) - if err != nil { - return errors.Wrap(err, "failed to dial WS") - } - - // Reset the deadline. - c.Conn.SetWriteDeadline(resetDeadline) - - c.events = make(chan Event, WSBuffer) - go startReadLoop(c.Conn, c.events) - - return err -} - -// Listen returns an event channel if there is a connection associated with it. -// It returns nil if there is none. -func (c *Conn) Listen() <-chan Event { - return c.events -} - -// resetDeadline is used to reset the write deadline after using the context's. -var resetDeadline = time.Time{} - -func (c *Conn) Send(ctx context.Context, b []byte) error { - d, ok := ctx.Deadline() - if ok { - c.Conn.SetWriteDeadline(d) - defer c.Conn.SetWriteDeadline(resetDeadline) - } - - if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil { - return err - } - - return nil -} - -func (c *Conn) Close() error { - WSDebug("Conn: Close is called; shutting down the Websocket connection.") - - // Have a deadline before closing. - var deadline = time.Now().Add(5 * time.Second) - c.Conn.SetWriteDeadline(deadline) - - // Close the WS. - err := c.Conn.Close() - - c.Conn.SetWriteDeadline(resetDeadline) - - WSDebug("Conn: Websocket closed; error:", err) - WSDebug("Conn: Flushing events...") - - // Flush all events before closing the channel. This will return as soon as - // c.events is closed, or after closed. - for range c.events { - } - - WSDebug("Flushed events.") - - return err -} - -func (c *Conn) CloseGracefully() error { - WSDebug("Conn: CloseGracefully is called; sending close frame.") - - c.Conn.SetWriteDeadline(time.Now().Add(CloseDeadline)) - - err := c.Conn.WriteMessage(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - WSError(err) - } - - WSDebug("Conn: Close frame sent; error:", err) - - return c.Close() -} - -// loopState is a thread-unsafe disposable state container for the read loop. -// It's made to completely separate the read loop of any synchronization that -// doesn't involve the websocket connection itself. -type loopState struct { - conn *websocket.Conn - zlib io.ReadCloser - buf bytes.Buffer -} - -func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) { - // Clean up the events channel in the end. - defer close(eventCh) - - // Allocate the read loop its own private resources. - state := loopState{conn: conn} - state.buf.Grow(CopyBufferSize) - - for { - b, err := state.handle() - if err != nil { - WSDebug("Conn: Read error:", err) - - // Is the error an EOF? - if errors.Is(err, io.EOF) { - // Yes it is, exit. - return - } - - // Is the error an intentional close call? Go 1.16 exposes - // ErrClosing, but we have to do this for now. - if strings.HasSuffix(err.Error(), "use of closed network connection") { - return - } - - // Unusual error; log and exit: - eventCh <- Event{nil, errors.Wrap(err, "WS error")} - return - } - - // If the payload length is 0, skip it. - if len(b) == 0 { - continue - } - - eventCh <- Event{b, nil} - } -} - -func (state *loopState) handle() ([]byte, error) { - // skip message type - t, r, err := state.conn.NextReader() - if err != nil { - return nil, err - } - - if t == websocket.BinaryMessage { - // Probably a zlib payload. - - if state.zlib == nil { - z, err := zlib.NewReader(r) - if err != nil { - return nil, errors.Wrap(err, "failed to create a zlib reader") - } - state.zlib = z - } else { - if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil { - return nil, errors.Wrap(err, "failed to reset zlib reader") - } - } - - defer state.zlib.Close() - r = state.zlib - } - - return state.readAll(r) -} - -// readAll reads bytes into an existing buffer, copy it over, then wipe the old -// buffer. -func (state *loopState) readAll(r io.Reader) ([]byte, error) { - defer state.buf.Reset() - - if _, err := state.buf.ReadFrom(r); err != nil { - return nil, err - } - - // Copy the bytes so we could empty the buffer for reuse. - cpy := make([]byte, state.buf.Len()) - copy(cpy, state.buf.Bytes()) - - // If the buffer's capacity is over the limit, then re-allocate a new one. - if state.buf.Cap() > MaxCapUntilReset { - state.buf = bytes.Buffer{} - state.buf.Grow(CopyBufferSize) - } - - return cpy, nil -} diff --git a/utils/wsutil/heart.go b/utils/wsutil/heart.go deleted file mode 100644 index d2e5be0..0000000 --- a/utils/wsutil/heart.go +++ /dev/null @@ -1,151 +0,0 @@ -package wsutil - -import ( - "context" - "time" - - "github.com/pkg/errors" - - "github.com/diamondburned/arikawa/v3/internal/heart" -) - -type brokenConnectionError struct { - underneath error -} - -// Error formats the broken connection error with the message "explicit -// connection break." -func (err brokenConnectionError) Error() string { - return "explicit connection break: " + err.underneath.Error() -} - -// Unwrap returns the underlying error. -func (err brokenConnectionError) Unwrap() error { - return err.underneath -} - -// ErrBrokenConnection marks the given error as a broken connection error. This -// error will cause the pacemaker loop to break and return the error. The error, -// when stringified, will say "explicit connection break." -func ErrBrokenConnection(err error) error { - return brokenConnectionError{underneath: err} -} - -// IsBrokenConnection returns true if the error is a broken connection error. -func IsBrokenConnection(err error) bool { - var broken *brokenConnectionError - return errors.As(err, &broken) -} - -// TODO API -type EventLoopHandler interface { - EventHandler - HeartbeatCtx(context.Context) error -} - -// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance -// is a valid instance only when RunAsync is called first. -type PacemakerLoop struct { - heart.Pacemaker - Extras ExtraHandlers - ErrorLog func(error) - - events <-chan Event - control chan func() - handler func(*OP) error -} - -func (p *PacemakerLoop) errorLog(err error) { - if p.ErrorLog == nil { - WSDebug("Uncaught error:", err) - return - } - - p.ErrorLog(err) -} - -// Pace calls the pacemaker's Pace function. -func (p *PacemakerLoop) Pace(ctx context.Context) error { - return p.Pacemaker.PaceCtx(ctx) -} - -// StartBeating asynchronously starts the pacemaker loop. -func (p *PacemakerLoop) StartBeating(pace time.Duration, evl EventLoopHandler, exit func(error)) { - WSDebug("Starting the pacemaker loop.") - - p.Pacemaker = heart.NewPacemaker(pace, evl.HeartbeatCtx) - p.control = make(chan func()) - p.handler = evl.HandleOP - p.events = nil // block forever - - go func() { exit(p.startLoop()) }() -} - -// Stop signals the pacemaker to stop. It does not wait for the pacer to stop. -// The pacer will call the given callback with a nil error. -func (p *PacemakerLoop) Stop() { - close(p.control) -} - -// SetEventChannel sets the event channel inside the event loop. There is no -// guarantee that the channel is set when the function returns. This function is -// concurrently safe. -func (p *PacemakerLoop) SetEventChannel(evCh <-chan Event) { - p.control <- func() { p.events = evCh } -} - -// SetPace (re)sets the pace duration. As with SetEventChannel, there is no -// guarantee that the pacer is reset when the function returns. This function is -// concurrently safe. -func (p *PacemakerLoop) SetPace(pace time.Duration) { - p.control <- func() { p.Pacemaker.SetPace(pace) } -} - -func (p *PacemakerLoop) startLoop() error { - defer WSDebug("Pacemaker loop has exited.") - defer p.Pacemaker.StopTicker() - - for { - select { - case <-p.Pacemaker.Ticks: - if err := p.Pacemaker.Pace(); err != nil { - return errors.Wrap(err, "pace failed, reconnecting") - } - - case fn, ok := <-p.control: - if !ok { // Intentional stop at p.Close(). - WSDebug("Pacemaker intentionally stopped using p.control.") - return nil - } - - fn() - - case ev, ok := <-p.events: - if !ok { - WSDebug("Events channel closed, stopping pacemaker.") - return nil - } - - if ev.Error != nil { - return errors.Wrap(ev.Error, "event returned error") - } - - o, err := DecodeOP(ev) - if err != nil { - return errors.Wrap(err, "failed to decode OP") - } - - // Check the events before handling. - p.Extras.Check(o) - - // Handle the event - if err := p.handler(o); err != nil { - if IsBrokenConnection(err) { - return errors.Wrap(err, "handler failed") - } - - p.errorLog(err) - } - } - } -} diff --git a/utils/wsutil/op.go b/utils/wsutil/op.go deleted file mode 100644 index 117ce13..0000000 --- a/utils/wsutil/op.go +++ /dev/null @@ -1,202 +0,0 @@ -package wsutil - -import ( - "context" - "fmt" - "sync" - - "github.com/pkg/errors" - - "github.com/diamondburned/arikawa/v3/internal/moreatomic" - "github.com/diamondburned/arikawa/v3/utils/json" -) - -var ErrEmptyPayload = errors.New("empty payload") - -// OPCode is a generic type for websocket OP codes. -type OPCode uint8 - -type OP struct { - Code OPCode `json:"op"` - Data json.Raw `json:"d,omitempty"` - - // Only for Gateway Dispatch (op 0) - Sequence int64 `json:"s,omitempty"` - EventName string `json:"t,omitempty"` -} - -func (op *OP) UnmarshalData(v interface{}) error { - return json.Unmarshal(op.Data, v) -} - -func DecodeOP(ev Event) (*OP, error) { - if ev.Error != nil { - return nil, ev.Error - } - - if len(ev.Data) == 0 { - return nil, ErrEmptyPayload - } - - var op *OP - if err := json.Unmarshal(ev.Data, &op); err != nil { - return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) - } - - return op, nil -} - -func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) { - op, err := DecodeOP(ev) - if err != nil { - return nil, err - } - - if op.Code != code { - return op, fmt.Errorf( - "unexpected OP Code: %d, expected %d (%s)", - op.Code, code, op.Data, - ) - } - - if err := json.Unmarshal(op.Data, v); err != nil { - return op, errors.Wrap(err, "failed to decode data") - } - - return op, nil -} - -// UnknownEventError is required by HandleOP if an event is encountered that is -// not known. Internally, unknown events are logged and ignored. It is not a -// fatal error. -type UnknownEventError struct { - Name string - Data json.Raw -} - -// Error formats the unknown event error to with the event name and payload -func (err UnknownEventError) Error() string { - return fmt.Sprintf("unknown event %s: %s", err.Name, string(err.Data)) -} - -// IsBrokenConnection returns true if the error is a broken connection error. -func IsUnknownEvent(err error) bool { - var uevent *UnknownEventError - return errors.As(err, &uevent) -} - -type EventHandler interface { - HandleOP(op *OP) error -} - -func HandleEvent(h EventHandler, ev Event) error { - o, err := DecodeOP(ev) - if err != nil { - return err - } - - return h.HandleOP(o) -} - -// WaitForEvent blocks until fn() returns true. All incoming events are handled -// regardless. -func WaitForEvent(ctx context.Context, h EventHandler, ch <-chan Event, fn func(*OP) bool) error { - for { - select { - case e, ok := <-ch: - if !ok { - return errors.New("event not found and event channel is closed") - } - - o, err := DecodeOP(e) - if err != nil { - return err - } - - // Handle the *OP first, in case it's an Invalid Session. This should - // also prevent a race condition with things that need Ready after - // Open(). - if err := h.HandleOP(o); err != nil { - // Explicitly ignore events we don't know. - if IsUnknownEvent(err) { - WSError(err) - continue - } - return err - } - - // Are these events what we're looking for? If we've found the event, - // return. - if fn(o) { - return nil - } - - case <-ctx.Done(): - return ctx.Err() - } - } -} - -type ExtraHandlers struct { - mutex sync.Mutex - handlers map[uint32]*ExtraHandler - serial uint32 -} - -type ExtraHandler struct { - Check func(*OP) bool - send chan *OP - - closed moreatomic.Bool -} - -func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) { - handler := &ExtraHandler{ - Check: check, - send: make(chan *OP), - } - - ex.mutex.Lock() - defer ex.mutex.Unlock() - - if ex.handlers == nil { - ex.handlers = make(map[uint32]*ExtraHandler, 1) - } - - i := ex.serial - ex.serial++ - - ex.handlers[i] = handler - - return handler.send, func() { - // Check the atomic bool before acquiring the mutex. Might help a bit in - // performance. - if handler.closed.Get() { - return - } - - ex.mutex.Lock() - defer ex.mutex.Unlock() - - delete(ex.handlers, i) - } -} - -// Check runs and sends OP data. It is not thread-safe. -func (ex *ExtraHandlers) Check(op *OP) { - ex.mutex.Lock() - defer ex.mutex.Unlock() - - for i, handler := range ex.handlers { - if handler.Check(op) { - // Attempt to send. - handler.send <- op - - // Mark the handler as closed. - handler.closed.Set(true) - - // Delete the handler. - delete(ex.handlers, i) - } - } -} diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go deleted file mode 100644 index cfa3467..0000000 --- a/utils/wsutil/ws.go +++ /dev/null @@ -1,200 +0,0 @@ -// Package wsutil provides abstractions around the Websocket, including rate -// limits. -package wsutil - -import ( - "context" - "log" - "sync" - "time" - - "github.com/pkg/errors" - "golang.org/x/time/rate" -) - -var ( - // WSTimeout is the timeout for connecting and writing to the Websocket, - // before Gateway cancels and fails. - WSTimeout = 30 * time.Second - // WSBuffer is the size of the Event channel. This has to be at least 1 to - // make space for the first Event: Ready or Resumed. - WSBuffer = 10 - // WSError is the default error handler - WSError = func(err error) { log.Println("Gateway error:", err) } - // WSDebug is used for extra debug logging. This is expected to behave - // similarly to log.Println(). - WSDebug = func(v ...interface{}) {} -) - -type Event struct { - Data []byte - - // Error is non-nil if Data is nil. - Error error -} - -// Websocket is a wrapper around a websocket Conn with thread safety and rate -// limiting for sending and throttling. -type Websocket struct { - mutex sync.Mutex - conn Connection - addr string - closed bool - - sendLimiter *rate.Limiter - dialLimiter *rate.Limiter - - // Timeout is the default timeout used if a context with no deadline is - // given to Dial. - // - // It must not be changed after the Websocket is used once. - Timeout time.Duration -} - -// New creates a default Websocket with the given address. -func New(addr string) *Websocket { - return NewCustom(NewConn(), addr) -} - -// NewCustom creates a new undialed Websocket. -func NewCustom(conn Connection, addr string) *Websocket { - return &Websocket{ - conn: conn, - addr: addr, - closed: true, - - sendLimiter: NewSendLimiter(), - dialLimiter: NewDialLimiter(), - - Timeout: WSTimeout, - } -} - -// Dial waits until the rate limiter allows then dials the websocket. -// -// If the passed context has no deadline, Dial will wrap it in a -// context.WithTimeout using ws.Timeout as timeout. -func (ws *Websocket) Dial(ctx context.Context) error { - if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, ws.Timeout) - defer cancel() - } - - if err := ws.dialLimiter.Wait(ctx); err != nil { - // Expired, fatal error - return errors.Wrap(err, "failed to wait") - } - - ws.mutex.Lock() - defer ws.mutex.Unlock() - - if !ws.closed { - WSDebug("Old connection not yet closed while dialog; closing it.") - ws.conn.Close() - } - - if err := ws.conn.Dial(ctx, ws.addr); err != nil { - return errors.Wrap(err, "failed to dial") - } - - ws.closed = false - - // Reset the send limiter. - ws.sendLimiter = NewSendLimiter() - - return nil -} - -// Listen returns the inner event channel or nil if the Websocket connection is -// not alive. -func (ws *Websocket) Listen() <-chan Event { - ws.mutex.Lock() - defer ws.mutex.Unlock() - - if ws.closed { - return nil - } - - return ws.conn.Listen() -} - -// Send sends b over the Websocket without a timeout. -func (ws *Websocket) Send(b []byte) error { - return ws.SendCtx(context.Background(), b) -} - -// SendCtx sends b over the Websocket with a deadline. It closes the internal -// Websocket if the Send method errors out. -func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error { - WSDebug("Waiting for the send rate limiter...") - - if err := ws.sendLimiter.Wait(ctx); err != nil { - WSDebug("Send rate limiter timed out.") - return errors.Wrap(err, "SendLimiter failed") - } - - WSDebug("Send has passed the rate limiting. Waiting on mutex.") - - ws.mutex.Lock() - defer ws.mutex.Unlock() - - WSDebug("Mutex lock acquired.") - - if ws.closed { - return ErrWebsocketClosed - } - - if err := ws.conn.Send(ctx, b); err != nil { - // We need to clean up ourselves if things are erroring out. - WSDebug("Conn: Error while sending; closing the connection. Error:", err) - ws.close(false) - return err - } - - return nil -} - -// Close closes the websocket connection. It assumes that the Websocket is -// closed even when it returns an error. If the Websocket was already closed -// before, ErrWebsocketClosed will be returned. -func (ws *Websocket) Close() error { - WSDebug("Conn: Acquiring mutex lock to close...") - - ws.mutex.Lock() - defer ws.mutex.Unlock() - - WSDebug("Conn: Write mutex acquired") - - return ws.close(false) -} - -func (ws *Websocket) CloseGracefully() error { - WSDebug("Conn: Acquiring mutex lock to close...") - - ws.mutex.Lock() - defer ws.mutex.Unlock() - - WSDebug("Conn: Write mutex acquired") - - return ws.close(true) -} - -// close closes the Websocket without acquiring the mutex. Refer to Close for -// more information. -func (ws *Websocket) close(graceful bool) error { - if ws.closed { - WSDebug("Conn: Websocket is already closed.") - return ErrWebsocketClosed - } - - ws.closed = true - - if graceful { - WSDebug("Conn: Closing gracefully") - return ws.conn.CloseGracefully() - } - - WSDebug("Conn: Closing") - return ws.conn.Close() -} diff --git a/voice/session.go b/voice/session.go index 26c0dcc..6a0e7db 100644 --- a/voice/session.go +++ b/voice/session.go @@ -2,20 +2,22 @@ package voice import ( "context" + "net" "sync" "time" "github.com/diamondburned/arikawa/v3/state" "github.com/diamondburned/arikawa/v3/utils/handler" + "github.com/diamondburned/arikawa/v3/utils/ws/ophandler" "github.com/pkg/errors" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" - "github.com/diamondburned/arikawa/v3/internal/handleloop" + "github.com/diamondburned/arikawa/v3/internal/lazytime" "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/diamondburned/arikawa/v3/session" - "github.com/diamondburned/arikawa/v3/utils/wsutil" + "github.com/diamondburned/arikawa/v3/utils/ws" "github.com/diamondburned/arikawa/v3/voice/udp" "github.com/diamondburned/arikawa/v3/voice/voicegateway" ) @@ -32,24 +34,65 @@ var ErrCannotSend = errors.New("cannot send audio to closed channel") // WSTimeout is the duration to wait for a gateway operation including Session // to complete before erroring out. This only applies to functions that don't // take in a context already. -var WSTimeout = 10 * time.Second +const WSTimeout = 25 * time.Second + +// ReconnectError is emitted into Session.Handler everytime the voice gateway +// fails to be reconnected. It implements the error interface. +type ReconnectError struct { + Err error +} + +// Error implements error. +func (e ReconnectError) Error() string { + return "voice reconnect error: " + e.Err.Error() +} + +// Unwrap returns e.Err. +func (e ReconnectError) Unwrap() error { return e.Err } + +// MainSession abstracts both session.Session and state.State. +type MainSession interface { + // AddHandler describes the method in handler.Handler. + AddHandler(handler interface{}) (rm func()) + // Gateway returns the session's main Discord gateway. + Gateway() *gateway.Gateway + // Me returns the current user. + Me() (*discord.User, error) + // Channel queries for the channel with the given ID. + Channel(discord.ChannelID) (*discord.Channel, error) +} + +var ( + _ MainSession = (*session.Session)(nil) + _ MainSession = (*state.State)(nil) +) + +// UDPDialer is the UDP dialer function type. It's the function signature for +// udp.DialConnection. +type UDPDialer = func(ctx context.Context, addr string, ssrc uint32) (*udp.Connection, error) // Session is a single voice session that wraps around the voice gateway and UDP // connection. type Session struct { *handler.Handler - ErrorLog func(err error) + session MainSession - session *session.Session - looper *handleloop.Loop - detach func() + mut sync.RWMutex + // connected is a non-nil blocking channel after Join is called and is + // closed once Leave is called. + disconnected chan struct{} - mut sync.RWMutex state voicegateway.State // guarded except UserID - // TODO: expose getters mutex-guarded. - gateway *voicegateway.Gateway + + detachReconnect []func() + voiceUDP *udp.Connection - // end of mutex + gateway *voicegateway.Gateway + gwCancel context.CancelFunc + gwDone <-chan struct{} + + // DialUDP is the custom function for dialing up a UDP connection. + DialUDP UDPDialer WSTimeout time.Duration // global WSTimeout WSMaxRetry int // 2 @@ -59,76 +102,102 @@ type Session struct { // joining determines the behavior of incoming event callbacks (Update). // If this is true, incoming events will just send into Updated channels. If // false, events will trigger a reconnection. - joining moreatomic.Bool - connected bool + joining moreatomic.Bool + // disconnectClosed is true if connected is already closed. It is only used + // to keep track of closing connected. + disconnectClosed bool } // NewSession creates a new voice session for the current user. -func NewSession(state *state.State) (*Session, error) { +func NewSession(state MainSession) (*Session, error) { u, err := state.Me() if err != nil { return nil, errors.Wrap(err, "failed to get me") } - return NewSessionCustom(state.Session, u.ID), nil + return NewSessionCustom(state, u.ID), nil } // NewSessionCustom creates a new voice session from the given session and user // ID. -func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session { - handler := handler.New() - hlooper := handleloop.NewLoop(handler) +func NewSessionCustom(ses MainSession, userID discord.UserID) *Session { + closed := make(chan struct{}) + close(closed) + session := &Session{ - Handler: handler, - looper: hlooper, + Handler: handler.New(), session: ses, state: voicegateway.State{ UserID: userID, }, - ErrorLog: func(err error) {}, + DialUDP: udp.DialConnection, WSTimeout: WSTimeout, WSMaxRetry: 2, WSRetryDelay: 2 * time.Second, WSWaitDuration: 5 * time.Second, + + // Set this pair of value in so we never have to nil-check the channel. + // We can just assume that it's either closed or connected. + disconnected: closed, + disconnectClosed: true, } return session } -// updateServer is specifically used to monitor for reconnects. -func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) { +func (s *Session) acquireUpdate(f func()) bool { if s.joining.Get() { - return + return false } s.mut.Lock() defer s.mut.Unlock() // Ignore if we haven't connected yet or we're still joining. - if !s.connected || s.state.GuildID != ev.GuildID { - return + select { + case <-s.disconnected: + return false + default: + // ok } - // Reconnect. - - s.state.Endpoint = ev.Endpoint - s.state.Token = ev.Token - - ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) - defer cancel() - - if err := s.reconnectCtx(ctx); err != nil { - s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update")) - } + f() + return true } -// JoinChannel joins a voice channel with the default WS timeout. See -// JoinChannelCtx for more information. -func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error { - ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) - defer cancel() +// updateServer is specifically used to monitor for reconnects. +func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) { + s.acquireUpdate(func() { + if s.state.GuildID != ev.GuildID { + return + } - return s.JoinChannelCtx(ctx, gID, cID, mute, deaf) + s.state.Endpoint = ev.Endpoint + s.state.Token = ev.Token + + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() + + s.reconnectCtx(ctx) + }) +} + +// updateState is specifically used after connecting to monitor when the bot is +// forced across channels. +func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) { + s.acquireUpdate(func() { + if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID { + return + } + + s.state.ChannelID = ev.ChannelID + s.state.SessionID = ev.SessionID + + ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) + defer cancel() + + s.reconnectCtx(ctx) + }) } type waitEventChs struct { @@ -136,11 +205,17 @@ type waitEventChs struct { stateUpdate chan *gateway.VoiceStateUpdateEvent } -// JoinChannelCtx joins a voice channel. Callers shouldn't use this method -// directly, but rather Voice's. This method shouldn't ever be called -// concurrently. -func (s *Session) JoinChannelCtx( - ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error { +// JoinChannel joins the given voice channel with the default timeout. +func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error { + var ch *discord.Channel + + if chID.IsValid() { + var err error + ch, err = s.session.Channel(chID) + if err != nil { + return errors.Wrap(err, "invalid channel ID") + } + } s.mut.Lock() defer s.mut.Unlock() @@ -155,14 +230,20 @@ func (s *Session) JoinChannelCtx( defer s.joining.Set(false) // Set the state. - s.state.ChannelID = cID - s.state.GuildID = gID - s.detach = s.session.AddHandler(s.updateServer) + if ch != nil { + s.state.ChannelID = ch.ID + s.state.GuildID = ch.GuildID + } else { + s.state.GuildID = 0 + // Ensure that if `cID` is zero that it passes null to the update event. + s.state.ChannelID = discord.NullChannelID + } - // Ensure that if `cID` is zero that it passes null to the update event. - channelID := discord.NullChannelID - if cID.IsValid() { - channelID = cID + if s.detachReconnect == nil { + s.detachReconnect = []func(){ + s.session.AddHandler(s.updateServer), + s.session.AddHandler(s.updateState), + } } chs := waitEventChs{ @@ -185,15 +266,16 @@ func (s *Session) JoinChannelCtx( // Ensure gateway and voiceUDP are already closed. s.ensureClosed() - data := gateway.UpdateVoiceStateData{ - GuildID: gID, - ChannelID: channelID, + // https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information + // Send a Voice State Update event to the gateway. + data := &gateway.UpdateVoiceStateCommand{ + GuildID: s.state.GuildID, + ChannelID: s.state.ChannelID, SelfMute: mute, SelfDeaf: deaf, } var err error - var timer *time.Timer // Retry 3 times maximum. @@ -230,17 +312,17 @@ func (s *Session) JoinChannelCtx( // Mark the session as connected and move on. This allows one of the // connected handlers to reconnect on its own. - s.connected = true + s.disconnected = make(chan struct{}) return s.reconnectCtx(ctx) } func (s *Session) askDiscord( - ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error { + ctx context.Context, data *gateway.UpdateVoiceStateCommand, chs waitEventChs) error { // https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information // Send a Voice State Update event to the gateway. - if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil { + if err := s.session.Gateway().Send(ctx, data); err != nil { return errors.Wrap(err, "failed to send Voice State Update event") } @@ -290,118 +372,183 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error { // reconnect uses the current state to reconnect to a new gateway and UDP // connection. -func (s *Session) reconnectCtx(ctx context.Context) (err error) { - wsutil.WSDebug("Sending stop handle.") - s.looper.Stop() +func (s *Session) reconnectCtx(ctx context.Context) error { + ws.WSDebug("Sending stop handle.") - wsutil.WSDebug("Start gateway.") + s.ensureClosed() + + ws.WSDebug("Start gateway.") s.gateway = voicegateway.New(s.state) // Open the voice gateway. The function will block until Ready is received. - if err := s.gateway.OpenCtx(ctx); err != nil { - return errors.Wrap(err, "failed to open voice gateway") + gwctx, gwcancel := context.WithCancel(context.Background()) + s.gwCancel = gwcancel + + gwch := s.gateway.Connect(gwctx) + + if err := s.spinGateway(ctx, gwch); err != nil { + // Early cancel the gateway. + gwcancel() + // Nil this so future reconnects don't use the invalid gwDone. + s.gwCancel = nil + // Emit the error. It's fine to do this here since this is the only + // place that can error out. + s.Handler.Call(&ReconnectError{err}) + return err } - // Start the handler dispatching - s.looper.Start(s.gateway.Events) - - // Get the Ready event. - voiceReady := s.gateway.Ready() - - // Prepare the UDP voice connection. - s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC) - if err != nil { - return errors.Wrap(err, "failed to open voice UDP connection") - } - - // Get the session description from the voice gateway. - d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{ - Protocol: "udp", - Data: voicegateway.SelectProtocolData{ - Address: s.voiceUDP.GatewayIP, - Port: s.voiceUDP.GatewayPort, - Mode: Protocol, - }, - }) - if err != nil { - return errors.Wrap(err, "failed to select protocol") - } - - s.voiceUDP.UseSecret(d.SecretKey) + // Start dispatching. + s.gwDone = ophandler.Loop(gwch, s.Handler) return nil } +func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error { + var err error + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case ev, ok := <-gwch: + if !ok { + return s.gateway.LastError() + } + + switch data := ev.Data.(type) { + case *ws.CloseEvent: + return errors.Wrap(err, "voice gateway error") + + case *voicegateway.ReadyEvent: + // Prepare the UDP voice connection. + s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC) + if err != nil { + return errors.Wrap(err, "failed to open voice UDP connection") + } + + if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{ + Protocol: "udp", + Data: voicegateway.SelectProtocolData{ + Address: s.voiceUDP.GatewayIP, + Port: s.voiceUDP.GatewayPort, + Mode: Protocol, + }, + }); err != nil { + return errors.Wrap(err, "failed to send SelectProtocolCommand") + } + + case *voicegateway.SessionDescriptionEvent: + // We're done. + s.voiceUDP.UseSecret(data.SecretKey) + return nil + } + + // Dispatch this event to the handler. + s.Handler.Call(ev.Data) + } + } +} + // Speaking tells Discord we're speaking. This method should not be called // concurrently. -func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error { - s.mut.RLock() - gateway := s.gateway - s.mut.RUnlock() - - return gateway.Speaking(flag) -} - -// UseContext tells the UDP voice connection to write with the given context. -func (s *Session) UseContext(ctx context.Context) error { +func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error { s.mut.Lock() - defer s.mut.Unlock() + gateway := s.gateway + s.mut.Unlock() - if s.voiceUDP == nil { - return ErrCannotSend + return gateway.Speaking(ctx, flag) +} + +func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) { + const maxAttempts = 5 + const retryDelay = 250 * time.Millisecond // adds up to about 1.25s + + var lazyWait lazytime.Timer + + // Hack: loop until we no longer get an error closed or until the connection + // is dead. This is a workaround for when the session is trying to reconnect + // itself in the background, which would drop the UDP connection. + for i := 0; i < maxAttempts; i++ { + s.mut.RLock() + voiceUDP := s.voiceUDP + disconnected := s.disconnected + s.mut.RUnlock() + + select { + case <-disconnected: + return net.ErrClosed + default: + if voiceUDP == nil { + // Session is still connected, but our voice UDP connection is + // nil, so we're probably in the process of reconnecting + // already. + goto retry + } + } + + if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) { + // Session is still connected, but our UDP connection is somehow + // closed, so we're probably waiting for the server to ask us to + // reconnect with a new session. + goto retry + } + + // Unknown error or none at all; exit. + return err + + retry: + // Wait a slight bit. We can probably make the caller wait a couple + // milliseconds without a wait. + lazyWait.Reset(retryDelay) + select { + case <-lazyWait.C: + continue + case <-disconnected: + return net.ErrClosed + } } - return s.voiceUDP.UseContext(ctx) + return } -// VoiceUDPConn gets a voice UDP connection. The caller could use this method to -// circumvent the rapid mutex-read-lock acquire inside Write. -func (s *Session) VoiceUDPConn() *udp.Connection { - s.mut.RLock() - defer s.mut.RUnlock() - - return s.voiceUDP -} - -// Write writes into the UDP voice connection WITHOUT a timeout. Refer to -// WriteCtx for more information. +// Write writes into the UDP voice connection. This method is thread safe as far +// as calling other methods of Session goes; HOWEVER it is not thread safe to +// call Write itself concurrently. func (s *Session) Write(b []byte) (int, error) { - return s.WriteCtx(context.Background(), b) + var n int + err := s.useUDP(func(c *udp.Connection) (err error) { + n, err = c.Write(b) + return + }) + return n, err } -// WriteCtx writes into the UDP voice connection with a context for timeout. -// This method is thread safe as far as calling other methods of Session goes; -// HOWEVER it is not thread safe to call Write itself concurrently. -func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) { - voiceUDP := s.VoiceUDPConn() - - if voiceUDP == nil { - return 0, ErrCannotSend - } - - return voiceUDP.WriteCtx(ctx, b) +// ReadPacket reads a single packet from the UDP connection. This is NOT at all +// thread safe, and must be used very carefully. The backing buffer is always +// reused. +func (s *Session) ReadPacket() (*udp.Packet, error) { + var p *udp.Packet + err := s.useUDP(func(c *udp.Connection) (err error) { + p, err = c.ReadPacket() + return + }) + return p, err } // Leave disconnects the current voice session from the currently connected // channel. -func (s *Session) Leave() error { - ctx, cancel := context.WithTimeout(context.Background(), WSTimeout) - defer cancel() - - return s.LeaveCtx(ctx) -} - -// LeaveCtx disconencts with a context. Refer to Leave for more information. -func (s *Session) LeaveCtx(ctx context.Context) error { +func (s *Session) Leave(ctx context.Context) error { s.mut.Lock() defer s.mut.Unlock() - s.connected = false + s.ensureClosed() // Unbind the handlers. - if s.detach != nil { - s.detach() - s.detach = nil + if s.detachReconnect != nil { + for _, detach := range s.detachReconnect { + detach() + } + s.detachReconnect = nil } // If we're already closed. @@ -409,46 +556,59 @@ func (s *Session) LeaveCtx(ctx context.Context) error { return nil } - s.looper.Stop() - - // Notify Discord that we're leaving. This will send a - // VoiceStateUpdateEvent, in which our handler will promptly remove the - // session from the map. - - err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{ + // Notify Discord that we're leaving. + err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{ GuildID: s.state.GuildID, ChannelID: discord.ChannelID(discord.NullSnowflake), SelfMute: true, SelfDeaf: true, }) - s.ensureClosed() - // wrap returns nil if err is nil - return errors.Wrap(err, "failed to update voice state") + // Wait for the gateway to exit first before we tell the user of the gateway + // send error. + if err := s.cancelGateway(ctx); err != nil { + return err + } + + if err != nil { + return errors.Wrap(err, "failed to update voice state") + } + + return nil +} + +func (s *Session) cancelGateway(ctx context.Context) error { + if s.gwCancel != nil { + s.gwCancel() + s.gwCancel = nil + + // Wait for the previous gateway to finish closing up, but make sure to + // bail if the context expires. + if err := ophandler.WaitForDone(ctx, s.gwDone); err != nil { + return errors.Wrap(err, "cannot wait for gateway to close") + } + } + + return nil } // close ensures everything is closed. It does not acquire the mutex. func (s *Session) ensureClosed() { - s.looper.Stop() - // Disconnect the UDP connection. if s.voiceUDP != nil { s.voiceUDP.Close() s.voiceUDP = nil } - // Disconnect the voice gateway, ignoring the error. - if s.gateway != nil { - if err := s.gateway.Close(); err != nil { - wsutil.WSDebug("Uncaught voice gateway close error:", err) - } - s.gateway = nil + if !s.disconnectClosed { + close(s.disconnected) + s.disconnectClosed = true + } + + if s.gwCancel != nil { + s.gwCancel() + // Don't actually clear this field, because we still want the caller to + // be able to wait for the gateway to completely exit using + // cancelGateway. } } - -// ReadPacket reads a single packet from the UDP connection. This is NOT at all -// thread safe, and must be used very carefully. The backing buffer is always -// reused. -func (s *Session) ReadPacket() (*udp.Packet, error) { - return s.VoiceUDPConn().ReadPacket() -} diff --git a/voice/session_example_test.go b/voice/session_example_test.go index 207aa82..e96bbd6 100644 --- a/voice/session_example_test.go +++ b/voice/session_example_test.go @@ -34,33 +34,25 @@ func TestNoop(t *testing.T) { } func ExampleSession() { - s, err := state.New("Bot " + token) - if err != nil { - log.Fatalln("failed to make state:", err) - } + s := state.New("Bot " + token) // This is required for bots. - voice.AddIntents(s.Gateway) + voice.AddIntents(s) if err := s.Open(context.TODO()); err != nil { log.Fatalln("failed to open gateway:", err) } defer s.Close() - c, err := s.Channel(channelID) - if err != nil { - log.Fatalln("failed to get channel:", err) - } - v, err := voice.NewSession(s) if err != nil { log.Fatalln("failed to create voice session:", err) } - if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil { + if err := v.JoinChannel(context.TODO(), channelID, false, false); err != nil { log.Fatalln("failed to join voice channel:", err) } - defer v.Leave() + defer v.Leave(context.TODO()) // Start writing Opus frames. for { diff --git a/voice/session_test.go b/voice/session_test.go index a3ddc94..7ea8a64 100644 --- a/voice/session_test.go +++ b/voice/session_test.go @@ -15,7 +15,7 @@ import ( "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/internal/testenv" "github.com/diamondburned/arikawa/v3/state" - "github.com/diamondburned/arikawa/v3/utils/wsutil" + "github.com/diamondburned/arikawa/v3/utils/ws" "github.com/diamondburned/arikawa/v3/voice/voicegateway" "github.com/pkg/errors" ) @@ -23,17 +23,14 @@ import ( func TestIntegration(t *testing.T) { config := testenv.Must(t) - wsutil.WSDebug = func(v ...interface{}) { + ws.WSDebug = func(v ...interface{}) { _, file, line, _ := runtime.Caller(1) caller := file + ":" + strconv.Itoa(line) log.Println(append([]interface{}{caller}, v...)...) } - s, err := state.New("Bot " + config.BotToken) - if err != nil { - t.Fatal("failed to create a new state:", err) - } - AddIntents(s.Gateway) + s := state.New("Bot " + config.BotToken) + AddIntents(s) func() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -70,7 +67,6 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) { if err != nil { t.Fatal("failed to create a new voice session:", err) } - v.ErrorLog = func(err error) { t.Error(err) } // Grab a timer to benchmark things. finish := timer() @@ -80,7 +76,10 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) { finish("receiving voice speaking event") }) - if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + t.Cleanup(cancel) + + if err := v.JoinChannel(ctx, c.ID, false, false); err != nil { t.Fatal("failed to join voice:", err) } @@ -88,28 +87,19 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) { log.Println("Leaving the voice channel concurrently.") raceMe(t, "failed to leave voice channel", func() (interface{}, error) { - return nil, v.Leave() + return nil, v.Leave(ctx) }) }) finish("joining the voice channel") - // Create a context and only cancel it AFTER we're done sending silence - // frames. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - t.Cleanup(cancel) - // Trigger speaking. - if err := v.Speaking(voicegateway.Microphone); err != nil { + if err := v.Speaking(ctx, voicegateway.Microphone); err != nil { t.Fatal("failed to start speaking:", err) } finish("sending the speaking command") - if err := v.UseContext(ctx); err != nil { - t.Fatal("failed to set ctx into vs:", err) - } - f, err := os.Open("testdata/nico.dca") if err != nil { t.Fatal("failed to open nico.dca:", err) diff --git a/voice/udp/udp.go b/voice/udp/udp.go index da0137a..00f0034 100644 --- a/voice/udp/udp.go +++ b/voice/udp/udp.go @@ -8,6 +8,7 @@ import ( "net" "time" + "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/pkg/errors" "golang.org/x/crypto/nacl/secretbox" ) @@ -39,13 +40,13 @@ type Connection struct { GatewayIP string GatewayPort uint16 - context context.Context - conn net.Conn - ssrc uint32 + conn net.Conn + ssrc uint32 // frequency rate.Limiter frequency *time.Ticker timeIncr uint32 + stopFreq chan struct{} packet [12]byte secret [32]byte @@ -59,9 +60,12 @@ type Connection struct { recvBuf []byte // len 1400 recvOpus []byte // len 1400 recvPacket *Packet // uses recvOpus' backing array + + closed moreatomic.Bool } -func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { +// DialConnection dials a UDP connection. +func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) { // Create a new UDP connection. conn, err := Dialer.DialContext(ctx, "udp", addr) if err != nil { @@ -114,7 +118,7 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti GatewayPort: port, frequency: time.NewTicker(20 * time.Millisecond), timeIncr: 960, - context: context.Background(), + stopFreq: make(chan struct{}), packet: packet, ssrc: ssrc, conn: conn, @@ -159,49 +163,29 @@ func (c *Connection) UseSecret(secret [32]byte) { c.secret = secret } -// UseContext lets the connection use the given context for its Write method. -// WriteCtx will override this context. -func (c *Connection) UseContext(ctx context.Context) error { - return c.useContext(ctx) +// SetWriteDeadline sets the UDP connection's write deadline. +func (c *Connection) SetWriteDeadline(deadline time.Time) { + c.conn.SetWriteDeadline(deadline) } -func (c *Connection) useContext(ctx context.Context) error { - if c.context == ctx { - return nil - } - - c.context = ctx - - if deadline, ok := c.context.Deadline(); ok { - return c.conn.SetWriteDeadline(deadline) - } else { - return c.conn.SetWriteDeadline(time.Time{}) - } +// SetReadDeadline sets the UDP connection's read deadline. +func (c *Connection) SetReadDeadline(deadline time.Time) { + c.conn.SetReadDeadline(deadline) } func (c *Connection) Close() error { - c.frequency.Stop() + if c.closed.Acquire() { + // Be sure to only run this ONCE. + c.frequency.Stop() + close(c.stopFreq) + } + return c.conn.Close() } -// Write sends bytes into the voice UDP connection using the preset context. +// Write sends a packet of audio into the voice UDP connection using the preset +// context. func (c *Connection) Write(b []byte) (int, error) { - return c.write(b) -} - -// WriteCtx sends bytes into the voice UDP connection with a timeout using the -// given context. It ignores the context inside the connection, but will restore -// the deadline after this call is done. -func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) { - oldCtx := c.context - - c.useContext(ctx) - defer c.useContext(oldCtx) - - return c.write(b) -} - -func (c *Connection) write(b []byte) (int, error) { // Write a new sequence. binary.BigEndian.PutUint16(c.packet[2:4], c.sequence) c.sequence++ @@ -215,9 +199,9 @@ func (c *Connection) write(b []byte) (int, error) { select { case <-c.frequency.C: - - case <-c.context.Done(): - return 0, c.context.Err() + // ok + case <-c.stopFreq: + return 0, net.ErrClosed } n, err := c.conn.Write(toSend) diff --git a/voice/voice.go b/voice/voice.go index 3a3a60c..4141213 100644 --- a/voice/voice.go +++ b/voice/voice.go @@ -8,7 +8,7 @@ import "github.com/diamondburned/arikawa/v3/gateway" // AddIntents adds the needed voice intents into gw. Bots should always call // this before Open if voice is required. -func AddIntents(gw *gateway.Gateway) { +func AddIntents(gw interface{ AddIntents(gateway.Intents) }) { gw.AddIntents(gateway.IntentGuilds) gw.AddIntents(gateway.IntentGuildVoiceStates) } diff --git a/voice/voicegateway/commands.go b/voice/voicegateway/commands.go deleted file mode 100644 index 98c2781..0000000 --- a/voice/voicegateway/commands.go +++ /dev/null @@ -1,167 +0,0 @@ -package voicegateway - -import ( - "context" - "time" - - "github.com/diamondburned/arikawa/v3/discord" - "github.com/pkg/errors" -) - -var ( - // ErrMissingForIdentify is an error when we are missing information to identify. - ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify") - - // ErrMissingForResume is an error when we are missing information to resume. - ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming") -) - -// OPCode 0 -// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload -type IdentifyData struct { - GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id" - UserID discord.UserID `json:"user_id"` - SessionID string `json:"session_id"` - Token string `json:"token"` -} - -// Identify sends an Identify operation (opcode 0) to the Gateway Gateway. -func (c *Gateway) Identify() error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - - return c.IdentifyCtx(ctx) -} - -// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway. -func (c *Gateway) IdentifyCtx(ctx context.Context) error { - guildID := c.state.GuildID - userID := c.state.UserID - sessionID := c.state.SessionID - token := c.state.Token - - if guildID == 0 || userID == 0 || sessionID == "" || token == "" { - return ErrMissingForIdentify - } - - return c.SendCtx(ctx, IdentifyOP, IdentifyData{ - GuildID: guildID, - UserID: userID, - SessionID: sessionID, - Token: token, - }) -} - -// OPCode 1 -// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload -type SelectProtocol struct { - Protocol string `json:"protocol"` - Data SelectProtocolData `json:"data"` -} - -type SelectProtocolData struct { - Address string `json:"address"` - Port uint16 `json:"port"` - Mode string `json:"mode"` -} - -// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway. -func (c *Gateway) SelectProtocol(data SelectProtocol) error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - - return c.SelectProtocolCtx(ctx, data) -} - -// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway. -func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error { - return c.SendCtx(ctx, SelectProtocolOP, data) -} - -// OPCode 3 -// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload -// type Heartbeat uint64 - -// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway. -func (c *Gateway) Heartbeat() error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - - return c.HeartbeatCtx(ctx) -} - -// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway. -func (c *Gateway) HeartbeatCtx(ctx context.Context) error { - return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano()) -} - -// https://discord.com/developers/docs/topics/voice-connections#speaking -type SpeakingFlag uint64 - -const ( - NotSpeaking SpeakingFlag = 0 - Microphone SpeakingFlag = 1 << iota - Soundshare - Priority -) - -// OPCode 5 -// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload -type SpeakingData struct { - Speaking SpeakingFlag `json:"speaking"` - Delay int `json:"delay"` - SSRC uint32 `json:"ssrc"` - UserID discord.UserID `json:"user_id,omitempty"` -} - -// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. -func (c *Gateway) Speaking(flag SpeakingFlag) error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - - return c.SpeakingCtx(ctx, flag) -} - -// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway. -func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error { - // How do we allow a user to stop speaking? - // Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation - - return c.SendCtx(ctx, SpeakingOP, SpeakingData{ - Speaking: flag, - Delay: 0, - SSRC: c.ready.SSRC, - }) -} - -// OPCode 7 -// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload -type ResumeData struct { - GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id" - SessionID string `json:"session_id"` - Token string `json:"token"` -} - -// Resume sends a Resume operation (opcode 7) to the Gateway Gateway. -func (c *Gateway) Resume() error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - return c.ResumeCtx(ctx) -} - -// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway. -func (c *Gateway) ResumeCtx(ctx context.Context) error { - guildID := c.state.GuildID - sessionID := c.state.SessionID - token := c.state.Token - - if !guildID.IsValid() || sessionID == "" || token == "" { - return ErrMissingForResume - } - - return c.SendCtx(ctx, ResumeOP, ResumeData{ - GuildID: guildID, - SessionID: sessionID, - Token: token, - }) -} diff --git a/voice/voicegateway/event_methods.go b/voice/voicegateway/event_methods.go new file mode 100644 index 0000000..9265c2a --- /dev/null +++ b/voice/voicegateway/event_methods.go @@ -0,0 +1,94 @@ +// Code generated by genevent. DO NOT EDIT. + +package voicegateway + +import "github.com/diamondburned/arikawa/v3/utils/ws" + +func init() { + OpUnmarshalers.Add( + func() ws.Event { return new(IdentifyCommand) }, + func() ws.Event { return new(SelectProtocolCommand) }, + func() ws.Event { return new(ReadyEvent) }, + func() ws.Event { return new(HeartbeatCommand) }, + func() ws.Event { return new(SessionDescriptionEvent) }, + func() ws.Event { return new(SpeakingEvent) }, + func() ws.Event { return new(HeartbeatAckEvent) }, + func() ws.Event { return new(ResumeCommand) }, + func() ws.Event { return new(HelloEvent) }, + func() ws.Event { return new(ResumedEvent) }, + func() ws.Event { return new(ClientConnectEvent) }, + func() ws.Event { return new(ClientDisconnectEvent) }, + ) +} + +// Op implements Event. It always returns Op 0. +func (*IdentifyCommand) Op() ws.OpCode { return 0 } + +// EventType implements Event. +func (*IdentifyCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 1. +func (*SelectProtocolCommand) Op() ws.OpCode { return 1 } + +// EventType implements Event. +func (*SelectProtocolCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 2. +func (*ReadyEvent) Op() ws.OpCode { return 2 } + +// EventType implements Event. +func (*ReadyEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 3. +func (*HeartbeatCommand) Op() ws.OpCode { return 3 } + +// EventType implements Event. +func (*HeartbeatCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 4. +func (*SessionDescriptionEvent) Op() ws.OpCode { return 4 } + +// EventType implements Event. +func (*SessionDescriptionEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 5. +func (*SpeakingEvent) Op() ws.OpCode { return 5 } + +// EventType implements Event. +func (*SpeakingEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 6. +func (*HeartbeatAckEvent) Op() ws.OpCode { return 6 } + +// EventType implements Event. +func (*HeartbeatAckEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 7. +func (*ResumeCommand) Op() ws.OpCode { return 7 } + +// EventType implements Event. +func (*ResumeCommand) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 8. +func (*HelloEvent) Op() ws.OpCode { return 8 } + +// EventType implements Event. +func (*HelloEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 9. +func (*ResumedEvent) Op() ws.OpCode { return 9 } + +// EventType implements Event. +func (*ResumedEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 12. +func (*ClientConnectEvent) Op() ws.OpCode { return 12 } + +// EventType implements Event. +func (*ClientConnectEvent) EventType() ws.EventType { return "" } + +// Op implements Event. It always returns Op 13. +func (*ClientDisconnectEvent) Op() ws.OpCode { return 13 } + +// EventType implements Event. +func (*ClientDisconnectEvent) EventType() ws.EventType { return "" } diff --git a/voice/voicegateway/events.go b/voice/voicegateway/events.go index 726ba1a..eef6377 100644 --- a/voice/voicegateway/events.go +++ b/voice/voicegateway/events.go @@ -4,9 +4,41 @@ import ( "strconv" "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/utils/ws" ) -// OPCode 2 +//go:generate go run ../../utils/cmd/genevent -p voicegateway -o event_methods.go + +// OpUnmarshalers contains the Op unmarshalers for the voice gateway events. +var OpUnmarshalers = ws.NewOpUnmarshalers() + +// IdentifyCommand is a command for Op 0. +// +// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload +type IdentifyCommand struct { + GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id" + UserID discord.UserID `json:"user_id"` + SessionID string `json:"session_id"` + Token string `json:"token"` +} + +// SelectProtocolCommand is a command for Op 1. +// +// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload +type SelectProtocolCommand struct { + Protocol string `json:"protocol"` + Data SelectProtocolData `json:"data"` +} + +// SelectProtocolData is the data inside a SelectProtocolCommand. +type SelectProtocolData struct { + Address string `json:"address"` + Port uint16 `json:"port"` + Mode string `json:"mode"` +} + +// ReadyEvent is an event for Op 2. +// // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-ready-payload type ReadyEvent struct { SSRC uint32 `json:"ssrc"` @@ -23,45 +55,79 @@ type ReadyEvent struct { // HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"` } +// Addr formats the URL inside Ready to be of format "host:port". func (r ReadyEvent) Addr() string { return r.IP + ":" + strconv.Itoa(r.Port) } -// OPCode 4 +// HeartbeatCommand is a command for Op 3. +// +// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload +type HeartbeatCommand uint64 + +// SessionDescriptionEvent is an event for Op 4. +// // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-session-description-payload type SessionDescriptionEvent struct { Mode string `json:"mode"` SecretKey [32]byte `json:"secret_key"` } -// OPCode 5 -type SpeakingEvent SpeakingData +// https://discord.com/developers/docs/topics/voice-connections#speaking +type SpeakingFlag uint64 -// OPCode 6 +const ( + NotSpeaking SpeakingFlag = 0 + Microphone SpeakingFlag = 1 << iota + Soundshare + Priority +) + +// SpeakingEvent is an event for Op 5. It is also a command. +// +// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload +type SpeakingEvent struct { + Speaking SpeakingFlag `json:"speaking"` + Delay int `json:"delay"` + SSRC uint32 `json:"ssrc"` + UserID discord.UserID `json:"user_id,omitempty"` +} + +// HeartbeatAckEvent is an event for Op 6. +// // https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-ack-payload -type HeartbeatACKEvent uint64 +type HeartbeatAckEvent uint64 -// OPCode 8 +// ResumeCommand is a command for Op 7. +// +// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload +type ResumeCommand struct { + GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id" + SessionID string `json:"session_id"` + Token string `json:"token"` +} + +// HelloEvent is an event for Op 8. +// // https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-hello-payload-since-v3 type HelloEvent struct { HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"` } -// OPCode 9 +// ResumedEvent is an event for Op 9. // https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resumed-payload type ResumedEvent struct{} -// OPCode 12 -// (undocumented) +// ClientConnectEvent is an event for Op 12. It is undocumented. type ClientConnectEvent struct { UserID discord.UserID `json:"user_id"` AudioSSRC uint32 `json:"audio_ssrc"` VideoSSRC uint32 `json:"video_ssrc"` } -// OPCode 13 -// Undocumented, existence mentioned in below issue -// https://github.com/discord/discord-api-docs/issues/510 +// ClientDisconnectEvent is an event for Op 13. It is undocumented, but its +// existence is mentioned in this issue: +// https://github.com/discord/discord-api-docs/issues/510. type ClientDisconnectEvent struct { UserID discord.UserID `json:"user_id"` } diff --git a/voice/voicegateway/gateway.go b/voice/voicegateway/gateway.go index 6050dfb..cba21ee 100644 --- a/voice/voicegateway/gateway.go +++ b/voice/voicegateway/gateway.go @@ -18,9 +18,7 @@ import ( "github.com/pkg/errors" "github.com/diamondburned/arikawa/v3/discord" - "github.com/diamondburned/arikawa/v3/internal/moreatomic" - "github.com/diamondburned/arikawa/v3/utils/json" - "github.com/diamondburned/arikawa/v3/utils/wsutil" + "github.com/diamondburned/arikawa/v3/utils/ws" ) const ( @@ -37,9 +35,9 @@ type Event = interface{} // State contains state information of a voice gateway. type State struct { - GuildID discord.GuildID + UserID discord.UserID // constant + GuildID discord.GuildID // constant ChannelID discord.ChannelID - UserID discord.UserID SessionID string Token string @@ -48,282 +46,164 @@ type State struct { // Gateway represents a Discord Gateway Gateway connection. type Gateway struct { - state State // constant + gateway *ws.Gateway + state State // constant mutex sync.RWMutex - ready ReadyEvent - - WS *wsutil.Websocket - - Timeout time.Duration - reconnect moreatomic.Bool - - EventLoop wsutil.PacemakerLoop - Events chan Event - - // ErrorLog will be called when an error occurs (defaults to log.Println) - ErrorLog func(err error) - // AfterClose is called after each close. Error can be non-nil, as this is - // called even when the Gateway is gracefully closed. It's used mainly for - // reconnections or any type of connection interruptions. (defaults to noop) - AfterClose func(err error) - - // Filled by methods, internal use - waitGroup *sync.WaitGroup + ready *ReadyEvent } +// DefaultGatewayOpts contains the default options to be used for connecting to +// the gateway. +var DefaultGatewayOpts = ws.GatewayOpts{ + ReconnectDelay: func(try int) time.Duration { + // minimum 4 seconds + return time.Duration(4+(2*try)) * time.Second + }, + // FatalCloseCodes contains the default gateway close codes that will cause + // the gateway to exit. In other words, it's a list of unrecoverable close + // codes. + FatalCloseCodes: []int{ + 4003, // not authenticated + 4004, // authentication failed + 4006, // session invalid + 4009, // session timed out + 4011, // server not found + 4012, // unknown protocol + 4014, // disconnected + 4016, // unknown encryption mode + }, + DialTimeout: 0, + ReconnectAttempt: 0, + AlwaysCloseGracefully: true, +} + +// New creates a new voice gateway. func New(state State) *Gateway { // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version + gw := ws.NewGateway( + ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint), + &DefaultGatewayOpts, + ) + return &Gateway{ - state: state, - WS: wsutil.New(endpoint), - Timeout: wsutil.WSTimeout, - Events: make(chan Event, wsutil.WSBuffer), - ErrorLog: wsutil.WSError, - AfterClose: func(error) {}, + gateway: gw, + state: state, } } -// TODO: get rid of -func (c *Gateway) Ready() ReadyEvent { - c.mutex.RLock() - defer c.mutex.RUnlock() +// Ready returns the ready event. +func (g *Gateway) Ready() *ReadyEvent { + g.mutex.RLock() + defer g.mutex.RUnlock() - return c.ready + return g.ready } -// OpenCtx shouldn't be used, but JoinServer instead. -func (c *Gateway) OpenCtx(ctx context.Context) error { - if c.state.Endpoint == "" { - return errors.New("missing endpoint in state") - } - - // https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection - var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version - - wsutil.WSDebug("VoiceGateway: Connecting to voice endpoint (endpoint=" + endpoint + ")") - - // Create a new context with a timeout for the connection. - ctx, cancel := context.WithTimeout(ctx, c.Timeout) - defer cancel() - - // Connect to the Gateway Gateway. - if err := c.WS.Dial(ctx); err != nil { - return errors.Wrap(err, "failed to connect to voice gateway") - } - - wsutil.WSDebug("VoiceGateway: Trying to start...") - - // Try to start or resume the connection. - if err := c.start(ctx); err != nil { - return err - } - - return nil +// LastError returns the last error that the gateway has received. It only +// returns a valid error if the gateway's event loop as exited. If the event +// loop hasn't been started AND stopped, the function will panic. +func (g *Gateway) LastError() error { + return g.gateway.LastError() } -// Start . -func (c *Gateway) start(ctx context.Context) error { - if err := c.__start(ctx); err != nil { - wsutil.WSDebug("VoiceGateway: Start failed: ", err) - - // Close can be called with the mutex still acquired here, as the - // pacemaker hasn't started yet. - if err := c.Close(); err != nil { - wsutil.WSDebug("VoiceGateway: Failed to close after start fail: ", err) - } - return err - } - - return nil +// Send is a function to send an Op payload to the Gateway. +func (g *Gateway) Send(ctx context.Context, data ws.Event) error { + return g.gateway.Send(ctx, data) } -// this function blocks until READY. -func (c *Gateway) __start(ctx context.Context) error { - // Make a new WaitGroup for use in background loops: - c.waitGroup = new(sync.WaitGroup) +// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway. +func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error { + g.mutex.RLock() + ssrc := g.ready.SSRC + g.mutex.RUnlock() - ch := c.WS.Listen() + return g.gateway.Send(ctx, &SpeakingEvent{ + Speaking: flag, + Delay: 0, + SSRC: ssrc, + }) +} - // Wait for hello. - wsutil.WSDebug("VoiceGateway: Waiting for Hello..") +func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op { + return g.gateway.Connect(ctx, (*gatewayImpl)(g)) +} - var hello *HelloEvent - // Wait for the Hello event; return if it times out. - select { - case e, ok := <-ch: - if !ok { - return errors.New("unexpected ws close while waiting for Hello") - } - if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil { - return errors.Wrap(err, "error at Hello") - } - case <-ctx.Done(): - return errors.Wrap(ctx.Err(), "failed to wait for Hello event") +var ( + // ErrMissingForIdentify is an error when we are missing information to + // identify. + ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify") + // ErrMissingForResume is an error when we are missing information to + // resume. + ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming") +) + +type gatewayImpl Gateway + +func (g *gatewayImpl) sendIdentify(ctx context.Context) error { + id := IdentifyCommand{ + GuildID: g.state.GuildID, + UserID: g.state.UserID, + SessionID: g.state.SessionID, + Token: g.state.Token, + } + if !id.GuildID.IsValid() || id == (IdentifyCommand{}) { + return ErrMissingForIdentify } - wsutil.WSDebug("VoiceGateway: Received Hello") + return g.gateway.Send(ctx, &id) +} - // Start the event handler, which also handles the pacemaker death signal. - c.waitGroup.Add(1) +func (g *gatewayImpl) sendResume(ctx context.Context) error { + if !g.state.GuildID.IsValid() || g.state.SessionID == "" || g.state.Token == "" { + return ErrMissingForResume + } - c.EventLoop.StartBeating(hello.HeartbeatInterval.Duration(), c, func(err error) { - c.waitGroup.Done() // mark so Close() can exit. - wsutil.WSDebug("VoiceGateway: Event loop stopped.") + return g.gateway.Send(ctx, &ResumeCommand{ + GuildID: g.state.GuildID, + SessionID: g.state.SessionID, + Token: g.state.Token, + }) +} - if err != nil { - c.ErrorLog(err) +func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool { + switch data := op.Data.(type) { + case *HelloEvent: + g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration()) - if err := c.Reconnect(); err != nil { - c.ErrorLog(errors.Wrap(err, "failed to reconnect voice")) + // Send Discord either the Identify packet (if it's a fresh + // connection), or a Resume packet (if it's a dead connection). + if g.ready == nil { + // SessionID is empty, so this is a completely new session. + if err := g.sendIdentify(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to send identify") + g.gateway.QueueReconnect() + } + } else { + if err := g.sendResume(ctx); err != nil { + g.gateway.SendErrorWrap(err, "failed to send resume") + g.gateway.QueueReconnect() } - - // Reconnect should spawn another eventLoop in its Start function. } - }) - - // https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection - // Turns out Hello is sent right away on connection start. - if !c.reconnect.Get() { - if err := c.IdentifyCtx(ctx); err != nil { - return errors.Wrap(err, "failed to identify") - } - } else { - if err := c.ResumeCtx(ctx); err != nil { - return errors.Wrap(err, "failed to resume") - } - } - // This bool is because we should only try and Resume once. - c.reconnect.Set(false) - - // Wait for either Ready or Resumed. - err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool { - return op.Code == ReadyOP || op.Code == ResumedOP - }) - if err != nil { - return errors.Wrap(err, "failed to wait for Ready or Resumed") + case *ReadyEvent: + g.mutex.Lock() + g.ready = data + g.mutex.Unlock() } - // Bind the event channel away. - c.EventLoop.SetEventChannel(ch) + return true +} - wsutil.WSDebug("VoiceGateway: Started successfully.") +func (g *gatewayImpl) SendHeartbeat(ctx context.Context) { + heartbeat := HeartbeatCommand(time.Now().UnixNano()) + if err := g.gateway.Send(ctx, &heartbeat); err != nil { + g.gateway.SendErrorWrap(err, "heartbeat error") + g.gateway.QueueReconnect() + } +} +func (g *gatewayImpl) Close() error { return nil } - -// Close closes the underlying Websocket connection. -func (g *Gateway) Close() error { - wsutil.WSDebug("VoiceGateway: Trying to close. Pacemaker check skipped.") - - wsutil.WSDebug("VoiceGateway: Closing the Websocket...") - err := g.WS.Close() - - if errors.Is(err, wsutil.ErrWebsocketClosed) { - wsutil.WSDebug("VoiceGateway: Websocket already closed.") - return nil - } - - wsutil.WSDebug("VoiceGateway: Websocket closed; error:", err) - - wsutil.WSDebug("VoiceGateway: Waiting for the Pacemaker loop to exit.") - g.waitGroup.Wait() - wsutil.WSDebug("VoiceGateway: Pacemaker loop exited.") - - g.AfterClose(err) - wsutil.WSDebug("VoiceGateway: AfterClose callback finished.") - - return err -} - -func (c *Gateway) Reconnect() error { - return c.ReconnectCtx(context.Background()) -} - -func (c *Gateway) ReconnectCtx(ctx context.Context) error { - wsutil.WSDebug("VoiceGateway: Reconnecting...") - - // TODO: implement a reconnect loop - - // Guarantee the gateway is already closed. Ignore its error, as we're - // redialing anyway. - c.Close() - - c.reconnect.Set(true) - - // Condition: err == ErrInvalidSession: - // If the connection is rate limited (documented behavior): - // https://discord.com/developers/docs/topics/gateway#rate-limiting - - if err := c.OpenCtx(ctx); err != nil { - return errors.Wrap(err, "failed to reopen gateway") - } - - wsutil.WSDebug("VoiceGateway: Reconnected successfully.") - - return nil -} - -func (c *Gateway) SessionDescriptionCtx( - ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) { - - // Add the handler first. - ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool { - return op.Code == SessionDescriptionOP - }) - defer cancel() - - if err := c.SelectProtocolCtx(ctx, sp); err != nil { - return nil, err - } - - var sesdesc *SessionDescriptionEvent - - // Wait for SessionDescriptionOP packet. - select { - case e, ok := <-ch: - if !ok { - return nil, errors.New("unexpected close waiting for session description") - } - if err := e.UnmarshalData(&sesdesc); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal session description") - } - case <-ctx.Done(): - return nil, errors.Wrap(ctx.Err(), "failed to wait for session description") - } - - return sesdesc, nil -} - -// Send sends a payload to the Gateway with the default timeout. -func (c *Gateway) Send(code OPCode, v interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), c.Timeout) - defer cancel() - - return c.SendCtx(ctx, code, v) -} - -func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error { - var op = wsutil.OP{ - Code: code, - } - - if v != nil { - b, err := json.Marshal(v) - if err != nil { - return errors.Wrap(err, "failed to encode v") - } - - op.Data = b - } - - b, err := json.Marshal(op) - if err != nil { - return errors.Wrap(err, "failed to encode payload") - } - - // WS should already be thread-safe. - return c.WS.SendCtx(ctx, b) -} diff --git a/voice/voicegateway/op.go b/voice/voicegateway/op.go deleted file mode 100644 index 45ffd97..0000000 --- a/voice/voicegateway/op.go +++ /dev/null @@ -1,99 +0,0 @@ -package voicegateway - -import ( - "fmt" - "sync" - - "github.com/diamondburned/arikawa/v3/utils/json" - "github.com/diamondburned/arikawa/v3/utils/wsutil" - "github.com/pkg/errors" -) - -// OPCode represents a Discord Gateway Gateway operation code. -type OPCode = wsutil.OPCode - -const ( - IdentifyOP OPCode = 0 // send - SelectProtocolOP OPCode = 1 // send - ReadyOP OPCode = 2 // receive - HeartbeatOP OPCode = 3 // send - SessionDescriptionOP OPCode = 4 // receive - SpeakingOP OPCode = 5 // send/receive - HeartbeatAckOP OPCode = 6 // receive - ResumeOP OPCode = 7 // send - HelloOP OPCode = 8 // receive - ResumedOP OPCode = 9 // receive - ClientConnectOP OPCode = 12 // receive - ClientDisconnectOP OPCode = 13 // receive -) - -func (c *Gateway) HandleOP(op *wsutil.OP) error { - wsutil.WSDebug("Handle OP", op.Code) - switch op.Code { - // Gives information required to make a UDP connection - case ReadyOP: - if err := unmarshalMutex(op.Data, &c.ready, &c.mutex); err != nil { - return errors.Wrap(err, "failed to parse READY event") - } - - c.Events <- &c.ready - - // Gives information about the encryption mode and secret key for sending voice packets - case SessionDescriptionOP: - // ? - // Already handled by Session. - - // Someone started or stopped speaking. - case SpeakingOP: - ev := new(SpeakingEvent) - - if err := json.Unmarshal(op.Data, ev); err != nil { - return errors.Wrap(err, "failed to parse Speaking event") - } - - c.Events <- ev - - // Heartbeat response from the server - case HeartbeatAckOP: - c.EventLoop.Echo() - - // Hello server, we hear you! :) - case HelloOP: - // ? - // Already handled on initial connection. - - // Server is saying the connection was resumed, no data here. - case ResumedOP: - wsutil.WSDebug("Gateway connection has been resumed.") - - case ClientConnectOP: - ev := new(ClientConnectEvent) - - if err := json.Unmarshal(op.Data, ev); err != nil { - return errors.Wrap(err, "failed to parse Speaking event") - } - - c.Events <- ev - - case ClientDisconnectOP: - ev := new(ClientDisconnectEvent) - - if err := json.Unmarshal(op.Data, ev); err != nil { - return errors.Wrap(err, "failed to parse Speaking event") - } - - c.Events <- ev - - default: - return fmt.Errorf("unknown OP code %d", op.Code) - } - - return nil -} - -func unmarshalMutex(d []byte, v interface{}, m *sync.RWMutex) error { - m.Lock() - err := json.Unmarshal(d, v) - m.Unlock() - return err -}