mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 18:53:30 +00:00
rate: Add rate.AcquireOptions
This commit is contained in:
parent
8a7c6c48a7
commit
428ef4ac70
55
api/api.go
55
api/api.go
|
@ -26,6 +26,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa/v3)"
|
||||||
type Client struct {
|
type Client struct {
|
||||||
*httputil.Client
|
*httputil.Client
|
||||||
*Session
|
*Session
|
||||||
|
AcquireOptions rate.AcquireOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(token string) *Client {
|
func NewClient(token string) *Client {
|
||||||
|
@ -33,31 +34,45 @@ func NewClient(token string) *Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
||||||
ses := Session{
|
c := &Client{
|
||||||
Limiter: rate.NewLimiter(Path),
|
Session: &Session{
|
||||||
Token: token,
|
Limiter: rate.NewLimiter(Path),
|
||||||
UserAgent: UserAgent,
|
Token: token,
|
||||||
|
UserAgent: UserAgent,
|
||||||
|
},
|
||||||
|
Client: httpClient.Copy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
hcl := httpClient.Copy()
|
c.Client.OnRequest = append(c.Client.OnRequest, c.InjectRequest)
|
||||||
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
|
c.Client.OnResponse = append(c.Client.OnResponse, c.OnResponse)
|
||||||
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
|
|
||||||
|
|
||||||
return &Client{
|
return c
|
||||||
Client: hcl,
|
|
||||||
Session: &ses,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithContext returns a shallow copy of Client with the given context. It's
|
// WithContext returns a shallow copy of Client with the given context. It's
|
||||||
// used for method timeouts and such. This method is thread-safe.
|
// used for method timeouts and such. This method is thread-safe.
|
||||||
func (c *Client) WithContext(ctx context.Context) *Client {
|
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
Client: c.Client.WithContext(ctx),
|
Client: c.Client.WithContext(ctx),
|
||||||
Session: c.Session,
|
Session: c.Session,
|
||||||
|
AcquireOptions: c.AcquireOptions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) InjectRequest(r httpdriver.Request) error {
|
||||||
|
r.AddHeader(http.Header{
|
||||||
|
"Authorization": {c.Session.Token},
|
||||||
|
"User-Agent": {c.Session.UserAgent},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := c.AcquireOptions.Context(r.GetContext())
|
||||||
|
return c.Session.Limiter.Acquire(ctx, r.GetPath())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
||||||
|
return c.Session.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||||
|
}
|
||||||
|
|
||||||
// Session keeps a single session. This is typically wrapped around Client.
|
// Session keeps a single session. This is typically wrapped around Client.
|
||||||
type Session struct {
|
type Session struct {
|
||||||
Limiter *rate.Limiter
|
Limiter *rate.Limiter
|
||||||
|
@ -65,17 +80,3 @@ type Session struct {
|
||||||
Token string
|
Token string
|
||||||
UserAgent string
|
UserAgent string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) InjectRequest(r httpdriver.Request) error {
|
|
||||||
r.AddHeader(http.Header{
|
|
||||||
"Authorization": {s.Token},
|
|
||||||
"User-Agent": {s.UserAgent},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Rate limit stuff
|
|
||||||
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
|
||||||
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
|
||||||
}
|
|
||||||
|
|
|
@ -19,8 +19,10 @@ import (
|
||||||
const ExtraDelay = 250 * time.Millisecond
|
const ExtraDelay = 250 * time.Millisecond
|
||||||
|
|
||||||
// ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
|
// ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
|
||||||
// exceeds the deadline of the context.Context.
|
// exceeds the deadline of the context.Context or api.AcquireOptions.DontWait
|
||||||
var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline")
|
// is set to true
|
||||||
|
var ErrTimedOutEarly = errors.New(
|
||||||
|
"rate: rate limit exceeds context deadline or is blocked acquire options")
|
||||||
|
|
||||||
// This makes me suicidal.
|
// This makes me suicidal.
|
||||||
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
|
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
|
||||||
|
@ -43,6 +45,25 @@ type CustomRateLimit struct {
|
||||||
Reset time.Duration
|
Reset time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type contextKey uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// AcquireOptionsKey is the key used to store the AcquireOptions in the
|
||||||
|
// context.
|
||||||
|
acquireOptionsKey contextKey = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
type AcquireOptions struct {
|
||||||
|
// DontWait prevents rate.Limiters from waiting for a rate limit. Instead
|
||||||
|
// they will return an rate.ErrTimedOutEarly.
|
||||||
|
DontWait bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context wraps the given ctx to have the AcquireOptions.
|
||||||
|
func (opts AcquireOptions) Context(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, acquireOptionsKey, opts)
|
||||||
|
}
|
||||||
|
|
||||||
type bucket struct {
|
type bucket struct {
|
||||||
lock moreatomic.CtxMutex
|
lock moreatomic.CtxMutex
|
||||||
custom *CustomRateLimit
|
custom *CustomRateLimit
|
||||||
|
@ -99,6 +120,13 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||||
|
|
||||||
// Acquire acquires the rate limiter for the given URL bucket.
|
// Acquire acquires the rate limiter for the given URL bucket.
|
||||||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
|
var options AcquireOptions
|
||||||
|
|
||||||
|
if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil {
|
||||||
|
// Zero value are default anyways, so we can ignore ok.
|
||||||
|
options, _ = untypedOptions.(AcquireOptions)
|
||||||
|
}
|
||||||
|
|
||||||
b := l.getBucket(path, true)
|
b := l.getBucket(path, true)
|
||||||
|
|
||||||
if err := b.lock.Lock(ctx); err != nil {
|
if err := b.lock.Lock(ctx); err != nil {
|
||||||
|
@ -118,7 +146,9 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if until.After(now) {
|
if until.After(now) {
|
||||||
if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
|
if options.DontWait {
|
||||||
|
return ErrTimedOutEarly
|
||||||
|
} else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
|
||||||
return ErrTimedOutEarly
|
return ErrTimedOutEarly
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue