From b9384042bb702105b15c02663c883586a64689c0 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 19 Apr 2020 09:17:04 -0700 Subject: [PATCH] Gateway: Added GatewayBot --- gateway/gateway.go | 79 +++++++++++++++++++++++++++------------ utils/httputil/options.go | 9 +++++ 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/gateway/gateway.go b/gateway/gateway.go index 110115f..23a738c 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -10,11 +10,13 @@ package gateway import ( "context" "log" + "net/http" "net/url" "sync" "time" "github.com/diamondburned/arikawa/api" + "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/utils/httputil" "github.com/diamondburned/arikawa/utils/json" "github.com/diamondburned/arikawa/utils/wsutil" @@ -48,18 +50,48 @@ var ( ) var ( - ErrMissingForResume = errors.New( - "missing session ID or sequence for resuming") - ErrWSMaxTries = errors.New("max tries reached") + ErrMissingForResume = errors.New("missing session ID or sequence for resuming") + ErrWSMaxTries = errors.New("max tries reached") ) -func GatewayURL() (string, error) { - var Gateway struct { - URL string `json:"url"` - } +// GatewayBotData contains the GatewayURL as well as extra metadata on how to +// shard bots. +type GatewayBotData struct { + URL string `json:"url"` + Shards int `json:"shards,omitempty"` + StartLimit *SessionStartLimit `json:"session_start_limit"` +} - return Gateway.URL, httputil.DefaultClient.RequestJSON( - &Gateway, "GET", EndpointGateway) +// SessionStartLimit is the information on the current session start limit. It's +// used in GatewayBotData. +type SessionStartLimit struct { + Total int `json:"total"` + Remaining int `json:"remaining"` + ResetAfter discord.Milliseconds `json:"reset_after"` +} + +// GatewayURL asks Discord for a Websocket URL to the Gateway. +func GatewayURL() (string, error) { + var g GatewayBotData + + return g.URL, httputil.DefaultClient.RequestJSON( + &g, "GET", + EndpointGateway, + ) +} + +// GatewayBot fetches the Gateway URL along with extra metadata. The token +// passed in will NOT be prefixed with Bot. +func GatewayBot(token string) (*GatewayBotData, error) { + var g *GatewayBotData + + return g, httputil.DefaultClient.RequestJSON( + &g, "GET", + EndpointGatewayBot, + httputil.WithHeaders(http.Header{ + "Authorization": {token}, + }), + ) } type Gateway struct { @@ -117,16 +149,6 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { return nil, errors.Wrap(err, "Failed to get gateway endpoint") } - g := &Gateway{ - Driver: driver, - WSTimeout: WSTimeout, - Events: make(chan Event, WSBuffer), - Identifier: DefaultIdentifier(token), - Sequence: NewSequence(), - ErrorLog: WSError, - AfterClose: func(error) {}, - } - // Parameters for the gateway param := url.Values{ "v": {Version}, @@ -136,11 +158,20 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { // Append the form to the URL URL += "?" + param.Encode() - // Create a new undialed Websocket. - g.WS = wsutil.NewCustom(wsutil.NewConn(driver), URL) + return NewCustomGateway(URL, token, driver), nil +} - // Try and dial it - return g, nil +func NewCustomGateway(gatewayURL, token string, driver json.Driver) *Gateway { + return &Gateway{ + WS: wsutil.NewCustom(wsutil.NewConn(driver), gatewayURL), + Driver: driver, + WSTimeout: WSTimeout, + Events: make(chan Event, WSBuffer), + Identifier: DefaultIdentifier(token), + Sequence: NewSequence(), + ErrorLog: WSError, + AfterClose: func(error) {}, + } } // Close closes the underlying Websocket connection. @@ -282,7 +313,7 @@ func (g *Gateway) start() error { // Expect either READY or RESUMED before continuing. WSDebug("Waiting for either READY or RESUMED.") - // WaitForEvent should + // WaitForEvent should err := WaitForEvent(g, ch, func(op *OP) bool { switch op.EventName { case "READY": diff --git a/utils/httputil/options.go b/utils/httputil/options.go index 7d93b0c..f1910e6 100644 --- a/utils/httputil/options.go +++ b/utils/httputil/options.go @@ -19,6 +19,15 @@ func MultipartRequest(r *http.Request) error { return nil } +func WithHeaders(headers http.Header) RequestOption { + return func(r *http.Request) error { + for key, values := range headers { + r.Header[key] = append(r.Header[key], values...) + } + return nil + } +} + func WithContentType(ctype string) RequestOption { return func(r *http.Request) error { r.Header.Set("Content-Type", ctype)