1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-07-23 13:20:51 +00:00

Compare commits

...

3 commits

Author SHA1 Message Date
diamondburned 9925461a25 Gateway: Potential fix for URL() 2021-04-05 12:20:56 -07:00
diamondburned c9a7ec8122 Gateway: Add URL test 2021-04-05 12:20:56 -07:00
diamondburned 2dadb0701d Gateway: Add automatic BotData connection
This commit modifies Gateway constructors to allow the user to easily
feed existing Identifier instances as well as updating those instances
to adhere to the Discord-returned gateway rate limits.

These changes should make it easier for typical bot sharding, although
automatic sharding is not implemented.
2021-04-05 12:20:56 -07:00
2 changed files with 77 additions and 15 deletions

View file

@ -11,6 +11,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync" "sync"
"time" "time"
@ -47,19 +48,22 @@ type BotData struct {
// SessionStartLimit is the information on the current session start limit. It's // SessionStartLimit is the information on the current session start limit. It's
// used in BotData. // used in BotData.
type SessionStartLimit struct { type SessionStartLimit struct {
Total int `json:"total"` Total int `json:"total"`
Remaining int `json:"remaining"` Remaining int `json:"remaining"`
ResetAfter discord.Milliseconds `json:"reset_after"` ResetAfter discord.Milliseconds `json:"reset_after"`
MaxConcurrency int `json:"max_concurrency"`
} }
// URL asks Discord for a Websocket URL to the Gateway. // URL asks Discord for a Websocket URL to the Gateway.
func URL() (string, error) { func URL() (string, error) {
var g BotData var g BotData
return g.URL, httputil.NewClient().RequestJSON( c := httputil.NewClient()
&g, "GET", if err := c.RequestJSON(&g, "GET", EndpointGateway); err != nil {
EndpointGateway, return "", err
) }
return g.URL, nil
} }
// BotURL fetches the Gateway URL along with extra metadata. The token // BotURL fetches the Gateway URL along with extra metadata. The token
@ -139,12 +143,31 @@ func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
return g, nil return g, nil
} }
// NewGateway creates a new Gateway with the default stdlib JSON driver. For // NewGateway creates a new Gateway to the default Discord server.
// more information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) { func NewGateway(token string) (*Gateway, error) {
URL, err := URL() return NewIdentifiedGateway(DefaultIdentifier(token))
if err != nil { }
return nil, errors.Wrap(err, "failed to get gateway endpoint")
// NewIdentifiedGateway creates a new Gateway with the given gateway identifier
// and the default everything. Sharded bots should prefer this function for the
// shared identifier.
func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
var gatewayURL string
var botData *BotData
var err error
if strings.HasPrefix(id.Token, "Bot ") {
botData, err = BotURL(id.Token)
if err != nil {
return nil, errors.Wrap(err, "failed to get bot data")
}
gatewayURL = botData.URL
} else {
gatewayURL, err = URL()
if err != nil {
return nil, errors.Wrap(err, "failed to get gateway endpoint")
}
} }
// Parameters for the gateway // Parameters for the gateway
@ -154,18 +177,42 @@ func NewGateway(token string) (*Gateway, error) {
} }
// Append the form to the URL // Append the form to the URL
URL += "?" + param.Encode() gatewayURL += "?" + param.Encode()
gateway := NewCustomIdentifiedGateway(gatewayURL, id)
return NewCustomGateway(URL, token), nil // Use the supplied connect rate limit, if any.
if botData != nil && botData.StartLimit != nil {
resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
limiter := gateway.Identifier.IdentifyGlobalLimit
// Update the burst to be the current given time and reset it back to
// the default when the given time is reached.
limiter.SetBurst(botData.StartLimit.Remaining)
limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
// Update the maximum number of identify requests allowed per 5s.
gateway.Identifier.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
}
return gateway, nil
} }
// NewCustomGateway creates a new Gateway with a custom gateway URL and a new
// Identifier. Most bots connecting to the official server should not use these
// custom functions.
func NewCustomGateway(gatewayURL, token string) *Gateway { func NewCustomGateway(gatewayURL, token string) *Gateway {
return NewCustomIdentifiedGateway(gatewayURL, DefaultIdentifier(token))
}
// NewCustomIdentifiedGateway creates a new Gateway with a custom gateway URL
// and a pre-existing Identifier. Refer to NewCustomGateway.
func NewCustomIdentifiedGateway(gatewayURL string, id *Identifier) *Gateway {
return &Gateway{ return &Gateway{
WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL), WS: wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
WSTimeout: wsutil.WSTimeout, WSTimeout: wsutil.WSTimeout,
Events: make(chan Event, wsutil.WSBuffer), Events: make(chan Event, wsutil.WSBuffer),
Identifier: DefaultIdentifier(token), Identifier: id,
Sequence: moreatomic.NewInt64(0), Sequence: moreatomic.NewInt64(0),
ErrorLog: wsutil.WSError, ErrorLog: wsutil.WSError,

View file

@ -21,6 +21,21 @@ func init() {
} }
} }
func TestURL(t *testing.T) {
u, err := URL()
if err != nil {
t.Fatal("failed to get gateway URL:", err)
}
if u == "" {
t.Fatal("gateway URL is empty")
}
if !strings.HasPrefix(u, "wss://") {
t.Fatal("gatewayURL is invalid:", u)
}
}
func TestInvalidToken(t *testing.T) { func TestInvalidToken(t *testing.T) {
g, err := NewGateway("bad token") g, err := NewGateway("bad token")
if err != nil { if err != nil {