1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-02 18:26:41 +00:00

Added an untested Gateway package

This commit is contained in:
diamondburned 2020-01-14 20:43:34 -08:00 committed by GitHub
parent 67450c4872
commit 8fc7229c3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 1190 additions and 45 deletions

View file

@ -28,21 +28,15 @@ type Client struct {
func NewClient(token string) *Client {
cli := &Client{
Client: httputil.NewClient(),
Client: httputil.DefaultClient,
Limiter: rate.NewLimiter(),
Token: token,
}
tw := httputil.NewTransportWrapper()
tw.Pre = func(r *http.Request) error {
if r.Header.Get("Authorization") == "" {
r.Header.Set("Authorization", cli.Token)
}
if r.UserAgent() == "" {
r.Header.Set("User-Agent", UserAgent)
}
r.Header.Set("Authorization", cli.Token)
r.Header.Set("User-Agent", UserAgent)
r.Header.Set("X-RateLimit-Precision", "millisecond")
// Rate limit stuff

View file

@ -128,4 +128,4 @@ func (c *Client) CreatePrivateChannel(
// shitty SDK, don't care, PR welcomed
// func (c *Client) CreateGroup(tokens []string, nicks map[])
func (c *Client) UserConnections() ([]discord.Connection, error) {}
// func (c *Client) UserConnections() ([]discord.Connection, error) {}

View file

@ -33,6 +33,10 @@ func (s Snowflake) String() string {
return strconv.FormatUint(uint64(s), 10)
}
func (s Snowflake) Valid() bool {
return uint64(s) < 1
}
func (s Snowflake) Time() time.Time {
return time.Unix(0, int64(s)>>22*1000000+DiscordEpoch)
}

View file

@ -28,7 +28,33 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
}
type Seconds uint
//
type UnixTimestamp int64
func (t UnixTimestamp) String() string {
return t.Time().String()
}
func (t UnixTimestamp) Time() time.Time {
return time.Unix(int64(t), 0)
}
//
type UnixMsTimestamp int64
func (t UnixMsTimestamp) String() string {
return t.Time().String()
}
func (t UnixMsTimestamp) Time() time.Time {
return time.Unix(0, int64(t)*int64(time.Millisecond))
}
//
type Seconds int
func DurationToSeconds(dura time.Duration) Seconds {
return Seconds(dura.Seconds())
@ -41,3 +67,19 @@ func (s Seconds) String() string {
func (s Seconds) Duration() time.Duration {
return time.Duration(s) * time.Second
}
//
type Milliseconds int
func DurationToMilliseconds(dura time.Duration) Milliseconds {
return Milliseconds(dura.Milliseconds())
}
func (ms Milliseconds) String() string {
return ms.Duration().String()
}
func (ms Milliseconds) Duration() time.Duration {
return time.Duration(ms) * time.Millisecond
}

78
gateway/activity.go Normal file
View file

@ -0,0 +1,78 @@
package gateway
import "github.com/diamondburned/arikawa/discord"
type Status string
const (
UnknownStatus Status = ""
OnlineStatus Status = "online"
DoNotDisturbStatus Status = "dnd"
IdleStatus Status = "idle"
InvisibleStatus Status = "invisible"
OfflineStatus Status = "offline"
)
type Activity struct {
Name string `json:"name"`
Type ActivityType `json:"type"`
URL discord.URL `json:"url"`
// User only
CreatedAt discord.UnixTimestamp `json:"created_at"`
Timestamps struct {
Start discord.UnixMsTimestamp `json:"start,omitempty"`
End discord.UnixMsTimestamp `json:"end,omitempty"`
} `json:"timestamps,omitempty"`
ApplicationID discord.Snowflake `json:"application_id,omitempty"`
Details string `json:"details,omitempty"`
State string `json:"state,omitempty"` // party status
Emoji discord.Emoji `json:"emoji,omitempty"`
Party struct {
ID string `json:"id,omitempty"`
Size [2]int `json:"size,omitempty"` // [ current, max ]
} `json:"party,omitempty"`
Assets struct {
LargeImage string `json:"large_image,omitempty"` // id
LargeText string `json:"large_text,omitempty"`
SmallImage string `json:"small_image,omitempty"` // id
SmallText string `json:"small_text,omitempty"`
} `json:"assets,omitempty"`
Secrets struct {
Join string `json:"join,omitempty"`
Spectate string `json:"spectate,omitempty"`
Match string `json:"match,omitempty"`
} `json:"secrets,omitempty"`
Instance bool `json:"instance,omitempty"`
Flags ActivityFlags `json:"flags,omitempty"`
}
type ActivityType uint8
const (
// Playing $name
GameActivity ActivityType = iota
// Streaming $details
StreamingActivity
// Listening to $name
ListeningActivity
// $emoji $name
CustomActivity
)
type ActivityFlags uint8
const (
InstanceActivity ActivityFlags = 1 << iota
JoinActivity
SpectateActivity
JoinRequestActivity
SyncActivity
PlayActivity
)

75
gateway/commands.go Normal file
View file

@ -0,0 +1,75 @@
package gateway
import (
"github.com/diamondburned/arikawa/discord"
)
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
type IdentifyData struct {
Token string `json:"token"`
Properties IdentifyProperties `json:"properties"`
Compress bool `json:"compress,omitempty"` // true
LargeThreshold uint `json:"large_threshold,omitempty"` // 50
GuildSubscription bool `json:"guild_subscriptions"` // true
Shard [2]int `json:"shard"` // [ shard_id, num_shards ]
Presence UpdateStatusData `json:"presence,omitempty"`
}
type IdentifyProperties struct {
// Required
OS string `json:"os"` // GOOS
Browser string `json:"browser"` // Arikawa
Device string `json:"device"` // Arikawa
// Optional
BrowserUserAgent string `json:"browser_user_agent,omitempty"`
BrowserVersion string `json:"browser_version,omitempty"`
OsVersion string `json:"os_version,omitempty"`
Referrer string `json:"referrer,omitempty"`
ReferringDomain string `json:"referring_domain,omitempty"`
}
func (g *Gateway) Identify() error {
return g.Send(IdentifyOP, g.Identity)
}
type ResumeData struct {
Token string `json:"token"`
SessionID string `json:"session_id"`
Sequence int64 `json:"seq"`
}
// HeartbeatData is the last sequence number to be sent.
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
return g.Send(HeartbeatOP, g.Sequence.Get())
}
type RequestGuildMembersData struct {
GuildID []discord.Snowflake `json:"guild_id"`
UserIDs []discord.Snowflake `json:"user_id,omitempty"`
Query string `json:"query,omitempty"`
Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"`
}
type UpdateVoiceStateData struct {
GuildID discord.Snowflake `json:"guild_id"`
ChannelID discord.Snowflake `json:"channel_id"`
SelfMute bool `json:"self_mute"`
SelfDeaf bool `json:"self_deaf"`
}
type UpdateStatusData struct {
Since discord.Milliseconds `json:"since,omitempty"` // 0 if not idle
Game *Activity `json:"game,omitempty"` // nullable
Status Status `json:"status"`
AFK bool `json:"afk"`
}

199
gateway/events.go Normal file
View file

@ -0,0 +1,199 @@
package gateway
import "github.com/diamondburned/arikawa/discord"
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
// https://discordapp.com/developers/docs/topics/gateway#connecting-and-resuming
type (
HelloEvent struct {
HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
}
ReadyEvent struct {
Version int `json:"version"`
User discord.User `json:"user"`
SessionID string `json:"session_id"`
PrivateChannels []discord.Channel `json:"private_channels"`
Guilds []discord.Guild `json:"guilds"`
Shard [2]int `json:"shard"` // [ shard_id num_shards ]
}
ResumedEvent struct{}
// InvalidSessionEvent indicates if the event is resumable.
InvalidSessionEvent bool
)
// https://discordapp.com/developers/docs/topics/gateway#channels
type (
ChannelCreateEvent discord.Channel
ChannelUpdateEvent discord.Channel
ChannelDeleteEvent discord.Channel
ChannelPinsUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id,omitempty"`
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
LastPin discord.Timestamp `json:"timestamp,omitempty"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#guilds
type (
GuildCreateEvent discord.Guild
GuildUpdateEvent discord.Guild
GuildDeleteEvent struct {
ID discord.Snowflake `json:"id"`
// Unavailable if false == removed
Unavailable bool `json:"unavailable"`
}
GuildBanAddEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
}
GuildBanRemoveEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
}
GuildEmojisUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Emojis []discord.Emoji `json:"emoji"`
}
GuildIntegrationsUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
}
GuildMemberAddEvent struct {
discord.Member
GuildID discord.Snowflake `json:"guild_id"`
}
GuildMemberRemoveEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
}
GuildMemberUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Roles []discord.Snowflake `json:"roles"`
User discord.User `json:"user"`
Nick string `json:"nick"`
}
// GuildMembersChunkEvent is sent when Guild Request Members is called.
GuildMembersChunkEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Members []discord.Member `json:"members"`
// Whatever's not found goes here
NotFound []string `json:"not_found,omitempty"`
// Only filled if requested
Presences []discord.Presence `json:"presences,omitempty"`
}
GuildRoleCreateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Role discord.Role `json:"role"`
}
GuildRoleUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Role discord.Role `json:"role"`
}
GuildRoleDeleteEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
RoleID discord.Snowflake `json:"role_id"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#messages
type (
MessageCreateEvent discord.Message
MessageUpdateEvent discord.Message
MessageDeleteEvent struct {
ID discord.Snowflake `json:"id"`
ChannelID discord.Snowflake `json:"channel_id"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
}
MessageDeleteBulkEvent struct {
IDs []discord.Snowflake `json:"ids"`
ChannelID discord.Snowflake `json:"channel_id"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
}
MessageReactionAddEvent struct {
UserID discord.Snowflake `json:"user_id"`
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
Emoji discord.Emoji `json:"emoji,omitempty"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
}
MessageReactionRemoveEvent struct {
UserID discord.Snowflake `json:"user_id"`
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
Emoji discord.Emoji `json:"emoji"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
}
MessageReactionRemoveAllEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#presence
type (
// Clients may only update their game status 5 times per 20 seconds.
PresenceUpdateEvent struct {
User discord.User `json:"user"`
Nick string `json:"nick"`
Roles []discord.Snowflake `json:"roles"`
GuildID discord.Snowflake `json:"guild_id"`
PremiumSince discord.Timestamp `json:"premium_since,omitempty"`
Game *Activity `json:"game"`
Activities []Activity `json:"activities"`
Status Status `json:"status"`
ClientStatus struct {
Desktop Status `json:"status,omitempty"`
Mobile Status `json:"mobile,omitempty"`
Web Status `json:"web,omitempty"`
} `json:"client_status"`
}
TypingStartEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
Timestamp discord.Timestamp `json:"timestamp"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
}
UserUpdateEvent discord.User
)
// https://discordapp.com/developers/docs/topics/gateway#voice
type (
VoiceStateUpdateEvent discord.VoiceState
VoiceServerUpdateEvent struct {
Token string `json:"token"`
GuildID discord.Snowflake `json:"guild_id"`
Endpoint string `json:"endpoint"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#webhooks
type (
WebhooksUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
ChannelID discord.Snowflake `json:"channel_id"`
}
)

59
gateway/events_map.go Normal file
View file

@ -0,0 +1,59 @@
package gateway
// Event is any event struct. They have an "Event" suffixed to them.
type Event = interface{}
var EventCreator = map[string]func() Event{
"HELLO": func() Event { return new(HelloEvent) },
"READY": func() Event { return new(ReadyEvent) },
"RESUMED": func() Event { return new(ResumedEvent) },
"INVALID_SESSION": func() Event { return new(InvalidSessionEvent) },
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
"GUILD_CREATE": func() Event { return new(GuildCreateEvent) },
"GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) },
"GUILD_DELETE": func() Event { return new(GuildDeleteEvent) },
"GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) },
"GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) },
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
"GUILD_INTEGRATIONS_UPDATE": func() Event {
return new(GuildIntegrationsUpdateEvent)
},
"GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) },
"GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) },
"GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) },
"GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) },
"GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) },
"GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) },
"GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) },
"MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) },
"MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) },
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
"MESSAGE_REACTION_ADD": func() Event {
return new(MessageReactionAddEvent)
},
"MESSAGE_REACTION_REMOVE": func() Event {
return new(MessageReactionRemoveEvent)
},
"MESSAGE_REACTION_REMOVE_ALL": func() Event {
return new(MessageReactionRemoveAllEvent)
},
"PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) },
"TYPING_START": func() Event { return new(TypingStartEvent) },
"USER_UPDATE": func() Event { return new(UserUpdateEvent) },
"VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) },
"VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) },
}

319
gateway/gateway.go Normal file
View file

@ -0,0 +1,319 @@
package gateway
import (
"context"
"net/url"
"runtime"
"time"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/httputil"
"github.com/diamondburned/arikawa/json"
"github.com/diamondburned/arikawa/wsutil"
"github.com/pkg/errors"
)
const (
EndpointGateway = api.Endpoint + "gateway"
EndpointGatewayBot = api.EndpointGateway + "/bot"
Version = "6"
Encoding = "json"
)
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = wsutil.DefaultTimeout
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
)
var (
ErrMissingForResume = errors.New(
"missing session ID or sequence for resuming")
ErrWSMaxTries = errors.New("max tries reached")
)
func GatewayURL() (string, error) {
var Gateway struct {
URL string `json:"url"`
}
return Gateway.URL, httputil.DefaultClient.RequestJSON(
&Gateway, "GET", EndpointGateway)
}
// Identity is used as the default identity when initializing a new Gateway.
var Identity = IdentifyProperties{
OS: runtime.GOOS,
Browser: "Arikawa",
Device: "Arikawa",
}
type Gateway struct {
WS *wsutil.Websocket
json.Driver
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
WSTimeout time.Duration
// Retries on connect and reconnect.
WSRetries uint // 3
Events chan Event
SessionID string
URL string // URL
Identity *IdentifyData
Pacemaker *Pacemaker
Sequence Sequence
ErrorLog func(err error) // default to log.Println
// Only use for debugging
// If this channel is non-nil, all incoming OP packets will also be sent
// here.
OP chan Event
// Filled by methods, internal use
paceDeath <-chan error
handler chan struct{}
}
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
// information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) {
return NewGatewayWithDriver(token, json.Default{})
}
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
URL, err := GatewayURL()
if err != nil {
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
}
g := &Gateway{
URL: URL,
Driver: driver,
WSTimeout: WSTimeout,
Events: make(chan Event),
Identity: &IdentifyData{
Token: token,
Properties: Identity,
Compress: true,
LargeThreshold: 50,
GuildSubscription: true,
},
Sequence: NewSequence(),
}
// Parameters for the gateway
param := url.Values{}
param.Set("v", Version)
param.Set("encoding", Encoding)
// Append the form to the URL
URL += "?" + param.Encode()
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Create a new undialed Websocket.
ws, err := wsutil.NewCustom(ctx, wsutil.NewConn(driver), URL)
if err != nil {
return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL)
}
g.WS = ws
// Try and dial it
return g, g.connect()
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
return g.WS.Close(nil)
}
// Reconnects and resumes.
func (g *Gateway) Reconnect() error {
// Close, but we don't care about the error (I think)
g.Close()
// Actually a reconnect at this point.
return g.connect()
}
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
var (
ses = g.SessionID
seq = g.Sequence.Get()
)
if ses == "" || seq == 0 {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
Token: g.Identity.Token,
SessionID: ses,
Sequence: seq,
})
}
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block. To block until a fatal error, use
// Wait().
func (g *Gateway) Start() error {
// This is where we'll get our events
ch := g.WS.Listen()
// Wait for an OP 10 Hello
var hello HelloEvent
if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
return errors.Wrap(err, "Error at Hello")
}
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID == "" {
// SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil {
return errors.Wrap(err, "Failed to identify")
}
// We should now expect a Ready event.
var ready ReadyEvent
if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil {
return errors.Wrap(err, "Error at Ready")
}
} else {
if err := g.Resume(); err != nil {
return errors.Wrap(err, "Failed to resume")
}
// We should now expect a Resumed event.
var resumed ResumedEvent
if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil {
return errors.Wrap(err, "Error at Resumed")
}
}
// Start the event handler
g.handler = make(chan struct{})
go g.handleWS(g.handler)
// Start the pacemaker with the heartrate received from Hello
g.Pacemaker = &Pacemaker{
Heartrate: hello.HeartbeatInterval.Duration(),
Pace: g.Heartbeat,
OnDead: g.Reconnect,
}
// Pacemaker dies here, only when it's fatal.
g.paceDeath = g.Pacemaker.StartAsync()
return nil
}
func (g *Gateway) Wait() error {
defer close(g.handler)
return <-g.paceDeath
}
// handleWS uses the Websocket and parses them into g.Events.
func (g *Gateway) handleWS(stop <-chan struct{}) {
ch := g.WS.Listen()
for {
select {
case <-stop:
return
case ev := <-ch:
// Check for error
if ev.Error != nil {
g.ErrorLog(errors.Wrap(ev.Error, "WS error"))
continue
}
// Handle the event
if err := HandleEvent(g, ev.Data); err != nil {
g.ErrorLog(errors.Wrap(ev.Error, "WS handler error"))
}
}
}
}
func (g *Gateway) Send(code OPCode, v interface{}) error {
var op = OP{
Code: code,
}
if v != nil {
b, err := g.Marshal(v)
if err != nil {
return errors.Wrap(err, "Failed to encode v")
}
op.Data = b
}
b, err := g.Marshal(op)
if err != nil {
return errors.Wrap(err, "Failed to encode payload")
}
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.WS.Send(ctx, b)
}
func (g *Gateway) connect() error {
// Reconnect timeout
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
var Lerr error
for i := uint(0); i < g.WSRetries; i++ {
// Check if context is expired
if err := ctx.Err(); err != nil {
// Don't bother if it's expired
return err
}
// Reconnect to the Gateway
if err := g.WS.Redial(ctx); err != nil {
// Save the error, retry again
Lerr = errors.Wrap(err, "Failed to reconnect")
continue
}
// Try to resume the connection
if err := g.Start(); err != nil {
// If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err == ErrInvalidSession {
continue // retry
}
// Else, fatal
return errors.Wrap(err, "Failed to start gateway")
}
// Started successfully, return
return nil
}
// Check if any earlier errors are fatal
if Lerr != nil {
return Lerr
}
// We tried.
return ErrWSMaxTries
}

153
gateway/op.go Normal file
View file

@ -0,0 +1,153 @@
package gateway
import (
"fmt"
"github.com/diamondburned/arikawa/json"
"github.com/diamondburned/arikawa/wsutil"
"github.com/pkg/errors"
)
type OPCode uint8
const (
DispatchOP OPCode = iota // recv
HeartbeatOP // send/recv
IdentifyOP // send...
StatusUpdateOP //
VoiceStateUpdateOP //
VoiceServerPingOP //
ResumeOP //
ReconnectOP // recv
RequestGuildMembersOP // send
InvalidSessionOP // recv...
HelloOP
HeartbeatAckOP
_
CallConnectOP
GuildSubscriptionsOP
)
type OP struct {
Code OPCode `json:"op"`
Data json.Raw `json:"d,omitempty"`
// Only for Dispatch (op 0)
Sequence int `json:"s,omitempty"`
EventName string `json:"t,omitempty"`
}
var ErrInvalidSession = errors.New("Invalid session")
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
if ev.Error != nil {
return nil, ev.Error
}
var op *OP
if err := driver.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "Failed to decode payload")
}
if op.Code == InvalidSessionOP {
return op, ErrInvalidSession
}
return op, nil
}
func DecodeEvent(driver json.Driver,
ev wsutil.Event, v interface{}) (OPCode, error) {
op, err := DecodeOP(driver, ev)
if err != nil {
return 0, err
}
if err := driver.Unmarshal(op.Data, v); err != nil {
return 0, errors.Wrap(err, "Failed to decode data")
}
return op.Code, nil
}
func AssertEvent(driver json.Driver,
ev wsutil.Event, code OPCode, v interface{}) error {
op, err := DecodeOP(driver, ev)
if err != nil {
return err
}
if op.Code != code {
return fmt.Errorf(
"Unexpected OP Code: %d, expected %d (%s)",
op.Code, code, op.Data,
)
}
if err := driver.Unmarshal(op.Data, v); err != nil {
return errors.Wrap(err, "Failed to decode data")
}
return nil
}
func HandleEvent(g *Gateway, data []byte) error {
// Parse the raw data into an OP struct
var op *OP
if err := g.Driver.Unmarshal(data, &op); err != nil {
return errors.Wrap(err, "OP error")
}
if g.OP != nil {
g.OP <- op
}
switch op.Code {
case HeartbeatAckOP:
// Heartbeat from the server?
g.Pacemaker.Echo()
case HeartbeatOP:
// Server requesting a heartbeat.
return g.Pacemaker.Pace()
case ReconnectOP:
// Server requests to reconnect, die and retry.
return g.Reconnect()
case InvalidSessionOP:
// Invalid session, respond with Identify.
return g.Identify()
case HelloOP:
// What is this OP doing here???
return nil
case DispatchOP:
// Check if we know the event
fn, ok := EventCreator[op.EventName]
if !ok {
return errors.New("Unknown event: " + op.EventName)
}
// Make a new pointer to the event
var ev = fn()
// Try and parse the event
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "Failed to parse event "+op.EventName)
}
// Throw the event into a channel, it's valid now.
g.Events <- ev
return nil
default:
return fmt.Errorf(
"Unknown OP code %d (event %s)", op.Code, op.EventName)
}
return nil
}

80
gateway/pacemaker.go Normal file
View file

@ -0,0 +1,80 @@
package gateway
import (
"time"
"github.com/pkg/errors"
)
var ErrDead = errors.New("no heartbeat replied")
type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
// LastBeat logs the received heartbeats, with the newest one
// first.
LastBeat [2]time.Time
// Any callback that returns an error will stop the pacer.
Pace func() error
// Event
OnDead func() error
stop chan<- struct{}
}
func (p *Pacemaker) Echo() {
// Swap our received heartbeats
p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
}
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
return false
}
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
}
func (p *Pacemaker) Stop() {
close(p.stop)
}
// Start beats until it's dead.
func (p *Pacemaker) Start() error {
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
stop := make(chan struct{})
p.stop = stop
for {
if err := p.Pace(); err != nil {
return err
}
if !p.Dead() {
continue
}
if err := p.OnDead(); err != nil {
return err
}
select {
case <-stop:
return nil
case <-tick.C:
continue
}
}
}
func (p *Pacemaker) StartAsync() (death <-chan error) {
var ch = make(chan error)
go func() {
ch <- p.Start()
}()
return ch
}

14
gateway/sequence.go Normal file
View file

@ -0,0 +1,14 @@
package gateway
import "sync/atomic"
type Sequence struct {
seq *int64
}
func NewSequence() Sequence {
return Sequence{new(int64)}
}
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(s.seq, seq) }
func (s *Sequence) Get() int64 { return atomic.LoadInt64(s.seq) }

1
go.mod
View file

@ -8,6 +8,7 @@ require (
github.com/gorilla/websocket v1.4.1
github.com/pkg/errors v0.8.1
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
go.uber.org/atomic v1.4.0
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
nhooyr.io/websocket v1.7.4

View file

@ -16,6 +16,8 @@ type Client struct {
SchemaEncoder
}
var DefaultClient = NewClient()
func NewClient() Client {
return Client{
Client: http.Client{

View file

@ -59,6 +59,7 @@ func WithJSONBody(json json.Driver, v interface{}) RequestOption {
return err
}
r.Header.Set("Content-Type", "application/json")
r.Body = ioutil.NopCloser(&buf)
return nil
}

View file

@ -38,6 +38,27 @@ func getBool(Bool bool) OptionBool {
return &Bool
}
//
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
// from encoding/json.
type Raw []byte
// Raw returns m as the JSON encoding of m.
func (m Raw) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
func (m Raw) String() string {
return string(m)
}
//
type Driver interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error

83
session/session.go Normal file
View file

@ -0,0 +1,83 @@
package session
import (
"log"
"sync"
"time"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/json"
)
/*
TODO:
and Session's supposed to handle callbacks too kec
might move all these to Gateway, dunno
could have a lock on Listen()
I can actually see people using gateway channels to handle things
themselves without any callback abstractions, so this is probably the way to go
welp shit
rewrite imminent
*/
type Session struct {
API *api.Client
Gateway *gateway.Conn
gatewayOnce sync.Once
ErrorLog func(err error) // default to log.Println
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
// LastBeat logs the received heartbeats, with the newest one
// first.
LastBeat [2]time.Time
// Used for Close()
stoppers []chan<- struct{}
closers []func() error
}
func New(token string) (*Session, error) {
// Initialize the session and the API interface
s := &Session{}
s.API = api.NewClient(token)
// Default logger
s.ErrorLog = func(err error) {
log.Println("Arikawa/session error:", err)
}
// Connect to the Gateway
c, err := gateway.NewConn(json.Default{})
if err != nil {
return nil, err
}
s.Gateway = c
return s, nil
}
func (s *Session) Close() error {
for _, stop := range s.stoppers {
close(stop)
}
var err error
for _, closer := range s.closers {
if cerr := closer(); cerr != nil {
err = cerr
}
}
return err
}

View file

@ -13,11 +13,15 @@ import (
)
var WSBuffer = 12
var WSReadLimit = 512 * 1024 // 512KiB
var WSReadLimit = 4096 // 4096 bytes
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself.
type Connection interface {
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
// Listen sends over events constantly. Error will be non-nil if Data is
// nil, so check for Error first.
Listen() <-chan Event
@ -67,10 +71,6 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
}
func (c *Conn) Listen() <-chan Event {
if c.events != nil {
return c.events
}
c.events = make(chan Event, WSBuffer)
go func() { c.readLoop(c.events) }()
return c.events
@ -139,6 +139,9 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
}
func (c *Conn) Close(err error) error {
// Close the event channels
defer close(c.events)
if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "")
}

View file

@ -9,3 +9,7 @@ import (
func NewSendLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Every(time.Minute), 120)
}
func NewDialLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Every(5*time.Second), 1)
}

View file

@ -19,49 +19,62 @@ type Event struct {
}
type Websocket struct {
conn Connection
Conn Connection
Addr string
WriteTimeout time.Duration
SendLimiter *rate.Limiter
SendLimiter *rate.Limiter
DialLimiter *rate.Limiter
listener <-chan Event
}
func New(ctx context.Context,
driver json.Driver, addr string) (*Websocket, error) {
if driver == nil {
driver = json.Default{}
}
c := NewConn(driver)
if err := c.Dial(ctx, addr); err != nil {
return nil, errors.Wrap(err, "Failed to dial")
}
return NewWithConn(c), nil
func New(ctx context.Context, addr string) (*Websocket, error) {
return NewCustom(ctx, NewConn(json.Default{}), addr)
}
// NewWithConn uses an already-dialed connection for Websocket.
func NewWithConn(conn Connection) *Websocket {
return &Websocket{
conn: conn,
// NewCustom creates a new undialed Websocket.
func NewCustom(
ctx context.Context, conn Connection, addr string) (*Websocket, error) {
WriteTimeout: DefaultTimeout,
SendLimiter: NewSendLimiter(),
ws := &Websocket{
Conn: conn,
Addr: addr,
SendLimiter: NewSendLimiter(),
DialLimiter: NewDialLimiter(),
}
return ws, nil
}
func (ws *Websocket) Redial(ctx context.Context) error {
if err := ws.DialLimiter.Wait(ctx); err != nil {
// Expired, fatal error
return errors.Wrap(err, "Failed to wait")
}
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
return errors.Wrap(err, "Failed to dial")
}
return nil
}
func (ws *Websocket) Listen() <-chan Event {
return ws.conn.Listen()
if ws.listener == nil {
ws.listener = ws.Conn.Listen()
}
return ws.listener
}
func (ws *Websocket) Send(b []byte) error {
ctx, cancel := context.WithTimeout(
context.Background(), ws.WriteTimeout)
defer cancel()
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
if err := ws.SendLimiter.Wait(ctx); err != nil {
return errors.Wrap(err, "SendLimiter failed")
}
return ws.conn.Send(ctx, b)
return ws.Conn.Send(ctx, b)
}
func (ws *Websocket) Close(err error) error {
return ws.Conn.Close(err)
}