mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-10-31 20:14:21 +00:00
gateway: Refactor for a better concurrent API
This commit refactors the whole package gateway as well as utils/ws (formerly utils/wsutil) and voice/voicegateway. The new refactor utilizes a design pattern involving a concurrent loop and an arriving event channel. An additional change was made to the way gateway events are typed. Before, pretty much any type will satisfy a gateway event type, since the actual type was just interface{}. The new refactor defines a concrete interface that events can implement: type Event interface { Op() OpCode EventType() EventType } Using this interface, the user can easily add custom gateway events independently of the library without relying on string maps. This adds a lot of type safety into the library and makes type-switching on Event types much more reasonable. Gateway error callbacks are also almost entirely removed in favor of custom gateway events. A catch-all can easily be added like this: s.AddHandler(func(err error) { log.Println("gateway error:, err") })
This commit is contained in:
parent
5c88317130
commit
17b9c73ce3
12
.build.yml
12
.build.yml
|
@ -12,15 +12,25 @@ environment:
|
|||
GO111MODULE: "on"
|
||||
CGO_ENABLED: "1"
|
||||
# Integration test variables.
|
||||
SHARD_COUNT: "3"
|
||||
SHARD_COUNT: "2"
|
||||
tested: "./api,./gateway,./bot,./discord"
|
||||
cov_file: "/tmp/cov_results"
|
||||
dismock: "github.com/mavolin/dismock/v2/pkg/dismock"
|
||||
dismock_v: "259685b84e4b6ab364b0fd858aac2aa2dfa42502"
|
||||
|
||||
tasks:
|
||||
- generate: |-
|
||||
cd arikawa
|
||||
go generate ./...
|
||||
|
||||
if [[ "$(git status --porcelain)" ]]; then
|
||||
echo "Repository differ after regeneration."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- build: cd arikawa && go build ./...
|
||||
- unit: cd arikawa && go test -tags unitonly -race ./...
|
||||
|
||||
- integration: |-
|
||||
sh -c '
|
||||
test -f ~/.env || {
|
||||
|
|
|
@ -23,11 +23,7 @@ func main() {
|
|||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("Session failed:", err)
|
||||
return
|
||||
}
|
||||
s := state.New("Bot " + token)
|
||||
|
||||
app, err := s.CurrentApplication()
|
||||
if err != nil {
|
||||
|
|
|
@ -22,11 +22,7 @@ func main() {
|
|||
log.Fatalln("no $BOT_TOKEN given")
|
||||
}
|
||||
|
||||
s, err := session.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("session failed:", err)
|
||||
return
|
||||
}
|
||||
s := session.New("Bot " + token)
|
||||
|
||||
app, err := s.CurrentApplication()
|
||||
if err != nil {
|
||||
|
|
|
@ -22,11 +22,7 @@ func main() {
|
|||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("Session failed:", err)
|
||||
return
|
||||
}
|
||||
s := state.New("Bot " + token)
|
||||
|
||||
app, err := s.CurrentApplication()
|
||||
if err != nil {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/session/shard"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
|
|
|
@ -19,11 +19,7 @@ func main() {
|
|||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
s, err := session.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("Session failed:", err)
|
||||
}
|
||||
|
||||
s := session.New("Bot " + token)
|
||||
s.AddHandler(func(c *gateway.MessageCreateEvent) {
|
||||
log.Println(c.Author.Username, "sent", c.Content)
|
||||
})
|
||||
|
|
|
@ -19,11 +19,7 @@ func main() {
|
|||
log.Fatalln("No $BOT_TOKEN given.")
|
||||
}
|
||||
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("Session failed:", err)
|
||||
}
|
||||
|
||||
s := state.New("Bot " + token)
|
||||
// Make a pre-handler
|
||||
s.PreHandler = handler.New()
|
||||
s.PreHandler.AddSyncHandler(func(c *gateway.MessageDeleteEvent) {
|
||||
|
|
40
README.md
40
README.md
|
@ -47,6 +47,46 @@ This example demonstrates the PreHandler feature of the state library.
|
|||
PreHandler calls all handlers that are registered (separately from the session),
|
||||
calling them before the state is updated.
|
||||
|
||||
### Bare Minimum Print Example
|
||||
|
||||
The least amount of code recommended to have a bot that logs all messages to
|
||||
console.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := state.New("Bot " + os.Getenv("DISCORD_TOKEN"))
|
||||
s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages)
|
||||
s.AddHandler(func(m *gateway.MessageCreateEvent) {
|
||||
log.Printf("%s: %s", m.Author.Username, m.Content)
|
||||
})
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Open(ctx); err != nil {
|
||||
log.Println("cannot open:", err)
|
||||
}
|
||||
|
||||
<-ctx.Done() // block until Ctrl+C
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
log.Println("cannot close:", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Bare Minimum Bot
|
||||
|
||||
The least amount of code for a basic ping-pong bot. It's similar to Serenity's
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/utils/httputil"
|
||||
)
|
||||
|
@ -30,7 +32,8 @@ func (c *Client) BotURL() (*BotData, error) {
|
|||
}
|
||||
|
||||
// GatewayURL asks Discord for a Websocket URL to the Gateway.
|
||||
func GatewayURL() (string, error) {
|
||||
func GatewayURL(ctx context.Context) (string, error) {
|
||||
var g BotData
|
||||
return g.URL, httputil.NewClient().RequestJSON(&g, "GET", EndpointGateway)
|
||||
err := httputil.NewClient().WithContext(ctx).RequestJSON(&g, "GET", EndpointGateway)
|
||||
return g.URL, err
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func DurationSinceEpoch(t time.Time) time.Duration {
|
|||
return time.Duration(t.UnixNano()) - Epoch
|
||||
}
|
||||
|
||||
//go:generate go run ../utils/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID
|
||||
//go:generate go run ../utils/cmd/gensnowflake -o snowflake_types.go AppID AttachmentID AuditLogEntryID ChannelID CommandID EmojiID GuildID IntegrationID InteractionID MessageID RoleID StageID StickerID StickerPackID TeamID UserID WebhookID
|
||||
|
||||
// Mention generates the mention syntax for this channel ID.
|
||||
func (s ChannelID) Mention() string { return "<#" + s.String() + ">" }
|
||||
|
|
|
@ -1,167 +0,0 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
)
|
||||
|
||||
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
|
||||
|
||||
// Identify structure is at identify.go
|
||||
|
||||
// Identify sends off the Identify command with the Gateway's IdentifyData.
|
||||
func (g *Gateway) Identify() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.IdentifyCtx(ctx)
|
||||
}
|
||||
|
||||
// IdentifyCtx sends off the Identify command with the Gateway's IdentifyData
|
||||
// with the given context for time out.
|
||||
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
|
||||
if err := g.Identifier.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "can't wait for identify()")
|
||||
}
|
||||
|
||||
return g.SendCtx(ctx, IdentifyOP, g.Identifier)
|
||||
}
|
||||
|
||||
type ResumeData struct {
|
||||
Token string `json:"token"`
|
||||
SessionID string `json:"session_id"`
|
||||
Sequence int64 `json:"seq"`
|
||||
}
|
||||
|
||||
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
|
||||
// from a dead connection. Start() resumes from a dead connection.
|
||||
func (g *Gateway) Resume() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.ResumeCtx(ctx)
|
||||
}
|
||||
|
||||
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
|
||||
// from a dead connection. Start() resumes from a dead connection.
|
||||
func (g *Gateway) ResumeCtx(ctx context.Context) error {
|
||||
var (
|
||||
ses = g.SessionID()
|
||||
seq = g.Sequence.Get()
|
||||
)
|
||||
|
||||
if ses == "" || seq == 0 {
|
||||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return g.SendCtx(ctx, ResumeOP, ResumeData{
|
||||
Token: g.Identifier.Token,
|
||||
SessionID: ses,
|
||||
Sequence: seq,
|
||||
})
|
||||
}
|
||||
|
||||
// HeartbeatData is the last sequence number to be sent.
|
||||
type HeartbeatData int
|
||||
|
||||
func (g *Gateway) Heartbeat() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.HeartbeatCtx(ctx)
|
||||
}
|
||||
|
||||
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
|
||||
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
|
||||
}
|
||||
|
||||
type RequestGuildMembersData struct {
|
||||
// GuildIDs contains the ids of the guilds to request data from. Multiple
|
||||
// guilds can only be requested when using user accounts.
|
||||
GuildIDs []discord.GuildID `json:"guild_id"`
|
||||
UserIDs []discord.UserID `json:"user_ids,omitempty"`
|
||||
|
||||
Query string `json:"query"`
|
||||
Limit uint `json:"limit"`
|
||||
Presences bool `json:"presences,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.RequestGuildMembersCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) RequestGuildMembersCtx(
|
||||
ctx context.Context, data RequestGuildMembersData) error {
|
||||
|
||||
return g.SendCtx(ctx, RequestGuildMembersOP, data)
|
||||
}
|
||||
|
||||
type UpdateVoiceStateData struct {
|
||||
GuildID discord.GuildID `json:"guild_id"`
|
||||
ChannelID discord.ChannelID `json:"channel_id"` // nullable
|
||||
SelfMute bool `json:"self_mute"`
|
||||
SelfDeaf bool `json:"self_deaf"`
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.UpdateVoiceStateCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateVoiceStateCtx(ctx context.Context, data UpdateVoiceStateData) error {
|
||||
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
|
||||
}
|
||||
|
||||
// UpdateStatusData is sent by this client to indicate a presence or status
|
||||
// update.
|
||||
type UpdateStatusData struct {
|
||||
Since discord.UnixMsTimestamp `json:"since"` // 0 if not idle
|
||||
|
||||
// Activities can be null or an empty slice.
|
||||
Activities []discord.Activity `json:"activities"`
|
||||
|
||||
Status discord.Status `json:"status"`
|
||||
AFK bool `json:"afk"`
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.UpdateStatusCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
|
||||
return g.SendCtx(ctx, StatusUpdateOP, data)
|
||||
}
|
||||
|
||||
// Undocumented
|
||||
type GuildSubscribeData struct {
|
||||
Typing bool `json:"typing"`
|
||||
Threads bool `json:"threads"`
|
||||
Activities bool `json:"activities"`
|
||||
GuildID discord.GuildID `json:"guild_id"`
|
||||
|
||||
// Channels is not documented. It's used to fetch the right members sidebar.
|
||||
Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.GuildSubscribeCtx(ctx, data)
|
||||
}
|
||||
|
||||
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
|
||||
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
|
||||
}
|
460
gateway/event_methods.go
Normal file
460
gateway/event_methods.go
Normal file
|
@ -0,0 +1,460 @@
|
|||
// Code generated by genevent. DO NOT EDIT.
|
||||
|
||||
package gateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
|
||||
func init() {
|
||||
OpUnmarshalers.Add(
|
||||
func() ws.Event { return new(HeartbeatCommand) },
|
||||
func() ws.Event { return new(HeartbeatAckEvent) },
|
||||
func() ws.Event { return new(ReconnectEvent) },
|
||||
func() ws.Event { return new(HelloEvent) },
|
||||
func() ws.Event { return new(ResumeCommand) },
|
||||
func() ws.Event { return new(InvalidSessionEvent) },
|
||||
func() ws.Event { return new(RequestGuildMembersCommand) },
|
||||
func() ws.Event { return new(UpdateVoiceStateCommand) },
|
||||
func() ws.Event { return new(UpdatePresenceCommand) },
|
||||
func() ws.Event { return new(GuildSubscribeCommand) },
|
||||
func() ws.Event { return new(ResumedEvent) },
|
||||
func() ws.Event { return new(ChannelCreateEvent) },
|
||||
func() ws.Event { return new(ChannelUpdateEvent) },
|
||||
func() ws.Event { return new(ChannelDeleteEvent) },
|
||||
func() ws.Event { return new(ChannelPinsUpdateEvent) },
|
||||
func() ws.Event { return new(ChannelUnreadUpdateEvent) },
|
||||
func() ws.Event { return new(ThreadCreateEvent) },
|
||||
func() ws.Event { return new(ThreadUpdateEvent) },
|
||||
func() ws.Event { return new(ThreadDeleteEvent) },
|
||||
func() ws.Event { return new(ThreadListSyncEvent) },
|
||||
func() ws.Event { return new(ThreadMemberUpdateEvent) },
|
||||
func() ws.Event { return new(ThreadMembersUpdateEvent) },
|
||||
func() ws.Event { return new(GuildCreateEvent) },
|
||||
func() ws.Event { return new(GuildUpdateEvent) },
|
||||
func() ws.Event { return new(GuildDeleteEvent) },
|
||||
func() ws.Event { return new(GuildBanAddEvent) },
|
||||
func() ws.Event { return new(GuildBanRemoveEvent) },
|
||||
func() ws.Event { return new(GuildEmojisUpdateEvent) },
|
||||
func() ws.Event { return new(GuildIntegrationsUpdateEvent) },
|
||||
func() ws.Event { return new(GuildMemberAddEvent) },
|
||||
func() ws.Event { return new(GuildMemberRemoveEvent) },
|
||||
func() ws.Event { return new(GuildMemberUpdateEvent) },
|
||||
func() ws.Event { return new(GuildMembersChunkEvent) },
|
||||
func() ws.Event { return new(GuildRoleCreateEvent) },
|
||||
func() ws.Event { return new(GuildRoleUpdateEvent) },
|
||||
func() ws.Event { return new(GuildRoleDeleteEvent) },
|
||||
func() ws.Event { return new(InviteCreateEvent) },
|
||||
func() ws.Event { return new(InviteDeleteEvent) },
|
||||
func() ws.Event { return new(MessageCreateEvent) },
|
||||
func() ws.Event { return new(MessageUpdateEvent) },
|
||||
func() ws.Event { return new(MessageDeleteEvent) },
|
||||
func() ws.Event { return new(MessageDeleteBulkEvent) },
|
||||
func() ws.Event { return new(MessageReactionAddEvent) },
|
||||
func() ws.Event { return new(MessageReactionRemoveEvent) },
|
||||
func() ws.Event { return new(MessageReactionRemoveAllEvent) },
|
||||
func() ws.Event { return new(MessageReactionRemoveEmojiEvent) },
|
||||
func() ws.Event { return new(MessageAckEvent) },
|
||||
func() ws.Event { return new(PresenceUpdateEvent) },
|
||||
func() ws.Event { return new(PresencesReplaceEvent) },
|
||||
func() ws.Event { return new(SessionsReplaceEvent) },
|
||||
func() ws.Event { return new(TypingStartEvent) },
|
||||
func() ws.Event { return new(UserUpdateEvent) },
|
||||
func() ws.Event { return new(VoiceStateUpdateEvent) },
|
||||
func() ws.Event { return new(VoiceServerUpdateEvent) },
|
||||
func() ws.Event { return new(WebhooksUpdateEvent) },
|
||||
func() ws.Event { return new(InteractionCreateEvent) },
|
||||
func() ws.Event { return new(UserGuildSettingsUpdateEvent) },
|
||||
func() ws.Event { return new(UserSettingsUpdateEvent) },
|
||||
func() ws.Event { return new(UserNoteUpdateEvent) },
|
||||
func() ws.Event { return new(RelationshipAddEvent) },
|
||||
func() ws.Event { return new(RelationshipRemoveEvent) },
|
||||
func() ws.Event { return new(ReadyEvent) },
|
||||
func() ws.Event { return new(ReadySupplementalEvent) },
|
||||
func() ws.Event { return new(IdentifyCommand) },
|
||||
)
|
||||
}
|
||||
|
||||
// Op implements Event. It always returns Op 1.
|
||||
func (*HeartbeatCommand) Op() ws.OpCode { return 1 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HeartbeatCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 11.
|
||||
func (*HeartbeatAckEvent) Op() ws.OpCode { return 11 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HeartbeatAckEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 7.
|
||||
func (*ReconnectEvent) Op() ws.OpCode { return 7 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ReconnectEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 10.
|
||||
func (*HelloEvent) Op() ws.OpCode { return 10 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HelloEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 6.
|
||||
func (*ResumeCommand) Op() ws.OpCode { return 6 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ResumeCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 9.
|
||||
func (*InvalidSessionEvent) Op() ws.OpCode { return 9 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*InvalidSessionEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 8.
|
||||
func (*RequestGuildMembersCommand) Op() ws.OpCode { return 8 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*RequestGuildMembersCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 4.
|
||||
func (*UpdateVoiceStateCommand) Op() ws.OpCode { return 4 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UpdateVoiceStateCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 3.
|
||||
func (*UpdatePresenceCommand) Op() ws.OpCode { return 3 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UpdatePresenceCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 14.
|
||||
func (*GuildSubscribeCommand) Op() ws.OpCode { return 14 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildSubscribeCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ResumedEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ResumedEvent) EventType() ws.EventType { return "RESUMED" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ChannelCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ChannelCreateEvent) EventType() ws.EventType { return "CHANNEL_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ChannelUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ChannelUpdateEvent) EventType() ws.EventType { return "CHANNEL_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ChannelDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ChannelDeleteEvent) EventType() ws.EventType { return "CHANNEL_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ChannelPinsUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ChannelPinsUpdateEvent) EventType() ws.EventType { return "CHANNEL_PINS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ChannelUnreadUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ChannelUnreadUpdateEvent) EventType() ws.EventType { return "CHANNEL_UNREAD_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadCreateEvent) EventType() ws.EventType { return "THREAD_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadUpdateEvent) EventType() ws.EventType { return "THREAD_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadDeleteEvent) EventType() ws.EventType { return "THREAD_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadListSyncEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadListSyncEvent) EventType() ws.EventType { return "THREAD_LIST_SYNC" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadMemberUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadMemberUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBER_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ThreadMembersUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ThreadMembersUpdateEvent) EventType() ws.EventType { return "THREAD_MEMBERS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildCreateEvent) EventType() ws.EventType { return "GUILD_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildUpdateEvent) EventType() ws.EventType { return "GUILD_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildDeleteEvent) EventType() ws.EventType { return "GUILD_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildBanAddEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildBanAddEvent) EventType() ws.EventType { return "GUILD_BAN_ADD" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildBanRemoveEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildBanRemoveEvent) EventType() ws.EventType { return "GUILD_BAN_REMOVE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildEmojisUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildEmojisUpdateEvent) EventType() ws.EventType { return "GUILD_EMOJIS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildIntegrationsUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildIntegrationsUpdateEvent) EventType() ws.EventType { return "GUILD_INTEGRATIONS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildMemberAddEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildMemberAddEvent) EventType() ws.EventType { return "GUILD_MEMBER_ADD" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildMemberRemoveEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildMemberRemoveEvent) EventType() ws.EventType { return "GUILD_MEMBER_REMOVE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildMemberUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildMemberUpdateEvent) EventType() ws.EventType { return "GUILD_MEMBER_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildMembersChunkEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildMembersChunkEvent) EventType() ws.EventType { return "GUILD_MEMBERS_CHUNK" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildRoleCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildRoleCreateEvent) EventType() ws.EventType { return "GUILD_ROLE_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildRoleUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildRoleUpdateEvent) EventType() ws.EventType { return "GUILD_ROLE_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*GuildRoleDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*GuildRoleDeleteEvent) EventType() ws.EventType { return "GUILD_ROLE_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*InviteCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*InviteCreateEvent) EventType() ws.EventType { return "INVITE_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*InviteDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*InviteDeleteEvent) EventType() ws.EventType { return "INVITE_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageCreateEvent) EventType() ws.EventType { return "MESSAGE_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageUpdateEvent) EventType() ws.EventType { return "MESSAGE_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageDeleteEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageDeleteEvent) EventType() ws.EventType { return "MESSAGE_DELETE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageDeleteBulkEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageDeleteBulkEvent) EventType() ws.EventType { return "MESSAGE_DELETE_BULK" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageReactionAddEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageReactionAddEvent) EventType() ws.EventType { return "MESSAGE_REACTION_ADD" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageReactionRemoveEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageReactionRemoveEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageReactionRemoveAllEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageReactionRemoveAllEvent) EventType() ws.EventType { return "MESSAGE_REACTION_REMOVE_ALL" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageReactionRemoveEmojiEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageReactionRemoveEmojiEvent) EventType() ws.EventType {
|
||||
return "MESSAGE_REACTION_REMOVE_EMOJI"
|
||||
}
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*MessageAckEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*MessageAckEvent) EventType() ws.EventType { return "MESSAGE_ACK" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*PresenceUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*PresenceUpdateEvent) EventType() ws.EventType { return "PRESENCE_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*PresencesReplaceEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*PresencesReplaceEvent) EventType() ws.EventType { return "PRESENCES_REPLACE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*SessionsReplaceEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*SessionsReplaceEvent) EventType() ws.EventType { return "SESSIONS_REPLACE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*TypingStartEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*TypingStartEvent) EventType() ws.EventType { return "TYPING_START" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*UserUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UserUpdateEvent) EventType() ws.EventType { return "USER_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*VoiceStateUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*VoiceStateUpdateEvent) EventType() ws.EventType { return "VOICE_STATE_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*VoiceServerUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*VoiceServerUpdateEvent) EventType() ws.EventType { return "VOICE_SERVER_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*WebhooksUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*WebhooksUpdateEvent) EventType() ws.EventType { return "WEBHOOKS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*InteractionCreateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*InteractionCreateEvent) EventType() ws.EventType { return "INTERACTION_CREATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*UserGuildSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UserGuildSettingsUpdateEvent) EventType() ws.EventType { return "USER_GUILD_SETTINGS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*UserSettingsUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UserSettingsUpdateEvent) EventType() ws.EventType { return "USER_SETTINGS_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*UserNoteUpdateEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*UserNoteUpdateEvent) EventType() ws.EventType { return "USER_NOTE_UPDATE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*RelationshipAddEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*RelationshipAddEvent) EventType() ws.EventType { return "RELATIONSHIP_ADD" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*RelationshipRemoveEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*RelationshipRemoveEvent) EventType() ws.EventType { return "RELATIONSHIP_REMOVE" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ReadyEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ReadyEvent) EventType() ws.EventType { return "READY" }
|
||||
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*ReadySupplementalEvent) Op() ws.OpCode { return dispatchOp }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ReadySupplementalEvent) EventType() ws.EventType { return "READY_SUPPLEMENTAL" }
|
||||
|
||||
// Op implements Event. It always returns Op 2.
|
||||
func (*IdentifyCommand) Op() ws.OpCode { return 2 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*IdentifyCommand) EventType() ws.EventType { return "" }
|
1150
gateway/events.go
1150
gateway/events.go
File diff suppressed because it is too large
Load diff
|
@ -1,76 +0,0 @@
|
|||
package gateway
|
||||
|
||||
// Event is any event struct. They have an "Event" suffixed to them.
|
||||
type Event = interface{}
|
||||
|
||||
// EventCreator maps an event type string to a constructor.
|
||||
var EventCreator = map[string]func() Event{
|
||||
"HELLO": func() Event { return new(HelloEvent) },
|
||||
"READY": func() Event { return new(ReadyEvent) },
|
||||
"READY_SUPPLEMENTAL": func() Event { return new(ReadySupplementalEvent) },
|
||||
"RESUMED": func() Event { return new(ResumedEvent) },
|
||||
"INVALID_SESSION": func() Event { return new(InvalidSessionEvent) },
|
||||
|
||||
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
|
||||
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
|
||||
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
|
||||
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
|
||||
"CHANNEL_UNREAD_UPDATE": func() Event { return new(ChannelUnreadUpdateEvent) },
|
||||
|
||||
"GUILD_CREATE": func() Event { return new(GuildCreateEvent) },
|
||||
"GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) },
|
||||
"GUILD_DELETE": func() Event { return new(GuildDeleteEvent) },
|
||||
|
||||
"GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) },
|
||||
"GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) },
|
||||
|
||||
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
|
||||
"GUILD_INTEGRATIONS_UPDATE": func() Event { return new(GuildIntegrationsUpdateEvent) },
|
||||
|
||||
"GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) },
|
||||
"GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) },
|
||||
"GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) },
|
||||
"GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) },
|
||||
|
||||
"GUILD_MEMBER_LIST_UPDATE": func() Event { return new(GuildMemberListUpdate) },
|
||||
|
||||
"GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) },
|
||||
"GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) },
|
||||
"GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) },
|
||||
|
||||
"INVITE_CREATE": func() Event { return new(InviteCreateEvent) },
|
||||
"INVITE_DELETE": func() Event { return new(InviteDeleteEvent) },
|
||||
|
||||
"MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) },
|
||||
"MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) },
|
||||
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
|
||||
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
|
||||
|
||||
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
|
||||
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
|
||||
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": func() Event { return new(MessageReactionRemoveEmojiEvent) },
|
||||
|
||||
"MESSAGE_ACK": func() Event { return new(MessageAckEvent) },
|
||||
|
||||
"PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) },
|
||||
"PRESENCES_REPLACE": func() Event { return new(PresencesReplaceEvent) },
|
||||
"SESSIONS_REPLACE": func() Event { return new(SessionsReplaceEvent) },
|
||||
|
||||
"TYPING_START": func() Event { return new(TypingStartEvent) },
|
||||
|
||||
"VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) },
|
||||
"VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) },
|
||||
|
||||
"WEBHOOKS_UPDATE": func() Event { return new(WebhooksUpdateEvent) },
|
||||
|
||||
"INTERACTION_CREATE": func() Event { return new(InteractionCreateEvent) },
|
||||
|
||||
"USER_UPDATE": func() Event { return new(UserUpdateEvent) },
|
||||
"USER_SETTINGS_UPDATE": func() Event { return new(UserSettingsUpdateEvent) },
|
||||
"USER_GUILD_SETTINGS_UPDATE": func() Event { return new(UserGuildSettingsUpdateEvent) },
|
||||
"USER_NOTE_UPDATE": func() Event { return new(UserNoteUpdateEvent) },
|
||||
|
||||
"RELATIONSHIP_ADD": func() Event { return new(RelationshipAddEvent) },
|
||||
"RELATIONSHIP_REMOVE": func() Event { return new(RelationshipRemoveEvent) },
|
||||
}
|
|
@ -9,19 +9,15 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/v3/internal/lazytime"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -29,26 +25,24 @@ var (
|
|||
Encoding = "json"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
|
||||
ErrWSMaxTries = errors.New(
|
||||
"could not connect to the Discord gateway before reaching the timeout")
|
||||
ErrClosed = errors.New("the gateway is closed and cannot reconnect")
|
||||
)
|
||||
// CodeInvalidSequence is the code returned by Discord to signal that the given
|
||||
// sequence number is invalid.
|
||||
const CodeInvalidSequence = 4007
|
||||
|
||||
// see
|
||||
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
|
||||
const errCodeShardingRequired = 4011
|
||||
// CodeShardingRequired is the code returned by Discord to signal that the bot
|
||||
// must reshard before proceeding. For more information, see
|
||||
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes.
|
||||
const CodeShardingRequired = 4011
|
||||
|
||||
// URL asks Discord for a Websocket URL to the Gateway.
|
||||
func URL() (string, error) {
|
||||
return api.GatewayURL()
|
||||
func URL(ctx context.Context) (string, error) {
|
||||
return api.GatewayURL(ctx)
|
||||
}
|
||||
|
||||
// BotURL fetches the Gateway URL along with extra metadata. The token
|
||||
// passed in will NOT be prefixed with Bot.
|
||||
func BotURL(token string) (*api.BotData, error) {
|
||||
return api.NewClient(token).BotURL()
|
||||
func BotURL(ctx context.Context, token string) (*api.BotData, error) {
|
||||
return api.NewClient(token).WithContext(ctx).BotURL()
|
||||
}
|
||||
|
||||
// AddGatewayParams appends into the given URL string the gateway URL
|
||||
|
@ -62,469 +56,371 @@ func AddGatewayParams(baseURL string) string {
|
|||
return baseURL + "?" + param.Encode()
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
WS *wsutil.Websocket
|
||||
|
||||
// WSTimeout is a timeout for an arbitrary action. An example of this is the
|
||||
// timeout for Start and the timeout for sending each Gateway command
|
||||
// independently.
|
||||
WSTimeout time.Duration
|
||||
|
||||
// ReconnectAttempts are the amount of attempts made to Reconnect, before
|
||||
// aborting. If this set to 0, unlimited attempts will be made.
|
||||
ReconnectAttempts uint
|
||||
|
||||
// All events sent over are pointers to Event structs (structs suffixed with
|
||||
// "Event"). This shouldn't be accessed if the Gateway is created with a
|
||||
// Session.
|
||||
Events chan Event
|
||||
|
||||
sessionMu sync.RWMutex
|
||||
sessionID string
|
||||
|
||||
Identifier *Identifier
|
||||
Sequence *moreatomic.Int64
|
||||
|
||||
PacerLoop wsutil.PacemakerLoop
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
// FatalErrorCallback is called, if the Gateway exits fatally. At the point
|
||||
// of calling, the gateway will be already closed.
|
||||
//
|
||||
// Currently this will only be called, if the ReconnectTimeout was changed
|
||||
// to a definite timeout, and connection could not be established during
|
||||
// that time.
|
||||
// err will be ErrWSMaxTries in that case.
|
||||
//
|
||||
// Defaults to noop.
|
||||
FatalErrorCallback func(err error)
|
||||
|
||||
// AfterClose is called after each close or pause. It is used mainly for
|
||||
// reconnections or any type of connection interruptions.
|
||||
//
|
||||
// Constructors will use a no-op function by default.
|
||||
AfterClose func(err error)
|
||||
|
||||
onShardingRequired func()
|
||||
|
||||
waitGroup sync.WaitGroup
|
||||
closed chan struct{}
|
||||
// State contains the gateway state. It is a piece of data that can be shared
|
||||
// across gateways during construction to be used for resuming a connection or
|
||||
// starting a new one with the previous data.
|
||||
//
|
||||
// The data structure itself is not thread-safe, so they may only be pulled from
|
||||
// the gateway after it's done and set before it's done.
|
||||
type State struct {
|
||||
Identifier Identifier
|
||||
SessionID string
|
||||
Sequence int64
|
||||
}
|
||||
|
||||
// NewGatewayWithIntents creates a new Gateway with the given intents and the
|
||||
// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
|
||||
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
|
||||
g, err := NewGateway(token)
|
||||
// Gateway describes an instance that handles the Discord gateway. It is
|
||||
// basically an abstracted concurrent event loop that the user could signal to
|
||||
// start connecting to the Discord gateway server.
|
||||
type Gateway struct {
|
||||
gateway *ws.Gateway
|
||||
state State
|
||||
|
||||
// non-mutex-guarded states
|
||||
// TODO: make lastBeat part of ws.Gateway so it can keep track of whether or
|
||||
// not the websocket is dead.
|
||||
beatMutex sync.Mutex
|
||||
sentBeat time.Time
|
||||
echoBeat time.Time
|
||||
retryTimer lazytime.Timer
|
||||
}
|
||||
|
||||
// NewWithIntents creates a new Gateway with the given intents and the default
|
||||
// stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
|
||||
func NewWithIntents(ctx context.Context, token string, intents ...Intents) (*Gateway, error) {
|
||||
var allIntents Intents
|
||||
for _, intent := range intents {
|
||||
allIntents |= intent
|
||||
}
|
||||
|
||||
g, err := New(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, intent := range intents {
|
||||
g.AddIntents(intent)
|
||||
}
|
||||
|
||||
g.AddIntents(allIntents)
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// NewGateway creates a new Gateway to the default Discord server.
|
||||
func NewGateway(token string) (*Gateway, error) {
|
||||
return NewIdentifiedGateway(DefaultIdentifier(token))
|
||||
// New creates a new Gateway to the default Discord server.
|
||||
func New(ctx context.Context, token string) (*Gateway, error) {
|
||||
return NewWithIdentifier(ctx, DefaultIdentifier(token))
|
||||
}
|
||||
|
||||
// NewIdentifiedGateway creates a new Gateway with the given gateway identifier
|
||||
// and the default everything. Sharded bots should prefer this function for the
|
||||
// shared identifier.
|
||||
func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
|
||||
var gatewayURL string
|
||||
var botData *api.BotData
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(id.Token, "Bot ") {
|
||||
botData, err = BotURL(id.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get bot data")
|
||||
}
|
||||
gatewayURL = botData.URL
|
||||
|
||||
} else {
|
||||
gatewayURL, err = URL()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get gateway endpoint")
|
||||
}
|
||||
// NewWithIdentifier creates a new Gateway with the given gateway identifier and
|
||||
// the default everything. Sharded bots should prefer this function for the
|
||||
// shared identifier. The given Identifier will be modified.
|
||||
func NewWithIdentifier(ctx context.Context, id Identifier) (*Gateway, error) {
|
||||
gatewayURL, err := id.QueryGateway(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gatewayURL = AddGatewayParams(gatewayURL)
|
||||
gateway := NewCustomIdentifiedGateway(gatewayURL, id)
|
||||
|
||||
// Use the supplied connect rate limit, if any.
|
||||
if botData != nil && botData.StartLimit != nil {
|
||||
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
|
||||
limiter := gateway.Identifier.IdentifyGlobalLimit
|
||||
|
||||
// Update the burst to be the current given time and reset it back to
|
||||
// the default when the given time is reached.
|
||||
limiter.SetBurst(botData.StartLimit.Remaining)
|
||||
limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
|
||||
|
||||
// Update the maximum number of identify requests allowed per 5s.
|
||||
gateway.Identifier.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
|
||||
}
|
||||
gateway := NewCustomWithIdentifier(gatewayURL, id, nil)
|
||||
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
// NewCustomGateway creates a new Gateway with a custom gateway URL and a new
|
||||
// NewCustom creates a new Gateway with a custom gateway URL and a new
|
||||
// Identifier. Most bots connecting to the official server should not use these
|
||||
// custom functions.
|
||||
func NewCustomGateway(gatewayURL, token string) *Gateway {
|
||||
return NewCustomIdentifiedGateway(gatewayURL, DefaultIdentifier(token))
|
||||
func NewCustom(gatewayURL, token string) *Gateway {
|
||||
return NewCustomWithIdentifier(gatewayURL, DefaultIdentifier(token), nil)
|
||||
}
|
||||
|
||||
// NewCustomIdentifiedGateway creates a new Gateway with a custom gateway URL
|
||||
// and a pre-existing Identifier. Refer to NewCustomGateway.
|
||||
func NewCustomIdentifiedGateway(gatewayURL string, id *Identifier) *Gateway {
|
||||
// DefaultGatewayOpts contains the default options to be used for connecting to
|
||||
// the gateway.
|
||||
var DefaultGatewayOpts = ws.GatewayOpts{
|
||||
ReconnectDelay: func(try int) time.Duration {
|
||||
// minimum 4 seconds
|
||||
return time.Duration(4+(2*try)) * time.Second
|
||||
},
|
||||
// FatalCloseCodes contains the default gateway close codes that will cause
|
||||
// the gateway to exit. In other words, it's a list of unrecoverable close
|
||||
// codes.
|
||||
FatalCloseCodes: []int{
|
||||
4003, // not authenticated
|
||||
4004, // authentication failed
|
||||
4010, // invalid shard sent
|
||||
4011, // sharding required
|
||||
4012, // invalid API version
|
||||
4013, // invalid intents
|
||||
4014, // disallowed intents
|
||||
},
|
||||
DialTimeout: 0,
|
||||
ReconnectAttempt: 0,
|
||||
AlwaysCloseGracefully: true,
|
||||
}
|
||||
|
||||
// NewCustomWithIdentifier creates a new Gateway with a custom gateway URL and a
|
||||
// pre-existing Identifier. If opts is nil, then DefaultGatewayOpts is used.
|
||||
func NewCustomWithIdentifier(gatewayURL string, id Identifier, opts *ws.GatewayOpts) *Gateway {
|
||||
return NewFromState(gatewayURL, State{Identifier: id}, opts)
|
||||
}
|
||||
|
||||
// NewFromState creates a new gateway from the given state and optionally
|
||||
// gateway options. If opts is nil, then DefaultGatewayOpts is used.
|
||||
func NewFromState(gatewayURL string, state State, opts *ws.GatewayOpts) *Gateway {
|
||||
if opts == nil {
|
||||
opts = &DefaultGatewayOpts
|
||||
}
|
||||
|
||||
gw := ws.NewGateway(ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), gatewayURL), opts)
|
||||
|
||||
return &Gateway{
|
||||
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
|
||||
WSTimeout: wsutil.WSTimeout,
|
||||
|
||||
Events: make(chan Event, wsutil.WSBuffer),
|
||||
Identifier: id,
|
||||
Sequence: moreatomic.NewInt64(0),
|
||||
|
||||
ErrorLog: wsutil.WSError,
|
||||
AfterClose: func(error) {},
|
||||
PacerLoop: wsutil.PacemakerLoop{ErrorLog: wsutil.WSError},
|
||||
gateway: gw,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
// AddIntents adds a Gateway Intent before connecting to the Gateway. As such,
|
||||
// this function will only work before Open() is called.
|
||||
// State returns a copy of the gateway's internal state. It panics if the
|
||||
// gateway is currently running.
|
||||
func (g *Gateway) State() State {
|
||||
g.gateway.AssertIsNotRunning()
|
||||
return g.state
|
||||
}
|
||||
|
||||
// SetState sets the gateway's state.
|
||||
func (g *Gateway) SetState(state State) {
|
||||
g.gateway.AssertIsNotRunning()
|
||||
g.state = state
|
||||
}
|
||||
|
||||
// AddIntents adds a Gateway Intent before connecting to the Gateway. This
|
||||
// function will only work before Connect() is called. Calling it once Connect()
|
||||
// is called will result in a panic.
|
||||
func (g *Gateway) AddIntents(i Intents) {
|
||||
if g.Identifier.Intents == nil {
|
||||
g.Identifier.Intents = option.NewUint(uint(i))
|
||||
} else {
|
||||
*g.Identifier.Intents |= uint(i)
|
||||
}
|
||||
g.gateway.AssertIsNotRunning()
|
||||
g.state.Identifier.AddIntents(i)
|
||||
}
|
||||
|
||||
// HasIntents reports if the Gateway has the passed Intents.
|
||||
// SentBeat returns the last time that the heart was beaten. If the gateway has
|
||||
// never connected, then a zero-value time is returned.
|
||||
func (g *Gateway) SentBeat() time.Time {
|
||||
g.beatMutex.Lock()
|
||||
defer g.beatMutex.Unlock()
|
||||
|
||||
return g.sentBeat
|
||||
}
|
||||
|
||||
// EchoBeat returns the last time that the heartbeat was acknowledged. It is
|
||||
// similar to SentBeat.
|
||||
func (g *Gateway) EchoBeat() time.Time {
|
||||
g.beatMutex.Lock()
|
||||
defer g.beatMutex.Unlock()
|
||||
|
||||
return g.echoBeat
|
||||
}
|
||||
|
||||
// Latency is a convenient function around SentBeat and EchoBeat. It subtracts
|
||||
// the EchoBeat with the SentBeat.
|
||||
func (g *Gateway) Latency() time.Duration {
|
||||
g.beatMutex.Lock()
|
||||
defer g.beatMutex.Unlock()
|
||||
|
||||
return g.echoBeat.Sub(g.sentBeat)
|
||||
}
|
||||
|
||||
// LastError returns the last error that the gateway has received. It only
|
||||
// returns a valid error if the gateway's event loop as exited. If the event
|
||||
// loop hasn't been started AND stopped, the function will panic.
|
||||
func (g *Gateway) LastError() error {
|
||||
return g.gateway.LastError()
|
||||
}
|
||||
|
||||
// Send is a function to send an Op payload to the Gateway.
|
||||
func (g *Gateway) Send(ctx context.Context, data ws.Event) error {
|
||||
return g.gateway.Send(ctx, data)
|
||||
}
|
||||
|
||||
// Connect starts the background goroutine that tries its best to maintain a
|
||||
// stable connection to the Discord gateway. To the user, the gateway should
|
||||
// appear to be working seamlessly.
|
||||
//
|
||||
// If no intents are set, i.e. if using a user account HasIntents will always
|
||||
// return true.
|
||||
func (g *Gateway) HasIntents(intents Intents) bool {
|
||||
if g.Identifier.Intents == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return Intents(*g.Identifier.Intents).Has(intents)
|
||||
}
|
||||
|
||||
// Close closes the underlying Websocket connection, invalidating the session
|
||||
// ID.
|
||||
// Behaviors
|
||||
//
|
||||
// It will send a closing frame before ending the connection, closing it
|
||||
// gracefully. This will cause the bot to appear as offline instantly.
|
||||
func (g *Gateway) Close() error {
|
||||
return g.close(true)
|
||||
}
|
||||
|
||||
// Pause pauses the Gateway connection, by ending the connection without
|
||||
// sending a closing frame. This allows the connection to be resumed at a later
|
||||
// point, by calling Reconnect or ReconnectCtx.
|
||||
func (g *Gateway) Pause() error {
|
||||
return g.close(false)
|
||||
}
|
||||
|
||||
func (g *Gateway) close(graceful bool) (err error) {
|
||||
wsutil.WSDebug("Trying to close. Pacemaker check skipped.")
|
||||
wsutil.WSDebug("Closing the Websocket...")
|
||||
|
||||
if graceful {
|
||||
err = g.WS.CloseGracefully()
|
||||
} else {
|
||||
err = g.WS.Close()
|
||||
}
|
||||
|
||||
if errors.Is(err, wsutil.ErrWebsocketClosed) {
|
||||
wsutil.WSDebug("Websocket already closed.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Explicitly signal the pacemaker loop to stop. We should do this in case
|
||||
// the Start function exited before it could bind the event channel into the
|
||||
// loop.
|
||||
g.PacerLoop.Stop()
|
||||
wsutil.WSDebug("Websocket closed; error:", err)
|
||||
|
||||
wsutil.WSDebug("Waiting for the Pacemaker loop to exit.")
|
||||
g.waitGroup.Wait()
|
||||
wsutil.WSDebug("Pacemaker loop exited.")
|
||||
|
||||
g.AfterClose(err)
|
||||
wsutil.WSDebug("AfterClose callback finished.")
|
||||
|
||||
if graceful {
|
||||
// If a Reconnect is in progress, signal to cancel.
|
||||
close(g.closed)
|
||||
|
||||
// Delete our session id, as we just invalidated it.
|
||||
g.sessionMu.Lock()
|
||||
g.sessionID = ""
|
||||
g.sessionMu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SessionID returns the session ID received after Ready. This function is
|
||||
// concurrently safe.
|
||||
func (g *Gateway) SessionID() string {
|
||||
g.sessionMu.RLock()
|
||||
defer g.sessionMu.RUnlock()
|
||||
|
||||
return g.sessionID
|
||||
}
|
||||
|
||||
// UseSessionID overrides the internal session ID for the one the user provides.
|
||||
func (g *Gateway) UseSessionID(sessionID string) {
|
||||
g.sessionMu.Lock()
|
||||
defer g.sessionMu.Unlock()
|
||||
|
||||
g.sessionID = sessionID
|
||||
}
|
||||
|
||||
// OnShardingRequired sets the function to be called if Discord closes with
|
||||
// error code 4011 aka Sharding Required. When called, the Gateway will already
|
||||
// be closed, and can (after increasing the number of shards) be reopened using
|
||||
// Open. Reconnect or ReconnectCtx, however, will not be available as the
|
||||
// session is invalidated.
|
||||
// There are several behaviors that the gateway will overload onto the channel.
|
||||
//
|
||||
// The gateway will completely halt what it's doing in the background when this
|
||||
// callback is called.
|
||||
func (g *Gateway) OnShardingRequired(fn func()) {
|
||||
g.onShardingRequired = fn
|
||||
// Once the gateway has exited, fatally or not, the event channel returned by
|
||||
// Connect will be closed. The user should therefore know whether or not the
|
||||
// gateway has exited by spinning on the channel until it is closed.
|
||||
//
|
||||
// If Connect is called twice, the second call will return the same exact
|
||||
// channel that the first call has made without starting any new goroutines,
|
||||
// except if the gateway is already closed, then a new gateway will spin up with
|
||||
// the existing gateway state.
|
||||
//
|
||||
// If the gateway stumbles upon any background errors, it will do its best to
|
||||
// recover from it, but errors will be notified to the user using the
|
||||
// BackgroundErrorEvent event. The user can type-assert the Op's data field,
|
||||
// like so:
|
||||
//
|
||||
// switch data := ev.Data.(type) {
|
||||
// case *gateway.BackgroundErrorEvent:
|
||||
// log.Println("gateway error:", data.Error)
|
||||
// }
|
||||
//
|
||||
// Closing
|
||||
//
|
||||
// As outlined in the first paragraph, closing the gateway would involve
|
||||
// cancelling the context that's given to gateway. If AlwaysCloseGracefully is
|
||||
// true (which it is by default), then the gateway is closed gracefully, and the
|
||||
// session ID is invalidated.
|
||||
//
|
||||
// To wait until the gateway has completely successfully exited, the user can
|
||||
// keep spinning on the event loop:
|
||||
//
|
||||
// for op := range ch {
|
||||
// select op.Data.(type) {
|
||||
// case *gateway.ReadyEvent:
|
||||
// // Close the gateway on READY.
|
||||
// cancel()
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Gateway is now completely closed.
|
||||
//
|
||||
// To capture the final close errors, the user can use the Error method once the
|
||||
// event channel is closed, like so:
|
||||
//
|
||||
// var err error
|
||||
//
|
||||
// for op := range ch {
|
||||
// switch data := op.Data.(type) {
|
||||
// case *gateway.ReadyEvent:
|
||||
// cancel()
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Gateway is now completely closed.
|
||||
// if gateway.LastError() != nil {
|
||||
// return gateway.LastError()
|
||||
// }
|
||||
//
|
||||
func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op {
|
||||
return g.gateway.Connect(ctx, &gatewayImpl{Gateway: g})
|
||||
}
|
||||
|
||||
// Reconnect tries to reconnect to the Gateway until the ReconnectAttempts are
|
||||
// reached.
|
||||
func (g *Gateway) Reconnect() {
|
||||
g.ReconnectCtx(context.Background())
|
||||
type gatewayImpl struct {
|
||||
*Gateway
|
||||
lastSentBeat time.Time
|
||||
}
|
||||
|
||||
// ReconnectCtx attempts to Reconnect until context expires.
|
||||
// If the context expires FatalErrorCallback will be called with ErrWSMaxTries,
|
||||
// and the last error returned by Open will be returned.
|
||||
func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
|
||||
wsutil.WSDebug("Reconnecting...")
|
||||
func (g *gatewayImpl) invalidate() {
|
||||
g.state.SessionID = ""
|
||||
g.state.Sequence = 0
|
||||
}
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
// redialing anyway.
|
||||
g.Pause()
|
||||
// sendIdentify sends off the Identify command with the Gateway's IdentifyData
|
||||
// with the given context for timeout.
|
||||
func (g *gatewayImpl) sendIdentify(ctx context.Context) error {
|
||||
if err := g.state.Identifier.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "can't wait for identify()")
|
||||
}
|
||||
|
||||
for try := uint(1); g.ReconnectAttempts == 0 || g.ReconnectAttempts >= try; try++ {
|
||||
select {
|
||||
case <-g.closed:
|
||||
g.ErrorLog(ErrClosed)
|
||||
return ErrClosed
|
||||
case <-ctx.Done():
|
||||
wsutil.WSDebug("Unable to Reconnect after", try, "attempts, aborting")
|
||||
g.FatalErrorCallback(ErrWSMaxTries)
|
||||
return err
|
||||
default:
|
||||
return g.gateway.Send(ctx, &g.state.Identifier.IdentifyCommand)
|
||||
}
|
||||
|
||||
func (g *gatewayImpl) sendResume(ctx context.Context) error {
|
||||
return g.gateway.Send(ctx, &ResumeCommand{
|
||||
Token: g.state.Identifier.Token,
|
||||
SessionID: g.state.SessionID,
|
||||
Sequence: g.state.Sequence,
|
||||
})
|
||||
}
|
||||
|
||||
func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool {
|
||||
if op.Code == dispatchOp {
|
||||
g.state.Sequence = op.Sequence
|
||||
}
|
||||
|
||||
switch data := op.Data.(type) {
|
||||
case *ws.CloseEvent:
|
||||
if data.Code == CodeInvalidSequence {
|
||||
// Invalid sequence.
|
||||
g.invalidate()
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Trying to dial, attempt", try)
|
||||
g.gateway.QueueReconnect()
|
||||
|
||||
// if we encounter an error, make sure we return it, and not nil
|
||||
if oerr := g.Open(ctx); oerr != nil {
|
||||
err = oerr
|
||||
g.ErrorLog(oerr)
|
||||
case *HelloEvent:
|
||||
g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration())
|
||||
|
||||
wait := time.Duration(4+2*try) * time.Second
|
||||
if wait > 60*time.Second {
|
||||
wait = 60 * time.Second
|
||||
// Send Discord either the Identify packet (if it's a fresh
|
||||
// connection), or a Resume packet (if it's a dead connection).
|
||||
if g.state.SessionID == "" || g.state.Sequence == 0 {
|
||||
// SessionID is empty, so this is a completely new session.
|
||||
if err := g.sendIdentify(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to send identify")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
|
||||
time.Sleep(wait)
|
||||
continue
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Started after attempt:", try)
|
||||
return nil
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Unable to Reconnect after", g.ReconnectAttempts, "attempts, aborting")
|
||||
return err
|
||||
}
|
||||
|
||||
// Open connects to the Websocket and authenticates it. You should usually use
|
||||
// this function over Start(). The given context provides cancellation and
|
||||
// timeout.
|
||||
func (g *Gateway) Open(ctx context.Context) error {
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to Reconnect")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Trying to start...")
|
||||
|
||||
// Try to resume the connection
|
||||
if err := g.StartCtx(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Started successfully, return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start calls StartCtx with a background context. You wouldn't usually use this
|
||||
// function, but Open() instead.
|
||||
func (g *Gateway) Start() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.StartCtx(ctx)
|
||||
}
|
||||
|
||||
// StartCtx authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. You wouldn't usually use this function, but OpenCtx() instead.
|
||||
func (g *Gateway) StartCtx(ctx context.Context) error {
|
||||
g.closed = make(chan struct{})
|
||||
|
||||
if err := g.start(ctx); err != nil {
|
||||
wsutil.WSDebug("Start failed:", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := g.Close(); err != nil {
|
||||
wsutil.WSDebug("Failed to close after start fail:", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) start(ctx context.Context) error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
||||
// Create a new Hello event and wait for it.
|
||||
var hello HelloEvent
|
||||
// Wait for an OP 10 Hello.
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return errors.New("unexpected ws close while waiting for Hello")
|
||||
}
|
||||
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("Hello received; duration:", hello.HeartbeatInterval)
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
g.waitGroup.Add(1)
|
||||
|
||||
// Use the pacemaker loop.
|
||||
g.PacerLoop.StartBeating(hello.HeartbeatInterval.Duration(), g, func(err error) {
|
||||
g.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("Event loop stopped with error:", err)
|
||||
|
||||
if err != nil && g.onShardingRequired != nil {
|
||||
// If Discord signals us sharding is required, do not attempt to
|
||||
// Reconnect, unless we don't know what to do. Instead invalidate
|
||||
// our session ID, as we cannot resume, call OnShardingRequired, and
|
||||
// exit.
|
||||
var cerr *websocket.CloseError
|
||||
if errors.As(err, &cerr) && cerr.Code == errCodeShardingRequired {
|
||||
g.ErrorLog(cerr)
|
||||
g.UseSessionID("")
|
||||
g.onShardingRequired()
|
||||
return
|
||||
} else {
|
||||
if err := g.sendResume(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to send resume")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
// Bail if there is no error or if the error is an explicit close, as
|
||||
// there might be an ongoing reconnection.
|
||||
if err == nil || errors.Is(err, wsutil.ErrWebsocketClosed) {
|
||||
return
|
||||
case *InvalidSessionEvent:
|
||||
// Wipe the session state.
|
||||
g.invalidate()
|
||||
|
||||
if !*data {
|
||||
g.gateway.QueueReconnect()
|
||||
break
|
||||
}
|
||||
|
||||
// Only attempt to Reconnect if we have a session ID at all. We may not
|
||||
// have one if we haven't even connected successfully once.
|
||||
if g.SessionID() != "" {
|
||||
g.ErrorLog(err)
|
||||
g.Reconnect()
|
||||
// Discord expects us to wait before reconnecting.
|
||||
g.retryTimer.Reset(time.Duration(rand.Intn(5)+1) * time.Second)
|
||||
if err := g.retryTimer.Wait(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to wait before identifying")
|
||||
g.gateway.QueueReconnect()
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
// Send Discord either the Identify packet (if it's a fresh connection), or
|
||||
// a Resume packet (if it's a dead connection).
|
||||
if g.SessionID() == "" {
|
||||
// SessionID is empty, so this is a completely new session.
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to identify")
|
||||
}
|
||||
} else {
|
||||
if err := g.ResumeCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to resume")
|
||||
// If we fail to identify, then the gateway cannot continue with
|
||||
// a bad identification, since it's likely a user error.
|
||||
if err := g.sendIdentify(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to identify")
|
||||
g.gateway.QueueReconnect()
|
||||
break
|
||||
}
|
||||
|
||||
case *HeartbeatCommand:
|
||||
g.SendHeartbeat(ctx)
|
||||
|
||||
case *HeartbeatAckEvent:
|
||||
now := time.Now()
|
||||
|
||||
g.beatMutex.Lock()
|
||||
g.sentBeat = g.lastSentBeat
|
||||
g.echoBeat = now
|
||||
g.beatMutex.Unlock()
|
||||
|
||||
case *ReconnectEvent:
|
||||
g.gateway.QueueReconnect()
|
||||
|
||||
case *ReadyEvent:
|
||||
g.state.SessionID = data.SessionID
|
||||
}
|
||||
|
||||
// Expect either READY or RESUMED before continuing.
|
||||
wsutil.WSDebug("Waiting for either READY or RESUMED.")
|
||||
return true
|
||||
}
|
||||
|
||||
// WaitForEvent should until the bot becomes ready or resumes (if a
|
||||
// previous ready event has already been called).
|
||||
err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
|
||||
switch op.EventName {
|
||||
case "READY":
|
||||
wsutil.WSDebug("Found READY event.")
|
||||
return true
|
||||
case "RESUMED":
|
||||
wsutil.WSDebug("Found RESUMED event.")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
// SendHeartbeat sends a heartbeat with the gateway's current sequence.
|
||||
func (g *gatewayImpl) SendHeartbeat(ctx context.Context) {
|
||||
g.lastSentBeat = time.Now()
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "first error")
|
||||
sequence := HeartbeatCommand(g.state.Sequence)
|
||||
if err := g.gateway.Send(ctx, &sequence); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "heartbeat error")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the event channel to the pacemaker loop.
|
||||
g.PacerLoop.SetEventChannel(ch)
|
||||
|
||||
wsutil.WSDebug("Started successfully.")
|
||||
|
||||
// Close closes the state.
|
||||
func (g *gatewayImpl) Close() error {
|
||||
g.retryTimer.Stop()
|
||||
g.invalidate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendCtx is a low-level function to send an OP payload to the Gateway. Most
|
||||
// users shouldn't touch this, unless they know what they're doing.
|
||||
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
var op = wsutil.OP{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to encode v")
|
||||
}
|
||||
|
||||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to encode payload")
|
||||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return g.WS.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
31
gateway/gateway_example_test.go
Normal file
31
gateway/gateway_example_test.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package gateway_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
token := os.Getenv("BOT_TOKEN")
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
g, err := gateway.NewWithIntents(ctx, token, gateway.IntentGuilds)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to initialize gateway:", err)
|
||||
}
|
||||
|
||||
for op := range g.Connect(ctx) {
|
||||
switch data := op.Data.(type) {
|
||||
case *gateway.ReadyEvent:
|
||||
log.Println("logged in as", data.User.Username)
|
||||
case *gateway.MessageCreateEvent:
|
||||
log.Println("got message", data.Content)
|
||||
}
|
||||
}
|
||||
}
|
197
gateway/gateway_test.go
Normal file
197
gateway/gateway_test.go
Normal file
|
@ -0,0 +1,197 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
var doLogOnce sync.Once
|
||||
|
||||
func doLog() {
|
||||
doLogOnce.Do(func() {
|
||||
if testing.Verbose() {
|
||||
ws.WSDebug = func(v ...interface{}) {
|
||||
log.Println(append([]interface{}{"Debug:"}, v...)...)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestURL(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
u, err := URL(ctx)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get gateway URL:", err)
|
||||
}
|
||||
|
||||
if u == "" {
|
||||
t.Fatal("gateway URL is empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(u, "wss://") {
|
||||
t.Fatal("gatewayURL is invalid:", u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidToken(t *testing.T) {
|
||||
doLog()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
g, err := New(ctx, "bad token")
|
||||
if err != nil {
|
||||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
|
||||
assertIsClose := func(err error) {
|
||||
if err == nil {
|
||||
t.Fatal("unexpected nil error")
|
||||
}
|
||||
|
||||
// 4004 Authentication Failed.
|
||||
if !strings.Contains(err.Error(), "4004") {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
for op := range g.Connect(ctx) {
|
||||
if op.Data == nil {
|
||||
// This shouldn't happen; the loop should've broken out.
|
||||
t.Fatal("nil event received")
|
||||
}
|
||||
|
||||
switch data := op.Data.(type) {
|
||||
case *ws.CloseEvent:
|
||||
assertIsClose(data)
|
||||
case *ws.BackgroundErrorEvent:
|
||||
t.Error("gateway error:", data)
|
||||
case *HelloEvent:
|
||||
t.Log("got Hello")
|
||||
case *InvalidSessionEvent:
|
||||
t.Log("got InvalidSession")
|
||||
default:
|
||||
t.Errorf("got unexpected event %#v", data)
|
||||
}
|
||||
}
|
||||
|
||||
assertIsClose(g.LastError())
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
doLog()
|
||||
|
||||
config := testenv.Must(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// NewGateway should call Start for us.
|
||||
g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
|
||||
gatewayOpenAndSpin(t, ctx, g)
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestReuseGateway(t *testing.T) {
|
||||
doLog()
|
||||
|
||||
config := testenv.Must(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// NewGateway should call Start for us.
|
||||
g, err := NewWithIntents(ctx, "Bot "+config.BotToken, IntentGuilds)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
|
||||
// Reuse this 3 times.
|
||||
for i := 0; i < 3; i++ {
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
gatewayOpenAndSpin(t, cctx, g)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func gatewayOpenAndSpin(t *testing.T, ctx context.Context, g *Gateway) {
|
||||
ch := g.Connect(ctx)
|
||||
|
||||
var reconnected bool
|
||||
reconnect := func() {
|
||||
if !reconnected {
|
||||
reconnected = true
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
for op := range ch {
|
||||
if op.Data == nil {
|
||||
// This shouldn't happen; the loop should've broken out.
|
||||
t.Fatal("nil event received")
|
||||
}
|
||||
|
||||
switch data := op.Data.(type) {
|
||||
case *ReadyEvent:
|
||||
t.Log("got Ready")
|
||||
if g.state.SessionID != data.SessionID {
|
||||
t.Fatal("missing SessionID")
|
||||
}
|
||||
log.Println("Bot's username is", data.User.Username)
|
||||
reconnect()
|
||||
case *ResumedEvent:
|
||||
t.Log("got Resumed, test done")
|
||||
return
|
||||
case *HelloEvent:
|
||||
t.Log("got Hello")
|
||||
case *ws.BackgroundErrorEvent:
|
||||
t.Error("gateway error:", data)
|
||||
default:
|
||||
t.Logf("got event %T", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wait(t *testing.T, evCh chan interface{}) interface{} {
|
||||
select {
|
||||
case ev := <-evCh:
|
||||
return ev
|
||||
case <-time.After(20 * time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func gotimeout(t *testing.T, fn func(context.Context)) {
|
||||
t.Helper()
|
||||
|
||||
// Try and reconnect for 20 seconds maximum.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var done = make(chan struct{})
|
||||
go func() {
|
||||
fn(ctx)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for function.")
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
|
@ -3,8 +3,10 @@ package gateway
|
|||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -13,27 +15,27 @@ import (
|
|||
|
||||
// DefaultPresence is used as the default presence when initializing a new
|
||||
// Gateway.
|
||||
var DefaultPresence *UpdateStatusData
|
||||
var DefaultPresence *UpdatePresenceCommand
|
||||
|
||||
// Identifier is a wrapper around IdentifyData to add in appropriate rate
|
||||
// Identifier is a wrapper around IdentifyCommand to add in appropriate rate
|
||||
// limiters.
|
||||
type Identifier struct {
|
||||
IdentifyData
|
||||
IdentifyCommand
|
||||
|
||||
IdentifyShortLimit *rate.Limiter `json:"-"` // optional
|
||||
IdentifyGlobalLimit *rate.Limiter `json:"-"` // optional
|
||||
}
|
||||
|
||||
// DefaultIdentifier creates a new default Identifier
|
||||
func DefaultIdentifier(token string) *Identifier {
|
||||
return NewIdentifier(DefaultIdentifyData(token))
|
||||
func DefaultIdentifier(token string) Identifier {
|
||||
return NewIdentifier(DefaultIdentifyCommand(token))
|
||||
}
|
||||
|
||||
// NewIdentifier creates a new identifier with the given IdentifyData and
|
||||
// NewIdentifier creates a new identifier with the given IdentifyCommand and
|
||||
// default rate limiters.
|
||||
func NewIdentifier(data IdentifyData) *Identifier {
|
||||
return &Identifier{
|
||||
IdentifyData: data,
|
||||
func NewIdentifier(data IdentifyCommand) Identifier {
|
||||
return Identifier{
|
||||
IdentifyCommand: data,
|
||||
IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1),
|
||||
IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000),
|
||||
}
|
||||
|
@ -41,15 +43,15 @@ func NewIdentifier(data IdentifyData) *Identifier {
|
|||
|
||||
// Wait waits for the rate limiters to pass. If a limiter is nil, then it will
|
||||
// not be used to wait. This is useful
|
||||
func (i *Identifier) Wait(ctx context.Context) error {
|
||||
if i.IdentifyShortLimit != nil {
|
||||
if err := i.IdentifyShortLimit.Wait(ctx); err != nil {
|
||||
func (id *Identifier) Wait(ctx context.Context) error {
|
||||
if id.IdentifyShortLimit != nil {
|
||||
if err := id.IdentifyShortLimit.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "can't wait for short limit")
|
||||
}
|
||||
}
|
||||
|
||||
if i.IdentifyGlobalLimit != nil {
|
||||
if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil {
|
||||
if id.IdentifyGlobalLimit != nil {
|
||||
if err := id.IdentifyGlobalLimit.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "can't wait for global limit")
|
||||
}
|
||||
}
|
||||
|
@ -57,6 +59,41 @@ func (i *Identifier) Wait(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// QueryGateway queries the gateway for the URL and updates the Identifier with
|
||||
// the appropriate information.
|
||||
func (id *Identifier) QueryGateway(ctx context.Context) (gatewayURL string, err error) {
|
||||
var botData *api.BotData
|
||||
|
||||
if strings.HasPrefix(id.Token, "Bot ") {
|
||||
botData, err = BotURL(ctx, id.Token)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to get bot data")
|
||||
}
|
||||
gatewayURL = botData.URL
|
||||
} else {
|
||||
gatewayURL, err = URL(ctx)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to get gateway endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
// Use the supplied connect rate limit, if any.
|
||||
if botData != nil && botData.StartLimit != nil {
|
||||
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
|
||||
limiter := id.IdentifyGlobalLimit
|
||||
|
||||
// Update the burst to be the current given time and reset it back to
|
||||
// the default when the given time is reached.
|
||||
limiter.SetBurst(botData.StartLimit.Remaining)
|
||||
limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
|
||||
|
||||
// Update the maximum number of identify requests allowed per 5s.
|
||||
id.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// DefaultIdentity is used as the default identity when initializing a new
|
||||
// Gateway.
|
||||
var DefaultIdentity = IdentifyProperties{
|
||||
|
@ -65,9 +102,9 @@ var DefaultIdentity = IdentifyProperties{
|
|||
Device: "Arikawa",
|
||||
}
|
||||
|
||||
// IdentifyData is the struct for a data that's sent over in an Identify
|
||||
// command.
|
||||
type IdentifyData struct {
|
||||
// IdentifyCommand is a command for Op 2. It is the struct for a data that's
|
||||
// sent over in an Identify command.
|
||||
type IdentifyCommand struct {
|
||||
Token string `json:"token"`
|
||||
Properties IdentifyProperties `json:"properties"`
|
||||
|
||||
|
@ -76,7 +113,7 @@ type IdentifyData struct {
|
|||
|
||||
Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ]
|
||||
|
||||
Presence *UpdateStatusData `json:"presence,omitempty"`
|
||||
Presence *UpdatePresenceCommand `json:"presence,omitempty"`
|
||||
|
||||
// ClientState is the client state for a user's accuont. Bot accounts should
|
||||
// NOT touch this field.
|
||||
|
@ -97,9 +134,9 @@ type IdentifyData struct {
|
|||
Intents option.Uint `json:"intents"`
|
||||
}
|
||||
|
||||
// DefaultIdentifyData creates a default IdentifyData with the given token.
|
||||
func DefaultIdentifyData(token string) IdentifyData {
|
||||
return IdentifyData{
|
||||
// DefaultIdentifyCommand creates a default IdentifyCommand with the given token.
|
||||
func DefaultIdentifyCommand(token string) IdentifyCommand {
|
||||
return IdentifyCommand{
|
||||
Token: token,
|
||||
Properties: DefaultIdentity,
|
||||
Presence: DefaultPresence,
|
||||
|
@ -110,14 +147,35 @@ func DefaultIdentifyData(token string) IdentifyData {
|
|||
}
|
||||
|
||||
// SetShard is a helper function to set the shard configuration inside
|
||||
// IdentifyData.
|
||||
func (i *IdentifyData) SetShard(id, num int) {
|
||||
// IdentifyCommand.
|
||||
func (i *IdentifyCommand) SetShard(id, num int) {
|
||||
if i.Shard == nil {
|
||||
i.Shard = new(Shard)
|
||||
}
|
||||
i.Shard[0], i.Shard[1] = id, num
|
||||
}
|
||||
|
||||
// AddIntents adds gateway intents into the identify data.
|
||||
func (i *IdentifyCommand) AddIntents(intents Intents) {
|
||||
if i.Intents == nil {
|
||||
i.Intents = option.NewUint(uint(intents))
|
||||
} else {
|
||||
*i.Intents |= uint(intents)
|
||||
}
|
||||
}
|
||||
|
||||
// HasIntents reports if the Gateway has the passed Intents.
|
||||
//
|
||||
// If no intents are set, e.g. if using a user account, HasIntents will always
|
||||
// return true.
|
||||
func (i *IdentifyCommand) HasIntents(intents Intents) bool {
|
||||
if i.Intents == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return Intents(*i.Intents).Has(intents)
|
||||
}
|
||||
|
||||
type IdentifyProperties struct {
|
||||
// Required
|
||||
OS string `json:"os"` // GOOS
|
||||
|
|
|
@ -1,160 +0,0 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/heart"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
wsutil.WSDebug = func(v ...interface{}) {
|
||||
log.Println(append([]interface{}{"Debug:"}, v...)...)
|
||||
}
|
||||
heart.Debug = func(v ...interface{}) {
|
||||
log.Println(append([]interface{}{"Heart:"}, v...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func TestURL(t *testing.T) {
|
||||
u, err := URL()
|
||||
if err != nil {
|
||||
t.Fatal("failed to get gateway URL:", err)
|
||||
}
|
||||
|
||||
if u == "" {
|
||||
t.Fatal("gateway URL is empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(u, "wss://") {
|
||||
t.Fatal("gatewayURL is invalid:", u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidToken(t *testing.T) {
|
||||
g, err := NewGateway("bad token")
|
||||
if err != nil {
|
||||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err = g.Open(ctx); err == nil {
|
||||
t.Fatal("unexpected success while opening with a bad token.")
|
||||
}
|
||||
|
||||
// 4004 Authentication Failed.
|
||||
if !strings.Contains(err.Error(), "4004") {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
config := testenv.Must(t)
|
||||
|
||||
var gateway *Gateway
|
||||
|
||||
// NewGateway should call Start for us.
|
||||
g, err := NewGateway("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make a Gateway:", err)
|
||||
}
|
||||
g.AddIntents(IntentGuilds)
|
||||
g.AfterClose = func(err error) {
|
||||
t.Log("closed.")
|
||||
}
|
||||
g.ErrorLog = func(err error) {
|
||||
t.Log("gateway error:", err)
|
||||
}
|
||||
gateway = g
|
||||
|
||||
gotimeout(t, func(ctx context.Context) {
|
||||
if err := g.Open(ctx); err != nil {
|
||||
t.Fatal("failed to authenticate with Discord:", err)
|
||||
}
|
||||
})
|
||||
|
||||
ev := wait(t, gateway.Events)
|
||||
ready, ok := ev.(*ReadyEvent)
|
||||
if !ok {
|
||||
t.Fatal("event received is not of type Ready:", ev)
|
||||
}
|
||||
|
||||
if gateway.SessionID() == "" {
|
||||
t.Fatal("session ID is empty")
|
||||
}
|
||||
|
||||
log.Println("Bot's username is", ready.User.Username)
|
||||
|
||||
// Send a faster heartbeat every second for testing.
|
||||
g.PacerLoop.SetPace(time.Second)
|
||||
|
||||
// Sleep past the rate limiter before reconnecting:
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
gotimeout(t, func(ctx context.Context) {
|
||||
g.ErrorLog = func(err error) {
|
||||
t.Error("unexpected error while reconnecting:", err)
|
||||
}
|
||||
|
||||
if err := gateway.ReconnectCtx(ctx); err != nil {
|
||||
t.Error("failed to reconnect Gateway:", err)
|
||||
}
|
||||
})
|
||||
|
||||
g.ErrorLog = func(err error) { t.Log("warning:", err) }
|
||||
|
||||
// Wait for the desired event:
|
||||
gotimeout(t, func(context.Context) {
|
||||
for ev := range gateway.Events {
|
||||
switch ev.(type) {
|
||||
// Accept only a Resumed event.
|
||||
case *ResumedEvent:
|
||||
return // exit
|
||||
case *ReadyEvent:
|
||||
t.Fatal("Ready event received instead of Resumed.")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err := g.Close(); err != nil {
|
||||
t.Fatal("failed to close Gateway:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func wait(t *testing.T, evCh chan interface{}) interface{} {
|
||||
select {
|
||||
case ev := <-evCh:
|
||||
return ev
|
||||
case <-time.After(20 * time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func gotimeout(t *testing.T, fn func(context.Context)) {
|
||||
t.Helper()
|
||||
|
||||
// Try and reconnect for 20 seconds maximum.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var done = make(chan struct{})
|
||||
go func() {
|
||||
fn(ctx)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for function.")
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
package gateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/v3/discord"
|
||||
import (
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
// Intents for the new Discord API feature, documented at
|
||||
// https://discord.com/developers/docs/topics/gateway#gateway-intents.
|
||||
|
@ -44,7 +47,7 @@ func (i Intents) IsPrivileged() (presences, member bool) {
|
|||
}
|
||||
|
||||
// EventIntents maps event types to intents.
|
||||
var EventIntents = map[string]Intents{
|
||||
var EventIntents = map[ws.EventType]Intents{
|
||||
"GUILD_CREATE": IntentGuilds,
|
||||
"GUILD_UPDATE": IntentGuilds,
|
||||
"GUILD_DELETE": IntentGuilds,
|
||||
|
|
118
gateway/op.go
118
gateway/op.go
|
@ -1,118 +0,0 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OPCode = wsutil.OPCode
|
||||
|
||||
const (
|
||||
DispatchOP OPCode = 0 // recv
|
||||
HeartbeatOP OPCode = 1 // send/recv
|
||||
IdentifyOP OPCode = 2 // send...
|
||||
StatusUpdateOP OPCode = 3 //
|
||||
VoiceStateUpdateOP OPCode = 4 //
|
||||
VoiceServerPingOP OPCode = 5 //
|
||||
ResumeOP OPCode = 6 //
|
||||
ReconnectOP OPCode = 7 // recv
|
||||
RequestGuildMembersOP OPCode = 8 // send
|
||||
InvalidSessionOP OPCode = 9 // recv...
|
||||
HelloOP OPCode = 10
|
||||
HeartbeatAckOP OPCode = 11
|
||||
CallConnectOP OPCode = 13
|
||||
GuildSubscriptionsOP OPCode = 14
|
||||
)
|
||||
|
||||
// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This
|
||||
// is used mostly internally to signal the heartbeat loop to reconnect, if
|
||||
// needed. It is not a fatal error.
|
||||
var ErrReconnectRequest = errors.New("ReconnectOP received")
|
||||
|
||||
func (g *Gateway) HandleOP(op *wsutil.OP) error {
|
||||
switch op.Code {
|
||||
case HeartbeatAckOP:
|
||||
// Heartbeat from the server?
|
||||
g.PacerLoop.Echo()
|
||||
|
||||
case HeartbeatOP:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Server requesting a heartbeat.
|
||||
if err := g.PacerLoop.Pace(ctx); err != nil {
|
||||
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
|
||||
}
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to Reconnect, die and retry.
|
||||
wsutil.WSDebug("ReconnectOP received.")
|
||||
|
||||
// Exit with the ReconnectOP error to force the heartbeat event loop to
|
||||
// Reconnect synchronously. Not really a fatal error.
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
|
||||
case InvalidSessionOP:
|
||||
// Discord expects us to sleep for no reason
|
||||
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Invalid session, try and Identify.
|
||||
if err := g.IdentifyCtx(ctx); err != nil {
|
||||
// Can't identify, Reconnect.
|
||||
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case HelloOP:
|
||||
return nil
|
||||
|
||||
case DispatchOP:
|
||||
// Set the sequence
|
||||
if op.Sequence > 0 {
|
||||
g.Sequence.Set(op.Sequence)
|
||||
}
|
||||
|
||||
// Check if we know the event
|
||||
fn, ok := EventCreator[op.EventName]
|
||||
if !ok {
|
||||
return &wsutil.UnknownEventError{
|
||||
Name: op.EventName,
|
||||
Data: op.Data,
|
||||
}
|
||||
}
|
||||
|
||||
// Make a new pointer to the event
|
||||
var ev = fn()
|
||||
|
||||
// Try and parse the event
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "failed to parse event "+op.EventName)
|
||||
}
|
||||
|
||||
// If the event is a ready, we'll want its sessionID
|
||||
if ev, ok := ev.(*ReadyEvent); ok {
|
||||
g.sessionMu.Lock()
|
||||
g.sessionID = ev.SessionID
|
||||
g.sessionMu.Unlock()
|
||||
}
|
||||
|
||||
// Throw the event into a channel; it's valid now.
|
||||
g.Events <- ev
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown OP code %d (event %s)", op.Code, op.EventName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
// +build perseverance
|
||||
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
)
|
||||
|
||||
func TestPerseverance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := testenv.Must(t)
|
||||
|
||||
g, err := NewGateway("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make the gateway:", err)
|
||||
}
|
||||
g.AddIntents(IntentGuilds)
|
||||
|
||||
if err := g.Open(); err != nil {
|
||||
t.Fatal("failed to open the gateway:", err)
|
||||
}
|
||||
|
||||
timeout := make(chan struct{}, 1)
|
||||
|
||||
// Automatically close the gateway after set duration.
|
||||
time.AfterFunc(testenv.PerseveranceTime, func() {
|
||||
t.Log("Perserverence test finshed. Closing gateway.")
|
||||
timeout <- struct{}{}
|
||||
|
||||
if err := g.Close(); err != nil {
|
||||
t.Error("failed to close gateway:", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Spin on events.
|
||||
for ev := range g.Events {
|
||||
t.Logf("Received event %T.", ev)
|
||||
}
|
||||
|
||||
// Exit gracefully if we have not.
|
||||
select {
|
||||
case <-timeout:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := g.Close(); err != nil {
|
||||
t.Fatal("failed to clean up gateway after fail:", err)
|
||||
}
|
||||
|
||||
t.Fatal("Test failed before timeout.")
|
||||
}
|
267
gateway/ready.go
267
gateway/ready.go
|
@ -1,267 +0,0 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
)
|
||||
|
||||
type (
|
||||
// ReadyEvent is the struct for a READY event.
|
||||
ReadyEvent struct {
|
||||
Version int `json:"version"`
|
||||
|
||||
User discord.User `json:"user"`
|
||||
SessionID string `json:"session_id"`
|
||||
|
||||
PrivateChannels []discord.Channel `json:"private_channels"`
|
||||
Guilds []GuildCreateEvent `json:"guilds"`
|
||||
|
||||
Shard *Shard `json:"shard,omitempty"`
|
||||
|
||||
// Undocumented fields
|
||||
|
||||
UserSettings *UserSettings `json:"user_settings,omitempty"`
|
||||
ReadStates []ReadState `json:"read_state,omitempty"`
|
||||
UserGuildSettings []UserGuildSetting `json:"user_guild_settings,omitempty"`
|
||||
Relationships []discord.Relationship `json:"relationships,omitempty"`
|
||||
Presences []discord.Presence `json:"presences,omitempty"`
|
||||
|
||||
FriendSuggestionCount int `json:"friend_suggestion_count,omitempty"`
|
||||
GeoOrderedRTCRegions []string `json:"geo_ordered_rtc_regions,omitempty"`
|
||||
}
|
||||
|
||||
// ReadState is a single ReadState entry. It is undocumented.
|
||||
ReadState struct {
|
||||
ChannelID discord.ChannelID `json:"id"`
|
||||
LastMessageID discord.MessageID `json:"last_message_id"`
|
||||
LastPinTimestamp discord.Timestamp `json:"last_pin_timestamp"`
|
||||
MentionCount int `json:"mention_count"`
|
||||
}
|
||||
|
||||
// UserSettings is the struct for (almost) all user settings. It is
|
||||
// undocumented.
|
||||
UserSettings struct {
|
||||
ShowCurrentGame bool `json:"show_current_game"`
|
||||
DefaultGuildsRestricted bool `json:"default_guilds_restricted"`
|
||||
InlineAttachmentMedia bool `json:"inline_attachment_media"`
|
||||
InlineEmbedMedia bool `json:"inline_embed_media"`
|
||||
GIFAutoPlay bool `json:"gif_auto_play"`
|
||||
RenderEmbeds bool `json:"render_embeds"`
|
||||
RenderReactions bool `json:"render_reactions"`
|
||||
AnimateEmoji bool `json:"animate_emoji"`
|
||||
AnimateStickers int `json:"animate_stickers"`
|
||||
EnableTTSCommand bool `json:"enable_tts_command"`
|
||||
MessageDisplayCompact bool `json:"message_display_compact"`
|
||||
ConvertEmoticons bool `json:"convert_emoticons"`
|
||||
ExplicitContentFilter uint8 `json:"explicit_content_filter"` // ???
|
||||
DisableGamesTab bool `json:"disable_games_tab"`
|
||||
DeveloperMode bool `json:"developer_mode"`
|
||||
DetectPlatformAccounts bool `json:"detect_platform_accounts"`
|
||||
StreamNotification bool `json:"stream_notification_enabled"`
|
||||
AccessibilityDetection bool `json:"allow_accessibility_detection"`
|
||||
ContactSync bool `json:"contact_sync_enabled"`
|
||||
NativePhoneIntegration bool `json:"native_phone_integration_enabled"`
|
||||
|
||||
TimezoneOffset int `json:"timezone_offset"`
|
||||
|
||||
Locale string `json:"locale"`
|
||||
Theme string `json:"theme"`
|
||||
|
||||
GuildPositions []discord.GuildID `json:"guild_positions"`
|
||||
GuildFolders []GuildFolder `json:"guild_folders"`
|
||||
RestrictedGuilds []discord.GuildID `json:"restricted_guilds"`
|
||||
|
||||
FriendSourceFlags FriendSourceFlags `json:"friend_source_flags"`
|
||||
|
||||
Status discord.Status `json:"status"`
|
||||
CustomStatus *CustomUserStatus `json:"custom_status"`
|
||||
}
|
||||
|
||||
// CustomUserStatus is the custom user status that allows setting an emoji
|
||||
// and a piece of text on each user.
|
||||
CustomUserStatus struct {
|
||||
Text string `json:"text"`
|
||||
ExpiresAt discord.Timestamp `json:"expires_at,omitempty"`
|
||||
EmojiID discord.EmojiID `json:"emoji_id,string"`
|
||||
EmojiName string `json:"emoji_name"`
|
||||
}
|
||||
|
||||
// UserGuildSetting stores the settings for a single guild. It is
|
||||
// undocumented.
|
||||
UserGuildSetting struct {
|
||||
GuildID discord.GuildID `json:"guild_id"`
|
||||
|
||||
SuppressRoles bool `json:"suppress_roles"`
|
||||
SuppressEveryone bool `json:"suppress_everyone"`
|
||||
Muted bool `json:"muted"`
|
||||
MuteConfig *UserMuteConfig `json:"mute_config"`
|
||||
|
||||
MobilePush bool `json:"mobile_push"`
|
||||
Notifications UserNotification `json:"message_notifications"`
|
||||
|
||||
ChannelOverrides []UserChannelOverride `json:"channel_overrides"`
|
||||
}
|
||||
|
||||
// A UserChannelOverride struct describes a channel settings override for a
|
||||
// users guild settings.
|
||||
UserChannelOverride struct {
|
||||
Muted bool `json:"muted"`
|
||||
MuteConfig *UserMuteConfig `json:"mute_config"`
|
||||
Notifications UserNotification `json:"message_notifications"`
|
||||
ChannelID discord.ChannelID `json:"channel_id"`
|
||||
}
|
||||
|
||||
// UserMuteConfig seems to describe the mute settings. It belongs to the
|
||||
// UserGuildSettingEntry and UserChannelOverride structs and is
|
||||
// undocumented.
|
||||
UserMuteConfig struct {
|
||||
SelectedTimeWindow int `json:"selected_time_window"`
|
||||
EndTime discord.Timestamp `json:"end_time"`
|
||||
}
|
||||
|
||||
// GuildFolder holds a single folder that you see in the left guild panel.
|
||||
GuildFolder struct {
|
||||
Name string `json:"name"`
|
||||
ID GuildFolderID `json:"id"`
|
||||
GuildIDs []discord.GuildID `json:"guild_ids"`
|
||||
Color discord.Color `json:"color"`
|
||||
}
|
||||
|
||||
// FriendSourceFlags describes sources that friend requests could be sent
|
||||
// from. It belongs to the UserSettings struct and is undocumented.
|
||||
FriendSourceFlags struct {
|
||||
All bool `json:"all,omitempty"`
|
||||
MutualGuilds bool `json:"mutual_guilds,omitempty"`
|
||||
MutualFriends bool `json:"mutual_friends,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
// UserNotification is the notification setting for a channel or guild.
|
||||
type UserNotification uint8
|
||||
|
||||
const (
|
||||
AllNotifications UserNotification = iota
|
||||
OnlyMentions
|
||||
NoNotifications
|
||||
GuildDefaults
|
||||
)
|
||||
|
||||
// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low
|
||||
// number of unknown significance.
|
||||
type GuildFolderID int64
|
||||
|
||||
func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
|
||||
var body = string(b)
|
||||
if body == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
body = strings.Trim(body, `"`)
|
||||
|
||||
u, err := strconv.ParseInt(body, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*g = GuildFolderID(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g GuildFolderID) MarshalJSON() ([]byte, error) {
|
||||
if g == 0 {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return []byte(strconv.FormatInt(int64(g), 10)), nil
|
||||
}
|
||||
|
||||
// ReadySupplemental event structs. For now, this event is never used, and its
|
||||
// usage have yet been discovered.
|
||||
type (
|
||||
// ReadySupplementalEvent is the struct for a READY_SUPPLEMENTAL event,
|
||||
// which is an undocumented event.
|
||||
ReadySupplementalEvent struct {
|
||||
Guilds []GuildCreateEvent `json:"guilds"` // only have ID and VoiceStates
|
||||
MergedMembers [][]SupplementalMember `json:"merged_members"`
|
||||
MergedPresences MergedPresences `json:"merged_presences"`
|
||||
}
|
||||
|
||||
// SupplementalMember is the struct for a member in the MergedMembers field
|
||||
// of ReadySupplementalEvent. It has slight differences to discord.Member.
|
||||
SupplementalMember struct {
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
Nick string `json:"nick,omitempty"`
|
||||
RoleIDs []discord.RoleID `json:"roles"`
|
||||
|
||||
GuildID discord.GuildID `json:"guild_id,omitempty"`
|
||||
IsPending bool `json:"pending,omitempty"`
|
||||
HoistedRole discord.RoleID `json:"hoisted_role"`
|
||||
|
||||
Mute bool `json:"mute"`
|
||||
Deaf bool `json:"deaf"`
|
||||
|
||||
// Joined specifies when the user joined the guild.
|
||||
Joined discord.Timestamp `json:"joined_at"`
|
||||
|
||||
// BoostedSince specifies when the user started boosting the guild.
|
||||
BoostedSince discord.Timestamp `json:"premium_since,omitempty"`
|
||||
}
|
||||
|
||||
// MergedPresences is the struct for presences of guilds' members and
|
||||
// friends. It is undocumented.
|
||||
MergedPresences struct {
|
||||
Guilds [][]SupplementalPresence `json:"guilds"`
|
||||
Friends []SupplementalPresence `json:"friends"`
|
||||
}
|
||||
|
||||
// SupplementalPresence is a single presence for either a guild member or
|
||||
// friend. It is used in MergedPresences and is undocumented.
|
||||
SupplementalPresence struct {
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
|
||||
// Status is either "idle", "dnd", "online", or "offline".
|
||||
Status discord.Status `json:"status"`
|
||||
// Activities are the user's current activities.
|
||||
Activities []discord.Activity `json:"activities"`
|
||||
// ClientStaus is the user's platform-dependent status.
|
||||
ClientStatus discord.ClientStatus `json:"client_status"`
|
||||
|
||||
// LastModified is only present in Friends.
|
||||
LastModified discord.UnixMsTimestamp `json:"last_modified,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
// ConvertSupplementalMembers converts a SupplementalMember to a regular Member.
|
||||
func ConvertSupplementalMembers(sms []SupplementalMember) []discord.Member {
|
||||
members := make([]discord.Member, len(sms))
|
||||
for i, sm := range sms {
|
||||
members[i] = discord.Member{
|
||||
User: discord.User{ID: sm.UserID},
|
||||
Nick: sm.Nick,
|
||||
RoleIDs: sm.RoleIDs,
|
||||
Joined: sm.Joined,
|
||||
BoostedSince: sm.BoostedSince,
|
||||
Deaf: sm.Deaf,
|
||||
Mute: sm.Mute,
|
||||
IsPending: sm.IsPending,
|
||||
}
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
// ConvertSupplementalPresences converts a SupplementalPresence to a regular
|
||||
// Presence with an empty GuildID.
|
||||
func ConvertSupplementalPresences(sps []SupplementalPresence) []discord.Presence {
|
||||
presences := make([]discord.Presence, len(sps))
|
||||
for i, sp := range sps {
|
||||
presences[i] = discord.Presence{
|
||||
User: discord.User{ID: sp.UserID},
|
||||
Status: sp.Status,
|
||||
Activities: sp.Activities,
|
||||
ClientStatus: sp.ClientStatus,
|
||||
}
|
||||
}
|
||||
return presences
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
// Package handleloop provides clean abstractions to handle listening to
|
||||
// channels and passing them onto event handlers.
|
||||
package handleloop
|
||||
|
||||
import "github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
|
||||
// Loop provides a reusable event looper abstraction. It is thread-safe to use
|
||||
// concurrently.
|
||||
type Loop struct {
|
||||
dst *handler.Handler
|
||||
run chan struct{}
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func NewLoop(dst *handler.Handler) *Loop {
|
||||
return &Loop{
|
||||
dst: dst,
|
||||
run: make(chan struct{}, 1), // intentional 1 buffer
|
||||
stop: make(chan struct{}), // intentional unbuffer
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts a new event loop. It will try to stop existing loops before.
|
||||
func (l *Loop) Start(src <-chan interface{}) {
|
||||
// Ensure we're stopped.
|
||||
l.Stop()
|
||||
|
||||
// Mark that we're running.
|
||||
l.run <- struct{}{}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case event := <-src:
|
||||
l.dst.Call(event)
|
||||
|
||||
case <-l.stop:
|
||||
l.stop <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop tries to stop the Loop. If the Loop is not running, then it does
|
||||
// nothing; thus, it can be called multiple times.
|
||||
func (l *Loop) Stop() {
|
||||
// Ensure that we are running before stopping.
|
||||
select {
|
||||
case <-l.run:
|
||||
// running
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// send a close request
|
||||
l.stop <- struct{}{}
|
||||
// wait for a reply
|
||||
<-l.stop
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
// Package heart implements a general purpose pacemaker.
|
||||
package heart
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Debug is the default logger that Pacemaker uses.
|
||||
var Debug = func(v ...interface{}) {}
|
||||
|
||||
var ErrDead = errors.New("no heartbeat replied")
|
||||
|
||||
// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic.
|
||||
type AtomicTime struct {
|
||||
unixnano int64
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Get() int64 {
|
||||
return atomic.LoadInt64(&t.unixnano)
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Set(time time.Time) {
|
||||
atomic.StoreInt64(&t.unixnano, time.UnixNano())
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Time() time.Time {
|
||||
return time.Unix(0, t.Get())
|
||||
}
|
||||
|
||||
// AtomicDuration is a thread-safe Duration guarded by atomic.
|
||||
type AtomicDuration struct {
|
||||
duration int64
|
||||
}
|
||||
|
||||
func (d *AtomicDuration) Get() time.Duration {
|
||||
return time.Duration(atomic.LoadInt64(&d.duration))
|
||||
}
|
||||
|
||||
func (d *AtomicDuration) Set(dura time.Duration) {
|
||||
atomic.StoreInt64(&d.duration, int64(dura))
|
||||
}
|
||||
|
||||
// Pacemaker is the internal pacemaker state. All fields are not thread-safe
|
||||
// unless they're atomic.
|
||||
type Pacemaker struct {
|
||||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate AtomicDuration
|
||||
|
||||
ticker time.Ticker
|
||||
Ticks <-chan time.Time
|
||||
|
||||
// Time in nanoseconds, guarded by atomic read/writes.
|
||||
SentBeat AtomicTime
|
||||
EchoBeat AtomicTime
|
||||
|
||||
// Any callback that returns an error will stop the pacer.
|
||||
Pacer func(context.Context) error
|
||||
}
|
||||
|
||||
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
|
||||
p := Pacemaker{
|
||||
Heartrate: AtomicDuration{int64(heartrate)},
|
||||
Pacer: pacer,
|
||||
ticker: *time.NewTicker(heartrate),
|
||||
}
|
||||
p.Ticks = p.ticker.C
|
||||
// Reset states to its old position.
|
||||
now := time.Now()
|
||||
p.EchoBeat.Set(now)
|
||||
p.SentBeat.Set(now)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
// Swap our received heartbeats
|
||||
p.EchoBeat.Set(time.Now())
|
||||
}
|
||||
|
||||
// Dead, if true, will have Pace return an ErrDead.
|
||||
func (p *Pacemaker) Dead() bool {
|
||||
var (
|
||||
echo = p.EchoBeat.Get()
|
||||
sent = p.SentBeat.Get()
|
||||
)
|
||||
|
||||
if echo == 0 || sent == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return sent-echo > int64(p.Heartrate.Get())*2
|
||||
}
|
||||
|
||||
// SetHeartRate sets the ticker's heart rate.
|
||||
func (p *Pacemaker) SetPace(heartrate time.Duration) {
|
||||
p.Heartrate.Set(heartrate)
|
||||
|
||||
// To uncomment when 1.16 releases and we drop support for 1.14.
|
||||
// p.ticker.Reset(heartrate)
|
||||
|
||||
p.ticker.Stop()
|
||||
p.ticker = *time.NewTicker(heartrate)
|
||||
p.Ticks = p.ticker.C
|
||||
}
|
||||
|
||||
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
|
||||
func (p *Pacemaker) StopTicker() {
|
||||
p.ticker.Stop()
|
||||
}
|
||||
|
||||
// pace sends a heartbeat with the appropriate timeout for the context.
|
||||
func (p *Pacemaker) Pace() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate.Get())
|
||||
defer cancel()
|
||||
|
||||
return p.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
|
||||
if err := p.Pacer(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.SentBeat.Set(time.Now())
|
||||
|
||||
if p.Dead() {
|
||||
return ErrDead
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
30
internal/lazytime/ticker.go
Normal file
30
internal/lazytime/ticker.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package lazytime
|
||||
|
||||
import "time"
|
||||
|
||||
type Ticker struct {
|
||||
C <-chan time.Time
|
||||
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
// Reset resets the ticker. If this is the first time calling, then a new timer
|
||||
// is created.
|
||||
func (t *Ticker) Reset(d time.Duration) {
|
||||
if t.ticker == nil {
|
||||
t.ticker = time.NewTicker(d)
|
||||
t.C = t.ticker.C
|
||||
} else {
|
||||
t.ticker.Reset(d)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the ticker. If the ticker has never been used, then it does
|
||||
// nothing.
|
||||
func (t *Ticker) Stop() {
|
||||
if t.ticker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.ticker.Stop()
|
||||
}
|
50
internal/lazytime/timer.go
Normal file
50
internal/lazytime/timer.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package lazytime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Timer struct {
|
||||
C <-chan time.Time
|
||||
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// Reset resets the timer by draining it and resetting the internal channel. If
|
||||
// this is the first time calling, then a new timer is created.
|
||||
func (t *Timer) Reset(d time.Duration) {
|
||||
if t.timer == nil {
|
||||
t.timer = time.NewTimer(d)
|
||||
t.C = t.timer.C
|
||||
return
|
||||
}
|
||||
|
||||
t.Stop()
|
||||
t.timer.Reset(d)
|
||||
}
|
||||
|
||||
// Stop stops the timer and drains it. If the timer has never been used, then it
|
||||
// does nothing.
|
||||
func (t *Timer) Stop() {
|
||||
if t.timer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.timer.Stop() {
|
||||
select {
|
||||
case <-t.timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait blocks until the timer fires or until the context expires.
|
||||
func (t *Timer) Wait(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -17,3 +17,17 @@ func (b *Bool) Set(val bool) {
|
|||
}
|
||||
atomic.StoreUint32(&b.val, x)
|
||||
}
|
||||
|
||||
func (b *Bool) SetTrue() {
|
||||
atomic.StoreUint32(&b.val, 1)
|
||||
}
|
||||
|
||||
func (b *Bool) SetFalse() {
|
||||
atomic.StoreUint32(&b.val, 0)
|
||||
}
|
||||
|
||||
// Acquire sets bool to true if it's false and returns true, otherwise returns
|
||||
// false.
|
||||
func (b *Bool) Acquire() bool {
|
||||
return atomic.CompareAndSwapUint32(&b.val, 0, 1)
|
||||
}
|
||||
|
|
|
@ -5,89 +5,64 @@ package session
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/handleloop"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws/ophandler"
|
||||
)
|
||||
|
||||
// ErrMFA is returned if the account requires a 2FA code to log in.
|
||||
var ErrMFA = errors.New("account has 2FA enabled")
|
||||
|
||||
// Closed is an event that's sent to Session's command handler. This works by
|
||||
// using (*Gateway).AfterClose. If the user sets this callback, no Closed events
|
||||
// would be sent.
|
||||
//
|
||||
// Usage
|
||||
//
|
||||
// ses.AddHandler(func(*session.Closed) {})
|
||||
//
|
||||
type Closed struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
// NewShardFunc creates a shard constructor for a session.
|
||||
func NewShardFunc(f func(m *shard.Manager, s *Session)) shard.NewShardFunc {
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
s := NewCustomShard(m, id)
|
||||
if f != nil {
|
||||
f(m, s)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewCustomShard creates a new session from the given shard manager and other
|
||||
// parameters.
|
||||
func NewCustomShard(m *shard.Manager, id *gateway.Identifier) *Session {
|
||||
return NewCustomSession(
|
||||
shard.NewGatewayShard(m, id),
|
||||
api.NewClient(id.Token),
|
||||
handler.New(),
|
||||
)
|
||||
}
|
||||
|
||||
// Session manages both the API and Gateway. As such, Session inherits all of
|
||||
// API's methods, as well has the Handler used for Gateway.
|
||||
type Session struct {
|
||||
*api.Client
|
||||
*gateway.Gateway
|
||||
|
||||
// Command handler with inherited methods.
|
||||
*handler.Handler
|
||||
|
||||
// internal state to not be copied around.
|
||||
looper *handleloop.Loop
|
||||
state *sessionState
|
||||
}
|
||||
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
|
||||
g, err := gateway.NewGatewayWithIntents(token, intents...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to connect to Gateway")
|
||||
type sessionState struct {
|
||||
sync.Mutex
|
||||
id gateway.Identifier
|
||||
gateway *gateway.Gateway
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
doneCh <-chan struct{}
|
||||
}
|
||||
|
||||
// NewWithIntents is similar to New but adds the given intents in during
|
||||
// construction.
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) *Session {
|
||||
var allIntent gateway.Intents
|
||||
for _, intent := range intents {
|
||||
allIntent |= intent
|
||||
}
|
||||
|
||||
return NewWithGateway(g), nil
|
||||
id := gateway.DefaultIdentifier(token)
|
||||
id.Intents = option.NewUint(uint(allIntent))
|
||||
|
||||
return NewWithIdentifier(id)
|
||||
}
|
||||
|
||||
// New creates a new session from a given token. Most bots should be using
|
||||
// NewWithIntents instead.
|
||||
func New(token string) (*Session, error) {
|
||||
// Create a gateway
|
||||
g, err := gateway.NewGateway(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to connect to Gateway")
|
||||
}
|
||||
|
||||
return NewWithGateway(g), nil
|
||||
func New(token string) *Session {
|
||||
return NewWithIdentifier(gateway.DefaultIdentifier(token))
|
||||
}
|
||||
|
||||
// Login tries to log in as a normal user account; MFA is optional.
|
||||
func Login(email, password, mfa string) (*Session, error) {
|
||||
func Login(ctx context.Context, email, password, mfa string) (*Session, error) {
|
||||
// Make a scratch HTTP client without a token
|
||||
client := api.NewClient("")
|
||||
client := api.NewClient("").WithContext(ctx)
|
||||
|
||||
// Try to login without TOTP
|
||||
l, err := client.Login(email, password)
|
||||
|
@ -97,7 +72,7 @@ func Login(email, password, mfa string) (*Session, error) {
|
|||
|
||||
if l.Token != "" && !l.MFA {
|
||||
// We got the token, return with a new Session.
|
||||
return New(l.Token)
|
||||
return New(l.Token), nil
|
||||
}
|
||||
|
||||
// Discord requests MFA, so we need the MFA token.
|
||||
|
@ -111,40 +86,117 @@ func Login(email, password, mfa string) (*Session, error) {
|
|||
return nil, errors.Wrap(err, "failed to login with 2FA")
|
||||
}
|
||||
|
||||
return New(l.Token)
|
||||
return New(l.Token), nil
|
||||
}
|
||||
|
||||
// NewWithGateway creates a new Session with the given Gateway.
|
||||
func NewWithGateway(gw *gateway.Gateway) *Session {
|
||||
return NewCustomSession(gw, api.NewClient(gw.Identifier.Token), handler.New())
|
||||
// NewWithIdentifier creates a bare Session with the given identifier.
|
||||
func NewWithIdentifier(id gateway.Identifier) *Session {
|
||||
return NewCustom(id, api.NewClient(id.Token), handler.New())
|
||||
}
|
||||
|
||||
// NewCustomSession constructs a bare Session from the given parameters.
|
||||
func NewCustomSession(gw *gateway.Gateway, cl *api.Client, h *handler.Handler) *Session {
|
||||
// NewWithGateway constructs a bare Session from the given UNOPENED gateway.
|
||||
func NewWithGateway(g *gateway.Gateway, h *handler.Handler) *Session {
|
||||
state := g.State()
|
||||
return &Session{
|
||||
Client: api.NewClient(state.Identifier.Token),
|
||||
Handler: h,
|
||||
state: &sessionState{
|
||||
gateway: g,
|
||||
id: state.Identifier,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewCustom constructs a bare Session from the given parameters.
|
||||
func NewCustom(id gateway.Identifier, cl *api.Client, h *handler.Handler) *Session {
|
||||
return &Session{
|
||||
Gateway: gw,
|
||||
Client: cl,
|
||||
Handler: h,
|
||||
looper: handleloop.NewLoop(h),
|
||||
state: &sessionState{id: id},
|
||||
}
|
||||
}
|
||||
|
||||
// AddIntents adds the given intents into the gateway. Calling it after Open has
|
||||
// already been called will result in a panic.
|
||||
func (s *Session) AddIntents(intents gateway.Intents) {
|
||||
s.state.Lock()
|
||||
|
||||
s.state.id.AddIntents(intents)
|
||||
|
||||
if s.state.gateway != nil {
|
||||
s.state.gateway.AddIntents(intents)
|
||||
}
|
||||
|
||||
s.state.Unlock()
|
||||
}
|
||||
|
||||
// HasIntents reports if the Gateway has the passed Intents.
|
||||
//
|
||||
// If no intents are set, e.g. if using a user account, HasIntents will always
|
||||
// return true.
|
||||
func (s *Session) HasIntents(intents gateway.Intents) bool {
|
||||
return s.state.id.HasIntents(intents)
|
||||
}
|
||||
|
||||
// Gateway returns the current session's gateway. If Open has never been called
|
||||
// or Session was never constructed with a gateway, then nil is returned.
|
||||
func (s *Session) Gateway() *gateway.Gateway {
|
||||
s.state.Lock()
|
||||
g := s.state.gateway
|
||||
s.state.Unlock()
|
||||
return g
|
||||
}
|
||||
|
||||
// Open opens the Discord gateway and its handler, then waits until either the
|
||||
// Ready or Resumed event gets through.
|
||||
func (s *Session) Open(ctx context.Context) error {
|
||||
// Start the handler beforehand so no events are missed.
|
||||
s.looper.Start(s.Gateway.Events)
|
||||
evCh := make(chan interface{})
|
||||
|
||||
// Set the AfterClose's handler.
|
||||
s.Gateway.AfterClose = func(err error) {
|
||||
s.Handler.Call(&Closed{
|
||||
Error: err,
|
||||
})
|
||||
s.state.Lock()
|
||||
defer s.state.Unlock()
|
||||
|
||||
if s.state.cancel != nil {
|
||||
if err := s.close(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.Gateway.Open(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to start gateway")
|
||||
if s.state.gateway == nil {
|
||||
g, err := gateway.NewWithIdentifier(ctx, s.state.id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.state.gateway = g
|
||||
}
|
||||
|
||||
return nil
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.state.ctx = ctx
|
||||
s.state.cancel = cancel
|
||||
|
||||
// TODO: change this to AddSyncHandler.
|
||||
rm := s.AddHandler(evCh)
|
||||
defer rm()
|
||||
|
||||
opCh := s.state.gateway.Connect(s.state.ctx)
|
||||
s.state.doneCh = ophandler.Loop(opCh, s.Handler)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.close(ctx)
|
||||
return ctx.Err()
|
||||
|
||||
case <-s.state.doneCh:
|
||||
// Event loop died.
|
||||
return s.state.gateway.LastError()
|
||||
|
||||
case ev := <-evCh:
|
||||
switch ev.(type) {
|
||||
case *gateway.ReadyEvent, *gateway.ResumedEvent:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of Session with the context replaced in
|
||||
|
@ -160,21 +212,34 @@ func (s *Session) WithContext(ctx context.Context) *Session {
|
|||
}
|
||||
|
||||
// Close closes the underlying Websocket connection, invalidating the session
|
||||
// ID.
|
||||
//
|
||||
// It will send a closing frame before ending the connection, closing it
|
||||
// gracefully. This will cause the bot to appear as offline instantly.
|
||||
// ID. It will send a closing frame before ending the connection, closing it
|
||||
// gracefully. This will cause the bot to appear as offline instantly. To
|
||||
// prevent this behavior, change Gateway.AlwaysCloseGracefully.
|
||||
func (s *Session) Close() error {
|
||||
// Stop the event handler
|
||||
s.looper.Stop()
|
||||
return s.Gateway.Close()
|
||||
s.state.Lock()
|
||||
defer s.state.Unlock()
|
||||
|
||||
return s.close(context.Background())
|
||||
}
|
||||
|
||||
// Pause pauses the Gateway connection, by ending the connection without
|
||||
// sending a closing frame. This allows the connection to be resumed at a later
|
||||
// point.
|
||||
func (s *Session) Pause() error {
|
||||
// Stop the event handler
|
||||
s.looper.Stop()
|
||||
return s.Gateway.Pause()
|
||||
func (s *Session) close(ctx context.Context) error {
|
||||
if s.state.cancel == nil {
|
||||
return errors.New("Session is already closed")
|
||||
}
|
||||
|
||||
s.state.cancel()
|
||||
s.state.cancel = nil
|
||||
s.state.ctx = nil
|
||||
|
||||
// Wait until we've successfully disconnected.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "cannot wait for gateway exit")
|
||||
case <-s.state.doneCh:
|
||||
// ok
|
||||
}
|
||||
|
||||
s.state.doneCh = nil
|
||||
|
||||
return s.state.gateway.LastError()
|
||||
}
|
||||
|
|
50
session/session_test.go
Normal file
50
session/session_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
)
|
||||
|
||||
func TestSession(t *testing.T) {
|
||||
attempts := 1
|
||||
timeout := 15 * time.Second
|
||||
|
||||
if !testing.Short() {
|
||||
attempts = 5
|
||||
timeout = time.Minute // 5s-10s each reconnection
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
env := testenv.Must(t)
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent, 1)
|
||||
|
||||
s := NewWithIntents(env.BotToken, gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
|
||||
for i := 0; i < attempts; i++ {
|
||||
if err := s.Open(ctx); err != nil {
|
||||
t.Fatal("failed to open:", err)
|
||||
}
|
||||
|
||||
if ready, ok := <-readyCh; !ok {
|
||||
t.Fatal("ready not received")
|
||||
} else {
|
||||
now := time.Now()
|
||||
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
|
||||
}
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatal("failed to close:", err)
|
||||
}
|
||||
|
||||
// Hold for an additional one second.
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
|
@ -61,7 +61,7 @@ type rescalingState struct {
|
|||
func NewManager(token string, fn NewShardFunc) (*Manager, error) {
|
||||
id := gateway.DefaultIdentifier(token)
|
||||
|
||||
url, err := updateIdentifier(context.Background(), id)
|
||||
url, err := updateIdentifier(context.Background(), &id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get gateway info")
|
||||
}
|
||||
|
@ -76,20 +76,20 @@ func NewManager(token string, fn NewShardFunc) (*Manager, error) {
|
|||
//
|
||||
// This function should rarely be used, since the shard information will be
|
||||
// queried from Discord if it's required to shard anyway.
|
||||
func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager, error) {
|
||||
func NewIdentifiedManager(idData gateway.IdentifyCommand, fn NewShardFunc) (*Manager, error) {
|
||||
// Ensure id.Shard is never nil.
|
||||
if data.Shard == nil {
|
||||
data.Shard = gateway.DefaultShard
|
||||
if idData.Shard == nil {
|
||||
idData.Shard = gateway.DefaultShard
|
||||
}
|
||||
|
||||
id := gateway.NewIdentifier(data)
|
||||
id := gateway.NewIdentifier(idData)
|
||||
|
||||
url, err := updateIdentifier(context.Background(), id)
|
||||
url, err := updateIdentifier(context.Background(), &id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get gateway info")
|
||||
}
|
||||
|
||||
id.Shard = data.Shard
|
||||
id.Shard = idData.Shard
|
||||
|
||||
return NewIdentifiedManagerWithURL(url, id, fn)
|
||||
}
|
||||
|
@ -97,7 +97,7 @@ func NewIdentifiedManager(data gateway.IdentifyData, fn NewShardFunc) (*Manager,
|
|||
// NewIdentifiedManagerWithURL creates a new Manager with the given Identifier
|
||||
// and gateway URL. It behaves similarly to NewIdentifiedManager.
|
||||
func NewIdentifiedManagerWithURL(
|
||||
url string, id *gateway.Identifier, fn NewShardFunc) (*Manager, error) {
|
||||
url string, id gateway.Identifier, fn NewShardFunc) (*Manager, error) {
|
||||
|
||||
m := Manager{
|
||||
gatewayURL: gateway.AddGatewayParams(url),
|
||||
|
@ -108,12 +108,12 @@ func NewIdentifiedManagerWithURL(
|
|||
var err error
|
||||
|
||||
for i := range m.shards {
|
||||
data := id.IdentifyData
|
||||
data := id.IdentifyCommand
|
||||
data.Shard = &gateway.Shard{i, len(m.shards)}
|
||||
|
||||
m.shards[i] = ShardState{
|
||||
ID: gateway.Identifier{
|
||||
IdentifyData: data,
|
||||
IdentifyCommand: data,
|
||||
IdentifyShortLimit: id.IdentifyShortLimit,
|
||||
IdentifyGlobalLimit: id.IdentifyGlobalLimit,
|
||||
},
|
||||
|
@ -263,10 +263,10 @@ func (m *Manager) rescale() {
|
|||
func (m *Manager) tryRescale(ctx context.Context) bool {
|
||||
m.mutex.Lock()
|
||||
|
||||
data := m.shards[0].ID.IdentifyData
|
||||
data := m.shards[0].ID.IdentifyCommand
|
||||
newID := gateway.NewIdentifier(data)
|
||||
|
||||
url, err := updateIdentifier(ctx, newID)
|
||||
url, err := updateIdentifier(ctx, &newID)
|
||||
if err != nil {
|
||||
m.mutex.Unlock()
|
||||
return false
|
||||
|
@ -282,12 +282,12 @@ func (m *Manager) tryRescale(ctx context.Context) bool {
|
|||
newShards := make([]ShardState, numShards)
|
||||
|
||||
for i := 0; i < numShards; i++ {
|
||||
data := newID.IdentifyData
|
||||
data := newID.IdentifyCommand
|
||||
data.Shard = &gateway.Shard{i, len(m.shards)}
|
||||
|
||||
newShards[i] = ShardState{
|
||||
ID: gateway.Identifier{
|
||||
IdentifyData: data,
|
||||
IdentifyCommand: data,
|
||||
IdentifyShortLimit: newID.IdentifyShortLimit,
|
||||
IdentifyGlobalLimit: newID.IdentifyGlobalLimit,
|
||||
},
|
|
@ -3,7 +3,10 @@ package shard
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -23,17 +26,14 @@ type Shard interface {
|
|||
// methods without deadlocking.
|
||||
type NewShardFunc func(m *Manager, id *gateway.Identifier) (Shard, error)
|
||||
|
||||
// NewGatewayShardFunc wraps around NewGatewayShard to be compatible with
|
||||
// NewShardFunc.
|
||||
var NewGatewayShardFunc NewShardFunc = func(m *Manager, id *gateway.Identifier) (Shard, error) {
|
||||
return NewGatewayShard(m, id), nil
|
||||
}
|
||||
|
||||
// NewGatewayShard creates a new gateway that's plugged into the shard manager.
|
||||
func NewGatewayShard(m *Manager, id *gateway.Identifier) *gateway.Gateway {
|
||||
gw := gateway.NewCustomIdentifiedGateway(m.GatewayURL(), id)
|
||||
gw.OnShardingRequired(m.Rescale)
|
||||
return gw
|
||||
// NewSessionShard creates a shard constructor for a session.
|
||||
// Accessing any shard and adding a handler will add a handler for all shards.
|
||||
func NewSessionShard(f func(m *Manager, s *session.Session)) NewShardFunc {
|
||||
return func(m *Manager, id *gateway.Identifier) (Shard, error) {
|
||||
s := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
|
||||
f(m, s)
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ShardState wraps around the Gateway interface to provide additional state.
|
|
@ -1,4 +1,4 @@
|
|||
package session
|
||||
package shard
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -6,29 +6,28 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
)
|
||||
|
||||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
||||
m, err := shard.NewIdentifiedManager(data, NewShardFunc(
|
||||
func(m *shard.Manager, s *Session) {
|
||||
m, err := NewIdentifiedManager(data, NewSessionShard(
|
||||
func(m *Manager, s *session.Session) {
|
||||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "initializing shard")
|
||||
|
||||
s.Gateway.ErrorLog = func(err error) {
|
||||
t.Error("gateway error:", err)
|
||||
}
|
||||
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
s.AddHandler(func(err error) {
|
||||
t.Error("unexpected error:", err)
|
||||
})
|
||||
},
|
||||
))
|
||||
if err != nil {
|
|
@ -6,10 +6,11 @@ import (
|
|||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/session/shard"
|
||||
"github.com/diamondburned/arikawa/v3/state/store"
|
||||
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
|
@ -27,7 +28,8 @@ var (
|
|||
// The user should initialize handlers and intents in the opts function.
|
||||
func NewShardFunc(opts func(*shard.Manager, *State)) shard.NewShardFunc {
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
state := NewFromSession(session.NewCustomShard(m, id), defaultstore.New())
|
||||
sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
|
||||
state := NewFromSession(sessn, defaultstore.New())
|
||||
opts(m, state)
|
||||
return state, nil
|
||||
}
|
||||
|
@ -110,29 +112,21 @@ type State struct {
|
|||
}
|
||||
|
||||
// New creates a new state.
|
||||
func New(token string) (*State, error) {
|
||||
func New(token string) *State {
|
||||
return NewWithStore(token, defaultstore.New())
|
||||
}
|
||||
|
||||
// NewWithIntents creates a new state with the given gateway intents. For more
|
||||
// information, refer to gateway.Intents.
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
|
||||
s, err := session.NewWithIntents(token, intents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewFromSession(s, defaultstore.New()), nil
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) *State {
|
||||
s := session.NewWithIntents(token, intents...)
|
||||
return NewFromSession(s, defaultstore.New())
|
||||
}
|
||||
|
||||
// NewWithStore creates a new state with the given store cabinet.
|
||||
func NewWithStore(token string, cabinet *store.Cabinet) (*State, error) {
|
||||
s, err := session.New(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewFromSession(s, cabinet), nil
|
||||
func NewWithStore(token string, cabinet *store.Cabinet) *State {
|
||||
s := session.New(token)
|
||||
return NewFromSession(s, cabinet)
|
||||
}
|
||||
|
||||
// NewFromSession creates a new State from the passed Session and Cabinet.
|
||||
|
@ -239,11 +233,11 @@ func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (dis
|
|||
merr = store.ErrNotFound
|
||||
)
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
g, gerr = s.Cabinet.Guild(guildID)
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if s.HasIntents(gateway.IntentGuildMembers) {
|
||||
m, merr = s.Cabinet.Member(guildID, userID)
|
||||
}
|
||||
|
||||
|
@ -300,11 +294,11 @@ func (s *State) Permissions(
|
|||
merr = store.ErrNotFound
|
||||
)
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
g, gerr = s.Cabinet.Guild(ch.GuildID)
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if s.HasIntents(gateway.IntentGuildMembers) {
|
||||
m, merr = s.Cabinet.Member(ch.GuildID, userID)
|
||||
}
|
||||
|
||||
|
@ -372,7 +366,7 @@ func (s *State) Channel(id discord.ChannelID) (c *discord.Channel, err error) {
|
|||
}
|
||||
|
||||
func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
cs, err = s.Cabinet.Channels(guildID)
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -384,7 +378,7 @@ func (s *State) Channels(guildID discord.GuildID) (cs []discord.Channel, err err
|
|||
return
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
for i := range cs {
|
||||
if err = s.Cabinet.ChannelSet(&cs[i], false); err != nil {
|
||||
return
|
||||
|
@ -436,7 +430,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
|
|||
func (s *State) Emoji(
|
||||
guildID discord.GuildID, emojiID discord.EmojiID) (e *discord.Emoji, err error) {
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
|
||||
if s.HasIntents(gateway.IntentGuildEmojis) {
|
||||
e, err = s.Cabinet.Emoji(guildID, emojiID)
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -464,7 +458,7 @@ func (s *State) Emoji(
|
|||
}
|
||||
|
||||
func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
|
||||
if s.HasIntents(gateway.IntentGuildEmojis) {
|
||||
es, err = s.Cabinet.Emojis(guildID)
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -476,7 +470,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error)
|
|||
return
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildEmojis) {
|
||||
if s.HasIntents(gateway.IntentGuildEmojis) {
|
||||
err = s.Cabinet.EmojiSet(guildID, es, false)
|
||||
}
|
||||
|
||||
|
@ -486,7 +480,7 @@ func (s *State) Emojis(guildID discord.GuildID) (es []discord.Emoji, err error)
|
|||
////
|
||||
|
||||
func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
c, err := s.Cabinet.Guild(id)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
|
@ -498,7 +492,7 @@ func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
|
|||
|
||||
// Guilds will only fill a maximum of 100 guilds from the API.
|
||||
func (s *State) Guilds() (gs []discord.Guild, err error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
gs, err = s.Cabinet.Guilds()
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -510,7 +504,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
for i := range gs {
|
||||
if err = s.Cabinet.GuildSet(&gs[i], false); err != nil {
|
||||
return
|
||||
|
@ -524,7 +518,7 @@ func (s *State) Guilds() (gs []discord.Guild, err error) {
|
|||
////
|
||||
|
||||
func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if s.HasIntents(gateway.IntentGuildMembers) {
|
||||
m, err := s.Cabinet.Member(guildID, userID)
|
||||
if err == nil {
|
||||
return m, nil
|
||||
|
@ -535,7 +529,7 @@ func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord
|
|||
}
|
||||
|
||||
func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if s.HasIntents(gateway.IntentGuildMembers) {
|
||||
ms, err = s.Cabinet.Members(guildID)
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -547,7 +541,7 @@ func (s *State) Members(guildID discord.GuildID) (ms []discord.Member, err error
|
|||
return
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if s.HasIntents(gateway.IntentGuildMembers) {
|
||||
for i := range ms {
|
||||
if err = s.Cabinet.MemberSet(guildID, &ms[i], false); err != nil {
|
||||
return
|
||||
|
@ -580,7 +574,7 @@ func (s *State) Message(
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
c, cerr = s.Session.Channel(channelID)
|
||||
if cerr == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if cerr == nil && s.HasIntents(gateway.IntentGuilds) {
|
||||
cerr = s.Cabinet.ChannelSet(c, false)
|
||||
}
|
||||
|
||||
|
@ -706,13 +700,13 @@ func (s *State) Messages(channelID discord.ChannelID, limit uint) ([]discord.Mes
|
|||
// Presence checks the state for user presences. If no guildID is given, it
|
||||
// will look for the presence in all cached guilds.
|
||||
func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Presence, error) {
|
||||
if !s.Gateway.HasIntents(gateway.IntentGuildPresences) {
|
||||
if !s.HasIntents(gateway.IntentGuildPresences) {
|
||||
return nil, store.ErrNotFound
|
||||
}
|
||||
|
||||
// If there's no guild ID, look in all guilds
|
||||
if !gID.IsValid() {
|
||||
if !s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if !s.HasIntents(gateway.IntentGuilds) {
|
||||
return nil, store.ErrNotFound
|
||||
}
|
||||
|
||||
|
@ -736,7 +730,7 @@ func (s *State) Presence(gID discord.GuildID, uID discord.UserID) (*discord.Pres
|
|||
////
|
||||
|
||||
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *discord.Role, err error) {
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
target, err = s.Cabinet.Role(guildID, roleID)
|
||||
if err == nil {
|
||||
return
|
||||
|
@ -754,7 +748,7 @@ func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (target *di
|
|||
target = &r
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
if err = s.RoleSet(guildID, &rs[i], false); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -779,7 +773,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if s.HasIntents(gateway.IntentGuilds) {
|
||||
for i := range rs {
|
||||
if err := s.RoleSet(guildID, &rs[i], false); err != nil {
|
||||
return rs, err
|
||||
|
@ -792,7 +786,7 @@ func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
|
|||
|
||||
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
|
||||
g, err = s.Session.Guild(id)
|
||||
if err == nil && s.Gateway.HasIntents(gateway.IntentGuilds) {
|
||||
if err == nil && s.HasIntents(gateway.IntentGuilds) {
|
||||
err = s.Cabinet.GuildSet(g, false)
|
||||
}
|
||||
|
||||
|
@ -801,7 +795,7 @@ func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
|
|||
|
||||
func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord.Member, err error) {
|
||||
m, err = s.Session.Member(gID, uID)
|
||||
if err == nil && s.Gateway.HasIntents(gateway.IntentGuildMembers) {
|
||||
if err == nil && s.HasIntents(gateway.IntentGuildMembers) {
|
||||
err = s.Cabinet.MemberSet(gID, m, false)
|
||||
}
|
||||
|
||||
|
@ -811,14 +805,12 @@ func (s *State) fetchMember(gID discord.GuildID, uID discord.UserID) (m *discord
|
|||
// tracksMessage reports whether the state would track the passed message and
|
||||
// messages from the same channel.
|
||||
func (s *State) tracksMessage(m *discord.Message) bool {
|
||||
return s.Gateway.Identifier.Intents == nil ||
|
||||
(m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuildMessages)) ||
|
||||
(!m.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentDirectMessages))
|
||||
return (m.GuildID.IsValid() && s.HasIntents(gateway.IntentGuildMessages)) ||
|
||||
(!m.GuildID.IsValid() && s.HasIntents(gateway.IntentDirectMessages))
|
||||
}
|
||||
|
||||
// tracksChannel reports whether the state would track the passed channel.
|
||||
func (s *State) tracksChannel(c *discord.Channel) bool {
|
||||
return s.Gateway.Identifier.Intents == nil ||
|
||||
(c.GuildID.IsValid() && s.Gateway.HasIntents(gateway.IntentGuilds)) ||
|
||||
return (c.GuildID.IsValid() && s.HasIntents(gateway.IntentGuilds)) ||
|
||||
!c.GuildID.IsValid()
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ func (s *State) onEvent(iface interface{}) {
|
|||
}
|
||||
|
||||
// Update available fields from ev into m
|
||||
ev.Update(m)
|
||||
ev.UpdateMember(m)
|
||||
|
||||
if err := s.Cabinet.MemberSet(ev.GuildID, m, true); err != nil {
|
||||
s.stateErr(err, "failed to update a member in state")
|
||||
|
|
32
state/state_example_test.go
Normal file
32
state/state_example_test.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package state_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
s := state.New("Bot " + os.Getenv("DISCORD_TOKEN"))
|
||||
s.AddIntents(gateway.IntentGuilds | gateway.IntentGuildMessages)
|
||||
s.AddHandler(func(m *gateway.MessageCreateEvent) {
|
||||
log.Printf("%s: %s", m.Author.Username, m.Content)
|
||||
})
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
if err := s.Open(ctx); err != nil {
|
||||
log.Println("cannot open:", err)
|
||||
}
|
||||
|
||||
<-ctx.Done() // block until Ctrl+C
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
log.Println("cannot close:", err)
|
||||
}
|
||||
}
|
|
@ -5,16 +5,24 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/session/shard"
|
||||
)
|
||||
|
||||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
data.Presence = &gateway.UpdatePresenceCommand{
|
||||
Status: discord.DoNotDisturbStatus,
|
||||
Activities: []discord.Activity{{
|
||||
Name: "Testing shards...",
|
||||
Type: discord.CustomActivity,
|
||||
}},
|
||||
}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
||||
|
@ -23,19 +31,18 @@ func TestSharding(t *testing.T) {
|
|||
now := time.Now().Format(time.StampMilli)
|
||||
t.Log(now, "initializing shard")
|
||||
|
||||
s.Gateway.ErrorLog = func(err error) {
|
||||
t.Error("gateway error:", err)
|
||||
}
|
||||
|
||||
s.AddIntents(gateway.IntentGuilds)
|
||||
s.AddHandler(readyCh)
|
||||
s.AddSyncHandler(readyCh)
|
||||
s.AddSyncHandler(func(err error) {
|
||||
t.Log("background error:", err)
|
||||
})
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal("failed to make shard manager:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
|
|
|
@ -2,22 +2,29 @@ package bot
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
// eventIntents maps event pointer types to intents.
|
||||
var eventIntents = map[reflect.Type]gateway.Intents{}
|
||||
var (
|
||||
// eventIntents maps event pointer types to intents.
|
||||
eventIntents map[reflect.Type]gateway.Intents
|
||||
eventIntentsOnce sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
for event, intent := range gateway.EventIntents {
|
||||
fn, ok := gateway.EventCreator[event]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
eventIntents[reflect.TypeOf(fn())] = intent
|
||||
}
|
||||
func ensureEventIntents() {
|
||||
eventIntentsOnce.Do(func() {
|
||||
eventIntents = map[reflect.Type]gateway.Intents{}
|
||||
gateway.OpUnmarshalers.Each(func(_ ws.OpCode, t ws.EventType, f ws.OpFunc) bool {
|
||||
intent, ok := gateway.EventIntents[t]
|
||||
if ok {
|
||||
eventIntents[reflect.TypeOf(f())] = intent
|
||||
}
|
||||
return false
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type command struct {
|
||||
|
@ -44,6 +51,8 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er
|
|||
|
||||
// intents returns the command's intents from the event.
|
||||
func (c *command) intents() gateway.Intents {
|
||||
ensureEventIntents()
|
||||
|
||||
intents, ok := eventIntents[c.event]
|
||||
if !ok {
|
||||
return 0
|
||||
|
|
|
@ -14,13 +14,13 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/api"
|
||||
"github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/session/shard"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
"github.com/diamondburned/arikawa/v3/state/store"
|
||||
"github.com/diamondburned/arikawa/v3/state/store/defaultstore"
|
||||
"github.com/diamondburned/arikawa/v3/utils/bot/extras/shellwords"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
)
|
||||
|
||||
// Prefixer checks a message if it starts with the desired prefix. By default,
|
||||
|
@ -55,22 +55,15 @@ func NewShardFunc(fn func(*state.State) (*Context, error)) shard.NewShardFunc {
|
|||
panic("bot.NewShardFunc missing fn")
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
var cab *store.Cabinet
|
||||
|
||||
return func(m *shard.Manager, id *gateway.Identifier) (shard.Shard, error) {
|
||||
state := state.NewFromSession(session.NewCustomShard(m, id), nil)
|
||||
sessn := session.NewCustom(*id, api.NewClient(id.Token), handler.New())
|
||||
state := state.NewFromSession(sessn, defaultstore.New())
|
||||
|
||||
bot, err := fn(state)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create bot instance")
|
||||
}
|
||||
|
||||
if state.Cabinet == nil {
|
||||
once.Do(func() { cab = defaultstore.New() })
|
||||
state.Cabinet = cab
|
||||
}
|
||||
|
||||
return bot, nil
|
||||
}
|
||||
}
|
||||
|
@ -206,10 +199,6 @@ func Start(
|
|||
// fail api request if they (will) take up more than 5 minutes
|
||||
ctx.Client.Client.Timeout = 5 * time.Minute
|
||||
|
||||
ctx.Gateway.ErrorLog = func(err error) {
|
||||
ctx.ErrorLogger(err)
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
if err := opts(ctx); err != nil {
|
||||
return nil, err
|
||||
|
@ -317,7 +306,7 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
|
|||
// AddIntents adds the given Gateway Intent into the Gateway. This is a
|
||||
// convenient function that calls Gateway's AddIntent.
|
||||
func (ctx *Context) AddIntents(i gateway.Intents) {
|
||||
ctx.Gateway.AddIntents(i)
|
||||
ctx.Session.AddIntents(i)
|
||||
}
|
||||
|
||||
// Subcommands returns the slice of subcommands. To add subcommands, use
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/gateway/shard"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/session/shard"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,7 @@ func (bot *shardedBot) OnReady(r *gateway.ReadyEvent) {
|
|||
func TestSharding(t *testing.T) {
|
||||
env := testenv.Must(t)
|
||||
|
||||
data := gateway.DefaultIdentifyData("Bot " + env.BotToken)
|
||||
data := gateway.DefaultIdentifyCommand("Bot " + env.BotToken)
|
||||
data.Shard = &gateway.Shard{0, env.ShardCount}
|
||||
|
||||
readyCh := make(chan *gateway.ReadyEvent)
|
||||
|
|
|
@ -4,27 +4,18 @@ import (
|
|||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/utils/bot"
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
"github.com/diamondburned/arikawa/v3/state/store"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
"github.com/diamondburned/arikawa/v3/utils/bot"
|
||||
)
|
||||
|
||||
func TestAdminOnly(t *testing.T) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Session: &session.Session{
|
||||
Gateway: &gateway.Gateway{
|
||||
Identifier: &gateway.Identifier{
|
||||
IdentifyData: gateway.IdentifyData{
|
||||
Intents: option.NewUint(uint(gateway.IntentGuilds | gateway.IntentGuildMembers)),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Session: session.New(""),
|
||||
Cabinet: mockCabinet(),
|
||||
},
|
||||
}
|
||||
|
@ -62,15 +53,7 @@ func TestAdminOnly(t *testing.T) {
|
|||
func TestGuildOnly(t *testing.T) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Session: &session.Session{
|
||||
Gateway: &gateway.Gateway{
|
||||
Identifier: &gateway.Identifier{
|
||||
IdentifyData: gateway.IdentifyData{
|
||||
Intents: option.NewUint(uint(gateway.IntentGuilds)),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Session: session.New(""),
|
||||
Cabinet: mockCabinet(),
|
||||
},
|
||||
}
|
||||
|
|
172
utils/cmd/genevent/genevent.go
Normal file
172
utils/cmd/genevent/genevent.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"go/format"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"unicode"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
pkg = "gateway"
|
||||
out = "-"
|
||||
)
|
||||
|
||||
type registry struct {
|
||||
PackageName string
|
||||
EventTypes []EventType
|
||||
}
|
||||
|
||||
type EventType struct {
|
||||
StructName string
|
||||
EventName string
|
||||
IsDispatch bool
|
||||
OpCode int
|
||||
}
|
||||
|
||||
func (t *EventType) MethodRecv() string {
|
||||
if len(t.StructName) == 0 {
|
||||
return "e"
|
||||
}
|
||||
return string(unicode.ToLower([]rune(t.StructName)[0]))
|
||||
}
|
||||
|
||||
//go:embed template.tmpl
|
||||
var packageTmpl string
|
||||
|
||||
var tmpl = template.Must(template.New("").Parse(packageTmpl))
|
||||
|
||||
const eventStructRegex = "(?m)" +
|
||||
`^// ([A-Za-z]+(?:Event|Command)) is (a dispatch event|an event|a command)` +
|
||||
`(?:` +
|
||||
` for ([A-Z_]+)` + "|" +
|
||||
` for Op (\d+)` +
|
||||
`)?` +
|
||||
`\.(?:.|\n)*?\ntype ([A-Za-z]+(?:Event|Command)) .*`
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&pkg, "p", pkg, "the package name to use")
|
||||
flag.StringVar(&out, "o", out, "output file, - for stdout")
|
||||
flag.Parse()
|
||||
|
||||
log.Println(eventStructRegex)
|
||||
|
||||
r := registry{
|
||||
PackageName: pkg,
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(".")
|
||||
if err != nil {
|
||||
log.Fatalln("failed to read current directory:", err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() || !strings.HasSuffix(file.Name(), ".go") {
|
||||
continue
|
||||
}
|
||||
if err := r.CrawlFile(file.Name()); err != nil {
|
||||
log.Fatalln("failed to crawl file:", err)
|
||||
}
|
||||
}
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
if err := tmpl.Execute(&buf, &r); err != nil {
|
||||
log.Fatalln("failed to execute template:", err)
|
||||
}
|
||||
|
||||
b, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
log.Fatalln("failed to fmt:", err)
|
||||
}
|
||||
|
||||
output := os.Stdout
|
||||
if out != "-" {
|
||||
f, err := os.Create(out)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to create output:", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
output = f
|
||||
}
|
||||
|
||||
if _, err := output.Write(b); err != nil {
|
||||
log.Fatalln("failed to write rendered:", err)
|
||||
}
|
||||
}
|
||||
|
||||
var reEventStruct = regexp.MustCompile(eventStructRegex)
|
||||
|
||||
func (r *registry) CrawlFile(name string) error {
|
||||
f, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read file")
|
||||
}
|
||||
|
||||
for _, match := range reEventStruct.FindAllSubmatch(f, -1) {
|
||||
// Validity check.
|
||||
if string(match[1]) != string(match[5]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(string(match[1]), "Command") && string(match[2]) != "a command" {
|
||||
log.Println(string(match[1]), "has invalid comment %q", string(match[2]))
|
||||
continue
|
||||
}
|
||||
|
||||
t := EventType{
|
||||
StructName: string(match[1]),
|
||||
EventName: string(match[3]),
|
||||
IsDispatch: string(match[2]) == "a dispatch event",
|
||||
OpCode: -1,
|
||||
}
|
||||
|
||||
if op := string(match[4]); op != "" && !t.IsDispatch {
|
||||
i, err := strconv.Atoi(op)
|
||||
if err != nil {
|
||||
log.Printf("error at struct %s: error parsing Op %v", t.StructName, err)
|
||||
}
|
||||
t.OpCode = i
|
||||
}
|
||||
|
||||
if t.IsDispatch && t.EventName == "" {
|
||||
t.EventName = guessEventName(t.StructName)
|
||||
}
|
||||
|
||||
r.EventTypes = append(r.EventTypes, t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func guessEventName(structName string) string {
|
||||
name := strings.TrimSuffix(structName, "Event")
|
||||
|
||||
var newName strings.Builder
|
||||
newName.Grow(len(name) * 2)
|
||||
|
||||
for i, r := range name {
|
||||
if unicode.IsLower(r) {
|
||||
newName.WriteRune(unicode.ToUpper(r))
|
||||
continue
|
||||
}
|
||||
|
||||
if i > 0 {
|
||||
newName.WriteByte('_')
|
||||
}
|
||||
|
||||
newName.WriteRune(r)
|
||||
}
|
||||
|
||||
return newName.String()
|
||||
}
|
27
utils/cmd/genevent/template.tmpl
Normal file
27
utils/cmd/genevent/template.tmpl
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Code generated by genevent. DO NOT EDIT.
|
||||
|
||||
package {{ .PackageName }}
|
||||
|
||||
import "github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
|
||||
func init() {
|
||||
OpUnmarshalers.Add(
|
||||
{{ range .EventTypes -}}
|
||||
func() ws.Event { return new({{ .StructName }}) },
|
||||
{{ end -}}
|
||||
)
|
||||
}
|
||||
|
||||
{{ range .EventTypes }}
|
||||
|
||||
{{ if .IsDispatch }}
|
||||
// Op implements Event. It always returns 0.
|
||||
func (*{{ .StructName }}) Op() ws.OpCode { return dispatchOp }
|
||||
{{ else if (gt .OpCode -1) }}
|
||||
// Op implements Event. It always returns Op {{ .OpCode }}.
|
||||
func (*{{ .StructName }}) Op() ws.OpCode { return {{ .OpCode }} }
|
||||
{{ end }}
|
||||
|
||||
// EventType implements Event.
|
||||
func (*{{ .StructName }}) EventType() ws.EventType { return "{{ .EventName }}" }
|
||||
{{ end }}
|
|
@ -239,6 +239,7 @@ type handler struct {
|
|||
chanclose reflect.Value // IsValid() if chan
|
||||
isIface bool
|
||||
isSync bool
|
||||
isOnce bool
|
||||
}
|
||||
|
||||
// newHandler reflects either a channel or a function into a handler. A function
|
||||
|
|
101
utils/ws/codec.go
Normal file
101
utils/ws/codec.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Codec holds the codec states for Websocket implementations to share with the
|
||||
// manager. It is used internally in the Websocket and the Connection
|
||||
// implementation.
|
||||
type Codec struct {
|
||||
Unmarshalers OpUnmarshalers
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
// NewCodec creates a new default Codec instance.
|
||||
func NewCodec(unmarshalers OpUnmarshalers) Codec {
|
||||
return Codec{
|
||||
Unmarshalers: unmarshalers,
|
||||
Headers: http.Header{
|
||||
"Accept-Encoding": {"zlib"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type codecOp struct {
|
||||
Op
|
||||
Data json.Raw `json:"d,omitempty"`
|
||||
}
|
||||
|
||||
const maxSharedBufferSize = 1 << 15 // 32KB
|
||||
|
||||
// DecodeBuffer boxes a byte slice to provide a shared and thread-unsafe buffer.
|
||||
// It is used internally and should only be handled around as an opaque thing.
|
||||
type DecodeBuffer struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewDecodeBuffer creates a new preallocated DecodeBuffer.
|
||||
func NewDecodeBuffer(cap int) DecodeBuffer {
|
||||
if cap > maxSharedBufferSize {
|
||||
cap = maxSharedBufferSize
|
||||
}
|
||||
|
||||
return DecodeBuffer{
|
||||
buf: make([]byte, 0, cap),
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeFrom reads the given reader and decodes it into an Op.
|
||||
//
|
||||
// buf is optional.
|
||||
func (c Codec) DecodeFrom(r io.Reader, buf *DecodeBuffer) Op {
|
||||
var op codecOp
|
||||
op.Data = json.Raw(buf.buf)
|
||||
|
||||
if err := json.DecodeStream(r, &op); err != nil {
|
||||
return newErrOp(err, "cannot read JSON stream")
|
||||
}
|
||||
|
||||
// buf isn't grown from here out. Set it back right now. If Data hasn't been
|
||||
// grown, then this will just set buf back to what it was.
|
||||
if cap(op.Data) < maxSharedBufferSize {
|
||||
buf.buf = op.Data[:0]
|
||||
}
|
||||
|
||||
fn := c.Unmarshalers.Lookup(op.Code, op.Type)
|
||||
if fn == nil {
|
||||
err := UnknownEventError{
|
||||
Op: op.Code,
|
||||
Type: op.Type,
|
||||
}
|
||||
return newErrOp(err, "")
|
||||
}
|
||||
|
||||
op.Op.Data = fn()
|
||||
if err := op.Data.UnmarshalTo(op.Op.Data); err != nil {
|
||||
return newErrOp(err, "cannot unmarshal JSON data from gateway")
|
||||
}
|
||||
|
||||
return op.Op
|
||||
}
|
||||
|
||||
func newErrOp(err error, wrap string) Op {
|
||||
if wrap != "" {
|
||||
err = errors.Wrap(err, wrap)
|
||||
}
|
||||
|
||||
ev := &BackgroundErrorEvent{
|
||||
Err: err,
|
||||
}
|
||||
|
||||
return Op{
|
||||
Code: ev.Op(),
|
||||
Type: ev.EventType(),
|
||||
Data: ev,
|
||||
}
|
||||
}
|
278
utils/ws/conn.go
Normal file
278
utils/ws/conn.go
Normal file
|
@ -0,0 +1,278 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"compress/zlib"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const rwBufferSize = 1 << 15 // 32KB
|
||||
|
||||
// ErrWebsocketClosed is returned if the websocket is already closed.
|
||||
var ErrWebsocketClosed = errors.New("websocket is closed")
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself, including
|
||||
// modifying the connection URL. The implementation doesn't have to be safe for
|
||||
// concurrent use.
|
||||
type Connection interface {
|
||||
// Dial dials the address (string). Context needs to be passed in for
|
||||
// timeout. This method should also be re-usable after Close is called.
|
||||
Dial(context.Context, string) (<-chan Op, error)
|
||||
|
||||
// Send allows the caller to send bytes.
|
||||
Send(context.Context, []byte) error
|
||||
|
||||
// Close should close the websocket connection. The underlying connection
|
||||
// may be reused, but this Connection instance will be reused with Dial. The
|
||||
// Connection must still be reusable even if Close returns an error. If
|
||||
// gracefully is true, then the implementation must send a close frame
|
||||
// prior.
|
||||
Close(gracefully bool) error
|
||||
}
|
||||
|
||||
// Conn is the default Websocket connection. It tries to compresses all payloads
|
||||
// using zlib.
|
||||
type Conn struct {
|
||||
dialer websocket.Dialer
|
||||
codec Codec
|
||||
|
||||
// conn is used for synchronizing the conn instance itself. Any use of conn
|
||||
// must copy conn out.
|
||||
conn *connMutex
|
||||
// mut is used for synchronizing the conn field.
|
||||
mut sync.Mutex
|
||||
|
||||
// CloseTimeout is the timeout for graceful closing. It's defaulted to 5s.
|
||||
CloseTimeout time.Duration
|
||||
}
|
||||
|
||||
type connMutex struct {
|
||||
wrmut chan struct{}
|
||||
*websocket.Conn
|
||||
}
|
||||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
||||
// NewConn creates a new default websocket connection with a default dialer.
|
||||
func NewConn(codec Codec) *Conn {
|
||||
return NewConnWithDialer(codec, websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: rwBufferSize,
|
||||
WriteBufferSize: rwBufferSize,
|
||||
EnableCompression: true,
|
||||
})
|
||||
}
|
||||
|
||||
// NewConnWithDialer creates a new default websocket connection with a custom
|
||||
// dialer.
|
||||
func NewConnWithDialer(codec Codec, dialer websocket.Dialer) *Conn {
|
||||
return &Conn{
|
||||
dialer: dialer,
|
||||
codec: codec,
|
||||
CloseTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial starts a new connection and returns the listening channel for it. If the
|
||||
// websocket is already dialed, then the connection is closed first.
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) (<-chan Op, error) {
|
||||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
// Ensure that the connection is already closed.
|
||||
if c.conn != nil {
|
||||
c.conn.close(c.CloseTimeout, false)
|
||||
}
|
||||
|
||||
conn, _, err := c.dialer.DialContext(ctx, addr, c.codec.Headers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
||||
events := make(chan Op, 1)
|
||||
go readLoop(conn, c.codec, events)
|
||||
|
||||
c.conn = &connMutex{
|
||||
wrmut: make(chan struct{}, 1),
|
||||
Conn: conn,
|
||||
}
|
||||
|
||||
return events, err
|
||||
}
|
||||
|
||||
// Close implements Connection.
|
||||
func (c *Conn) Close(gracefully bool) error {
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
return c.conn.close(c.CloseTimeout, gracefully)
|
||||
}
|
||||
|
||||
func (c *connMutex) close(timeout time.Duration, gracefully bool) error {
|
||||
if c == nil || c.Conn == nil {
|
||||
WSDebug("Conn: Close is called on already closed connection")
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
WSDebug("Conn: Close is called; shutting down the Websocket connection.")
|
||||
|
||||
if gracefully {
|
||||
// Have a deadline before closing.
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case c.wrmut <- struct{}{}:
|
||||
// Lock acquired. We can now safely set the deadline and write.
|
||||
c.SetWriteDeadline(deadline)
|
||||
|
||||
WSDebug("Conn: Graceful closing requested, sending close frame.")
|
||||
|
||||
if err := c.WriteMessage(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||
); err != nil {
|
||||
WSError(err)
|
||||
}
|
||||
|
||||
// Release the lock.
|
||||
<-c.wrmut
|
||||
|
||||
case <-ctx.Done():
|
||||
// We couldn't acquire the lock. Resort to just closing the
|
||||
// connection directly.
|
||||
}
|
||||
}
|
||||
|
||||
// Close the WS.
|
||||
err := c.Conn.Close()
|
||||
|
||||
if err != nil {
|
||||
WSDebug("Conn: Websocket closed; error:", err)
|
||||
} else {
|
||||
WSDebug("Conn: Websocket closed successfully")
|
||||
}
|
||||
|
||||
c.Conn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// resetDeadline is used to reset the write deadline after using the context's.
|
||||
var resetDeadline = time.Time{}
|
||||
|
||||
// Send implements Connection.
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
c.mut.Lock()
|
||||
conn := c.conn
|
||||
c.mut.Unlock()
|
||||
|
||||
select {
|
||||
case conn.wrmut <- struct{}{}:
|
||||
defer func() { <-conn.wrmut }()
|
||||
|
||||
if ctx != context.Background() {
|
||||
d, ok := ctx.Deadline()
|
||||
if ok {
|
||||
conn.SetWriteDeadline(d)
|
||||
defer conn.SetWriteDeadline(resetDeadline)
|
||||
}
|
||||
}
|
||||
|
||||
return conn.WriteMessage(websocket.TextMessage, b)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// loopState is a thread-unsafe disposable state container for the read loop.
|
||||
// It's made to completely separate the read loop of any synchronization that
|
||||
// doesn't involve the websocket connection itself.
|
||||
type loopState struct {
|
||||
conn *websocket.Conn
|
||||
codec Codec
|
||||
zlib io.ReadCloser
|
||||
buf DecodeBuffer
|
||||
}
|
||||
|
||||
func readLoop(conn *websocket.Conn, codec Codec, opCh chan<- Op) {
|
||||
// Clean up the events channel in the end.
|
||||
defer close(opCh)
|
||||
|
||||
// Allocate the read loop its own private resources.
|
||||
state := loopState{
|
||||
conn: conn,
|
||||
codec: codec,
|
||||
buf: NewDecodeBuffer(1 << 14), // 16KB
|
||||
}
|
||||
|
||||
for {
|
||||
b, err := state.handle()
|
||||
if err != nil {
|
||||
WSDebug("Conn: fatal Conn error:", err)
|
||||
|
||||
closeEv := &CloseEvent{
|
||||
Err: err,
|
||||
Code: -1,
|
||||
}
|
||||
|
||||
var closeErr *websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeEv.Code = closeErr.Code
|
||||
closeEv.Err = fmt.Errorf("%d %s", closeErr.Code, closeErr.Text)
|
||||
}
|
||||
|
||||
opCh <- Op{
|
||||
Code: closeEv.Op(),
|
||||
Type: closeEv.EventType(),
|
||||
Data: closeEv,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
opCh <- b
|
||||
}
|
||||
}
|
||||
|
||||
func (state *loopState) handle() (Op, error) {
|
||||
// skip message type
|
||||
t, r, err := state.conn.NextReader()
|
||||
if err != nil {
|
||||
return Op{}, err
|
||||
}
|
||||
|
||||
if t == websocket.BinaryMessage {
|
||||
// Probably a zlib payload.
|
||||
|
||||
if state.zlib == nil {
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
return Op{}, errors.Wrap(err, "failed to create a zlib reader")
|
||||
}
|
||||
state.zlib = z
|
||||
} else {
|
||||
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
||||
return Op{}, errors.Wrap(err, "failed to reset zlib reader")
|
||||
}
|
||||
}
|
||||
|
||||
defer state.zlib.Close()
|
||||
r = state.zlib
|
||||
}
|
||||
|
||||
return state.codec.DecodeFrom(r, &state.buf), nil
|
||||
}
|
364
utils/ws/gateway.go
Normal file
364
utils/ws/gateway.go
Normal file
|
@ -0,0 +1,364 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/lazytime"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ConnectionError is given to the user if the gateway fails to connect to the
|
||||
// gateway for any reason, including during an initial connection or a
|
||||
// reconnection. To check for this error, use the errors.As function.
|
||||
type ConnectionError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Unwrap unwraps the ConnectionError.
|
||||
func (err ConnectionError) Unwrap() error { return err.Err }
|
||||
|
||||
// Error formats the error.
|
||||
func (err ConnectionError) Error() string {
|
||||
return fmt.Sprintf("error reconnecting: %s", err.Err)
|
||||
}
|
||||
|
||||
// BackgroundErrorEvent describes an error that the gateway event loop might
|
||||
// stumble upon while it's running. See Gateway's documentation for possible
|
||||
// usages.
|
||||
type BackgroundErrorEvent struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
var _ Event = (*BackgroundErrorEvent)(nil)
|
||||
|
||||
// Unwrap returns err.Err.
|
||||
func (err *BackgroundErrorEvent) Unwrap() error { return err.Err }
|
||||
|
||||
// Error formats the BackgroundErrorEvent.
|
||||
func (err *BackgroundErrorEvent) Error() string {
|
||||
return "background gateway error: " + err.Err.Error()
|
||||
}
|
||||
|
||||
// Op implements Op. It returns -1.
|
||||
func (err *BackgroundErrorEvent) Op() OpCode { return -1 }
|
||||
|
||||
// EventType implements Op. It returns an opaque unique string.
|
||||
func (err *BackgroundErrorEvent) EventType() EventType {
|
||||
return "__ws.BackgroundErrorEvent"
|
||||
}
|
||||
|
||||
// GatewayOpts describes the gateway event loop options.
|
||||
type GatewayOpts struct {
|
||||
// ReconnectDelay determines the duration to idle after each failed retry.
|
||||
// This can be used to implement exponential backoff. The default is already
|
||||
// sane, so this field rarely needs to be changed.
|
||||
ReconnectDelay func(try int) time.Duration
|
||||
|
||||
// FatalCloseCodes is a list of close codes that will cause the gateway to
|
||||
// exit out if it stumbles on one of these. It is a copy of FatalCloseCodes
|
||||
// (the global variable) by default.
|
||||
FatalCloseCodes []int
|
||||
|
||||
// DialTimeout is the timeout to wait for each websocket dial before failing
|
||||
// it and retrying. Default is 0.
|
||||
DialTimeout time.Duration
|
||||
|
||||
// ReconnectAttempt is the maximum number of attempts made to Reconnect
|
||||
// before aborting the whole gateway. If this set to 0, unlimited attempts
|
||||
// will be made. Default is 0.
|
||||
ReconnectAttempt int
|
||||
|
||||
// AlwaysCloseGracefully, if true, will always make the Gateway close
|
||||
// gracefully once the context given to Open is cancelled. It governs the
|
||||
// Close behavior. The default is true.
|
||||
AlwaysCloseGracefully bool
|
||||
}
|
||||
|
||||
// DefaultGatewayOpts is the default event loop options.
|
||||
var DefaultGatewayOpts = GatewayOpts{
|
||||
ReconnectDelay: func(try int) time.Duration {
|
||||
// minimum 4 seconds
|
||||
return time.Duration(4+(2*try)) * time.Second
|
||||
},
|
||||
DialTimeout: 0,
|
||||
ReconnectAttempt: 0,
|
||||
AlwaysCloseGracefully: true,
|
||||
}
|
||||
|
||||
// Gateway describes an instance that handles the Discord gateway. It is
|
||||
// basically an abstracted concurrent event loop that the user could signal to
|
||||
// start connecting to the Discord gateway server.
|
||||
type Gateway struct {
|
||||
ws *Websocket
|
||||
|
||||
reconnect chan struct{}
|
||||
heart lazytime.Ticker
|
||||
srcOp <-chan Op // from WS
|
||||
outer outerState
|
||||
lastError error
|
||||
|
||||
opts GatewayOpts
|
||||
}
|
||||
|
||||
// outerState holds gateway state that the caller may change concurrently. As
|
||||
// such, it holds a mutex to allow that. The main purpose of this
|
||||
// synchronization is to allow the caller to use the gateway while the event
|
||||
// loop is still running without having the event loop muddle in without locking
|
||||
// properly. For example, opCh is given to the event loop as a copy; the event
|
||||
// loop must never access the outerState directly.
|
||||
type outerState struct {
|
||||
sync.Mutex
|
||||
ch chan Op
|
||||
started bool
|
||||
}
|
||||
|
||||
// Handler describes a gateway handler. It describes the core that governs the
|
||||
// behavior of the gateway event loop.
|
||||
type Handler interface {
|
||||
// OnOp is called by the gateway event loop on every new Op. If the returned
|
||||
// boolean is false, then the loop fatally exits.
|
||||
OnOp(context.Context, Op) (canContinue bool)
|
||||
// SendHeartbeat is called by the gateway event loop everytime a heartbeat
|
||||
// needs to be sent over.
|
||||
SendHeartbeat(context.Context)
|
||||
// Close closes the handler.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// NewGateway creates a new Gateway with a custom gateway URL and a pre-existing
|
||||
// Identifier. If opts is nil, then DefaultOpts is used.
|
||||
func NewGateway(ws *Websocket, opts *GatewayOpts) *Gateway {
|
||||
if opts == nil {
|
||||
opts = &DefaultGatewayOpts
|
||||
}
|
||||
|
||||
return &Gateway{
|
||||
ws: ws,
|
||||
opts: *opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Send is a function to send an Op payload to the Gateway.
|
||||
func (g *Gateway) Send(ctx context.Context, data Event) error {
|
||||
op := Op{
|
||||
Code: data.Op(),
|
||||
Type: data.EventType(),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to encode payload")
|
||||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return g.ws.Send(ctx, b)
|
||||
}
|
||||
|
||||
// HasStarted returns true if the gateway event loop is currently spinning.
|
||||
func (g *Gateway) HasStarted() bool {
|
||||
g.outer.Lock()
|
||||
defer g.outer.Unlock()
|
||||
|
||||
return g.outer.started
|
||||
}
|
||||
|
||||
// AssertIsNotRunning asserts that the gateway is currently not running. If the
|
||||
// gateway is running, the method will panic. Since a gateway cannot be started
|
||||
// back up, this method can be used to detect whether or not the caller in a
|
||||
// single goroutine can read the state safely.
|
||||
func (g *Gateway) AssertIsNotRunning() {
|
||||
g.outer.Lock()
|
||||
defer g.outer.Unlock()
|
||||
|
||||
if !g.outer.started {
|
||||
return
|
||||
}
|
||||
|
||||
// Hack to ensure that the event channel is closed.
|
||||
select {
|
||||
case _, ok := <-g.outer.ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// The panic behavior is a must, because if this branch is hit, then
|
||||
// we've actually stolen an event from the channel unexpectedly, putting
|
||||
// the event loop under a weird state.
|
||||
//
|
||||
// An alternative solution to this bug would be to mutex-guard the error
|
||||
// field, but the purpose of this method isn't to be called before the
|
||||
// gateway has been stopped.
|
||||
panic("ws: Error called while Gateway is still running")
|
||||
default:
|
||||
panic("ws: Error called while Gateway is still running")
|
||||
}
|
||||
}
|
||||
|
||||
// Connect starts the background goroutine that tries its best to maintain a
|
||||
// stable connection to the Websocket gateway. To the user, the gateway should
|
||||
// appear to be working seamlessly.
|
||||
//
|
||||
// For more documentation, refer to (*gateway.Gateway).Connect.
|
||||
func (g *Gateway) Connect(ctx context.Context, h Handler) <-chan Op {
|
||||
g.outer.Lock()
|
||||
defer g.outer.Unlock()
|
||||
|
||||
if !g.outer.started {
|
||||
g.outer.started = true
|
||||
g.outer.ch = make(chan Op, 1)
|
||||
go g.spin(ctx, h)
|
||||
}
|
||||
|
||||
return g.outer.ch
|
||||
}
|
||||
|
||||
// LastError returns the last error that the gateway has received.
|
||||
func (g *Gateway) LastError() error {
|
||||
g.AssertIsNotRunning()
|
||||
return g.lastError
|
||||
}
|
||||
|
||||
// finalize closes the gateway permanently.
|
||||
func (g *Gateway) finalize(h Handler) {
|
||||
var err error
|
||||
|
||||
if g.opts.AlwaysCloseGracefully {
|
||||
err = g.ws.CloseGracefully()
|
||||
} else {
|
||||
err = g.ws.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
g.SendErrorWrap(err, "failed to finalize websocket")
|
||||
}
|
||||
|
||||
if err := h.Close(); err != nil {
|
||||
g.SendError(err)
|
||||
}
|
||||
|
||||
g.outer.Lock()
|
||||
close(g.outer.ch)
|
||||
g.outer.started = false
|
||||
g.outer.Unlock()
|
||||
}
|
||||
|
||||
// QueueReconnect queues a reconnection in the gateway loop. This method should
|
||||
// only be called in the event loop ONCE; calling more than once will deadlock
|
||||
// the loop.
|
||||
func (g *Gateway) QueueReconnect() {
|
||||
select {
|
||||
case g.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
g.heart.Stop()
|
||||
}
|
||||
|
||||
// ResetHeartbeat resets the heartbeat to be the given duration.
|
||||
func (g *Gateway) ResetHeartbeat(d time.Duration) {
|
||||
g.heart.Reset(d)
|
||||
}
|
||||
|
||||
// SendError sends the given error wrapped in a BackgroundErrorEvent into the
|
||||
// event channel.
|
||||
func (g *Gateway) SendError(err error) {
|
||||
event := &BackgroundErrorEvent{err}
|
||||
|
||||
g.outer.ch <- Op{
|
||||
Code: event.Op(),
|
||||
Type: event.EventType(),
|
||||
Data: event,
|
||||
}
|
||||
g.lastError = err
|
||||
}
|
||||
|
||||
// SendErrorWrap is a convenient function over SendError.
|
||||
func (g *Gateway) SendErrorWrap(err error, message string) {
|
||||
g.SendError(errors.Wrap(err, message))
|
||||
}
|
||||
|
||||
func (g *Gateway) spin(ctx context.Context, h Handler) {
|
||||
// Always close the event channel once we exit.
|
||||
defer g.finalize(h)
|
||||
|
||||
var retryTimer lazytime.Timer
|
||||
defer retryTimer.Stop()
|
||||
|
||||
g.reconnect = make(chan struct{}, 1)
|
||||
g.reconnect <- struct{}{}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case op, ok := <-g.srcOp:
|
||||
if !ok {
|
||||
// Skip zero-value Ops that may happen on gateway closure.
|
||||
continue
|
||||
}
|
||||
|
||||
switch data := op.Data.(type) {
|
||||
case *CloseEvent:
|
||||
for _, code := range g.opts.FatalCloseCodes {
|
||||
if code == data.Code {
|
||||
// Don't wrap the error, but instead, just pipe it as-is
|
||||
// through the channel.
|
||||
g.outer.ch <- op
|
||||
g.lastError = data
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok = h.OnOp(ctx, op)
|
||||
g.outer.ch <- op
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
case <-g.heart.C:
|
||||
h.SendHeartbeat(ctx)
|
||||
|
||||
case <-g.reconnect:
|
||||
// Close the previous connection if it's not already. Ignore the
|
||||
// already closed error.
|
||||
if err := g.ws.Close(); err != nil && !errors.Is(err, ErrWebsocketClosed) {
|
||||
g.SendErrorWrap(err, "error closing before reconnecting")
|
||||
}
|
||||
|
||||
// Invalidate our srcOp.
|
||||
g.srcOp = nil
|
||||
|
||||
// Keep track of the last error for notifying.
|
||||
var err error
|
||||
|
||||
for try := 0; g.opts.ReconnectAttempt == 0 || try < g.opts.ReconnectAttempt; try++ {
|
||||
g.srcOp, err = g.ws.Dial(ctx)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Signal an error before retrying.
|
||||
g.SendError(ConnectionError{err})
|
||||
|
||||
retryTimer.Reset(g.opts.ReconnectDelay(try))
|
||||
if err := retryTimer.Wait(ctx); err != nil {
|
||||
g.SendError(ConnectionError{ctx.Err()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that we've reconnected successfully. Exit otherwise.
|
||||
if g.srcOp == nil {
|
||||
err = errors.Wrap(err, "failed to reconnect after max attempts")
|
||||
g.SendError(ConnectionError{err})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
212
utils/ws/op.go
Normal file
212
utils/ws/op.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OpCode is the type for websocket Op codes. Op codes less than 0 are
|
||||
// internal Op codes and should usually be ignored.
|
||||
type OpCode int
|
||||
|
||||
// CloseEvent is an event that is given from wsutil when the websocket is
|
||||
// closed.
|
||||
type CloseEvent struct {
|
||||
// Err is the underlying error.
|
||||
Err error
|
||||
// Code is the websocket close code, if any.
|
||||
Code int
|
||||
}
|
||||
|
||||
// Unwrap returns err.Err.
|
||||
func (e *CloseEvent) Unwrap() error { return e.Err }
|
||||
|
||||
// Error formats the CloseEvent. A CloseEvent is also an error.
|
||||
func (e *CloseEvent) Error() string {
|
||||
return fmt.Sprintf("websocket closed, reason: %s", e.Err)
|
||||
}
|
||||
|
||||
// Op implements Event. It returns -1.
|
||||
func (e *CloseEvent) Op() OpCode { return -1 }
|
||||
|
||||
// EventType implements Event. It returns an emty string.
|
||||
func (e *CloseEvent) EventType() EventType { return "__ws.CloseEvent" }
|
||||
|
||||
// EventType is a type for event types, which is the "t" field in the payload.
|
||||
type EventType string
|
||||
|
||||
// Event describes an Event data that comes from a gateway Operation.
|
||||
type Event interface {
|
||||
Op() OpCode
|
||||
EventType() EventType
|
||||
}
|
||||
|
||||
// OpFunc is a constructor function for an Operation.
|
||||
type OpFunc func() Event
|
||||
|
||||
// OpUnmarshalers contains a map of event constructor function.
|
||||
type OpUnmarshalers struct {
|
||||
r map[opFuncID]OpFunc
|
||||
}
|
||||
|
||||
type opFuncID struct {
|
||||
Op OpCode `json:"op"`
|
||||
T EventType `json:"t"`
|
||||
}
|
||||
|
||||
// NewOpUnmarshalers creates a nwe OpUnmarshalers instance from the given
|
||||
// constructor functions.
|
||||
func NewOpUnmarshalers(funcs ...OpFunc) OpUnmarshalers {
|
||||
m := OpUnmarshalers{r: make(map[opFuncID]OpFunc)}
|
||||
m.Add(funcs...)
|
||||
return m
|
||||
}
|
||||
|
||||
// Each iterates over the marshaler map.
|
||||
func (m OpUnmarshalers) Each(f func(OpCode, EventType, OpFunc) (done bool)) {
|
||||
for id, fn := range m.r {
|
||||
if f(id.Op, id.T, fn) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds the given functions into the unmarshaler registry.
|
||||
func (m OpUnmarshalers) Add(funcs ...OpFunc) {
|
||||
for _, fn := range funcs {
|
||||
ev := fn()
|
||||
id := opFuncID{
|
||||
Op: ev.Op(),
|
||||
T: ev.EventType(),
|
||||
}
|
||||
|
||||
m.r[id] = fn
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup searches the OpMarshalers map for the given constructor function.
|
||||
func (m OpUnmarshalers) Lookup(op OpCode, t EventType) OpFunc {
|
||||
return m.r[opFuncID{op, t}]
|
||||
}
|
||||
|
||||
// Op is a gateway Operation.
|
||||
type Op struct {
|
||||
Code OpCode `json:"op"`
|
||||
Data Event `json:"d,omitempty"`
|
||||
|
||||
// Type is only for gateway dispatch events.
|
||||
Type EventType `json:"t,omitempty"`
|
||||
// Sequence is only for gateway dispatch events (Op 0).
|
||||
Sequence int64 `json:"s,omitempty"`
|
||||
}
|
||||
|
||||
// UnknownEventError is required by HandleOp if an event is encountered that is
|
||||
// not known. Internally, unknown events are logged and ignored. It is not a
|
||||
// fatal error.
|
||||
type UnknownEventError struct {
|
||||
Op OpCode
|
||||
Type EventType
|
||||
}
|
||||
|
||||
// Error formats the unknown event error to with the event name and payload
|
||||
func (err UnknownEventError) Error() string {
|
||||
return fmt.Sprintf("unknown op %d, event %s", err.Op, err.Type)
|
||||
}
|
||||
|
||||
// IsBrokenConnection returns true if the error is a broken connection error.
|
||||
func IsUnknownEvent(err error) bool {
|
||||
var uevent *UnknownEventError
|
||||
return errors.As(err, &uevent)
|
||||
}
|
||||
|
||||
// ReadOps reads maximum n Ops and accumulate them into a slice.
|
||||
func ReadOps(ctx context.Context, ch <-chan Op, n int) ([]Op, error) {
|
||||
ops := make([]Op, 0, n)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ops, ctx.Err()
|
||||
case op := <-ch:
|
||||
ops = append(ops, op)
|
||||
if len(ops) == n {
|
||||
return ops, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadOp reads a single Op.
|
||||
func ReadOp(ctx context.Context, ch <-chan Op) (Op, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return Op{}, ctx.Err()
|
||||
case op := <-ch:
|
||||
return op, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcaster is primarily used for debugging.
|
||||
type Broadcaster struct {
|
||||
src <-chan Op
|
||||
dst map[chan<- Op]struct{}
|
||||
mut sync.Mutex
|
||||
void bool
|
||||
}
|
||||
|
||||
// NewBroadcaster creates a new broadcaster.
|
||||
func NewBroadcaster(src <-chan Op) *Broadcaster {
|
||||
return &Broadcaster{
|
||||
src: src,
|
||||
dst: make(map[chan<- Op]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the broadcasting loop.
|
||||
func (b *Broadcaster) Start() {
|
||||
b.mut.Lock()
|
||||
if b.void {
|
||||
panic("Start called on voided Broadcaster")
|
||||
}
|
||||
b.mut.Unlock()
|
||||
|
||||
go func() {
|
||||
for op := range b.src {
|
||||
b.mut.Lock()
|
||||
|
||||
for ch := range b.dst {
|
||||
ch <- op
|
||||
}
|
||||
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
b.mut.Lock()
|
||||
b.void = true
|
||||
|
||||
for ch := range b.dst {
|
||||
close(ch)
|
||||
}
|
||||
|
||||
b.mut.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Subscribe subscribes the given channel
|
||||
func (b *Broadcaster) Subscribe(ch chan<- Op) {
|
||||
b.mut.Lock()
|
||||
if b.void {
|
||||
panic("Subscribe called on voided Broadcaster")
|
||||
}
|
||||
b.dst[ch] = struct{}{}
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
// NewSubscribed creates a newly subscribed Op channel.
|
||||
func (b *Broadcaster) NewSubscribed() <-chan Op {
|
||||
ch := make(chan Op, 1)
|
||||
b.Subscribe(ch)
|
||||
return ch
|
||||
}
|
35
utils/ws/ophandler/ophandler.go
Normal file
35
utils/ws/ophandler/ophandler.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
// Package ophandler provides an Op channel reader that redistributes the events
|
||||
// into handlers.
|
||||
package ophandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
// Loop starts a background goroutine that starts reading from src and
|
||||
// distributes received events into the given handler. It's stopped once src is
|
||||
// closed. The returned channel will be closed once src is closed.
|
||||
func Loop(src <-chan ws.Op, dst *handler.Handler) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for op := range src {
|
||||
dst.Call(op.Data)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
// WaitForDone waits for the done channel returned by Loop until the channel is
|
||||
// closed or the context expires.
|
||||
func WaitForDone(ctx context.Context, done <-chan struct{}) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package wsutil
|
||||
package ws
|
||||
|
||||
import (
|
||||
"time"
|
118
utils/ws/ws.go
Normal file
118
utils/ws/ws.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
// Package wsutil provides abstractions around the Websocket, including rate
|
||||
// limits.
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
// Websocket is a wrapper around a websocket Conn with thread safety and rate
|
||||
// limiting for sending and throttling.
|
||||
type Websocket struct {
|
||||
mutex sync.Mutex
|
||||
conn Connection
|
||||
addr string
|
||||
|
||||
// If you ever need access to these fields from outside the package, please
|
||||
// open an issue. It might be worth it to refactor these out for distributed
|
||||
// sharding.
|
||||
|
||||
sendLimiter *rate.Limiter
|
||||
dialLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
// NewWebsocket creates a default Websocket with the given address.
|
||||
func NewWebsocket(c Codec, addr string) *Websocket {
|
||||
return NewCustomWebsocket(NewConn(c), addr)
|
||||
}
|
||||
|
||||
// NewCustomWebsocket creates a new undialed Websocket.
|
||||
func NewCustomWebsocket(conn Connection, addr string) *Websocket {
|
||||
return &Websocket{
|
||||
conn: conn,
|
||||
addr: addr,
|
||||
|
||||
sendLimiter: NewSendLimiter(),
|
||||
dialLimiter: NewDialLimiter(),
|
||||
}
|
||||
}
|
||||
|
||||
// Dial waits until the rate limiter allows then dials the websocket.
|
||||
func (ws *Websocket) Dial(ctx context.Context) (<-chan Op, error) {
|
||||
if err := ws.dialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return nil, errors.Wrap(err, "failed to wait for dial rate limiter")
|
||||
}
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
// Reset the send limiter.
|
||||
// TODO: see if each limit only applies to one connection or not.
|
||||
ws.sendLimiter = NewSendLimiter()
|
||||
|
||||
return ws.conn.Dial(ctx, ws.addr)
|
||||
}
|
||||
|
||||
// Send sends b over the Websocket with a deadline. It closes the internal
|
||||
// Websocket if the Send method errors out.
|
||||
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
|
||||
WSDebug("Acquiring the websoccket mutex for sending.")
|
||||
|
||||
ws.mutex.Lock()
|
||||
WSDebug("Mutex lock acquired.")
|
||||
sendLimiter := ws.sendLimiter
|
||||
conn := ws.conn
|
||||
ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Waiting for the send rate limiter...")
|
||||
|
||||
if err := sendLimiter.Wait(ctx); err != nil {
|
||||
WSDebug("Send rate limiter timed out.")
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
|
||||
|
||||
return conn.Send(ctx, b)
|
||||
}
|
||||
|
||||
// Close closes the websocket connection. It assumes that the Websocket is
|
||||
// closed even when it returns an error. If the Websocket was already closed
|
||||
// before, ErrWebsocketClosed will be returned.
|
||||
func (ws *Websocket) Close() error {
|
||||
WSDebug("Conn: Acquiring mutex lock to close...")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write mutex acquired")
|
||||
|
||||
return ws.conn.Close(false)
|
||||
}
|
||||
|
||||
// CloseGracefully is similar to Close, but a proper close frame is sent to
|
||||
// Discord, invalidating the internal session ID and voiding resumes.
|
||||
func (ws *Websocket) CloseGracefully() error {
|
||||
WSDebug("Conn: Acquiring mutex lock to close...")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write mutex acquired")
|
||||
|
||||
return ws.conn.Close(true)
|
||||
}
|
|
@ -1,270 +0,0 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CopyBufferSize is used for the initial size of the internal WS' buffer. Its
|
||||
// size is 4KB.
|
||||
var CopyBufferSize = 4096
|
||||
|
||||
// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
|
||||
// re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
|
||||
var MaxCapUntilReset = CopyBufferSize * 4
|
||||
|
||||
// CloseDeadline controls the deadline to wait for sending the Close frame.
|
||||
var CloseDeadline = time.Second
|
||||
|
||||
// ErrWebsocketClosed is returned if the websocket is already closed.
|
||||
var ErrWebsocketClosed = errors.New("websocket is closed")
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself, including
|
||||
// modifying the connection URL. The implementation doesn't have to be safe for
|
||||
// concurrent use.
|
||||
type Connection interface {
|
||||
// Dial dials the address (string). Context needs to be passed in for
|
||||
// timeout. This method should also be re-usable after Close is called.
|
||||
Dial(context.Context, string) error
|
||||
|
||||
// Listen returns an event channel that sends over events constantly. It can
|
||||
// return nil if there isn't an ongoing connection.
|
||||
Listen() <-chan Event
|
||||
|
||||
// Send allows the caller to send bytes. It does not need to clean itself
|
||||
// up on errors, as the Websocket wrapper will do that.
|
||||
//
|
||||
// If the data is nil, it should send a close frame
|
||||
Send(context.Context, []byte) error
|
||||
|
||||
// Close should close the websocket connection. The underlying connection
|
||||
// may be reused, but this Connection instance will be reused with Dial. The
|
||||
// Connection must still be reusable even if Close returns an error.
|
||||
Close() error
|
||||
// CloseGracefully sends a close frame and then closes the websocket
|
||||
// connection.
|
||||
CloseGracefully() error
|
||||
}
|
||||
|
||||
// Conn is the default Websocket connection. It tries to compresses all payloads
|
||||
// using zlib.
|
||||
type Conn struct {
|
||||
Dialer websocket.Dialer
|
||||
Header http.Header
|
||||
Conn *websocket.Conn
|
||||
events chan Event
|
||||
}
|
||||
|
||||
var _ Connection = (*Conn)(nil)
|
||||
|
||||
// NewConn creates a new default websocket connection with a default dialer.
|
||||
func NewConn() *Conn {
|
||||
return NewConnWithDialer(websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: WSTimeout,
|
||||
ReadBufferSize: CopyBufferSize,
|
||||
WriteBufferSize: CopyBufferSize,
|
||||
EnableCompression: true,
|
||||
})
|
||||
}
|
||||
|
||||
// NewConnWithDialer creates a new default websocket connection with a custom
|
||||
// dialer.
|
||||
func NewConnWithDialer(dialer websocket.Dialer) *Conn {
|
||||
return &Conn{
|
||||
Dialer: dialer,
|
||||
Header: http.Header{
|
||||
"Accept-Encoding": {"zlib"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
|
||||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
||||
// Reset the deadline.
|
||||
c.Conn.SetWriteDeadline(resetDeadline)
|
||||
|
||||
c.events = make(chan Event, WSBuffer)
|
||||
go startReadLoop(c.Conn, c.events)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Listen returns an event channel if there is a connection associated with it.
|
||||
// It returns nil if there is none.
|
||||
func (c *Conn) Listen() <-chan Event {
|
||||
return c.events
|
||||
}
|
||||
|
||||
// resetDeadline is used to reset the write deadline after using the context's.
|
||||
var resetDeadline = time.Time{}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
d, ok := ctx.Deadline()
|
||||
if ok {
|
||||
c.Conn.SetWriteDeadline(d)
|
||||
defer c.Conn.SetWriteDeadline(resetDeadline)
|
||||
}
|
||||
|
||||
if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
WSDebug("Conn: Close is called; shutting down the Websocket connection.")
|
||||
|
||||
// Have a deadline before closing.
|
||||
var deadline = time.Now().Add(5 * time.Second)
|
||||
c.Conn.SetWriteDeadline(deadline)
|
||||
|
||||
// Close the WS.
|
||||
err := c.Conn.Close()
|
||||
|
||||
c.Conn.SetWriteDeadline(resetDeadline)
|
||||
|
||||
WSDebug("Conn: Websocket closed; error:", err)
|
||||
WSDebug("Conn: Flushing events...")
|
||||
|
||||
// Flush all events before closing the channel. This will return as soon as
|
||||
// c.events is closed, or after closed.
|
||||
for range c.events {
|
||||
}
|
||||
|
||||
WSDebug("Flushed events.")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) CloseGracefully() error {
|
||||
WSDebug("Conn: CloseGracefully is called; sending close frame.")
|
||||
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(CloseDeadline))
|
||||
|
||||
err := c.Conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
WSError(err)
|
||||
}
|
||||
|
||||
WSDebug("Conn: Close frame sent; error:", err)
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// loopState is a thread-unsafe disposable state container for the read loop.
|
||||
// It's made to completely separate the read loop of any synchronization that
|
||||
// doesn't involve the websocket connection itself.
|
||||
type loopState struct {
|
||||
conn *websocket.Conn
|
||||
zlib io.ReadCloser
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
|
||||
// Clean up the events channel in the end.
|
||||
defer close(eventCh)
|
||||
|
||||
// Allocate the read loop its own private resources.
|
||||
state := loopState{conn: conn}
|
||||
state.buf.Grow(CopyBufferSize)
|
||||
|
||||
for {
|
||||
b, err := state.handle()
|
||||
if err != nil {
|
||||
WSDebug("Conn: Read error:", err)
|
||||
|
||||
// Is the error an EOF?
|
||||
if errors.Is(err, io.EOF) {
|
||||
// Yes it is, exit.
|
||||
return
|
||||
}
|
||||
|
||||
// Is the error an intentional close call? Go 1.16 exposes
|
||||
// ErrClosing, but we have to do this for now.
|
||||
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
return
|
||||
}
|
||||
|
||||
// Unusual error; log and exit:
|
||||
eventCh <- Event{nil, errors.Wrap(err, "WS error")}
|
||||
return
|
||||
}
|
||||
|
||||
// If the payload length is 0, skip it.
|
||||
if len(b) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
eventCh <- Event{b, nil}
|
||||
}
|
||||
}
|
||||
|
||||
func (state *loopState) handle() ([]byte, error) {
|
||||
// skip message type
|
||||
t, r, err := state.conn.NextReader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t == websocket.BinaryMessage {
|
||||
// Probably a zlib payload.
|
||||
|
||||
if state.zlib == nil {
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create a zlib reader")
|
||||
}
|
||||
state.zlib = z
|
||||
} else {
|
||||
if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to reset zlib reader")
|
||||
}
|
||||
}
|
||||
|
||||
defer state.zlib.Close()
|
||||
r = state.zlib
|
||||
}
|
||||
|
||||
return state.readAll(r)
|
||||
}
|
||||
|
||||
// readAll reads bytes into an existing buffer, copy it over, then wipe the old
|
||||
// buffer.
|
||||
func (state *loopState) readAll(r io.Reader) ([]byte, error) {
|
||||
defer state.buf.Reset()
|
||||
|
||||
if _, err := state.buf.ReadFrom(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy the bytes so we could empty the buffer for reuse.
|
||||
cpy := make([]byte, state.buf.Len())
|
||||
copy(cpy, state.buf.Bytes())
|
||||
|
||||
// If the buffer's capacity is over the limit, then re-allocate a new one.
|
||||
if state.buf.Cap() > MaxCapUntilReset {
|
||||
state.buf = bytes.Buffer{}
|
||||
state.buf.Grow(CopyBufferSize)
|
||||
}
|
||||
|
||||
return cpy, nil
|
||||
}
|
|
@ -1,151 +0,0 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/heart"
|
||||
)
|
||||
|
||||
type brokenConnectionError struct {
|
||||
underneath error
|
||||
}
|
||||
|
||||
// Error formats the broken connection error with the message "explicit
|
||||
// connection break."
|
||||
func (err brokenConnectionError) Error() string {
|
||||
return "explicit connection break: " + err.underneath.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (err brokenConnectionError) Unwrap() error {
|
||||
return err.underneath
|
||||
}
|
||||
|
||||
// ErrBrokenConnection marks the given error as a broken connection error. This
|
||||
// error will cause the pacemaker loop to break and return the error. The error,
|
||||
// when stringified, will say "explicit connection break."
|
||||
func ErrBrokenConnection(err error) error {
|
||||
return brokenConnectionError{underneath: err}
|
||||
}
|
||||
|
||||
// IsBrokenConnection returns true if the error is a broken connection error.
|
||||
func IsBrokenConnection(err error) bool {
|
||||
var broken *brokenConnectionError
|
||||
return errors.As(err, &broken)
|
||||
}
|
||||
|
||||
// TODO API
|
||||
type EventLoopHandler interface {
|
||||
EventHandler
|
||||
HeartbeatCtx(context.Context) error
|
||||
}
|
||||
|
||||
// PacemakerLoop provides an event loop with a pacemaker. A zero-value instance
|
||||
// is a valid instance only when RunAsync is called first.
|
||||
type PacemakerLoop struct {
|
||||
heart.Pacemaker
|
||||
Extras ExtraHandlers
|
||||
ErrorLog func(error)
|
||||
|
||||
events <-chan Event
|
||||
control chan func()
|
||||
handler func(*OP) error
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) errorLog(err error) {
|
||||
if p.ErrorLog == nil {
|
||||
WSDebug("Uncaught error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.ErrorLog(err)
|
||||
}
|
||||
|
||||
// Pace calls the pacemaker's Pace function.
|
||||
func (p *PacemakerLoop) Pace(ctx context.Context) error {
|
||||
return p.Pacemaker.PaceCtx(ctx)
|
||||
}
|
||||
|
||||
// StartBeating asynchronously starts the pacemaker loop.
|
||||
func (p *PacemakerLoop) StartBeating(pace time.Duration, evl EventLoopHandler, exit func(error)) {
|
||||
WSDebug("Starting the pacemaker loop.")
|
||||
|
||||
p.Pacemaker = heart.NewPacemaker(pace, evl.HeartbeatCtx)
|
||||
p.control = make(chan func())
|
||||
p.handler = evl.HandleOP
|
||||
p.events = nil // block forever
|
||||
|
||||
go func() { exit(p.startLoop()) }()
|
||||
}
|
||||
|
||||
// Stop signals the pacemaker to stop. It does not wait for the pacer to stop.
|
||||
// The pacer will call the given callback with a nil error.
|
||||
func (p *PacemakerLoop) Stop() {
|
||||
close(p.control)
|
||||
}
|
||||
|
||||
// SetEventChannel sets the event channel inside the event loop. There is no
|
||||
// guarantee that the channel is set when the function returns. This function is
|
||||
// concurrently safe.
|
||||
func (p *PacemakerLoop) SetEventChannel(evCh <-chan Event) {
|
||||
p.control <- func() { p.events = evCh }
|
||||
}
|
||||
|
||||
// SetPace (re)sets the pace duration. As with SetEventChannel, there is no
|
||||
// guarantee that the pacer is reset when the function returns. This function is
|
||||
// concurrently safe.
|
||||
func (p *PacemakerLoop) SetPace(pace time.Duration) {
|
||||
p.control <- func() { p.Pacemaker.SetPace(pace) }
|
||||
}
|
||||
|
||||
func (p *PacemakerLoop) startLoop() error {
|
||||
defer WSDebug("Pacemaker loop has exited.")
|
||||
defer p.Pacemaker.StopTicker()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.Pacemaker.Ticks:
|
||||
if err := p.Pacemaker.Pace(); err != nil {
|
||||
return errors.Wrap(err, "pace failed, reconnecting")
|
||||
}
|
||||
|
||||
case fn, ok := <-p.control:
|
||||
if !ok { // Intentional stop at p.Close().
|
||||
WSDebug("Pacemaker intentionally stopped using p.control.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fn()
|
||||
|
||||
case ev, ok := <-p.events:
|
||||
if !ok {
|
||||
WSDebug("Events channel closed, stopping pacemaker.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if ev.Error != nil {
|
||||
return errors.Wrap(ev.Error, "event returned error")
|
||||
}
|
||||
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to decode OP")
|
||||
}
|
||||
|
||||
// Check the events before handling.
|
||||
p.Extras.Check(o)
|
||||
|
||||
// Handle the event
|
||||
if err := p.handler(o); err != nil {
|
||||
if IsBrokenConnection(err) {
|
||||
return errors.Wrap(err, "handler failed")
|
||||
}
|
||||
|
||||
p.errorLog(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,202 +0,0 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
)
|
||||
|
||||
var ErrEmptyPayload = errors.New("empty payload")
|
||||
|
||||
// OPCode is a generic type for websocket OP codes.
|
||||
type OPCode uint8
|
||||
|
||||
type OP struct {
|
||||
Code OPCode `json:"op"`
|
||||
Data json.Raw `json:"d,omitempty"`
|
||||
|
||||
// Only for Gateway Dispatch (op 0)
|
||||
Sequence int64 `json:"s,omitempty"`
|
||||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
func (op *OP) UnmarshalData(v interface{}) error {
|
||||
return json.Unmarshal(op.Data, v)
|
||||
}
|
||||
|
||||
func DecodeOP(ev Event) (*OP, error) {
|
||||
if ev.Error != nil {
|
||||
return nil, ev.Error
|
||||
}
|
||||
|
||||
if len(ev.Data) == 0 {
|
||||
return nil, ErrEmptyPayload
|
||||
}
|
||||
|
||||
var op *OP
|
||||
if err := json.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) {
|
||||
op, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op.Code != code {
|
||||
return op, fmt.Errorf(
|
||||
"unexpected OP Code: %d, expected %d (%s)",
|
||||
op.Code, code, op.Data,
|
||||
)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(op.Data, v); err != nil {
|
||||
return op, errors.Wrap(err, "failed to decode data")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// UnknownEventError is required by HandleOP if an event is encountered that is
|
||||
// not known. Internally, unknown events are logged and ignored. It is not a
|
||||
// fatal error.
|
||||
type UnknownEventError struct {
|
||||
Name string
|
||||
Data json.Raw
|
||||
}
|
||||
|
||||
// Error formats the unknown event error to with the event name and payload
|
||||
func (err UnknownEventError) Error() string {
|
||||
return fmt.Sprintf("unknown event %s: %s", err.Name, string(err.Data))
|
||||
}
|
||||
|
||||
// IsBrokenConnection returns true if the error is a broken connection error.
|
||||
func IsUnknownEvent(err error) bool {
|
||||
var uevent *UnknownEventError
|
||||
return errors.As(err, &uevent)
|
||||
}
|
||||
|
||||
type EventHandler interface {
|
||||
HandleOP(op *OP) error
|
||||
}
|
||||
|
||||
func HandleEvent(h EventHandler, ev Event) error {
|
||||
o, err := DecodeOP(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.HandleOP(o)
|
||||
}
|
||||
|
||||
// WaitForEvent blocks until fn() returns true. All incoming events are handled
|
||||
// regardless.
|
||||
func WaitForEvent(ctx context.Context, h EventHandler, ch <-chan Event, fn func(*OP) bool) error {
|
||||
for {
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return errors.New("event not found and event channel is closed")
|
||||
}
|
||||
|
||||
o, err := DecodeOP(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle the *OP first, in case it's an Invalid Session. This should
|
||||
// also prevent a race condition with things that need Ready after
|
||||
// Open().
|
||||
if err := h.HandleOP(o); err != nil {
|
||||
// Explicitly ignore events we don't know.
|
||||
if IsUnknownEvent(err) {
|
||||
WSError(err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Are these events what we're looking for? If we've found the event,
|
||||
// return.
|
||||
if fn(o) {
|
||||
return nil
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ExtraHandlers struct {
|
||||
mutex sync.Mutex
|
||||
handlers map[uint32]*ExtraHandler
|
||||
serial uint32
|
||||
}
|
||||
|
||||
type ExtraHandler struct {
|
||||
Check func(*OP) bool
|
||||
send chan *OP
|
||||
|
||||
closed moreatomic.Bool
|
||||
}
|
||||
|
||||
func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) {
|
||||
handler := &ExtraHandler{
|
||||
Check: check,
|
||||
send: make(chan *OP),
|
||||
}
|
||||
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
if ex.handlers == nil {
|
||||
ex.handlers = make(map[uint32]*ExtraHandler, 1)
|
||||
}
|
||||
|
||||
i := ex.serial
|
||||
ex.serial++
|
||||
|
||||
ex.handlers[i] = handler
|
||||
|
||||
return handler.send, func() {
|
||||
// Check the atomic bool before acquiring the mutex. Might help a bit in
|
||||
// performance.
|
||||
if handler.closed.Get() {
|
||||
return
|
||||
}
|
||||
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
delete(ex.handlers, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Check runs and sends OP data. It is not thread-safe.
|
||||
func (ex *ExtraHandlers) Check(op *OP) {
|
||||
ex.mutex.Lock()
|
||||
defer ex.mutex.Unlock()
|
||||
|
||||
for i, handler := range ex.handlers {
|
||||
if handler.Check(op) {
|
||||
// Attempt to send.
|
||||
handler.send <- op
|
||||
|
||||
// Mark the handler as closed.
|
||||
handler.closed.Set(true)
|
||||
|
||||
// Delete the handler.
|
||||
delete(ex.handlers, i)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,200 +0,0 @@
|
|||
// Package wsutil provides abstractions around the Websocket, including rate
|
||||
// limits.
|
||||
package wsutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = 30 * time.Second
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) { log.Println("Gateway error:", err) }
|
||||
// WSDebug is used for extra debug logging. This is expected to behave
|
||||
// similarly to log.Println().
|
||||
WSDebug = func(v ...interface{}) {}
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Data []byte
|
||||
|
||||
// Error is non-nil if Data is nil.
|
||||
Error error
|
||||
}
|
||||
|
||||
// Websocket is a wrapper around a websocket Conn with thread safety and rate
|
||||
// limiting for sending and throttling.
|
||||
type Websocket struct {
|
||||
mutex sync.Mutex
|
||||
conn Connection
|
||||
addr string
|
||||
closed bool
|
||||
|
||||
sendLimiter *rate.Limiter
|
||||
dialLimiter *rate.Limiter
|
||||
|
||||
// Timeout is the default timeout used if a context with no deadline is
|
||||
// given to Dial.
|
||||
//
|
||||
// It must not be changed after the Websocket is used once.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// New creates a default Websocket with the given address.
|
||||
func New(addr string) *Websocket {
|
||||
return NewCustom(NewConn(), addr)
|
||||
}
|
||||
|
||||
// NewCustom creates a new undialed Websocket.
|
||||
func NewCustom(conn Connection, addr string) *Websocket {
|
||||
return &Websocket{
|
||||
conn: conn,
|
||||
addr: addr,
|
||||
closed: true,
|
||||
|
||||
sendLimiter: NewSendLimiter(),
|
||||
dialLimiter: NewDialLimiter(),
|
||||
|
||||
Timeout: WSTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial waits until the rate limiter allows then dials the websocket.
|
||||
//
|
||||
// If the passed context has no deadline, Dial will wrap it in a
|
||||
// context.WithTimeout using ws.Timeout as timeout.
|
||||
func (ws *Websocket) Dial(ctx context.Context) error {
|
||||
if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, ws.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if err := ws.dialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return errors.Wrap(err, "failed to wait")
|
||||
}
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
if !ws.closed {
|
||||
WSDebug("Old connection not yet closed while dialog; closing it.")
|
||||
ws.conn.Close()
|
||||
}
|
||||
|
||||
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
|
||||
return errors.Wrap(err, "failed to dial")
|
||||
}
|
||||
|
||||
ws.closed = false
|
||||
|
||||
// Reset the send limiter.
|
||||
ws.sendLimiter = NewSendLimiter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Listen returns the inner event channel or nil if the Websocket connection is
|
||||
// not alive.
|
||||
func (ws *Websocket) Listen() <-chan Event {
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
if ws.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ws.conn.Listen()
|
||||
}
|
||||
|
||||
// Send sends b over the Websocket without a timeout.
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
return ws.SendCtx(context.Background(), b)
|
||||
}
|
||||
|
||||
// SendCtx sends b over the Websocket with a deadline. It closes the internal
|
||||
// Websocket if the Send method errors out.
|
||||
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
|
||||
WSDebug("Waiting for the send rate limiter...")
|
||||
|
||||
if err := ws.sendLimiter.Wait(ctx); err != nil {
|
||||
WSDebug("Send rate limiter timed out.")
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Mutex lock acquired.")
|
||||
|
||||
if ws.closed {
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
if err := ws.conn.Send(ctx, b); err != nil {
|
||||
// We need to clean up ourselves if things are erroring out.
|
||||
WSDebug("Conn: Error while sending; closing the connection. Error:", err)
|
||||
ws.close(false)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the websocket connection. It assumes that the Websocket is
|
||||
// closed even when it returns an error. If the Websocket was already closed
|
||||
// before, ErrWebsocketClosed will be returned.
|
||||
func (ws *Websocket) Close() error {
|
||||
WSDebug("Conn: Acquiring mutex lock to close...")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write mutex acquired")
|
||||
|
||||
return ws.close(false)
|
||||
}
|
||||
|
||||
func (ws *Websocket) CloseGracefully() error {
|
||||
WSDebug("Conn: Acquiring mutex lock to close...")
|
||||
|
||||
ws.mutex.Lock()
|
||||
defer ws.mutex.Unlock()
|
||||
|
||||
WSDebug("Conn: Write mutex acquired")
|
||||
|
||||
return ws.close(true)
|
||||
}
|
||||
|
||||
// close closes the Websocket without acquiring the mutex. Refer to Close for
|
||||
// more information.
|
||||
func (ws *Websocket) close(graceful bool) error {
|
||||
if ws.closed {
|
||||
WSDebug("Conn: Websocket is already closed.")
|
||||
return ErrWebsocketClosed
|
||||
}
|
||||
|
||||
ws.closed = true
|
||||
|
||||
if graceful {
|
||||
WSDebug("Conn: Closing gracefully")
|
||||
return ws.conn.CloseGracefully()
|
||||
}
|
||||
|
||||
WSDebug("Conn: Closing")
|
||||
return ws.conn.Close()
|
||||
}
|
498
voice/session.go
498
voice/session.go
|
@ -2,20 +2,22 @@ package voice
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
"github.com/diamondburned/arikawa/v3/utils/handler"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws/ophandler"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/gateway"
|
||||
"github.com/diamondburned/arikawa/v3/internal/handleloop"
|
||||
"github.com/diamondburned/arikawa/v3/internal/lazytime"
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v3/session"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
"github.com/diamondburned/arikawa/v3/voice/udp"
|
||||
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
|
||||
)
|
||||
|
@ -32,24 +34,65 @@ var ErrCannotSend = errors.New("cannot send audio to closed channel")
|
|||
// WSTimeout is the duration to wait for a gateway operation including Session
|
||||
// to complete before erroring out. This only applies to functions that don't
|
||||
// take in a context already.
|
||||
var WSTimeout = 10 * time.Second
|
||||
const WSTimeout = 25 * time.Second
|
||||
|
||||
// ReconnectError is emitted into Session.Handler everytime the voice gateway
|
||||
// fails to be reconnected. It implements the error interface.
|
||||
type ReconnectError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error implements error.
|
||||
func (e ReconnectError) Error() string {
|
||||
return "voice reconnect error: " + e.Err.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns e.Err.
|
||||
func (e ReconnectError) Unwrap() error { return e.Err }
|
||||
|
||||
// MainSession abstracts both session.Session and state.State.
|
||||
type MainSession interface {
|
||||
// AddHandler describes the method in handler.Handler.
|
||||
AddHandler(handler interface{}) (rm func())
|
||||
// Gateway returns the session's main Discord gateway.
|
||||
Gateway() *gateway.Gateway
|
||||
// Me returns the current user.
|
||||
Me() (*discord.User, error)
|
||||
// Channel queries for the channel with the given ID.
|
||||
Channel(discord.ChannelID) (*discord.Channel, error)
|
||||
}
|
||||
|
||||
var (
|
||||
_ MainSession = (*session.Session)(nil)
|
||||
_ MainSession = (*state.State)(nil)
|
||||
)
|
||||
|
||||
// UDPDialer is the UDP dialer function type. It's the function signature for
|
||||
// udp.DialConnection.
|
||||
type UDPDialer = func(ctx context.Context, addr string, ssrc uint32) (*udp.Connection, error)
|
||||
|
||||
// Session is a single voice session that wraps around the voice gateway and UDP
|
||||
// connection.
|
||||
type Session struct {
|
||||
*handler.Handler
|
||||
ErrorLog func(err error)
|
||||
session MainSession
|
||||
|
||||
session *session.Session
|
||||
looper *handleloop.Loop
|
||||
detach func()
|
||||
mut sync.RWMutex
|
||||
// connected is a non-nil blocking channel after Join is called and is
|
||||
// closed once Leave is called.
|
||||
disconnected chan struct{}
|
||||
|
||||
mut sync.RWMutex
|
||||
state voicegateway.State // guarded except UserID
|
||||
// TODO: expose getters mutex-guarded.
|
||||
gateway *voicegateway.Gateway
|
||||
|
||||
detachReconnect []func()
|
||||
|
||||
voiceUDP *udp.Connection
|
||||
// end of mutex
|
||||
gateway *voicegateway.Gateway
|
||||
gwCancel context.CancelFunc
|
||||
gwDone <-chan struct{}
|
||||
|
||||
// DialUDP is the custom function for dialing up a UDP connection.
|
||||
DialUDP UDPDialer
|
||||
|
||||
WSTimeout time.Duration // global WSTimeout
|
||||
WSMaxRetry int // 2
|
||||
|
@ -59,76 +102,102 @@ type Session struct {
|
|||
// joining determines the behavior of incoming event callbacks (Update).
|
||||
// If this is true, incoming events will just send into Updated channels. If
|
||||
// false, events will trigger a reconnection.
|
||||
joining moreatomic.Bool
|
||||
connected bool
|
||||
joining moreatomic.Bool
|
||||
// disconnectClosed is true if connected is already closed. It is only used
|
||||
// to keep track of closing connected.
|
||||
disconnectClosed bool
|
||||
}
|
||||
|
||||
// NewSession creates a new voice session for the current user.
|
||||
func NewSession(state *state.State) (*Session, error) {
|
||||
func NewSession(state MainSession) (*Session, error) {
|
||||
u, err := state.Me()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get me")
|
||||
}
|
||||
|
||||
return NewSessionCustom(state.Session, u.ID), nil
|
||||
return NewSessionCustom(state, u.ID), nil
|
||||
}
|
||||
|
||||
// NewSessionCustom creates a new voice session from the given session and user
|
||||
// ID.
|
||||
func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
|
||||
handler := handler.New()
|
||||
hlooper := handleloop.NewLoop(handler)
|
||||
func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
|
||||
closed := make(chan struct{})
|
||||
close(closed)
|
||||
|
||||
session := &Session{
|
||||
Handler: handler,
|
||||
looper: hlooper,
|
||||
Handler: handler.New(),
|
||||
session: ses,
|
||||
state: voicegateway.State{
|
||||
UserID: userID,
|
||||
},
|
||||
ErrorLog: func(err error) {},
|
||||
DialUDP: udp.DialConnection,
|
||||
WSTimeout: WSTimeout,
|
||||
WSMaxRetry: 2,
|
||||
WSRetryDelay: 2 * time.Second,
|
||||
WSWaitDuration: 5 * time.Second,
|
||||
|
||||
// Set this pair of value in so we never have to nil-check the channel.
|
||||
// We can just assume that it's either closed or connected.
|
||||
disconnected: closed,
|
||||
disconnectClosed: true,
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// updateServer is specifically used to monitor for reconnects.
|
||||
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||
func (s *Session) acquireUpdate(f func()) bool {
|
||||
if s.joining.Get() {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
// Ignore if we haven't connected yet or we're still joining.
|
||||
if !s.connected || s.state.GuildID != ev.GuildID {
|
||||
return
|
||||
select {
|
||||
case <-s.disconnected:
|
||||
return false
|
||||
default:
|
||||
// ok
|
||||
}
|
||||
|
||||
// Reconnect.
|
||||
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.reconnectCtx(ctx); err != nil {
|
||||
s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
|
||||
}
|
||||
f()
|
||||
return true
|
||||
}
|
||||
|
||||
// JoinChannel joins a voice channel with the default WS timeout. See
|
||||
// JoinChannelCtx for more information.
|
||||
func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
// updateServer is specifically used to monitor for reconnects.
|
||||
func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
|
||||
s.acquireUpdate(func() {
|
||||
if s.state.GuildID != ev.GuildID {
|
||||
return
|
||||
}
|
||||
|
||||
return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
|
||||
s.state.Endpoint = ev.Endpoint
|
||||
s.state.Token = ev.Token
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
s.reconnectCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// updateState is specifically used after connecting to monitor when the bot is
|
||||
// forced across channels.
|
||||
func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
|
||||
s.acquireUpdate(func() {
|
||||
if s.state.GuildID != ev.GuildID || s.state.UserID != ev.UserID {
|
||||
return
|
||||
}
|
||||
|
||||
s.state.ChannelID = ev.ChannelID
|
||||
s.state.SessionID = ev.SessionID
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
s.reconnectCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
type waitEventChs struct {
|
||||
|
@ -136,11 +205,17 @@ type waitEventChs struct {
|
|||
stateUpdate chan *gateway.VoiceStateUpdateEvent
|
||||
}
|
||||
|
||||
// JoinChannelCtx joins a voice channel. Callers shouldn't use this method
|
||||
// directly, but rather Voice's. This method shouldn't ever be called
|
||||
// concurrently.
|
||||
func (s *Session) JoinChannelCtx(
|
||||
ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
|
||||
// JoinChannel joins the given voice channel with the default timeout.
|
||||
func (s *Session) JoinChannel(ctx context.Context, chID discord.ChannelID, mute, deaf bool) error {
|
||||
var ch *discord.Channel
|
||||
|
||||
if chID.IsValid() {
|
||||
var err error
|
||||
ch, err = s.session.Channel(chID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "invalid channel ID")
|
||||
}
|
||||
}
|
||||
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
@ -155,14 +230,20 @@ func (s *Session) JoinChannelCtx(
|
|||
defer s.joining.Set(false)
|
||||
|
||||
// Set the state.
|
||||
s.state.ChannelID = cID
|
||||
s.state.GuildID = gID
|
||||
s.detach = s.session.AddHandler(s.updateServer)
|
||||
if ch != nil {
|
||||
s.state.ChannelID = ch.ID
|
||||
s.state.GuildID = ch.GuildID
|
||||
} else {
|
||||
s.state.GuildID = 0
|
||||
// Ensure that if `cID` is zero that it passes null to the update event.
|
||||
s.state.ChannelID = discord.NullChannelID
|
||||
}
|
||||
|
||||
// Ensure that if `cID` is zero that it passes null to the update event.
|
||||
channelID := discord.NullChannelID
|
||||
if cID.IsValid() {
|
||||
channelID = cID
|
||||
if s.detachReconnect == nil {
|
||||
s.detachReconnect = []func(){
|
||||
s.session.AddHandler(s.updateServer),
|
||||
s.session.AddHandler(s.updateState),
|
||||
}
|
||||
}
|
||||
|
||||
chs := waitEventChs{
|
||||
|
@ -185,15 +266,16 @@ func (s *Session) JoinChannelCtx(
|
|||
// Ensure gateway and voiceUDP are already closed.
|
||||
s.ensureClosed()
|
||||
|
||||
data := gateway.UpdateVoiceStateData{
|
||||
GuildID: gID,
|
||||
ChannelID: channelID,
|
||||
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
||||
// Send a Voice State Update event to the gateway.
|
||||
data := &gateway.UpdateVoiceStateCommand{
|
||||
GuildID: s.state.GuildID,
|
||||
ChannelID: s.state.ChannelID,
|
||||
SelfMute: mute,
|
||||
SelfDeaf: deaf,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
var timer *time.Timer
|
||||
|
||||
// Retry 3 times maximum.
|
||||
|
@ -230,17 +312,17 @@ func (s *Session) JoinChannelCtx(
|
|||
|
||||
// Mark the session as connected and move on. This allows one of the
|
||||
// connected handlers to reconnect on its own.
|
||||
s.connected = true
|
||||
s.disconnected = make(chan struct{})
|
||||
|
||||
return s.reconnectCtx(ctx)
|
||||
}
|
||||
|
||||
func (s *Session) askDiscord(
|
||||
ctx context.Context, data gateway.UpdateVoiceStateData, chs waitEventChs) error {
|
||||
ctx context.Context, data *gateway.UpdateVoiceStateCommand, chs waitEventChs) error {
|
||||
|
||||
// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
|
||||
// Send a Voice State Update event to the gateway.
|
||||
if err := s.session.Gateway.UpdateVoiceStateCtx(ctx, data); err != nil {
|
||||
if err := s.session.Gateway().Send(ctx, data); err != nil {
|
||||
return errors.Wrap(err, "failed to send Voice State Update event")
|
||||
}
|
||||
|
||||
|
@ -290,118 +372,183 @@ func (s *Session) waitForIncoming(ctx context.Context, chs waitEventChs) error {
|
|||
|
||||
// reconnect uses the current state to reconnect to a new gateway and UDP
|
||||
// connection.
|
||||
func (s *Session) reconnectCtx(ctx context.Context) (err error) {
|
||||
wsutil.WSDebug("Sending stop handle.")
|
||||
s.looper.Stop()
|
||||
func (s *Session) reconnectCtx(ctx context.Context) error {
|
||||
ws.WSDebug("Sending stop handle.")
|
||||
|
||||
wsutil.WSDebug("Start gateway.")
|
||||
s.ensureClosed()
|
||||
|
||||
ws.WSDebug("Start gateway.")
|
||||
s.gateway = voicegateway.New(s.state)
|
||||
|
||||
// Open the voice gateway. The function will block until Ready is received.
|
||||
if err := s.gateway.OpenCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to open voice gateway")
|
||||
gwctx, gwcancel := context.WithCancel(context.Background())
|
||||
s.gwCancel = gwcancel
|
||||
|
||||
gwch := s.gateway.Connect(gwctx)
|
||||
|
||||
if err := s.spinGateway(ctx, gwch); err != nil {
|
||||
// Early cancel the gateway.
|
||||
gwcancel()
|
||||
// Nil this so future reconnects don't use the invalid gwDone.
|
||||
s.gwCancel = nil
|
||||
// Emit the error. It's fine to do this here since this is the only
|
||||
// place that can error out.
|
||||
s.Handler.Call(&ReconnectError{err})
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the handler dispatching
|
||||
s.looper.Start(s.gateway.Events)
|
||||
|
||||
// Get the Ready event.
|
||||
voiceReady := s.gateway.Ready()
|
||||
|
||||
// Prepare the UDP voice connection.
|
||||
s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to open voice UDP connection")
|
||||
}
|
||||
|
||||
// Get the session description from the voice gateway.
|
||||
d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
|
||||
Protocol: "udp",
|
||||
Data: voicegateway.SelectProtocolData{
|
||||
Address: s.voiceUDP.GatewayIP,
|
||||
Port: s.voiceUDP.GatewayPort,
|
||||
Mode: Protocol,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to select protocol")
|
||||
}
|
||||
|
||||
s.voiceUDP.UseSecret(d.SecretKey)
|
||||
// Start dispatching.
|
||||
s.gwDone = ophandler.Loop(gwch, s.Handler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) spinGateway(ctx context.Context, gwch <-chan ws.Op) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ev, ok := <-gwch:
|
||||
if !ok {
|
||||
return s.gateway.LastError()
|
||||
}
|
||||
|
||||
switch data := ev.Data.(type) {
|
||||
case *ws.CloseEvent:
|
||||
return errors.Wrap(err, "voice gateway error")
|
||||
|
||||
case *voicegateway.ReadyEvent:
|
||||
// Prepare the UDP voice connection.
|
||||
s.voiceUDP, err = s.DialUDP(ctx, data.Addr(), data.SSRC)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to open voice UDP connection")
|
||||
}
|
||||
|
||||
if err := s.gateway.Send(ctx, &voicegateway.SelectProtocolCommand{
|
||||
Protocol: "udp",
|
||||
Data: voicegateway.SelectProtocolData{
|
||||
Address: s.voiceUDP.GatewayIP,
|
||||
Port: s.voiceUDP.GatewayPort,
|
||||
Mode: Protocol,
|
||||
},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to send SelectProtocolCommand")
|
||||
}
|
||||
|
||||
case *voicegateway.SessionDescriptionEvent:
|
||||
// We're done.
|
||||
s.voiceUDP.UseSecret(data.SecretKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dispatch this event to the handler.
|
||||
s.Handler.Call(ev.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Speaking tells Discord we're speaking. This method should not be called
|
||||
// concurrently.
|
||||
func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
|
||||
s.mut.RLock()
|
||||
gateway := s.gateway
|
||||
s.mut.RUnlock()
|
||||
|
||||
return gateway.Speaking(flag)
|
||||
}
|
||||
|
||||
// UseContext tells the UDP voice connection to write with the given context.
|
||||
func (s *Session) UseContext(ctx context.Context) error {
|
||||
func (s *Session) Speaking(ctx context.Context, flag voicegateway.SpeakingFlag) error {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
gateway := s.gateway
|
||||
s.mut.Unlock()
|
||||
|
||||
if s.voiceUDP == nil {
|
||||
return ErrCannotSend
|
||||
return gateway.Speaking(ctx, flag)
|
||||
}
|
||||
|
||||
func (s *Session) useUDP(f func(c *udp.Connection) error) (err error) {
|
||||
const maxAttempts = 5
|
||||
const retryDelay = 250 * time.Millisecond // adds up to about 1.25s
|
||||
|
||||
var lazyWait lazytime.Timer
|
||||
|
||||
// Hack: loop until we no longer get an error closed or until the connection
|
||||
// is dead. This is a workaround for when the session is trying to reconnect
|
||||
// itself in the background, which would drop the UDP connection.
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
s.mut.RLock()
|
||||
voiceUDP := s.voiceUDP
|
||||
disconnected := s.disconnected
|
||||
s.mut.RUnlock()
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
if voiceUDP == nil {
|
||||
// Session is still connected, but our voice UDP connection is
|
||||
// nil, so we're probably in the process of reconnecting
|
||||
// already.
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
|
||||
if err = f(voiceUDP); err != nil && errors.Is(err, net.ErrClosed) {
|
||||
// Session is still connected, but our UDP connection is somehow
|
||||
// closed, so we're probably waiting for the server to ask us to
|
||||
// reconnect with a new session.
|
||||
goto retry
|
||||
}
|
||||
|
||||
// Unknown error or none at all; exit.
|
||||
return err
|
||||
|
||||
retry:
|
||||
// Wait a slight bit. We can probably make the caller wait a couple
|
||||
// milliseconds without a wait.
|
||||
lazyWait.Reset(retryDelay)
|
||||
select {
|
||||
case <-lazyWait.C:
|
||||
continue
|
||||
case <-disconnected:
|
||||
return net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
return s.voiceUDP.UseContext(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// VoiceUDPConn gets a voice UDP connection. The caller could use this method to
|
||||
// circumvent the rapid mutex-read-lock acquire inside Write.
|
||||
func (s *Session) VoiceUDPConn() *udp.Connection {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
|
||||
return s.voiceUDP
|
||||
}
|
||||
|
||||
// Write writes into the UDP voice connection WITHOUT a timeout. Refer to
|
||||
// WriteCtx for more information.
|
||||
// Write writes into the UDP voice connection. This method is thread safe as far
|
||||
// as calling other methods of Session goes; HOWEVER it is not thread safe to
|
||||
// call Write itself concurrently.
|
||||
func (s *Session) Write(b []byte) (int, error) {
|
||||
return s.WriteCtx(context.Background(), b)
|
||||
var n int
|
||||
err := s.useUDP(func(c *udp.Connection) (err error) {
|
||||
n, err = c.Write(b)
|
||||
return
|
||||
})
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteCtx writes into the UDP voice connection with a context for timeout.
|
||||
// This method is thread safe as far as calling other methods of Session goes;
|
||||
// HOWEVER it is not thread safe to call Write itself concurrently.
|
||||
func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
voiceUDP := s.VoiceUDPConn()
|
||||
|
||||
if voiceUDP == nil {
|
||||
return 0, ErrCannotSend
|
||||
}
|
||||
|
||||
return voiceUDP.WriteCtx(ctx, b)
|
||||
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
|
||||
// thread safe, and must be used very carefully. The backing buffer is always
|
||||
// reused.
|
||||
func (s *Session) ReadPacket() (*udp.Packet, error) {
|
||||
var p *udp.Packet
|
||||
err := s.useUDP(func(c *udp.Connection) (err error) {
|
||||
p, err = c.ReadPacket()
|
||||
return
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
// Leave disconnects the current voice session from the currently connected
|
||||
// channel.
|
||||
func (s *Session) Leave() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.LeaveCtx(ctx)
|
||||
}
|
||||
|
||||
// LeaveCtx disconencts with a context. Refer to Leave for more information.
|
||||
func (s *Session) LeaveCtx(ctx context.Context) error {
|
||||
func (s *Session) Leave(ctx context.Context) error {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
s.connected = false
|
||||
s.ensureClosed()
|
||||
|
||||
// Unbind the handlers.
|
||||
if s.detach != nil {
|
||||
s.detach()
|
||||
s.detach = nil
|
||||
if s.detachReconnect != nil {
|
||||
for _, detach := range s.detachReconnect {
|
||||
detach()
|
||||
}
|
||||
s.detachReconnect = nil
|
||||
}
|
||||
|
||||
// If we're already closed.
|
||||
|
@ -409,46 +556,59 @@ func (s *Session) LeaveCtx(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
s.looper.Stop()
|
||||
|
||||
// Notify Discord that we're leaving. This will send a
|
||||
// VoiceStateUpdateEvent, in which our handler will promptly remove the
|
||||
// session from the map.
|
||||
|
||||
err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
|
||||
// Notify Discord that we're leaving.
|
||||
err := s.session.Gateway().Send(ctx, &gateway.UpdateVoiceStateCommand{
|
||||
GuildID: s.state.GuildID,
|
||||
ChannelID: discord.ChannelID(discord.NullSnowflake),
|
||||
SelfMute: true,
|
||||
SelfDeaf: true,
|
||||
})
|
||||
|
||||
s.ensureClosed()
|
||||
// wrap returns nil if err is nil
|
||||
return errors.Wrap(err, "failed to update voice state")
|
||||
// Wait for the gateway to exit first before we tell the user of the gateway
|
||||
// send error.
|
||||
if err := s.cancelGateway(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update voice state")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) cancelGateway(ctx context.Context) error {
|
||||
if s.gwCancel != nil {
|
||||
s.gwCancel()
|
||||
s.gwCancel = nil
|
||||
|
||||
// Wait for the previous gateway to finish closing up, but make sure to
|
||||
// bail if the context expires.
|
||||
if err := ophandler.WaitForDone(ctx, s.gwDone); err != nil {
|
||||
return errors.Wrap(err, "cannot wait for gateway to close")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close ensures everything is closed. It does not acquire the mutex.
|
||||
func (s *Session) ensureClosed() {
|
||||
s.looper.Stop()
|
||||
|
||||
// Disconnect the UDP connection.
|
||||
if s.voiceUDP != nil {
|
||||
s.voiceUDP.Close()
|
||||
s.voiceUDP = nil
|
||||
}
|
||||
|
||||
// Disconnect the voice gateway, ignoring the error.
|
||||
if s.gateway != nil {
|
||||
if err := s.gateway.Close(); err != nil {
|
||||
wsutil.WSDebug("Uncaught voice gateway close error:", err)
|
||||
}
|
||||
s.gateway = nil
|
||||
if !s.disconnectClosed {
|
||||
close(s.disconnected)
|
||||
s.disconnectClosed = true
|
||||
}
|
||||
|
||||
if s.gwCancel != nil {
|
||||
s.gwCancel()
|
||||
// Don't actually clear this field, because we still want the caller to
|
||||
// be able to wait for the gateway to completely exit using
|
||||
// cancelGateway.
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a single packet from the UDP connection. This is NOT at all
|
||||
// thread safe, and must be used very carefully. The backing buffer is always
|
||||
// reused.
|
||||
func (s *Session) ReadPacket() (*udp.Packet, error) {
|
||||
return s.VoiceUDPConn().ReadPacket()
|
||||
}
|
||||
|
|
|
@ -34,33 +34,25 @@ func TestNoop(t *testing.T) {
|
|||
}
|
||||
|
||||
func ExampleSession() {
|
||||
s, err := state.New("Bot " + token)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to make state:", err)
|
||||
}
|
||||
s := state.New("Bot " + token)
|
||||
|
||||
// This is required for bots.
|
||||
voice.AddIntents(s.Gateway)
|
||||
voice.AddIntents(s)
|
||||
|
||||
if err := s.Open(context.TODO()); err != nil {
|
||||
log.Fatalln("failed to open gateway:", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
c, err := s.Channel(channelID)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to get channel:", err)
|
||||
}
|
||||
|
||||
v, err := voice.NewSession(s)
|
||||
if err != nil {
|
||||
log.Fatalln("failed to create voice session:", err)
|
||||
}
|
||||
|
||||
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
|
||||
if err := v.JoinChannel(context.TODO(), channelID, false, false); err != nil {
|
||||
log.Fatalln("failed to join voice channel:", err)
|
||||
}
|
||||
defer v.Leave()
|
||||
defer v.Leave(context.TODO())
|
||||
|
||||
// Start writing Opus frames.
|
||||
for {
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/internal/testenv"
|
||||
"github.com/diamondburned/arikawa/v3/state"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
"github.com/diamondburned/arikawa/v3/voice/voicegateway"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
@ -23,17 +23,14 @@ import (
|
|||
func TestIntegration(t *testing.T) {
|
||||
config := testenv.Must(t)
|
||||
|
||||
wsutil.WSDebug = func(v ...interface{}) {
|
||||
ws.WSDebug = func(v ...interface{}) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
caller := file + ":" + strconv.Itoa(line)
|
||||
log.Println(append([]interface{}{caller}, v...)...)
|
||||
}
|
||||
|
||||
s, err := state.New("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
t.Fatal("failed to create a new state:", err)
|
||||
}
|
||||
AddIntents(s.Gateway)
|
||||
s := state.New("Bot " + config.BotToken)
|
||||
AddIntents(s)
|
||||
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
|
@ -70,7 +67,6 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
|
|||
if err != nil {
|
||||
t.Fatal("failed to create a new voice session:", err)
|
||||
}
|
||||
v.ErrorLog = func(err error) { t.Error(err) }
|
||||
|
||||
// Grab a timer to benchmark things.
|
||||
finish := timer()
|
||||
|
@ -80,7 +76,10 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
|
|||
finish("receiving voice speaking event")
|
||||
})
|
||||
|
||||
if err := v.JoinChannel(c.GuildID, c.ID, false, false); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
if err := v.JoinChannel(ctx, c.ID, false, false); err != nil {
|
||||
t.Fatal("failed to join voice:", err)
|
||||
}
|
||||
|
||||
|
@ -88,28 +87,19 @@ func testVoice(t *testing.T, s *state.State, c *discord.Channel) {
|
|||
log.Println("Leaving the voice channel concurrently.")
|
||||
|
||||
raceMe(t, "failed to leave voice channel", func() (interface{}, error) {
|
||||
return nil, v.Leave()
|
||||
return nil, v.Leave(ctx)
|
||||
})
|
||||
})
|
||||
|
||||
finish("joining the voice channel")
|
||||
|
||||
// Create a context and only cancel it AFTER we're done sending silence
|
||||
// frames.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Trigger speaking.
|
||||
if err := v.Speaking(voicegateway.Microphone); err != nil {
|
||||
if err := v.Speaking(ctx, voicegateway.Microphone); err != nil {
|
||||
t.Fatal("failed to start speaking:", err)
|
||||
}
|
||||
|
||||
finish("sending the speaking command")
|
||||
|
||||
if err := v.UseContext(ctx); err != nil {
|
||||
t.Fatal("failed to set ctx into vs:", err)
|
||||
}
|
||||
|
||||
f, err := os.Open("testdata/nico.dca")
|
||||
if err != nil {
|
||||
t.Fatal("failed to open nico.dca:", err)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
@ -39,13 +40,13 @@ type Connection struct {
|
|||
GatewayIP string
|
||||
GatewayPort uint16
|
||||
|
||||
context context.Context
|
||||
conn net.Conn
|
||||
ssrc uint32
|
||||
conn net.Conn
|
||||
ssrc uint32
|
||||
|
||||
// frequency rate.Limiter
|
||||
frequency *time.Ticker
|
||||
timeIncr uint32
|
||||
stopFreq chan struct{}
|
||||
|
||||
packet [12]byte
|
||||
secret [32]byte
|
||||
|
@ -59,9 +60,12 @@ type Connection struct {
|
|||
recvBuf []byte // len 1400
|
||||
recvOpus []byte // len 1400
|
||||
recvPacket *Packet // uses recvOpus' backing array
|
||||
|
||||
closed moreatomic.Bool
|
||||
}
|
||||
|
||||
func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
|
||||
// DialConnection dials a UDP connection.
|
||||
func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
|
||||
// Create a new UDP connection.
|
||||
conn, err := Dialer.DialContext(ctx, "udp", addr)
|
||||
if err != nil {
|
||||
|
@ -114,7 +118,7 @@ func DialConnectionCtx(ctx context.Context, addr string, ssrc uint32) (*Connecti
|
|||
GatewayPort: port,
|
||||
frequency: time.NewTicker(20 * time.Millisecond),
|
||||
timeIncr: 960,
|
||||
context: context.Background(),
|
||||
stopFreq: make(chan struct{}),
|
||||
packet: packet,
|
||||
ssrc: ssrc,
|
||||
conn: conn,
|
||||
|
@ -159,49 +163,29 @@ func (c *Connection) UseSecret(secret [32]byte) {
|
|||
c.secret = secret
|
||||
}
|
||||
|
||||
// UseContext lets the connection use the given context for its Write method.
|
||||
// WriteCtx will override this context.
|
||||
func (c *Connection) UseContext(ctx context.Context) error {
|
||||
return c.useContext(ctx)
|
||||
// SetWriteDeadline sets the UDP connection's write deadline.
|
||||
func (c *Connection) SetWriteDeadline(deadline time.Time) {
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
||||
func (c *Connection) useContext(ctx context.Context) error {
|
||||
if c.context == ctx {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.context = ctx
|
||||
|
||||
if deadline, ok := c.context.Deadline(); ok {
|
||||
return c.conn.SetWriteDeadline(deadline)
|
||||
} else {
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
// SetReadDeadline sets the UDP connection's read deadline.
|
||||
func (c *Connection) SetReadDeadline(deadline time.Time) {
|
||||
c.conn.SetReadDeadline(deadline)
|
||||
}
|
||||
|
||||
func (c *Connection) Close() error {
|
||||
c.frequency.Stop()
|
||||
if c.closed.Acquire() {
|
||||
// Be sure to only run this ONCE.
|
||||
c.frequency.Stop()
|
||||
close(c.stopFreq)
|
||||
}
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Write sends bytes into the voice UDP connection using the preset context.
|
||||
// Write sends a packet of audio into the voice UDP connection using the preset
|
||||
// context.
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
// WriteCtx sends bytes into the voice UDP connection with a timeout using the
|
||||
// given context. It ignores the context inside the connection, but will restore
|
||||
// the deadline after this call is done.
|
||||
func (c *Connection) WriteCtx(ctx context.Context, b []byte) (int, error) {
|
||||
oldCtx := c.context
|
||||
|
||||
c.useContext(ctx)
|
||||
defer c.useContext(oldCtx)
|
||||
|
||||
return c.write(b)
|
||||
}
|
||||
|
||||
func (c *Connection) write(b []byte) (int, error) {
|
||||
// Write a new sequence.
|
||||
binary.BigEndian.PutUint16(c.packet[2:4], c.sequence)
|
||||
c.sequence++
|
||||
|
@ -215,9 +199,9 @@ func (c *Connection) write(b []byte) (int, error) {
|
|||
|
||||
select {
|
||||
case <-c.frequency.C:
|
||||
|
||||
case <-c.context.Done():
|
||||
return 0, c.context.Err()
|
||||
// ok
|
||||
case <-c.stopFreq:
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
n, err := c.conn.Write(toSend)
|
||||
|
|
|
@ -8,7 +8,7 @@ import "github.com/diamondburned/arikawa/v3/gateway"
|
|||
|
||||
// AddIntents adds the needed voice intents into gw. Bots should always call
|
||||
// this before Open if voice is required.
|
||||
func AddIntents(gw *gateway.Gateway) {
|
||||
func AddIntents(gw interface{ AddIntents(gateway.Intents) }) {
|
||||
gw.AddIntents(gateway.IntentGuilds)
|
||||
gw.AddIntents(gateway.IntentGuildVoiceStates)
|
||||
}
|
||||
|
|
|
@ -1,167 +0,0 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMissingForIdentify is an error when we are missing information to identify.
|
||||
ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify")
|
||||
|
||||
// ErrMissingForResume is an error when we are missing information to resume.
|
||||
ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming")
|
||||
)
|
||||
|
||||
// OPCode 0
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload
|
||||
type IdentifyData struct {
|
||||
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Identify sends an Identify operation (opcode 0) to the Gateway Gateway.
|
||||
func (c *Gateway) Identify() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.IdentifyCtx(ctx)
|
||||
}
|
||||
|
||||
// IdentifyCtx sends an Identify operation (opcode 0) to the Gateway Gateway.
|
||||
func (c *Gateway) IdentifyCtx(ctx context.Context) error {
|
||||
guildID := c.state.GuildID
|
||||
userID := c.state.UserID
|
||||
sessionID := c.state.SessionID
|
||||
token := c.state.Token
|
||||
|
||||
if guildID == 0 || userID == 0 || sessionID == "" || token == "" {
|
||||
return ErrMissingForIdentify
|
||||
}
|
||||
|
||||
return c.SendCtx(ctx, IdentifyOP, IdentifyData{
|
||||
GuildID: guildID,
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
// OPCode 1
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload
|
||||
type SelectProtocol struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Data SelectProtocolData `json:"data"`
|
||||
}
|
||||
|
||||
type SelectProtocolData struct {
|
||||
Address string `json:"address"`
|
||||
Port uint16 `json:"port"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// SelectProtocol sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
|
||||
func (c *Gateway) SelectProtocol(data SelectProtocol) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SelectProtocolCtx(ctx, data)
|
||||
}
|
||||
|
||||
// SelectProtocolCtx sends a Select Protocol operation (opcode 1) to the Gateway Gateway.
|
||||
func (c *Gateway) SelectProtocolCtx(ctx context.Context, data SelectProtocol) error {
|
||||
return c.SendCtx(ctx, SelectProtocolOP, data)
|
||||
}
|
||||
|
||||
// OPCode 3
|
||||
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
|
||||
// type Heartbeat uint64
|
||||
|
||||
// Heartbeat sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
|
||||
func (c *Gateway) Heartbeat() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.HeartbeatCtx(ctx)
|
||||
}
|
||||
|
||||
// HeartbeatCtx sends a Heartbeat operation (opcode 3) to the Gateway Gateway.
|
||||
func (c *Gateway) HeartbeatCtx(ctx context.Context) error {
|
||||
return c.SendCtx(ctx, HeartbeatOP, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// https://discord.com/developers/docs/topics/voice-connections#speaking
|
||||
type SpeakingFlag uint64
|
||||
|
||||
const (
|
||||
NotSpeaking SpeakingFlag = 0
|
||||
Microphone SpeakingFlag = 1 << iota
|
||||
Soundshare
|
||||
Priority
|
||||
)
|
||||
|
||||
// OPCode 5
|
||||
// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload
|
||||
type SpeakingData struct {
|
||||
Speaking SpeakingFlag `json:"speaking"`
|
||||
Delay int `json:"delay"`
|
||||
SSRC uint32 `json:"ssrc"`
|
||||
UserID discord.UserID `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (c *Gateway) Speaking(flag SpeakingFlag) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SpeakingCtx(ctx, flag)
|
||||
}
|
||||
|
||||
// SpeakingCtx sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (c *Gateway) SpeakingCtx(ctx context.Context, flag SpeakingFlag) error {
|
||||
// How do we allow a user to stop speaking?
|
||||
// Also: https://discordapp.com/developers/docs/topics/voice-connections#voice-data-interpolation
|
||||
|
||||
return c.SendCtx(ctx, SpeakingOP, SpeakingData{
|
||||
Speaking: flag,
|
||||
Delay: 0,
|
||||
SSRC: c.ready.SSRC,
|
||||
})
|
||||
}
|
||||
|
||||
// OPCode 7
|
||||
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload
|
||||
type ResumeData struct {
|
||||
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Resume sends a Resume operation (opcode 7) to the Gateway Gateway.
|
||||
func (c *Gateway) Resume() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
return c.ResumeCtx(ctx)
|
||||
}
|
||||
|
||||
// ResumeCtx sends a Resume operation (opcode 7) to the Gateway Gateway.
|
||||
func (c *Gateway) ResumeCtx(ctx context.Context) error {
|
||||
guildID := c.state.GuildID
|
||||
sessionID := c.state.SessionID
|
||||
token := c.state.Token
|
||||
|
||||
if !guildID.IsValid() || sessionID == "" || token == "" {
|
||||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return c.SendCtx(ctx, ResumeOP, ResumeData{
|
||||
GuildID: guildID,
|
||||
SessionID: sessionID,
|
||||
Token: token,
|
||||
})
|
||||
}
|
94
voice/voicegateway/event_methods.go
Normal file
94
voice/voicegateway/event_methods.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
// Code generated by genevent. DO NOT EDIT.
|
||||
|
||||
package voicegateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
|
||||
func init() {
|
||||
OpUnmarshalers.Add(
|
||||
func() ws.Event { return new(IdentifyCommand) },
|
||||
func() ws.Event { return new(SelectProtocolCommand) },
|
||||
func() ws.Event { return new(ReadyEvent) },
|
||||
func() ws.Event { return new(HeartbeatCommand) },
|
||||
func() ws.Event { return new(SessionDescriptionEvent) },
|
||||
func() ws.Event { return new(SpeakingEvent) },
|
||||
func() ws.Event { return new(HeartbeatAckEvent) },
|
||||
func() ws.Event { return new(ResumeCommand) },
|
||||
func() ws.Event { return new(HelloEvent) },
|
||||
func() ws.Event { return new(ResumedEvent) },
|
||||
func() ws.Event { return new(ClientConnectEvent) },
|
||||
func() ws.Event { return new(ClientDisconnectEvent) },
|
||||
)
|
||||
}
|
||||
|
||||
// Op implements Event. It always returns Op 0.
|
||||
func (*IdentifyCommand) Op() ws.OpCode { return 0 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*IdentifyCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 1.
|
||||
func (*SelectProtocolCommand) Op() ws.OpCode { return 1 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*SelectProtocolCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 2.
|
||||
func (*ReadyEvent) Op() ws.OpCode { return 2 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ReadyEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 3.
|
||||
func (*HeartbeatCommand) Op() ws.OpCode { return 3 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HeartbeatCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 4.
|
||||
func (*SessionDescriptionEvent) Op() ws.OpCode { return 4 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*SessionDescriptionEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 5.
|
||||
func (*SpeakingEvent) Op() ws.OpCode { return 5 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*SpeakingEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 6.
|
||||
func (*HeartbeatAckEvent) Op() ws.OpCode { return 6 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HeartbeatAckEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 7.
|
||||
func (*ResumeCommand) Op() ws.OpCode { return 7 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ResumeCommand) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 8.
|
||||
func (*HelloEvent) Op() ws.OpCode { return 8 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*HelloEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 9.
|
||||
func (*ResumedEvent) Op() ws.OpCode { return 9 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ResumedEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 12.
|
||||
func (*ClientConnectEvent) Op() ws.OpCode { return 12 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ClientConnectEvent) EventType() ws.EventType { return "" }
|
||||
|
||||
// Op implements Event. It always returns Op 13.
|
||||
func (*ClientDisconnectEvent) Op() ws.OpCode { return 13 }
|
||||
|
||||
// EventType implements Event.
|
||||
func (*ClientDisconnectEvent) EventType() ws.EventType { return "" }
|
|
@ -4,9 +4,41 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
// OPCode 2
|
||||
//go:generate go run ../../utils/cmd/genevent -p voicegateway -o event_methods.go
|
||||
|
||||
// OpUnmarshalers contains the Op unmarshalers for the voice gateway events.
|
||||
var OpUnmarshalers = ws.NewOpUnmarshalers()
|
||||
|
||||
// IdentifyCommand is a command for Op 0.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-identify-payload
|
||||
type IdentifyCommand struct {
|
||||
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// SelectProtocolCommand is a command for Op 1.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-select-protocol-payload
|
||||
type SelectProtocolCommand struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Data SelectProtocolData `json:"data"`
|
||||
}
|
||||
|
||||
// SelectProtocolData is the data inside a SelectProtocolCommand.
|
||||
type SelectProtocolData struct {
|
||||
Address string `json:"address"`
|
||||
Port uint16 `json:"port"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// ReadyEvent is an event for Op 2.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection-example-voice-ready-payload
|
||||
type ReadyEvent struct {
|
||||
SSRC uint32 `json:"ssrc"`
|
||||
|
@ -23,45 +55,79 @@ type ReadyEvent struct {
|
|||
// HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
// Addr formats the URL inside Ready to be of format "host:port".
|
||||
func (r ReadyEvent) Addr() string {
|
||||
return r.IP + ":" + strconv.Itoa(r.Port)
|
||||
}
|
||||
|
||||
// OPCode 4
|
||||
// HeartbeatCommand is a command for Op 3.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-payload
|
||||
type HeartbeatCommand uint64
|
||||
|
||||
// SessionDescriptionEvent is an event for Op 4.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-example-session-description-payload
|
||||
type SessionDescriptionEvent struct {
|
||||
Mode string `json:"mode"`
|
||||
SecretKey [32]byte `json:"secret_key"`
|
||||
}
|
||||
|
||||
// OPCode 5
|
||||
type SpeakingEvent SpeakingData
|
||||
// https://discord.com/developers/docs/topics/voice-connections#speaking
|
||||
type SpeakingFlag uint64
|
||||
|
||||
// OPCode 6
|
||||
const (
|
||||
NotSpeaking SpeakingFlag = 0
|
||||
Microphone SpeakingFlag = 1 << iota
|
||||
Soundshare
|
||||
Priority
|
||||
)
|
||||
|
||||
// SpeakingEvent is an event for Op 5. It is also a command.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#speaking-example-speaking-payload
|
||||
type SpeakingEvent struct {
|
||||
Speaking SpeakingFlag `json:"speaking"`
|
||||
Delay int `json:"delay"`
|
||||
SSRC uint32 `json:"ssrc"`
|
||||
UserID discord.UserID `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// HeartbeatAckEvent is an event for Op 6.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-heartbeat-ack-payload
|
||||
type HeartbeatACKEvent uint64
|
||||
type HeartbeatAckEvent uint64
|
||||
|
||||
// OPCode 8
|
||||
// ResumeCommand is a command for Op 7.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resume-connection-payload
|
||||
type ResumeCommand struct {
|
||||
GuildID discord.GuildID `json:"server_id"` // yes, this should be "server_id"
|
||||
SessionID string `json:"session_id"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// HelloEvent is an event for Op 8.
|
||||
//
|
||||
// https://discord.com/developers/docs/topics/voice-connections#heartbeating-example-hello-payload-since-v3
|
||||
type HelloEvent struct {
|
||||
HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
// OPCode 9
|
||||
// ResumedEvent is an event for Op 9.
|
||||
// https://discord.com/developers/docs/topics/voice-connections#resuming-voice-connection-example-resumed-payload
|
||||
type ResumedEvent struct{}
|
||||
|
||||
// OPCode 12
|
||||
// (undocumented)
|
||||
// ClientConnectEvent is an event for Op 12. It is undocumented.
|
||||
type ClientConnectEvent struct {
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
AudioSSRC uint32 `json:"audio_ssrc"`
|
||||
VideoSSRC uint32 `json:"video_ssrc"`
|
||||
}
|
||||
|
||||
// OPCode 13
|
||||
// Undocumented, existence mentioned in below issue
|
||||
// https://github.com/discord/discord-api-docs/issues/510
|
||||
// ClientDisconnectEvent is an event for Op 13. It is undocumented, but its
|
||||
// existence is mentioned in this issue:
|
||||
// https://github.com/discord/discord-api-docs/issues/510.
|
||||
type ClientDisconnectEvent struct {
|
||||
UserID discord.UserID `json:"user_id"`
|
||||
}
|
||||
|
|
|
@ -18,9 +18,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/discord"
|
||||
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/diamondburned/arikawa/v3/utils/ws"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,9 +35,9 @@ type Event = interface{}
|
|||
|
||||
// State contains state information of a voice gateway.
|
||||
type State struct {
|
||||
GuildID discord.GuildID
|
||||
UserID discord.UserID // constant
|
||||
GuildID discord.GuildID // constant
|
||||
ChannelID discord.ChannelID
|
||||
UserID discord.UserID
|
||||
|
||||
SessionID string
|
||||
Token string
|
||||
|
@ -48,282 +46,164 @@ type State struct {
|
|||
|
||||
// Gateway represents a Discord Gateway Gateway connection.
|
||||
type Gateway struct {
|
||||
state State // constant
|
||||
gateway *ws.Gateway
|
||||
state State // constant
|
||||
|
||||
mutex sync.RWMutex
|
||||
ready ReadyEvent
|
||||
|
||||
WS *wsutil.Websocket
|
||||
|
||||
Timeout time.Duration
|
||||
reconnect moreatomic.Bool
|
||||
|
||||
EventLoop wsutil.PacemakerLoop
|
||||
Events chan Event
|
||||
|
||||
// ErrorLog will be called when an error occurs (defaults to log.Println)
|
||||
ErrorLog func(err error)
|
||||
// AfterClose is called after each close. Error can be non-nil, as this is
|
||||
// called even when the Gateway is gracefully closed. It's used mainly for
|
||||
// reconnections or any type of connection interruptions. (defaults to noop)
|
||||
AfterClose func(err error)
|
||||
|
||||
// Filled by methods, internal use
|
||||
waitGroup *sync.WaitGroup
|
||||
ready *ReadyEvent
|
||||
}
|
||||
|
||||
// DefaultGatewayOpts contains the default options to be used for connecting to
|
||||
// the gateway.
|
||||
var DefaultGatewayOpts = ws.GatewayOpts{
|
||||
ReconnectDelay: func(try int) time.Duration {
|
||||
// minimum 4 seconds
|
||||
return time.Duration(4+(2*try)) * time.Second
|
||||
},
|
||||
// FatalCloseCodes contains the default gateway close codes that will cause
|
||||
// the gateway to exit. In other words, it's a list of unrecoverable close
|
||||
// codes.
|
||||
FatalCloseCodes: []int{
|
||||
4003, // not authenticated
|
||||
4004, // authentication failed
|
||||
4006, // session invalid
|
||||
4009, // session timed out
|
||||
4011, // server not found
|
||||
4012, // unknown protocol
|
||||
4014, // disconnected
|
||||
4016, // unknown encryption mode
|
||||
},
|
||||
DialTimeout: 0,
|
||||
ReconnectAttempt: 0,
|
||||
AlwaysCloseGracefully: true,
|
||||
}
|
||||
|
||||
// New creates a new voice gateway.
|
||||
func New(state State) *Gateway {
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
var endpoint = "wss://" + strings.TrimSuffix(state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
gw := ws.NewGateway(
|
||||
ws.NewWebsocket(ws.NewCodec(OpUnmarshalers), endpoint),
|
||||
&DefaultGatewayOpts,
|
||||
)
|
||||
|
||||
return &Gateway{
|
||||
state: state,
|
||||
WS: wsutil.New(endpoint),
|
||||
Timeout: wsutil.WSTimeout,
|
||||
Events: make(chan Event, wsutil.WSBuffer),
|
||||
ErrorLog: wsutil.WSError,
|
||||
AfterClose: func(error) {},
|
||||
gateway: gw,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: get rid of
|
||||
func (c *Gateway) Ready() ReadyEvent {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
// Ready returns the ready event.
|
||||
func (g *Gateway) Ready() *ReadyEvent {
|
||||
g.mutex.RLock()
|
||||
defer g.mutex.RUnlock()
|
||||
|
||||
return c.ready
|
||||
return g.ready
|
||||
}
|
||||
|
||||
// OpenCtx shouldn't be used, but JoinServer instead.
|
||||
func (c *Gateway) OpenCtx(ctx context.Context) error {
|
||||
if c.state.Endpoint == "" {
|
||||
return errors.New("missing endpoint in state")
|
||||
}
|
||||
|
||||
// https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
var endpoint = "wss://" + strings.TrimSuffix(c.state.Endpoint, ":80") + "/?v=" + Version
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Connecting to voice endpoint (endpoint=" + endpoint + ")")
|
||||
|
||||
// Create a new context with a timeout for the connection.
|
||||
ctx, cancel := context.WithTimeout(ctx, c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the Gateway Gateway.
|
||||
if err := c.WS.Dial(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to connect to voice gateway")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Trying to start...")
|
||||
|
||||
// Try to start or resume the connection.
|
||||
if err := c.start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// LastError returns the last error that the gateway has received. It only
|
||||
// returns a valid error if the gateway's event loop as exited. If the event
|
||||
// loop hasn't been started AND stopped, the function will panic.
|
||||
func (g *Gateway) LastError() error {
|
||||
return g.gateway.LastError()
|
||||
}
|
||||
|
||||
// Start .
|
||||
func (c *Gateway) start(ctx context.Context) error {
|
||||
if err := c.__start(ctx); err != nil {
|
||||
wsutil.WSDebug("VoiceGateway: Start failed: ", err)
|
||||
|
||||
// Close can be called with the mutex still acquired here, as the
|
||||
// pacemaker hasn't started yet.
|
||||
if err := c.Close(); err != nil {
|
||||
wsutil.WSDebug("VoiceGateway: Failed to close after start fail: ", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// Send is a function to send an Op payload to the Gateway.
|
||||
func (g *Gateway) Send(ctx context.Context, data ws.Event) error {
|
||||
return g.gateway.Send(ctx, data)
|
||||
}
|
||||
|
||||
// this function blocks until READY.
|
||||
func (c *Gateway) __start(ctx context.Context) error {
|
||||
// Make a new WaitGroup for use in background loops:
|
||||
c.waitGroup = new(sync.WaitGroup)
|
||||
// Speaking sends a Speaking operation (opcode 5) to the Gateway Gateway.
|
||||
func (g *Gateway) Speaking(ctx context.Context, flag SpeakingFlag) error {
|
||||
g.mutex.RLock()
|
||||
ssrc := g.ready.SSRC
|
||||
g.mutex.RUnlock()
|
||||
|
||||
ch := c.WS.Listen()
|
||||
return g.gateway.Send(ctx, &SpeakingEvent{
|
||||
Speaking: flag,
|
||||
Delay: 0,
|
||||
SSRC: ssrc,
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for hello.
|
||||
wsutil.WSDebug("VoiceGateway: Waiting for Hello..")
|
||||
func (g *Gateway) Connect(ctx context.Context) <-chan ws.Op {
|
||||
return g.gateway.Connect(ctx, (*gatewayImpl)(g))
|
||||
}
|
||||
|
||||
var hello *HelloEvent
|
||||
// Wait for the Hello event; return if it times out.
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return errors.New("unexpected ws close while waiting for Hello")
|
||||
}
|
||||
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "error at Hello")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
|
||||
var (
|
||||
// ErrMissingForIdentify is an error when we are missing information to
|
||||
// identify.
|
||||
ErrMissingForIdentify = errors.New("missing GuildID, UserID, SessionID, or Token for identify")
|
||||
// ErrMissingForResume is an error when we are missing information to
|
||||
// resume.
|
||||
ErrMissingForResume = errors.New("missing GuildID, SessionID, or Token for resuming")
|
||||
)
|
||||
|
||||
type gatewayImpl Gateway
|
||||
|
||||
func (g *gatewayImpl) sendIdentify(ctx context.Context) error {
|
||||
id := IdentifyCommand{
|
||||
GuildID: g.state.GuildID,
|
||||
UserID: g.state.UserID,
|
||||
SessionID: g.state.SessionID,
|
||||
Token: g.state.Token,
|
||||
}
|
||||
if !id.GuildID.IsValid() || id == (IdentifyCommand{}) {
|
||||
return ErrMissingForIdentify
|
||||
}
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Received Hello")
|
||||
return g.gateway.Send(ctx, &id)
|
||||
}
|
||||
|
||||
// Start the event handler, which also handles the pacemaker death signal.
|
||||
c.waitGroup.Add(1)
|
||||
func (g *gatewayImpl) sendResume(ctx context.Context) error {
|
||||
if !g.state.GuildID.IsValid() || g.state.SessionID == "" || g.state.Token == "" {
|
||||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
c.EventLoop.StartBeating(hello.HeartbeatInterval.Duration(), c, func(err error) {
|
||||
c.waitGroup.Done() // mark so Close() can exit.
|
||||
wsutil.WSDebug("VoiceGateway: Event loop stopped.")
|
||||
return g.gateway.Send(ctx, &ResumeCommand{
|
||||
GuildID: g.state.GuildID,
|
||||
SessionID: g.state.SessionID,
|
||||
Token: g.state.Token,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.ErrorLog(err)
|
||||
func (g *gatewayImpl) OnOp(ctx context.Context, op ws.Op) bool {
|
||||
switch data := op.Data.(type) {
|
||||
case *HelloEvent:
|
||||
g.gateway.ResetHeartbeat(data.HeartbeatInterval.Duration())
|
||||
|
||||
if err := c.Reconnect(); err != nil {
|
||||
c.ErrorLog(errors.Wrap(err, "failed to reconnect voice"))
|
||||
// Send Discord either the Identify packet (if it's a fresh
|
||||
// connection), or a Resume packet (if it's a dead connection).
|
||||
if g.ready == nil {
|
||||
// SessionID is empty, so this is a completely new session.
|
||||
if err := g.sendIdentify(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to send identify")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
} else {
|
||||
if err := g.sendResume(ctx); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "failed to send resume")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
|
||||
// Reconnect should spawn another eventLoop in its Start function.
|
||||
}
|
||||
})
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/voice-connections#establishing-a-voice-websocket-connection
|
||||
// Turns out Hello is sent right away on connection start.
|
||||
if !c.reconnect.Get() {
|
||||
if err := c.IdentifyCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to identify")
|
||||
}
|
||||
} else {
|
||||
if err := c.ResumeCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to resume")
|
||||
}
|
||||
}
|
||||
// This bool is because we should only try and Resume once.
|
||||
c.reconnect.Set(false)
|
||||
|
||||
// Wait for either Ready or Resumed.
|
||||
err := wsutil.WaitForEvent(ctx, c, ch, func(op *wsutil.OP) bool {
|
||||
return op.Code == ReadyOP || op.Code == ResumedOP
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to wait for Ready or Resumed")
|
||||
case *ReadyEvent:
|
||||
g.mutex.Lock()
|
||||
g.ready = data
|
||||
g.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Bind the event channel away.
|
||||
c.EventLoop.SetEventChannel(ch)
|
||||
return true
|
||||
}
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Started successfully.")
|
||||
func (g *gatewayImpl) SendHeartbeat(ctx context.Context) {
|
||||
heartbeat := HeartbeatCommand(time.Now().UnixNano())
|
||||
if err := g.gateway.Send(ctx, &heartbeat); err != nil {
|
||||
g.gateway.SendErrorWrap(err, "heartbeat error")
|
||||
g.gateway.QueueReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gatewayImpl) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
wsutil.WSDebug("VoiceGateway: Trying to close. Pacemaker check skipped.")
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Closing the Websocket...")
|
||||
err := g.WS.Close()
|
||||
|
||||
if errors.Is(err, wsutil.ErrWebsocketClosed) {
|
||||
wsutil.WSDebug("VoiceGateway: Websocket already closed.")
|
||||
return nil
|
||||
}
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Websocket closed; error:", err)
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Waiting for the Pacemaker loop to exit.")
|
||||
g.waitGroup.Wait()
|
||||
wsutil.WSDebug("VoiceGateway: Pacemaker loop exited.")
|
||||
|
||||
g.AfterClose(err)
|
||||
wsutil.WSDebug("VoiceGateway: AfterClose callback finished.")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Gateway) Reconnect() error {
|
||||
return c.ReconnectCtx(context.Background())
|
||||
}
|
||||
|
||||
func (c *Gateway) ReconnectCtx(ctx context.Context) error {
|
||||
wsutil.WSDebug("VoiceGateway: Reconnecting...")
|
||||
|
||||
// TODO: implement a reconnect loop
|
||||
|
||||
// Guarantee the gateway is already closed. Ignore its error, as we're
|
||||
// redialing anyway.
|
||||
c.Close()
|
||||
|
||||
c.reconnect.Set(true)
|
||||
|
||||
// Condition: err == ErrInvalidSession:
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discord.com/developers/docs/topics/gateway#rate-limiting
|
||||
|
||||
if err := c.OpenCtx(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to reopen gateway")
|
||||
}
|
||||
|
||||
wsutil.WSDebug("VoiceGateway: Reconnected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Gateway) SessionDescriptionCtx(
|
||||
ctx context.Context, sp SelectProtocol) (*SessionDescriptionEvent, error) {
|
||||
|
||||
// Add the handler first.
|
||||
ch, cancel := c.EventLoop.Extras.Add(func(op *wsutil.OP) bool {
|
||||
return op.Code == SessionDescriptionOP
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if err := c.SelectProtocolCtx(ctx, sp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sesdesc *SessionDescriptionEvent
|
||||
|
||||
// Wait for SessionDescriptionOP packet.
|
||||
select {
|
||||
case e, ok := <-ch:
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected close waiting for session description")
|
||||
}
|
||||
if err := e.UnmarshalData(&sesdesc); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal session description")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, errors.Wrap(ctx.Err(), "failed to wait for session description")
|
||||
}
|
||||
|
||||
return sesdesc, nil
|
||||
}
|
||||
|
||||
// Send sends a payload to the Gateway with the default timeout.
|
||||
func (c *Gateway) Send(code OPCode, v interface{}) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.SendCtx(ctx, code, v)
|
||||
}
|
||||
|
||||
func (c *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
|
||||
var op = wsutil.OP{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to encode v")
|
||||
}
|
||||
|
||||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := json.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to encode payload")
|
||||
}
|
||||
|
||||
// WS should already be thread-safe.
|
||||
return c.WS.SendCtx(ctx, b)
|
||||
}
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
package voicegateway
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/v3/utils/json"
|
||||
"github.com/diamondburned/arikawa/v3/utils/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OPCode represents a Discord Gateway Gateway operation code.
|
||||
type OPCode = wsutil.OPCode
|
||||
|
||||
const (
|
||||
IdentifyOP OPCode = 0 // send
|
||||
SelectProtocolOP OPCode = 1 // send
|
||||
ReadyOP OPCode = 2 // receive
|
||||
HeartbeatOP OPCode = 3 // send
|
||||
SessionDescriptionOP OPCode = 4 // receive
|
||||
SpeakingOP OPCode = 5 // send/receive
|
||||
HeartbeatAckOP OPCode = 6 // receive
|
||||
ResumeOP OPCode = 7 // send
|
||||
HelloOP OPCode = 8 // receive
|
||||
ResumedOP OPCode = 9 // receive
|
||||
ClientConnectOP OPCode = 12 // receive
|
||||
ClientDisconnectOP OPCode = 13 // receive
|
||||
)
|
||||
|
||||
func (c *Gateway) HandleOP(op *wsutil.OP) error {
|
||||
wsutil.WSDebug("Handle OP", op.Code)
|
||||
switch op.Code {
|
||||
// Gives information required to make a UDP connection
|
||||
case ReadyOP:
|
||||
if err := unmarshalMutex(op.Data, &c.ready, &c.mutex); err != nil {
|
||||
return errors.Wrap(err, "failed to parse READY event")
|
||||
}
|
||||
|
||||
c.Events <- &c.ready
|
||||
|
||||
// Gives information about the encryption mode and secret key for sending voice packets
|
||||
case SessionDescriptionOP:
|
||||
// ?
|
||||
// Already handled by Session.
|
||||
|
||||
// Someone started or stopped speaking.
|
||||
case SpeakingOP:
|
||||
ev := new(SpeakingEvent)
|
||||
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "failed to parse Speaking event")
|
||||
}
|
||||
|
||||
c.Events <- ev
|
||||
|
||||
// Heartbeat response from the server
|
||||
case HeartbeatAckOP:
|
||||
c.EventLoop.Echo()
|
||||
|
||||
// Hello server, we hear you! :)
|
||||
case HelloOP:
|
||||
// ?
|
||||
// Already handled on initial connection.
|
||||
|
||||
// Server is saying the connection was resumed, no data here.
|
||||
case ResumedOP:
|
||||
wsutil.WSDebug("Gateway connection has been resumed.")
|
||||
|
||||
case ClientConnectOP:
|
||||
ev := new(ClientConnectEvent)
|
||||
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "failed to parse Speaking event")
|
||||
}
|
||||
|
||||
c.Events <- ev
|
||||
|
||||
case ClientDisconnectOP:
|
||||
ev := new(ClientDisconnectEvent)
|
||||
|
||||
if err := json.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "failed to parse Speaking event")
|
||||
}
|
||||
|
||||
c.Events <- ev
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown OP code %d", op.Code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalMutex(d []byte, v interface{}, m *sync.RWMutex) error {
|
||||
m.Lock()
|
||||
err := json.Unmarshal(d, v)
|
||||
m.Unlock()
|
||||
return err
|
||||
}
|
Loading…
Reference in a new issue