mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-02 18:26:41 +00:00
Added an untested Gateway package
This commit is contained in:
parent
67450c4872
commit
8fc7229c3f
12
api/api.go
12
api/api.go
|
@ -28,21 +28,15 @@ type Client struct {
|
|||
|
||||
func NewClient(token string) *Client {
|
||||
cli := &Client{
|
||||
Client: httputil.NewClient(),
|
||||
Client: httputil.DefaultClient,
|
||||
Limiter: rate.NewLimiter(),
|
||||
Token: token,
|
||||
}
|
||||
|
||||
tw := httputil.NewTransportWrapper()
|
||||
tw.Pre = func(r *http.Request) error {
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
r.Header.Set("Authorization", cli.Token)
|
||||
}
|
||||
|
||||
if r.UserAgent() == "" {
|
||||
r.Header.Set("User-Agent", UserAgent)
|
||||
}
|
||||
|
||||
r.Header.Set("Authorization", cli.Token)
|
||||
r.Header.Set("User-Agent", UserAgent)
|
||||
r.Header.Set("X-RateLimit-Precision", "millisecond")
|
||||
|
||||
// Rate limit stuff
|
||||
|
|
|
@ -128,4 +128,4 @@ func (c *Client) CreatePrivateChannel(
|
|||
// shitty SDK, don't care, PR welcomed
|
||||
// func (c *Client) CreateGroup(tokens []string, nicks map[])
|
||||
|
||||
func (c *Client) UserConnections() ([]discord.Connection, error) {}
|
||||
// func (c *Client) UserConnections() ([]discord.Connection, error) {}
|
||||
|
|
|
@ -33,6 +33,10 @@ func (s Snowflake) String() string {
|
|||
return strconv.FormatUint(uint64(s), 10)
|
||||
}
|
||||
|
||||
func (s Snowflake) Valid() bool {
|
||||
return uint64(s) < 1
|
||||
}
|
||||
|
||||
func (s Snowflake) Time() time.Time {
|
||||
return time.Unix(0, int64(s)>>22*1000000+DiscordEpoch)
|
||||
}
|
||||
|
|
|
@ -28,7 +28,33 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
|
|||
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
|
||||
}
|
||||
|
||||
type Seconds uint
|
||||
//
|
||||
|
||||
type UnixTimestamp int64
|
||||
|
||||
func (t UnixTimestamp) String() string {
|
||||
return t.Time().String()
|
||||
}
|
||||
|
||||
func (t UnixTimestamp) Time() time.Time {
|
||||
return time.Unix(int64(t), 0)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type UnixMsTimestamp int64
|
||||
|
||||
func (t UnixMsTimestamp) String() string {
|
||||
return t.Time().String()
|
||||
}
|
||||
|
||||
func (t UnixMsTimestamp) Time() time.Time {
|
||||
return time.Unix(0, int64(t)*int64(time.Millisecond))
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type Seconds int
|
||||
|
||||
func DurationToSeconds(dura time.Duration) Seconds {
|
||||
return Seconds(dura.Seconds())
|
||||
|
@ -41,3 +67,19 @@ func (s Seconds) String() string {
|
|||
func (s Seconds) Duration() time.Duration {
|
||||
return time.Duration(s) * time.Second
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type Milliseconds int
|
||||
|
||||
func DurationToMilliseconds(dura time.Duration) Milliseconds {
|
||||
return Milliseconds(dura.Milliseconds())
|
||||
}
|
||||
|
||||
func (ms Milliseconds) String() string {
|
||||
return ms.Duration().String()
|
||||
}
|
||||
|
||||
func (ms Milliseconds) Duration() time.Duration {
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
|
78
gateway/activity.go
Normal file
78
gateway/activity.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package gateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/discord"
|
||||
|
||||
type Status string
|
||||
|
||||
const (
|
||||
UnknownStatus Status = ""
|
||||
OnlineStatus Status = "online"
|
||||
DoNotDisturbStatus Status = "dnd"
|
||||
IdleStatus Status = "idle"
|
||||
InvisibleStatus Status = "invisible"
|
||||
OfflineStatus Status = "offline"
|
||||
)
|
||||
|
||||
type Activity struct {
|
||||
Name string `json:"name"`
|
||||
Type ActivityType `json:"type"`
|
||||
URL discord.URL `json:"url"`
|
||||
|
||||
// User only
|
||||
|
||||
CreatedAt discord.UnixTimestamp `json:"created_at"`
|
||||
Timestamps struct {
|
||||
Start discord.UnixMsTimestamp `json:"start,omitempty"`
|
||||
End discord.UnixMsTimestamp `json:"end,omitempty"`
|
||||
} `json:"timestamps,omitempty"`
|
||||
|
||||
ApplicationID discord.Snowflake `json:"application_id,omitempty"`
|
||||
Details string `json:"details,omitempty"`
|
||||
State string `json:"state,omitempty"` // party status
|
||||
Emoji discord.Emoji `json:"emoji,omitempty"`
|
||||
|
||||
Party struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Size [2]int `json:"size,omitempty"` // [ current, max ]
|
||||
} `json:"party,omitempty"`
|
||||
|
||||
Assets struct {
|
||||
LargeImage string `json:"large_image,omitempty"` // id
|
||||
LargeText string `json:"large_text,omitempty"`
|
||||
SmallImage string `json:"small_image,omitempty"` // id
|
||||
SmallText string `json:"small_text,omitempty"`
|
||||
} `json:"assets,omitempty"`
|
||||
|
||||
Secrets struct {
|
||||
Join string `json:"join,omitempty"`
|
||||
Spectate string `json:"spectate,omitempty"`
|
||||
Match string `json:"match,omitempty"`
|
||||
} `json:"secrets,omitempty"`
|
||||
|
||||
Instance bool `json:"instance,omitempty"`
|
||||
Flags ActivityFlags `json:"flags,omitempty"`
|
||||
}
|
||||
|
||||
type ActivityType uint8
|
||||
|
||||
const (
|
||||
// Playing $name
|
||||
GameActivity ActivityType = iota
|
||||
// Streaming $details
|
||||
StreamingActivity
|
||||
// Listening to $name
|
||||
ListeningActivity
|
||||
// $emoji $name
|
||||
CustomActivity
|
||||
)
|
||||
|
||||
type ActivityFlags uint8
|
||||
|
||||
const (
|
||||
InstanceActivity ActivityFlags = 1 << iota
|
||||
JoinActivity
|
||||
SpectateActivity
|
||||
JoinRequestActivity
|
||||
SyncActivity
|
||||
PlayActivity
|
||||
)
|
75
gateway/commands.go
Normal file
75
gateway/commands.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
)
|
||||
|
||||
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
|
||||
|
||||
type IdentifyData struct {
|
||||
Token string `json:"token"`
|
||||
Properties IdentifyProperties `json:"properties"`
|
||||
|
||||
Compress bool `json:"compress,omitempty"` // true
|
||||
LargeThreshold uint `json:"large_threshold,omitempty"` // 50
|
||||
GuildSubscription bool `json:"guild_subscriptions"` // true
|
||||
|
||||
Shard [2]int `json:"shard"` // [ shard_id, num_shards ]
|
||||
|
||||
Presence UpdateStatusData `json:"presence,omitempty"`
|
||||
}
|
||||
|
||||
type IdentifyProperties struct {
|
||||
// Required
|
||||
OS string `json:"os"` // GOOS
|
||||
Browser string `json:"browser"` // Arikawa
|
||||
Device string `json:"device"` // Arikawa
|
||||
|
||||
// Optional
|
||||
BrowserUserAgent string `json:"browser_user_agent,omitempty"`
|
||||
BrowserVersion string `json:"browser_version,omitempty"`
|
||||
OsVersion string `json:"os_version,omitempty"`
|
||||
Referrer string `json:"referrer,omitempty"`
|
||||
ReferringDomain string `json:"referring_domain,omitempty"`
|
||||
}
|
||||
|
||||
func (g *Gateway) Identify() error {
|
||||
return g.Send(IdentifyOP, g.Identity)
|
||||
}
|
||||
|
||||
type ResumeData struct {
|
||||
Token string `json:"token"`
|
||||
SessionID string `json:"session_id"`
|
||||
Sequence int64 `json:"seq"`
|
||||
}
|
||||
|
||||
// HeartbeatData is the last sequence number to be sent.
|
||||
type HeartbeatData int
|
||||
|
||||
func (g *Gateway) Heartbeat() error {
|
||||
return g.Send(HeartbeatOP, g.Sequence.Get())
|
||||
}
|
||||
|
||||
type RequestGuildMembersData struct {
|
||||
GuildID []discord.Snowflake `json:"guild_id"`
|
||||
UserIDs []discord.Snowflake `json:"user_id,omitempty"`
|
||||
|
||||
Query string `json:"query,omitempty"`
|
||||
Limit uint `json:"limit"`
|
||||
Presences bool `json:"presences,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateVoiceStateData struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
SelfMute bool `json:"self_mute"`
|
||||
SelfDeaf bool `json:"self_deaf"`
|
||||
}
|
||||
|
||||
type UpdateStatusData struct {
|
||||
Since discord.Milliseconds `json:"since,omitempty"` // 0 if not idle
|
||||
Game *Activity `json:"game,omitempty"` // nullable
|
||||
|
||||
Status Status `json:"status"`
|
||||
AFK bool `json:"afk"`
|
||||
}
|
199
gateway/events.go
Normal file
199
gateway/events.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package gateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/discord"
|
||||
|
||||
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#connecting-and-resuming
|
||||
type (
|
||||
HelloEvent struct {
|
||||
HeartbeatInterval discord.Milliseconds `json:"heartbeat_interval"`
|
||||
}
|
||||
|
||||
ReadyEvent struct {
|
||||
Version int `json:"version"`
|
||||
|
||||
User discord.User `json:"user"`
|
||||
SessionID string `json:"session_id"`
|
||||
|
||||
PrivateChannels []discord.Channel `json:"private_channels"`
|
||||
Guilds []discord.Guild `json:"guilds"`
|
||||
|
||||
Shard [2]int `json:"shard"` // [ shard_id num_shards ]
|
||||
}
|
||||
|
||||
ResumedEvent struct{}
|
||||
|
||||
// InvalidSessionEvent indicates if the event is resumable.
|
||||
InvalidSessionEvent bool
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#channels
|
||||
type (
|
||||
ChannelCreateEvent discord.Channel
|
||||
ChannelUpdateEvent discord.Channel
|
||||
ChannelDeleteEvent discord.Channel
|
||||
ChannelPinsUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
|
||||
LastPin discord.Timestamp `json:"timestamp,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#guilds
|
||||
type (
|
||||
GuildCreateEvent discord.Guild
|
||||
GuildUpdateEvent discord.Guild
|
||||
GuildDeleteEvent struct {
|
||||
ID discord.Snowflake `json:"id"`
|
||||
// Unavailable if false == removed
|
||||
Unavailable bool `json:"unavailable"`
|
||||
}
|
||||
|
||||
GuildBanAddEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
User discord.User `json:"user"`
|
||||
}
|
||||
GuildBanRemoveEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
User discord.User `json:"user"`
|
||||
}
|
||||
|
||||
GuildEmojisUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Emojis []discord.Emoji `json:"emoji"`
|
||||
}
|
||||
|
||||
GuildIntegrationsUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
}
|
||||
|
||||
GuildMemberAddEvent struct {
|
||||
discord.Member
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
}
|
||||
GuildMemberRemoveEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
User discord.User `json:"user"`
|
||||
}
|
||||
GuildMemberUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Roles []discord.Snowflake `json:"roles"`
|
||||
User discord.User `json:"user"`
|
||||
Nick string `json:"nick"`
|
||||
}
|
||||
|
||||
// GuildMembersChunkEvent is sent when Guild Request Members is called.
|
||||
GuildMembersChunkEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Members []discord.Member `json:"members"`
|
||||
|
||||
// Whatever's not found goes here
|
||||
NotFound []string `json:"not_found,omitempty"`
|
||||
|
||||
// Only filled if requested
|
||||
Presences []discord.Presence `json:"presences,omitempty"`
|
||||
}
|
||||
|
||||
GuildRoleCreateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Role discord.Role `json:"role"`
|
||||
}
|
||||
GuildRoleUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Role discord.Role `json:"role"`
|
||||
}
|
||||
GuildRoleDeleteEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
RoleID discord.Snowflake `json:"role_id"`
|
||||
}
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#messages
|
||||
type (
|
||||
MessageCreateEvent discord.Message
|
||||
MessageUpdateEvent discord.Message
|
||||
MessageDeleteEvent struct {
|
||||
ID discord.Snowflake `json:"id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
}
|
||||
MessageDeleteBulkEvent struct {
|
||||
IDs []discord.Snowflake `json:"ids"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
}
|
||||
|
||||
MessageReactionAddEvent struct {
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
MessageID discord.Snowflake `json:"message_id"`
|
||||
|
||||
Emoji discord.Emoji `json:"emoji,omitempty"`
|
||||
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
Member *discord.Member `json:"member,omitempty"`
|
||||
}
|
||||
MessageReactionRemoveEvent struct {
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
MessageID discord.Snowflake `json:"message_id"`
|
||||
|
||||
Emoji discord.Emoji `json:"emoji"`
|
||||
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
}
|
||||
MessageReactionRemoveAllEvent struct {
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
}
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#presence
|
||||
type (
|
||||
// Clients may only update their game status 5 times per 20 seconds.
|
||||
PresenceUpdateEvent struct {
|
||||
User discord.User `json:"user"`
|
||||
Nick string `json:"nick"`
|
||||
Roles []discord.Snowflake `json:"roles"`
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
|
||||
PremiumSince discord.Timestamp `json:"premium_since,omitempty"`
|
||||
|
||||
Game *Activity `json:"game"`
|
||||
Activities []Activity `json:"activities"`
|
||||
|
||||
Status Status `json:"status"`
|
||||
ClientStatus struct {
|
||||
Desktop Status `json:"status,omitempty"`
|
||||
Mobile Status `json:"mobile,omitempty"`
|
||||
Web Status `json:"web,omitempty"`
|
||||
} `json:"client_status"`
|
||||
}
|
||||
TypingStartEvent struct {
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
Timestamp discord.Timestamp `json:"timestamp"`
|
||||
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
Member *discord.Member `json:"member,omitempty"`
|
||||
}
|
||||
UserUpdateEvent discord.User
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#voice
|
||||
type (
|
||||
VoiceStateUpdateEvent discord.VoiceState
|
||||
VoiceServerUpdateEvent struct {
|
||||
Token string `json:"token"`
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
)
|
||||
|
||||
// https://discordapp.com/developers/docs/topics/gateway#webhooks
|
||||
type (
|
||||
WebhooksUpdateEvent struct {
|
||||
GuildID discord.Snowflake `json:"guild_id"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
}
|
||||
)
|
59
gateway/events_map.go
Normal file
59
gateway/events_map.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package gateway
|
||||
|
||||
// Event is any event struct. They have an "Event" suffixed to them.
|
||||
type Event = interface{}
|
||||
|
||||
var EventCreator = map[string]func() Event{
|
||||
"HELLO": func() Event { return new(HelloEvent) },
|
||||
"READY": func() Event { return new(ReadyEvent) },
|
||||
"RESUMED": func() Event { return new(ResumedEvent) },
|
||||
"INVALID_SESSION": func() Event { return new(InvalidSessionEvent) },
|
||||
|
||||
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
|
||||
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
|
||||
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
|
||||
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
|
||||
|
||||
"GUILD_CREATE": func() Event { return new(GuildCreateEvent) },
|
||||
"GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) },
|
||||
"GUILD_DELETE": func() Event { return new(GuildDeleteEvent) },
|
||||
|
||||
"GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) },
|
||||
"GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) },
|
||||
|
||||
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
|
||||
"GUILD_INTEGRATIONS_UPDATE": func() Event {
|
||||
return new(GuildIntegrationsUpdateEvent)
|
||||
},
|
||||
|
||||
"GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) },
|
||||
"GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) },
|
||||
"GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) },
|
||||
"GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) },
|
||||
|
||||
"GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) },
|
||||
"GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) },
|
||||
"GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) },
|
||||
|
||||
"MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) },
|
||||
"MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) },
|
||||
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
|
||||
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
|
||||
|
||||
"MESSAGE_REACTION_ADD": func() Event {
|
||||
return new(MessageReactionAddEvent)
|
||||
},
|
||||
"MESSAGE_REACTION_REMOVE": func() Event {
|
||||
return new(MessageReactionRemoveEvent)
|
||||
},
|
||||
"MESSAGE_REACTION_REMOVE_ALL": func() Event {
|
||||
return new(MessageReactionRemoveAllEvent)
|
||||
},
|
||||
|
||||
"PRESENCE_UPDATE": func() Event { return new(PresenceUpdateEvent) },
|
||||
"TYPING_START": func() Event { return new(TypingStartEvent) },
|
||||
"USER_UPDATE": func() Event { return new(UserUpdateEvent) },
|
||||
|
||||
"VOICE_STATE_UPDATE": func() Event { return new(VoiceStateUpdateEvent) },
|
||||
"VOICE_SERVER_UPDATE": func() Event { return new(VoiceServerUpdateEvent) },
|
||||
}
|
319
gateway/gateway.go
Normal file
319
gateway/gateway.go
Normal file
|
@ -0,0 +1,319 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
"github.com/diamondburned/arikawa/httputil"
|
||||
"github.com/diamondburned/arikawa/json"
|
||||
"github.com/diamondburned/arikawa/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
EndpointGateway = api.Endpoint + "gateway"
|
||||
EndpointGatewayBot = api.EndpointGateway + "/bot"
|
||||
|
||||
Version = "6"
|
||||
Encoding = "json"
|
||||
)
|
||||
|
||||
var (
|
||||
// WSTimeout is the timeout for connecting and writing to the Websocket,
|
||||
// before Gateway cancels and fails.
|
||||
WSTimeout = wsutil.DefaultTimeout
|
||||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingForResume = errors.New(
|
||||
"missing session ID or sequence for resuming")
|
||||
ErrWSMaxTries = errors.New("max tries reached")
|
||||
)
|
||||
|
||||
func GatewayURL() (string, error) {
|
||||
var Gateway struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
return Gateway.URL, httputil.DefaultClient.RequestJSON(
|
||||
&Gateway, "GET", EndpointGateway)
|
||||
}
|
||||
|
||||
// Identity is used as the default identity when initializing a new Gateway.
|
||||
var Identity = IdentifyProperties{
|
||||
OS: runtime.GOOS,
|
||||
Browser: "Arikawa",
|
||||
Device: "Arikawa",
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
WS *wsutil.Websocket
|
||||
json.Driver
|
||||
|
||||
// Timeout for connecting and writing to the Websocket, uses default
|
||||
// WSTimeout (global).
|
||||
WSTimeout time.Duration
|
||||
// Retries on connect and reconnect.
|
||||
WSRetries uint // 3
|
||||
|
||||
Events chan Event
|
||||
SessionID string
|
||||
|
||||
URL string // URL
|
||||
|
||||
Identity *IdentifyData
|
||||
Pacemaker *Pacemaker
|
||||
Sequence Sequence
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
// Only use for debugging
|
||||
|
||||
// If this channel is non-nil, all incoming OP packets will also be sent
|
||||
// here.
|
||||
OP chan Event
|
||||
|
||||
// Filled by methods, internal use
|
||||
paceDeath <-chan error
|
||||
handler chan struct{}
|
||||
}
|
||||
|
||||
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
|
||||
// information, refer to NewGatewayWithDriver.
|
||||
func NewGateway(token string) (*Gateway, error) {
|
||||
return NewGatewayWithDriver(token, json.Default{})
|
||||
}
|
||||
|
||||
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
|
||||
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||
URL, err := GatewayURL()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
||||
}
|
||||
|
||||
g := &Gateway{
|
||||
URL: URL,
|
||||
Driver: driver,
|
||||
WSTimeout: WSTimeout,
|
||||
Events: make(chan Event),
|
||||
Identity: &IdentifyData{
|
||||
Token: token,
|
||||
Properties: Identity,
|
||||
Compress: true,
|
||||
LargeThreshold: 50,
|
||||
GuildSubscription: true,
|
||||
},
|
||||
Sequence: NewSequence(),
|
||||
}
|
||||
|
||||
// Parameters for the gateway
|
||||
param := url.Values{}
|
||||
param.Set("v", Version)
|
||||
param.Set("encoding", Encoding)
|
||||
// Append the form to the URL
|
||||
URL += "?" + param.Encode()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a new undialed Websocket.
|
||||
ws, err := wsutil.NewCustom(ctx, wsutil.NewConn(driver), URL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to connect to Gateway "+URL)
|
||||
}
|
||||
g.WS = ws
|
||||
|
||||
// Try and dial it
|
||||
return g, g.connect()
|
||||
}
|
||||
|
||||
// Close closes the underlying Websocket connection.
|
||||
func (g *Gateway) Close() error {
|
||||
return g.WS.Close(nil)
|
||||
}
|
||||
|
||||
// Reconnects and resumes.
|
||||
func (g *Gateway) Reconnect() error {
|
||||
// Close, but we don't care about the error (I think)
|
||||
g.Close()
|
||||
// Actually a reconnect at this point.
|
||||
return g.connect()
|
||||
}
|
||||
|
||||
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
|
||||
// from a dead connection. Start() resumes from a dead connection.
|
||||
func (g *Gateway) Resume() error {
|
||||
var (
|
||||
ses = g.SessionID
|
||||
seq = g.Sequence.Get()
|
||||
)
|
||||
|
||||
if ses == "" || seq == 0 {
|
||||
return ErrMissingForResume
|
||||
}
|
||||
|
||||
return g.Send(ResumeOP, ResumeData{
|
||||
Token: g.Identity.Token,
|
||||
SessionID: ses,
|
||||
Sequence: seq,
|
||||
})
|
||||
}
|
||||
|
||||
// Start authenticates with the websocket, or resume from a dead Websocket
|
||||
// connection. This function doesn't block. To block until a fatal error, use
|
||||
// Wait().
|
||||
func (g *Gateway) Start() error {
|
||||
// This is where we'll get our events
|
||||
ch := g.WS.Listen()
|
||||
|
||||
// Wait for an OP 10 Hello
|
||||
var hello HelloEvent
|
||||
if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "Error at Hello")
|
||||
}
|
||||
|
||||
// Send Discord either the Identify packet (if it's a fresh connection), or
|
||||
// a Resume packet (if it's a dead connection).
|
||||
if g.SessionID == "" {
|
||||
// SessionID is empty, so this is a completely new session.
|
||||
if err := g.Identify(); err != nil {
|
||||
return errors.Wrap(err, "Failed to identify")
|
||||
}
|
||||
|
||||
// We should now expect a Ready event.
|
||||
var ready ReadyEvent
|
||||
if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil {
|
||||
return errors.Wrap(err, "Error at Ready")
|
||||
}
|
||||
} else {
|
||||
if err := g.Resume(); err != nil {
|
||||
return errors.Wrap(err, "Failed to resume")
|
||||
}
|
||||
|
||||
// We should now expect a Resumed event.
|
||||
var resumed ResumedEvent
|
||||
if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil {
|
||||
return errors.Wrap(err, "Error at Resumed")
|
||||
}
|
||||
}
|
||||
|
||||
// Start the event handler
|
||||
g.handler = make(chan struct{})
|
||||
go g.handleWS(g.handler)
|
||||
|
||||
// Start the pacemaker with the heartrate received from Hello
|
||||
g.Pacemaker = &Pacemaker{
|
||||
Heartrate: hello.HeartbeatInterval.Duration(),
|
||||
Pace: g.Heartbeat,
|
||||
OnDead: g.Reconnect,
|
||||
}
|
||||
// Pacemaker dies here, only when it's fatal.
|
||||
g.paceDeath = g.Pacemaker.StartAsync()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) Wait() error {
|
||||
defer close(g.handler)
|
||||
return <-g.paceDeath
|
||||
}
|
||||
|
||||
// handleWS uses the Websocket and parses them into g.Events.
|
||||
func (g *Gateway) handleWS(stop <-chan struct{}) {
|
||||
ch := g.WS.Listen()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case ev := <-ch:
|
||||
// Check for error
|
||||
if ev.Error != nil {
|
||||
g.ErrorLog(errors.Wrap(ev.Error, "WS error"))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle the event
|
||||
if err := HandleEvent(g, ev.Data); err != nil {
|
||||
g.ErrorLog(errors.Wrap(ev.Error, "WS handler error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) Send(code OPCode, v interface{}) error {
|
||||
var op = OP{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := g.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode v")
|
||||
}
|
||||
|
||||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := g.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
return g.WS.Send(ctx, b)
|
||||
}
|
||||
|
||||
func (g *Gateway) connect() error {
|
||||
// Reconnect timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
var Lerr error
|
||||
|
||||
for i := uint(0); i < g.WSRetries; i++ {
|
||||
// Check if context is expired
|
||||
if err := ctx.Err(); err != nil {
|
||||
// Don't bother if it's expired
|
||||
return err
|
||||
}
|
||||
|
||||
// Reconnect to the Gateway
|
||||
if err := g.WS.Redial(ctx); err != nil {
|
||||
// Save the error, retry again
|
||||
Lerr = errors.Wrap(err, "Failed to reconnect")
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to resume the connection
|
||||
if err := g.Start(); err != nil {
|
||||
// If the connection is rate limited (documented behavior):
|
||||
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
|
||||
if err == ErrInvalidSession {
|
||||
continue // retry
|
||||
}
|
||||
|
||||
// Else, fatal
|
||||
return errors.Wrap(err, "Failed to start gateway")
|
||||
}
|
||||
|
||||
// Started successfully, return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if any earlier errors are fatal
|
||||
if Lerr != nil {
|
||||
return Lerr
|
||||
}
|
||||
|
||||
// We tried.
|
||||
return ErrWSMaxTries
|
||||
}
|
153
gateway/op.go
Normal file
153
gateway/op.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/diamondburned/arikawa/json"
|
||||
"github.com/diamondburned/arikawa/wsutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OPCode uint8
|
||||
|
||||
const (
|
||||
DispatchOP OPCode = iota // recv
|
||||
HeartbeatOP // send/recv
|
||||
IdentifyOP // send...
|
||||
StatusUpdateOP //
|
||||
VoiceStateUpdateOP //
|
||||
VoiceServerPingOP //
|
||||
ResumeOP //
|
||||
ReconnectOP // recv
|
||||
RequestGuildMembersOP // send
|
||||
InvalidSessionOP // recv...
|
||||
HelloOP
|
||||
HeartbeatAckOP
|
||||
_
|
||||
CallConnectOP
|
||||
GuildSubscriptionsOP
|
||||
)
|
||||
|
||||
type OP struct {
|
||||
Code OPCode `json:"op"`
|
||||
Data json.Raw `json:"d,omitempty"`
|
||||
|
||||
// Only for Dispatch (op 0)
|
||||
Sequence int `json:"s,omitempty"`
|
||||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
var ErrInvalidSession = errors.New("Invalid session")
|
||||
|
||||
func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
||||
if ev.Error != nil {
|
||||
return nil, ev.Error
|
||||
}
|
||||
|
||||
var op *OP
|
||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode payload")
|
||||
}
|
||||
|
||||
if op.Code == InvalidSessionOP {
|
||||
return op, ErrInvalidSession
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func DecodeEvent(driver json.Driver,
|
||||
ev wsutil.Event, v interface{}) (OPCode, error) {
|
||||
|
||||
op, err := DecodeOP(driver, ev)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
return 0, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return op.Code, nil
|
||||
}
|
||||
|
||||
func AssertEvent(driver json.Driver,
|
||||
ev wsutil.Event, code OPCode, v interface{}) error {
|
||||
|
||||
op, err := DecodeOP(driver, ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if op.Code != code {
|
||||
return fmt.Errorf(
|
||||
"Unexpected OP Code: %d, expected %d (%s)",
|
||||
op.Code, code, op.Data,
|
||||
)
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
return errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleEvent(g *Gateway, data []byte) error {
|
||||
// Parse the raw data into an OP struct
|
||||
var op *OP
|
||||
if err := g.Driver.Unmarshal(data, &op); err != nil {
|
||||
return errors.Wrap(err, "OP error")
|
||||
}
|
||||
|
||||
if g.OP != nil {
|
||||
g.OP <- op
|
||||
}
|
||||
|
||||
switch op.Code {
|
||||
case HeartbeatAckOP:
|
||||
// Heartbeat from the server?
|
||||
g.Pacemaker.Echo()
|
||||
|
||||
case HeartbeatOP:
|
||||
// Server requesting a heartbeat.
|
||||
return g.Pacemaker.Pace()
|
||||
|
||||
case ReconnectOP:
|
||||
// Server requests to reconnect, die and retry.
|
||||
return g.Reconnect()
|
||||
|
||||
case InvalidSessionOP:
|
||||
// Invalid session, respond with Identify.
|
||||
return g.Identify()
|
||||
|
||||
case HelloOP:
|
||||
// What is this OP doing here???
|
||||
return nil
|
||||
|
||||
case DispatchOP:
|
||||
// Check if we know the event
|
||||
fn, ok := EventCreator[op.EventName]
|
||||
if !ok {
|
||||
return errors.New("Unknown event: " + op.EventName)
|
||||
}
|
||||
|
||||
// Make a new pointer to the event
|
||||
var ev = fn()
|
||||
|
||||
// Try and parse the event
|
||||
if err := g.Driver.Unmarshal(op.Data, ev); err != nil {
|
||||
return errors.Wrap(err, "Failed to parse event "+op.EventName)
|
||||
}
|
||||
|
||||
// Throw the event into a channel, it's valid now.
|
||||
g.Events <- ev
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"Unknown OP code %d (event %s)", op.Code, op.EventName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
80
gateway/pacemaker.go
Normal file
80
gateway/pacemaker.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var ErrDead = errors.New("no heartbeat replied")
|
||||
|
||||
type Pacemaker struct {
|
||||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate time.Duration
|
||||
|
||||
// LastBeat logs the received heartbeats, with the newest one
|
||||
// first.
|
||||
LastBeat [2]time.Time
|
||||
|
||||
// Any callback that returns an error will stop the pacer.
|
||||
Pace func() error
|
||||
// Event
|
||||
OnDead func() error
|
||||
|
||||
stop chan<- struct{}
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Echo() {
|
||||
// Swap our received heartbeats
|
||||
p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
|
||||
}
|
||||
|
||||
// Dead, if true, will have Pace return an ErrDead.
|
||||
func (p *Pacemaker) Dead() bool {
|
||||
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
|
||||
}
|
||||
|
||||
func (p *Pacemaker) Stop() {
|
||||
close(p.stop)
|
||||
}
|
||||
|
||||
// Start beats until it's dead.
|
||||
func (p *Pacemaker) Start() error {
|
||||
tick := time.NewTicker(p.Heartrate)
|
||||
defer tick.Stop()
|
||||
|
||||
stop := make(chan struct{})
|
||||
p.stop = stop
|
||||
|
||||
for {
|
||||
if err := p.Pace(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !p.Dead() {
|
||||
continue
|
||||
}
|
||||
if err := p.OnDead(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
return nil
|
||||
case <-tick.C:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pacemaker) StartAsync() (death <-chan error) {
|
||||
var ch = make(chan error)
|
||||
go func() {
|
||||
ch <- p.Start()
|
||||
}()
|
||||
return ch
|
||||
}
|
14
gateway/sequence.go
Normal file
14
gateway/sequence.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package gateway
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type Sequence struct {
|
||||
seq *int64
|
||||
}
|
||||
|
||||
func NewSequence() Sequence {
|
||||
return Sequence{new(int64)}
|
||||
}
|
||||
|
||||
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(s.seq, seq) }
|
||||
func (s *Sequence) Get() int64 { return atomic.LoadInt64(s.seq) }
|
1
go.mod
1
go.mod
|
@ -8,6 +8,7 @@ require (
|
|||
github.com/gorilla/websocket v1.4.1
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
go.uber.org/atomic v1.4.0
|
||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
nhooyr.io/websocket v1.7.4
|
||||
|
|
|
@ -16,6 +16,8 @@ type Client struct {
|
|||
SchemaEncoder
|
||||
}
|
||||
|
||||
var DefaultClient = NewClient()
|
||||
|
||||
func NewClient() Client {
|
||||
return Client{
|
||||
Client: http.Client{
|
||||
|
|
|
@ -59,6 +59,7 @@ func WithJSONBody(json json.Driver, v interface{}) RequestOption {
|
|||
return err
|
||||
}
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Body = ioutil.NopCloser(&buf)
|
||||
return nil
|
||||
}
|
||||
|
|
21
json/json.go
21
json/json.go
|
@ -38,6 +38,27 @@ func getBool(Bool bool) OptionBool {
|
|||
return &Bool
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
// Raw is a raw encoded JSON value. It implements Marshaler and Unmarshaler and
|
||||
// can be used to delay JSON decoding or precompute a JSON encoding. It's taken
|
||||
// from encoding/json.
|
||||
type Raw []byte
|
||||
|
||||
// Raw returns m as the JSON encoding of m.
|
||||
func (m Raw) MarshalJSON() ([]byte, error) {
|
||||
if m == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type Driver interface {
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
|
|
83
session/session.go
Normal file
83
session/session.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/json"
|
||||
)
|
||||
|
||||
/*
|
||||
TODO:
|
||||
|
||||
and Session's supposed to handle callbacks too kec
|
||||
|
||||
might move all these to Gateway, dunno
|
||||
|
||||
could have a lock on Listen()
|
||||
|
||||
I can actually see people using gateway channels to handle things
|
||||
themselves without any callback abstractions, so this is probably the way to go
|
||||
|
||||
welp shit
|
||||
|
||||
rewrite imminent
|
||||
*/
|
||||
|
||||
type Session struct {
|
||||
API *api.Client
|
||||
Gateway *gateway.Conn
|
||||
gatewayOnce sync.Once
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
// Heartrate is the received duration between heartbeats.
|
||||
Heartrate time.Duration
|
||||
|
||||
// LastBeat logs the received heartbeats, with the newest one
|
||||
// first.
|
||||
LastBeat [2]time.Time
|
||||
|
||||
// Used for Close()
|
||||
stoppers []chan<- struct{}
|
||||
closers []func() error
|
||||
}
|
||||
|
||||
func New(token string) (*Session, error) {
|
||||
// Initialize the session and the API interface
|
||||
s := &Session{}
|
||||
s.API = api.NewClient(token)
|
||||
|
||||
// Default logger
|
||||
s.ErrorLog = func(err error) {
|
||||
log.Println("Arikawa/session error:", err)
|
||||
}
|
||||
|
||||
// Connect to the Gateway
|
||||
c, err := gateway.NewConn(json.Default{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Gateway = c
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Session) Close() error {
|
||||
for _, stop := range s.stoppers {
|
||||
close(stop)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
for _, closer := range s.closers {
|
||||
if cerr := closer(); cerr != nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -13,11 +13,15 @@ import (
|
|||
)
|
||||
|
||||
var WSBuffer = 12
|
||||
var WSReadLimit = 512 * 1024 // 512KiB
|
||||
var WSReadLimit = 4096 // 4096 bytes
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself.
|
||||
type Connection interface {
|
||||
// Dial dials the address (string). Context needs to be passed in for
|
||||
// timeout. This method should also be re-usable after Close is called.
|
||||
Dial(context.Context, string) error
|
||||
|
||||
// Listen sends over events constantly. Error will be non-nil if Data is
|
||||
// nil, so check for Error first.
|
||||
Listen() <-chan Event
|
||||
|
@ -67,10 +71,6 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
}
|
||||
|
||||
func (c *Conn) Listen() <-chan Event {
|
||||
if c.events != nil {
|
||||
return c.events
|
||||
}
|
||||
|
||||
c.events = make(chan Event, WSBuffer)
|
||||
go func() { c.readLoop(c.events) }()
|
||||
return c.events
|
||||
|
@ -139,6 +139,9 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
|
|||
}
|
||||
|
||||
func (c *Conn) Close(err error) error {
|
||||
// Close the event channels
|
||||
defer close(c.events)
|
||||
|
||||
if err == nil {
|
||||
return c.Conn.Close(websocket.StatusNormalClosure, "")
|
||||
}
|
||||
|
|
|
@ -9,3 +9,7 @@ import (
|
|||
func NewSendLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Every(time.Minute), 120)
|
||||
}
|
||||
|
||||
func NewDialLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Every(5*time.Second), 1)
|
||||
}
|
||||
|
|
71
wsutil/ws.go
71
wsutil/ws.go
|
@ -19,49 +19,62 @@ type Event struct {
|
|||
}
|
||||
|
||||
type Websocket struct {
|
||||
conn Connection
|
||||
Conn Connection
|
||||
Addr string
|
||||
|
||||
WriteTimeout time.Duration
|
||||
SendLimiter *rate.Limiter
|
||||
SendLimiter *rate.Limiter
|
||||
DialLimiter *rate.Limiter
|
||||
|
||||
listener <-chan Event
|
||||
}
|
||||
|
||||
func New(ctx context.Context,
|
||||
driver json.Driver, addr string) (*Websocket, error) {
|
||||
|
||||
if driver == nil {
|
||||
driver = json.Default{}
|
||||
}
|
||||
|
||||
c := NewConn(driver)
|
||||
if err := c.Dial(ctx, addr); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to dial")
|
||||
}
|
||||
|
||||
return NewWithConn(c), nil
|
||||
func New(ctx context.Context, addr string) (*Websocket, error) {
|
||||
return NewCustom(ctx, NewConn(json.Default{}), addr)
|
||||
}
|
||||
|
||||
// NewWithConn uses an already-dialed connection for Websocket.
|
||||
func NewWithConn(conn Connection) *Websocket {
|
||||
return &Websocket{
|
||||
conn: conn,
|
||||
// NewCustom creates a new undialed Websocket.
|
||||
func NewCustom(
|
||||
ctx context.Context, conn Connection, addr string) (*Websocket, error) {
|
||||
|
||||
WriteTimeout: DefaultTimeout,
|
||||
SendLimiter: NewSendLimiter(),
|
||||
ws := &Websocket{
|
||||
Conn: conn,
|
||||
Addr: addr,
|
||||
|
||||
SendLimiter: NewSendLimiter(),
|
||||
DialLimiter: NewDialLimiter(),
|
||||
}
|
||||
|
||||
return ws, nil
|
||||
}
|
||||
|
||||
func (ws *Websocket) Redial(ctx context.Context) error {
|
||||
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
||||
// Expired, fatal error
|
||||
return errors.Wrap(err, "Failed to wait")
|
||||
}
|
||||
|
||||
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
||||
return errors.Wrap(err, "Failed to dial")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *Websocket) Listen() <-chan Event {
|
||||
return ws.conn.Listen()
|
||||
if ws.listener == nil {
|
||||
ws.listener = ws.Conn.Listen()
|
||||
}
|
||||
return ws.listener
|
||||
}
|
||||
|
||||
func (ws *Websocket) Send(b []byte) error {
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(), ws.WriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
func (ws *Websocket) Send(ctx context.Context, b []byte) error {
|
||||
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "SendLimiter failed")
|
||||
}
|
||||
|
||||
return ws.conn.Send(ctx, b)
|
||||
return ws.Conn.Send(ctx, b)
|
||||
}
|
||||
|
||||
func (ws *Websocket) Close(err error) error {
|
||||
return ws.Conn.Close(err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue