1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-30 18:53:30 +00:00

rate: Add rate.AcquireOptions

This commit is contained in:
Maximilian von Lindern 2021-05-29 22:28:45 +02:00 committed by diamondburned
parent 8a7c6c48a7
commit 428ef4ac70
2 changed files with 61 additions and 30 deletions

View file

@ -26,6 +26,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa/v3)"
type Client struct { type Client struct {
*httputil.Client *httputil.Client
*Session *Session
AcquireOptions rate.AcquireOptions
} }
func NewClient(token string) *Client { func NewClient(token string) *Client {
@ -33,31 +34,45 @@ func NewClient(token string) *Client {
} }
func NewCustomClient(token string, httpClient *httputil.Client) *Client { func NewCustomClient(token string, httpClient *httputil.Client) *Client {
ses := Session{ c := &Client{
Limiter: rate.NewLimiter(Path), Session: &Session{
Token: token, Limiter: rate.NewLimiter(Path),
UserAgent: UserAgent, Token: token,
UserAgent: UserAgent,
},
Client: httpClient.Copy(),
} }
hcl := httpClient.Copy() c.Client.OnRequest = append(c.Client.OnRequest, c.InjectRequest)
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest) c.Client.OnResponse = append(c.Client.OnResponse, c.OnResponse)
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
return &Client{ return c
Client: hcl,
Session: &ses,
}
} }
// WithContext returns a shallow copy of Client with the given context. It's // WithContext returns a shallow copy of Client with the given context. It's
// used for method timeouts and such. This method is thread-safe. // used for method timeouts and such. This method is thread-safe.
func (c *Client) WithContext(ctx context.Context) *Client { func (c *Client) WithContext(ctx context.Context) *Client {
return &Client{ return &Client{
Client: c.Client.WithContext(ctx), Client: c.Client.WithContext(ctx),
Session: c.Session, Session: c.Session,
AcquireOptions: c.AcquireOptions,
} }
} }
func (c *Client) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {c.Session.Token},
"User-Agent": {c.Session.UserAgent},
})
ctx := c.AcquireOptions.Context(r.GetContext())
return c.Session.Limiter.Acquire(ctx, r.GetPath())
}
func (c *Client) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return c.Session.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}
// Session keeps a single session. This is typically wrapped around Client. // Session keeps a single session. This is typically wrapped around Client.
type Session struct { type Session struct {
Limiter *rate.Limiter Limiter *rate.Limiter
@ -65,17 +80,3 @@ type Session struct {
Token string Token string
UserAgent string UserAgent string
} }
func (s *Session) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {s.Token},
"User-Agent": {s.UserAgent},
})
// Rate limit stuff
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
}
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}

View file

@ -19,8 +19,10 @@ import (
const ExtraDelay = 250 * time.Millisecond const ExtraDelay = 250 * time.Millisecond
// ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit // ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
// exceeds the deadline of the context.Context. // exceeds the deadline of the context.Context or api.AcquireOptions.DontWait
var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline") // is set to true
var ErrTimedOutEarly = errors.New(
"rate: rate limit exceeds context deadline or is blocked acquire options")
// This makes me suicidal. // This makes me suicidal.
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go // https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
@ -43,6 +45,25 @@ type CustomRateLimit struct {
Reset time.Duration Reset time.Duration
} }
type contextKey uint8
const (
// AcquireOptionsKey is the key used to store the AcquireOptions in the
// context.
acquireOptionsKey contextKey = iota
)
type AcquireOptions struct {
// DontWait prevents rate.Limiters from waiting for a rate limit. Instead
// they will return an rate.ErrTimedOutEarly.
DontWait bool
}
// Context wraps the given ctx to have the AcquireOptions.
func (opts AcquireOptions) Context(ctx context.Context) context.Context {
return context.WithValue(ctx, acquireOptionsKey, opts)
}
type bucket struct { type bucket struct {
lock moreatomic.CtxMutex lock moreatomic.CtxMutex
custom *CustomRateLimit custom *CustomRateLimit
@ -99,6 +120,13 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
// Acquire acquires the rate limiter for the given URL bucket. // Acquire acquires the rate limiter for the given URL bucket.
func (l *Limiter) Acquire(ctx context.Context, path string) error { func (l *Limiter) Acquire(ctx context.Context, path string) error {
var options AcquireOptions
if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil {
// Zero value are default anyways, so we can ignore ok.
options, _ = untypedOptions.(AcquireOptions)
}
b := l.getBucket(path, true) b := l.getBucket(path, true)
if err := b.lock.Lock(ctx); err != nil { if err := b.lock.Lock(ctx); err != nil {
@ -118,7 +146,9 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
} }
if until.After(now) { if until.After(now) {
if deadline, ok := ctx.Deadline(); ok && until.After(deadline) { if options.DontWait {
return ErrTimedOutEarly
} else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
return ErrTimedOutEarly return ErrTimedOutEarly
} }