diff --git a/api/api.go b/api/api.go index 11bd37a..301a2c5 100644 --- a/api/api.go +++ b/api/api.go @@ -28,21 +28,15 @@ type Client struct { func NewClient(token string) *Client { cli := &Client{ - Client: httputil.NewClient(), + Client: httputil.DefaultClient, Limiter: rate.NewLimiter(), Token: token, } tw := httputil.NewTransportWrapper() tw.Pre = func(r *http.Request) error { - if r.Header.Get("Authorization") == "" { - r.Header.Set("Authorization", cli.Token) - } - - if r.UserAgent() == "" { - r.Header.Set("User-Agent", UserAgent) - } - + r.Header.Set("Authorization", cli.Token) + r.Header.Set("User-Agent", UserAgent) r.Header.Set("X-RateLimit-Precision", "millisecond") // Rate limit stuff diff --git a/api/user.go b/api/user.go index 22d0827..dff944e 100644 --- a/api/user.go +++ b/api/user.go @@ -128,4 +128,4 @@ func (c *Client) CreatePrivateChannel( // shitty SDK, don't care, PR welcomed // func (c *Client) CreateGroup(tokens []string, nicks map[]) -func (c *Client) UserConnections() ([]discord.Connection, error) {} +// func (c *Client) UserConnections() ([]discord.Connection, error) {} diff --git a/discord/snowflake.go b/discord/snowflake.go index d64bd0e..e0c8113 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -33,6 +33,10 @@ func (s Snowflake) String() string { return strconv.FormatUint(uint64(s), 10) } +func (s Snowflake) Valid() bool { + return uint64(s) < 1 +} + func (s Snowflake) Time() time.Time { return time.Unix(0, int64(s)>>22*1000000+DiscordEpoch) } diff --git a/discord/time.go b/discord/time.go index 3f463ac..bc0a366 100644 --- a/discord/time.go +++ b/discord/time.go @@ -28,7 +28,33 @@ func (t Timestamp) MarshalJSON() ([]byte, error) { return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil } -type Seconds uint +// + +type UnixTimestamp int64 + +func (t UnixTimestamp) String() string { + return t.Time().String() +} + +func (t UnixTimestamp) Time() time.Time { + return time.Unix(int64(t), 0) +} + +// + +type UnixMsTimestamp int64 + +func (t UnixMsTimestamp) String() string { + return t.Time().String() +} + +func (t UnixMsTimestamp) Time() time.Time { + return time.Unix(0, int64(t)*int64(time.Millisecond)) +} + +// + +type Seconds int func DurationToSeconds(dura time.Duration) Seconds { return Seconds(dura.Seconds()) @@ -41,3 +67,19 @@ func (s Seconds) String() string { func (s Seconds) Duration() time.Duration { return time.Duration(s) * time.Second } + +// + +type Milliseconds int + +func DurationToMilliseconds(dura time.Duration) Milliseconds { + return Milliseconds(dura.Milliseconds()) +} + +func (ms Milliseconds) String() string { + return ms.Duration().String() +} + +func (ms Milliseconds) Duration() time.Duration { + return time.Duration(ms) * time.Millisecond +} diff --git a/gateway/activity.go b/gateway/activity.go new file mode 100644 index 0000000..e5fcbfe --- /dev/null +++ b/gateway/activity.go @@ -0,0 +1,78 @@ +package gateway + +import "github.com/diamondburned/arikawa/discord" + +type Status string + +const ( + UnknownStatus Status = "" + OnlineStatus Status = "online" + DoNotDisturbStatus Status = "dnd" + IdleStatus Status = "idle" + InvisibleStatus Status = "invisible" + OfflineStatus Status = "offline" +) + +type Activity struct { + Name string `json:"name"` + Type ActivityType `json:"type"` + URL discord.URL `json:"url"` + + // User only + + CreatedAt discord.UnixTimestamp `json:"created_at"` + Timestamps struct { + Start discord.UnixMsTimestamp `json:"start,omitempty"` + End discord.UnixMsTimestamp `json:"end,omitempty"` + } `json:"timestamps,omitempty"` + + ApplicationID discord.Snowflake `json:"application_id,omitempty"` + Details string `json:"details,omitempty"` + State string `json:"state,omitempty"` // party status + Emoji discord.Emoji `json:"emoji,omitempty"` + + Party struct { + ID string `json:"id,omitempty"` + Size [2]int `json:"size,omitempty"` // [ current, max ] + } `json:"party,omitempty"` + + Assets struct { + LargeImage string `json:"large_image,omitempty"` // id + LargeText string `json:"large_text,omitempty"` + SmallImage string `json:"small_image,omitempty"` // id + SmallText string `json:"small_text,omitempty"` + } `json:"assets,omitempty"` + + Secrets struct { + Join string `json:"join,omitempty"` + Spectate string `json:"spectate,omitempty"` + Match string `json:"match,omitempty"` + } `json:"secrets,omitempty"` + + Instance bool `json:"instance,omitempty"` + Flags ActivityFlags `json:"flags,omitempty"` +} + +type ActivityType uint8 + +const ( + // Playing $name + GameActivity ActivityType = iota + // Streaming $details + StreamingActivity + // Listening to $name + ListeningActivity + // $emoji $name + CustomActivity +) + +type ActivityFlags uint8 + +const ( + InstanceActivity ActivityFlags = 1 << iota + JoinActivity + SpectateActivity + JoinRequestActivity + SyncActivity + PlayActivity +) diff --git a/gateway/commands.go b/gateway/commands.go new file mode 100644 index 0000000..9d24121 --- /dev/null +++ b/gateway/commands.go @@ -0,0 +1,75 @@ +package gateway + +import ( + "github.com/diamondburned/arikawa/discord" +) + +// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent + +type IdentifyData struct { + Token string `json:"token"` + Properties IdentifyProperties `json:"properties"` + + Compress bool `json:"compress,omitempty"` // true + LargeThreshold uint `json:"large_threshold,omitempty"` // 50 + GuildSubscription bool `json:"guild_subscriptions"` // true + + Shard [2]int `json:"shard"` // [ shard_id, num_shards ] + + Presence UpdateStatusData `json:"presence,omitempty"` +} + +type IdentifyProperties struct { + // Required + OS string `json:"os"` // GOOS + Browser string `json:"browser"` // Arikawa + Device string `json:"device"` // Arikawa + + // Optional + BrowserUserAgent string `json:"browser_user_agent,omitempty"` + BrowserVersion string `json:"browser_version,omitempty"` + OsVersion string `json:"os_version,omitempty"` + Referrer string `json:"referrer,omitempty"` + ReferringDomain string `json:"referring_domain,omitempty"` +} + +func (g *Gateway) Identify() error { + return g.Send(IdentifyOP, g.Identity) +} + +type ResumeData struct { + Token string `json:"token"` + SessionID string `json:"session_id"` + Sequence int64 `json:"seq"` +} + +// HeartbeatData is the last sequence number to be sent. +type HeartbeatData int + +func (g *Gateway) Heartbeat() error { + return g.Send(HeartbeatOP, g.Sequence.Get()) +} + +type RequestGuildMembersData struct { + GuildID []discord.Snowflake `json:"guild_id"` + UserIDs []discord.Snowflake `json:"user_id,omitempty"` + + Query string `json:"query,omitempty"` + Limit uint `json:"limit"` + Presences bool `json:"presences,omitempty"` +} + +type UpdateVoiceStateData struct { + GuildID discord.Snowflake `json:"guild_id"` + ChannelID discord.Snowflake `json:"channel_id"` + SelfMute bool `json:"self_mute"` + SelfDeaf bool `json:"self_deaf"` +} + +type UpdateStatusData struct { + Since discord.Milliseconds `json:"since,omitempty"` // 0 if not idle + Game *Activity `json:"game,omitempty"` // nullable + + Status Status `json:"status"` + AFK bool `json:"afk"` +} diff --git a/gateway/events.go b/gateway/events.go new file mode 100644 index 0000000..4539ac8 --- /dev/null +++ b/gateway/events.go @@ -0,0 +1,199 @@ +package gateway + +import "github.com/diamondburned/arikawa/discord" + +// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent + +// https://discordapp.com/developers/docs/topics/gateway#connecting-and-resuming +type ( + HelloEvent struct { + HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"` + } + + ReadyEvent struct { + Version int `json:"version"` + + User discord.User `json:"user"` + SessionID string `json:"session_id"` + + PrivateChannels []discord.Channel `json:"private_channels"` + Guilds []discord.Guild `json:"guilds"` + + Shard [2]int `json:"shard"` // [ shard_id num_shards ] + } + + ResumedEvent struct{} + + // InvalidSessionEvent indicates if the event is resumable. + InvalidSessionEvent bool +) + +// https://discordapp.com/developers/docs/topics/gateway#channels +type ( + ChannelCreateEvent discord.Channel + ChannelUpdateEvent discord.Channel + ChannelDeleteEvent discord.Channel + ChannelPinsUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id,omitempty"` + ChannelID discord.Snowflake `json:"channel_id,omitempty"` + LastPin discord.Timestamp `json:"timestamp,omitempty"` + } +) + +// https://discordapp.com/developers/docs/topics/gateway#guilds +type ( + GuildCreateEvent discord.Guild + GuildUpdateEvent discord.Guild + GuildDeleteEvent struct { + ID discord.Snowflake `json:"id"` + // Unavailable if false == removed + Unavailable bool `json:"unavailable"` + } + + GuildBanAddEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + User discord.User `json:"user"` + } + GuildBanRemoveEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + User discord.User `json:"user"` + } + + GuildEmojisUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + Emojis []discord.Emoji `json:"emoji"` + } + + GuildIntegrationsUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + } + + GuildMemberAddEvent struct { + discord.Member + GuildID discord.Snowflake `json:"guild_id"` + } + GuildMemberRemoveEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + User discord.User `json:"user"` + } + GuildMemberUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + Roles []discord.Snowflake `json:"roles"` + User discord.User `json:"user"` + Nick string `json:"nick"` + } + + // GuildMembersChunkEvent is sent when Guild Request Members is called. + GuildMembersChunkEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + Members []discord.Member `json:"members"` + + // Whatever's not found goes here + NotFound []string `json:"not_found,omitempty"` + + // Only filled if requested + Presences []discord.Presence `json:"presences,omitempty"` + } + + GuildRoleCreateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + Role discord.Role `json:"role"` + } + GuildRoleUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + Role discord.Role `json:"role"` + } + GuildRoleDeleteEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + RoleID discord.Snowflake `json:"role_id"` + } +) + +// https://discordapp.com/developers/docs/topics/gateway#messages +type ( + MessageCreateEvent discord.Message + MessageUpdateEvent discord.Message + MessageDeleteEvent struct { + ID discord.Snowflake `json:"id"` + ChannelID discord.Snowflake `json:"channel_id"` + GuildID discord.Snowflake `json:"guild_id,omitempty"` + } + MessageDeleteBulkEvent struct { + IDs []discord.Snowflake `json:"ids"` + ChannelID discord.Snowflake `json:"channel_id"` + GuildID discord.Snowflake `json:"guild_id,omitempty"` + } + + MessageReactionAddEvent struct { + UserID discord.Snowflake `json:"user_id"` + ChannelID discord.Snowflake `json:"channel_id"` + MessageID discord.Snowflake `json:"message_id"` + + Emoji discord.Emoji `json:"emoji,omitempty"` + + GuildID discord.Snowflake `json:"guild_id,omitempty"` + Member *discord.Member `json:"member,omitempty"` + } + MessageReactionRemoveEvent struct { + UserID discord.Snowflake `json:"user_id"` + ChannelID discord.Snowflake `json:"channel_id"` + MessageID discord.Snowflake `json:"message_id"` + + Emoji discord.Emoji `json:"emoji"` + + GuildID discord.Snowflake `json:"guild_id,omitempty"` + } + MessageReactionRemoveAllEvent struct { + ChannelID discord.Snowflake `json:"channel_id"` + } +) + +// https://discordapp.com/developers/docs/topics/gateway#presence +type ( + // Clients may only update their game status 5 times per 20 seconds. + PresenceUpdateEvent struct { + User discord.User `json:"user"` + Nick string `json:"nick"` + Roles []discord.Snowflake `json:"roles"` + GuildID discord.Snowflake `json:"guild_id"` + + PremiumSince discord.Timestamp `json:"premium_since,omitempty"` + + Game *Activity `json:"game"` + Activities []Activity `json:"activities"` + + Status Status `json:"status"` + ClientStatus struct { + Desktop Status `json:"status,omitempty"` + Mobile Status `json:"mobile,omitempty"` + Web Status `json:"web,omitempty"` + } `json:"client_status"` + } + TypingStartEvent struct { + ChannelID discord.Snowflake `json:"channel_id"` + UserID discord.Snowflake `json:"user_id"` + Timestamp discord.Timestamp `json:"timestamp"` + + GuildID discord.Snowflake `json:"guild_id,omitempty"` + Member *discord.Member `json:"member,omitempty"` + } + UserUpdateEvent discord.User +) + +// https://discordapp.com/developers/docs/topics/gateway#voice +type ( + VoiceStateUpdateEvent discord.VoiceState + VoiceServerUpdateEvent struct { + Token string `json:"token"` + GuildID discord.Snowflake `json:"guild_id"` + Endpoint string `json:"endpoint"` + } +) + +// https://discordapp.com/developers/docs/topics/gateway#webhooks +type ( + WebhooksUpdateEvent struct { + GuildID discord.Snowflake `json:"guild_id"` + ChannelID discord.Snowflake `json:"channel_id"` + } +) diff --git a/gateway/events_map.go b/gateway/events_map.go new file mode 100644 index 0000000..1ec0b2e --- /dev/null +++ b/gateway/events_map.go @@ -0,0 +1,59 @@ +package gateway + +// Event is any event struct. They have an "Event" suffixed to them. +type Event = interface{} + +var EventCreator = map[string]func() Event{ + "HELLO": func() Event { return new(HelloEvent) }, + "READY": func() Event { return new(ReadyEvent) }, + "RESUMED": func() Event { return new(ResumedEvent) }, + "INVALID_SESSION": func() Event { return new(InvalidSessionEvent) }, + + "CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) }, + "CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) }, + "CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) }, + "CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) }, + + "GUILD_CREATE": func() Event { return new(GuildCreateEvent) }, + "GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) }, + "GUILD_DELETE": func() Event { return new(GuildDeleteEvent) }, + + "GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) }, + "GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) }, + + "GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) }, + "GUILD_INTEGRATIONS_UPDATE": func() Event { + return new(GuildIntegrationsUpdateEvent) + }, + + "GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) }, + "GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) }, + "GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) }, + "GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) }, + + "GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) }, + "GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) }, + "GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) }, + + "MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) }, + "MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) }, + "MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) }, + "MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) }, + + "MESSAGE_REACTION_ADD": func() Event { + return new(MessageReactionAddEvent) + }, + "MESSAGE_REACTION_REMOVE": func() Event { + return new(MessageReactionRemoveEvent) + }, + "MESSAGE_REACTION_REMOVE_ALL": func() Event { + return new(MessageReactionRemoveAllEvent) + }, + + "PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) }, + "TYPING_START": func() Event { return new(TypingStartEvent) }, + "USER_UPDATE": func() Event { return new(UserUpdateEvent) }, + + "VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) }, + "VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) }, +} diff --git a/gateway/gateway.go b/gateway/gateway.go new file mode 100644 index 0000000..0e9b40a --- /dev/null +++ b/gateway/gateway.go @@ -0,0 +1,319 @@ +package gateway + +import ( + "context" + "net/url" + "runtime" + "time" + + "github.com/diamondburned/arikawa/api" + "github.com/diamondburned/arikawa/httputil" + "github.com/diamondburned/arikawa/json" + "github.com/diamondburned/arikawa/wsutil" + "github.com/pkg/errors" +) + +const ( + EndpointGateway = api.Endpoint + "gateway" + EndpointGatewayBot = api.EndpointGateway + "/bot" + + Version = "6" + Encoding = "json" +) + +var ( + // WSTimeout is the timeout for connecting and writing to the Websocket, + // before Gateway cancels and fails. + WSTimeout = wsutil.DefaultTimeout + // WSBuffer is the size of the Event channel. This has to be at least 1 to + // make space for the first Event: Ready or Resumed. + WSBuffer = 10 +) + +var ( + ErrMissingForResume = errors.New( + "missing session ID or sequence for resuming") + ErrWSMaxTries = errors.New("max tries reached") +) + +func GatewayURL() (string, error) { + var Gateway struct { + URL string `json:"url"` + } + + return Gateway.URL, httputil.DefaultClient.RequestJSON( + &Gateway, "GET", EndpointGateway) +} + +// Identity is used as the default identity when initializing a new Gateway. +var Identity = IdentifyProperties{ + OS: runtime.GOOS, + Browser: "Arikawa", + Device: "Arikawa", +} + +type Gateway struct { + WS *wsutil.Websocket + json.Driver + + // Timeout for connecting and writing to the Websocket, uses default + // WSTimeout (global). + WSTimeout time.Duration + // Retries on connect and reconnect. + WSRetries uint // 3 + + Events chan Event + SessionID string + + URL string // URL + + Identity *IdentifyData + Pacemaker *Pacemaker + Sequence Sequence + + ErrorLog func(err error) // default to log.Println + + // Only use for debugging + + // If this channel is non-nil, all incoming OP packets will also be sent + // here. + OP chan Event + + // Filled by methods, internal use + paceDeath <-chan error + handler chan struct{} +} + +// NewGateway starts a new Gateway with the default stdlib JSON driver. For more +// information, refer to NewGatewayWithDriver. +func NewGateway(token string) (*Gateway, error) { + return NewGatewayWithDriver(token, json.Default{}) +} + +// NewGatewayWithDriver connects to the Gateway and authenticates automatically. +func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { + URL, err := GatewayURL() + if err != nil { + return nil, errors.Wrap(err, "Failed to get gateway endpoint") + } + + g := &Gateway{ + URL: URL, + Driver: driver, + WSTimeout: WSTimeout, + Events: make(chan Event), + Identity: &IdentifyData{ + Token: token, + Properties: Identity, + Compress: true, + LargeThreshold: 50, + GuildSubscription: true, + }, + Sequence: NewSequence(), + } + + // Parameters for the gateway + param := url.Values{} + param.Set("v", Version) + param.Set("encoding", Encoding) + // Append the form to the URL + URL += "?" + param.Encode() + + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + // Create a new undialed Websocket. + ws, err := wsutil.NewCustom(ctx, wsutil.NewConn(driver), URL) + if err != nil { + return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL) + } + g.WS = ws + + // Try and dial it + return g, g.connect() +} + +// Close closes the underlying Websocket connection. +func (g *Gateway) Close() error { + return g.WS.Close(nil) +} + +// Reconnects and resumes. +func (g *Gateway) Reconnect() error { + // Close, but we don't care about the error (I think) + g.Close() + // Actually a reconnect at this point. + return g.connect() +} + +// Resume sends to the Websocket a Resume OP, but it doesn't actually resume +// from a dead connection. Start() resumes from a dead connection. +func (g *Gateway) Resume() error { + var ( + ses = g.SessionID + seq = g.Sequence.Get() + ) + + if ses == "" || seq == 0 { + return ErrMissingForResume + } + + return g.Send(ResumeOP, ResumeData{ + Token: g.Identity.Token, + SessionID: ses, + Sequence: seq, + }) +} + +// Start authenticates with the websocket, or resume from a dead Websocket +// connection. This function doesn't block. To block until a fatal error, use +// Wait(). +func (g *Gateway) Start() error { + // This is where we'll get our events + ch := g.WS.Listen() + + // Wait for an OP 10 Hello + var hello HelloEvent + if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil { + return errors.Wrap(err, "Error at Hello") + } + + // Send Discord either the Identify packet (if it's a fresh connection), or + // a Resume packet (if it's a dead connection). + if g.SessionID == "" { + // SessionID is empty, so this is a completely new session. + if err := g.Identify(); err != nil { + return errors.Wrap(err, "Failed to identify") + } + + // We should now expect a Ready event. + var ready ReadyEvent + if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil { + return errors.Wrap(err, "Error at Ready") + } + } else { + if err := g.Resume(); err != nil { + return errors.Wrap(err, "Failed to resume") + } + + // We should now expect a Resumed event. + var resumed ResumedEvent + if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil { + return errors.Wrap(err, "Error at Resumed") + } + } + + // Start the event handler + g.handler = make(chan struct{}) + go g.handleWS(g.handler) + + // Start the pacemaker with the heartrate received from Hello + g.Pacemaker = &Pacemaker{ + Heartrate: hello.HeartbeatInterval.Duration(), + Pace: g.Heartbeat, + OnDead: g.Reconnect, + } + // Pacemaker dies here, only when it's fatal. + g.paceDeath = g.Pacemaker.StartAsync() + + return nil +} + +func (g *Gateway) Wait() error { + defer close(g.handler) + return <-g.paceDeath +} + +// handleWS uses the Websocket and parses them into g.Events. +func (g *Gateway) handleWS(stop <-chan struct{}) { + ch := g.WS.Listen() + + for { + select { + case <-stop: + return + case ev := <-ch: + // Check for error + if ev.Error != nil { + g.ErrorLog(errors.Wrap(ev.Error, "WS error")) + continue + } + + // Handle the event + if err := HandleEvent(g, ev.Data); err != nil { + g.ErrorLog(errors.Wrap(ev.Error, "WS handler error")) + } + } + } +} + +func (g *Gateway) Send(code OPCode, v interface{}) error { + var op = OP{ + Code: code, + } + + if v != nil { + b, err := g.Marshal(v) + if err != nil { + return errors.Wrap(err, "Failed to encode v") + } + + op.Data = b + } + + b, err := g.Marshal(op) + if err != nil { + return errors.Wrap(err, "Failed to encode payload") + } + + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + return g.WS.Send(ctx, b) +} + +func (g *Gateway) connect() error { + // Reconnect timeout + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + var Lerr error + + for i := uint(0); i < g.WSRetries; i++ { + // Check if context is expired + if err := ctx.Err(); err != nil { + // Don't bother if it's expired + return err + } + + // Reconnect to the Gateway + if err := g.WS.Redial(ctx); err != nil { + // Save the error, retry again + Lerr = errors.Wrap(err, "Failed to reconnect") + continue + } + + // Try to resume the connection + if err := g.Start(); err != nil { + // If the connection is rate limited (documented behavior): + // https://discordapp.com/developers/docs/topics/gateway#rate-limiting + if err == ErrInvalidSession { + continue // retry + } + + // Else, fatal + return errors.Wrap(err, "Failed to start gateway") + } + + // Started successfully, return + return nil + } + + // Check if any earlier errors are fatal + if Lerr != nil { + return Lerr + } + + // We tried. + return ErrWSMaxTries +} diff --git a/gateway/op.go b/gateway/op.go new file mode 100644 index 0000000..7732440 --- /dev/null +++ b/gateway/op.go @@ -0,0 +1,153 @@ +package gateway + +import ( + "fmt" + + "github.com/diamondburned/arikawa/json" + "github.com/diamondburned/arikawa/wsutil" + "github.com/pkg/errors" +) + +type OPCode uint8 + +const ( + DispatchOP OPCode = iota // recv + HeartbeatOP // send/recv + IdentifyOP // send... + StatusUpdateOP // + VoiceStateUpdateOP // + VoiceServerPingOP // + ResumeOP // + ReconnectOP // recv + RequestGuildMembersOP // send + InvalidSessionOP // recv... + HelloOP + HeartbeatAckOP + _ + CallConnectOP + GuildSubscriptionsOP +) + +type OP struct { + Code OPCode `json:"op"` + Data json.Raw `json:"d,omitempty"` + + // Only for Dispatch (op 0) + Sequence int `json:"s,omitempty"` + EventName string `json:"t,omitempty"` +} + +var ErrInvalidSession = errors.New("Invalid session") + +func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { + if ev.Error != nil { + return nil, ev.Error + } + + var op *OP + if err := driver.Unmarshal(ev.Data, &op); err != nil { + return nil, errors.Wrap(err, "Failed to decode payload") + } + + if op.Code == InvalidSessionOP { + return op, ErrInvalidSession + } + + return op, nil +} + +func DecodeEvent(driver json.Driver, + ev wsutil.Event, v interface{}) (OPCode, error) { + + op, err := DecodeOP(driver, ev) + if err != nil { + return 0, err + } + + if err := driver.Unmarshal(op.Data, v); err != nil { + return 0, errors.Wrap(err, "Failed to decode data") + } + + return op.Code, nil +} + +func AssertEvent(driver json.Driver, + ev wsutil.Event, code OPCode, v interface{}) error { + + op, err := DecodeOP(driver, ev) + if err != nil { + return err + } + + if op.Code != code { + return fmt.Errorf( + "Unexpected OP Code: %d, expected %d (%s)", + op.Code, code, op.Data, + ) + } + + if err := driver.Unmarshal(op.Data, v); err != nil { + return errors.Wrap(err, "Failed to decode data") + } + + return nil +} + +func HandleEvent(g *Gateway, data []byte) error { + // Parse the raw data into an OP struct + var op *OP + if err := g.Driver.Unmarshal(data, &op); err != nil { + return errors.Wrap(err, "OP error") + } + + if g.OP != nil { + g.OP <- op + } + + switch op.Code { + case HeartbeatAckOP: + // Heartbeat from the server? + g.Pacemaker.Echo() + + case HeartbeatOP: + // Server requesting a heartbeat. + return g.Pacemaker.Pace() + + case ReconnectOP: + // Server requests to reconnect, die and retry. + return g.Reconnect() + + case InvalidSessionOP: + // Invalid session, respond with Identify. + return g.Identify() + + case HelloOP: + // What is this OP doing here??? + return nil + + case DispatchOP: + // Check if we know the event + fn, ok := EventCreator[op.EventName] + if !ok { + return errors.New("Unknown event: " + op.EventName) + } + + // Make a new pointer to the event + var ev = fn() + + // Try and parse the event + if err := g.Driver.Unmarshal(op.Data, ev); err != nil { + return errors.Wrap(err, "Failed to parse event "+op.EventName) + } + + // Throw the event into a channel, it's valid now. + g.Events <- ev + return nil + + default: + return fmt.Errorf( + "Unknown OP code %d (event %s)", op.Code, op.EventName) + } + + return nil +} diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go new file mode 100644 index 0000000..cef57c8 --- /dev/null +++ b/gateway/pacemaker.go @@ -0,0 +1,80 @@ +package gateway + +import ( + "time" + + "github.com/pkg/errors" +) + +var ErrDead = errors.New("no heartbeat replied") + +type Pacemaker struct { + // Heartrate is the received duration between heartbeats. + Heartrate time.Duration + + // LastBeat logs the received heartbeats, with the newest one + // first. + LastBeat [2]time.Time + + // Any callback that returns an error will stop the pacer. + Pace func() error + // Event + OnDead func() error + + stop chan<- struct{} +} + +func (p *Pacemaker) Echo() { + // Swap our received heartbeats + p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0] +} + +// Dead, if true, will have Pace return an ErrDead. +func (p *Pacemaker) Dead() bool { + if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() { + return false + } + + return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2 +} + +func (p *Pacemaker) Stop() { + close(p.stop) +} + +// Start beats until it's dead. +func (p *Pacemaker) Start() error { + tick := time.NewTicker(p.Heartrate) + defer tick.Stop() + + stop := make(chan struct{}) + p.stop = stop + + for { + if err := p.Pace(); err != nil { + return err + } + + if !p.Dead() { + continue + } + if err := p.OnDead(); err != nil { + return err + } + + select { + case <-stop: + return nil + case <-tick.C: + continue + } + } +} + +func (p *Pacemaker) StartAsync() (death <-chan error) { + var ch = make(chan error) + go func() { + ch <- p.Start() + }() + return ch +} diff --git a/gateway/sequence.go b/gateway/sequence.go new file mode 100644 index 0000000..7f1a9a3 --- /dev/null +++ b/gateway/sequence.go @@ -0,0 +1,14 @@ +package gateway + +import "sync/atomic" + +type Sequence struct { + seq *int64 +} + +func NewSequence() Sequence { + return Sequence{new(int64)} +} + +func (s *Sequence) Set(seq int64) { atomic.StoreInt64(s.seq, seq) } +func (s *Sequence) Get() int64 { return atomic.LoadInt64(s.seq) } diff --git a/go.mod b/go.mod index 2b109ec..beab7fa 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/gorilla/websocket v1.4.1 github.com/pkg/errors v0.8.1 github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa + go.uber.org/atomic v1.4.0 golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 nhooyr.io/websocket v1.7.4 diff --git a/httputil/client.go b/httputil/client.go index 4538e07..d88dc80 100644 --- a/httputil/client.go +++ b/httputil/client.go @@ -16,6 +16,8 @@ type Client struct { SchemaEncoder } +var DefaultClient = NewClient() + func NewClient() Client { return Client{ Client: http.Client{ diff --git a/httputil/options.go b/httputil/options.go index 7960da8..212e3af 100644 --- a/httputil/options.go +++ b/httputil/options.go @@ -59,6 +59,7 @@ func WithJSONBody(json json.Driver, v interface{}) RequestOption { return err } + r.Header.Set("Content-Type", "application/json") r.Body = ioutil.NopCloser(&buf) return nil } diff --git a/json/json.go b/json/json.go index 89d4c13..a97623a 100644 --- a/json/json.go +++ b/json/json.go @@ -38,6 +38,27 @@ func getBool(Bool bool) OptionBool { return &Bool } +// + +// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and +// can be used to delay JSON decoding or precompute a JSON encoding. It's taken +// from encoding/json. +type Raw []byte + +// Raw returns m as the JSON encoding of m. +func (m Raw) MarshalJSON() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + return m, nil +} + +func (m Raw) String() string { + return string(m) +} + +// + type Driver interface { Marshal(v interface{}) ([]byte, error) Unmarshal(data []byte, v interface{}) error diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..8a8a75b --- /dev/null +++ b/session/session.go @@ -0,0 +1,83 @@ +package session + +import ( + "log" + "sync" + "time" + + "github.com/diamondburned/arikawa/api" + "github.com/diamondburned/arikawa/gateway" + "github.com/diamondburned/arikawa/json" +) + +/* + TODO: + + and Session's supposed to handle callbacks too kec + + might move all these to Gateway, dunno + + could have a lock on Listen() + + I can actually see people using gateway channels to handle things + themselves without any callback abstractions, so this is probably the way to go + + welp shit + + rewrite imminent +*/ + +type Session struct { + API *api.Client + Gateway *gateway.Conn + gatewayOnce sync.Once + + ErrorLog func(err error) // default to log.Println + + // Heartrate is the received duration between heartbeats. + Heartrate time.Duration + + // LastBeat logs the received heartbeats, with the newest one + // first. + LastBeat [2]time.Time + + // Used for Close() + stoppers []chan<- struct{} + closers []func() error +} + +func New(token string) (*Session, error) { + // Initialize the session and the API interface + s := &Session{} + s.API = api.NewClient(token) + + // Default logger + s.ErrorLog = func(err error) { + log.Println("Arikawa/session error:", err) + } + + // Connect to the Gateway + c, err := gateway.NewConn(json.Default{}) + if err != nil { + return nil, err + } + s.Gateway = c + + return s, nil +} + +func (s *Session) Close() error { + for _, stop := range s.stoppers { + close(stop) + } + + var err error + + for _, closer := range s.closers { + if cerr := closer(); cerr != nil { + err = cerr + } + } + + return err +} diff --git a/wsutil/conn.go b/wsutil/conn.go index 1aa5c52..35d7499 100644 --- a/wsutil/conn.go +++ b/wsutil/conn.go @@ -13,11 +13,15 @@ import ( ) var WSBuffer = 12 -var WSReadLimit = 512 * 1024 // 512KiB +var WSReadLimit = 4096 // 4096 bytes // Connection is an interface that abstracts around a generic Websocket driver. // This connection expects the driver to handle compression by itself. type Connection interface { + // Dial dials the address (string). Context needs to be passed in for + // timeout. This method should also be re-usable after Close is called. + Dial(context.Context, string) error + // Listen sends over events constantly. Error will be non-nil if Data is // nil, so check for Error first. Listen() <-chan Event @@ -67,10 +71,6 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { } func (c *Conn) Listen() <-chan Event { - if c.events != nil { - return c.events - } - c.events = make(chan Event, WSBuffer) go func() { c.readLoop(c.events) }() return c.events @@ -139,6 +139,9 @@ func (c *Conn) Send(ctx context.Context, b []byte) error { } func (c *Conn) Close(err error) error { + // Close the event channels + defer close(c.events) + if err == nil { return c.Conn.Close(websocket.StatusNormalClosure, "") } diff --git a/wsutil/throttler.go b/wsutil/throttler.go index 63e6e37..f586b5f 100644 --- a/wsutil/throttler.go +++ b/wsutil/throttler.go @@ -9,3 +9,7 @@ import ( func NewSendLimiter() *rate.Limiter { return rate.NewLimiter(rate.Every(time.Minute), 120) } + +func NewDialLimiter() *rate.Limiter { + return rate.NewLimiter(rate.Every(5*time.Second), 1) +} diff --git a/wsutil/ws.go b/wsutil/ws.go index e823839..31c0da0 100644 --- a/wsutil/ws.go +++ b/wsutil/ws.go @@ -19,49 +19,62 @@ type Event struct { } type Websocket struct { - conn Connection + Conn Connection + Addr string - WriteTimeout time.Duration - SendLimiter *rate.Limiter + SendLimiter *rate.Limiter + DialLimiter *rate.Limiter + + listener <-chan Event } -func New(ctx context.Context, - driver json.Driver, addr string) (*Websocket, error) { - - if driver == nil { - driver = json.Default{} - } - - c := NewConn(driver) - if err := c.Dial(ctx, addr); err != nil { - return nil, errors.Wrap(err, "Failed to dial") - } - - return NewWithConn(c), nil +func New(ctx context.Context, addr string) (*Websocket, error) { + return NewCustom(ctx, NewConn(json.Default{}), addr) } -// NewWithConn uses an already-dialed connection for Websocket. -func NewWithConn(conn Connection) *Websocket { - return &Websocket{ - conn: conn, +// NewCustom creates a new undialed Websocket. +func NewCustom( + ctx context.Context, conn Connection, addr string) (*Websocket, error) { - WriteTimeout: DefaultTimeout, - SendLimiter: NewSendLimiter(), + ws := &Websocket{ + Conn: conn, + Addr: addr, + + SendLimiter: NewSendLimiter(), + DialLimiter: NewDialLimiter(), } + + return ws, nil +} + +func (ws *Websocket) Redial(ctx context.Context) error { + if err := ws.DialLimiter.Wait(ctx); err != nil { + // Expired, fatal error + return errors.Wrap(err, "Failed to wait") + } + + if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { + return errors.Wrap(err, "Failed to dial") + } + + return nil } func (ws *Websocket) Listen() <-chan Event { - return ws.conn.Listen() + if ws.listener == nil { + ws.listener = ws.Conn.Listen() + } + return ws.listener } -func (ws *Websocket) Send(b []byte) error { - ctx, cancel := context.WithTimeout( - context.Background(), ws.WriteTimeout) - defer cancel() - +func (ws *Websocket) Send(ctx context.Context, b []byte) error { if err := ws.SendLimiter.Wait(ctx); err != nil { return errors.Wrap(err, "SendLimiter failed") } - return ws.conn.Send(ctx, b) + return ws.Conn.Send(ctx, b) +} + +func (ws *Websocket) Close(err error) error { + return ws.Conn.Close(err) }