1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-27 17:23:00 +00:00

rate: Add rate.AcquireOptions

This commit is contained in:
Maximilian von Lindern 2021-05-29 22:28:45 +02:00 committed by diamondburned
parent 8a7c6c48a7
commit 428ef4ac70
2 changed files with 61 additions and 30 deletions

View file

@ -26,6 +26,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa/v3)"
type Client struct {
*httputil.Client
*Session
AcquireOptions rate.AcquireOptions
}
func NewClient(token string) *Client {
@ -33,31 +34,45 @@ func NewClient(token string) *Client {
}
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
ses := Session{
Limiter: rate.NewLimiter(Path),
Token: token,
UserAgent: UserAgent,
c := &Client{
Session: &Session{
Limiter: rate.NewLimiter(Path),
Token: token,
UserAgent: UserAgent,
},
Client: httpClient.Copy(),
}
hcl := httpClient.Copy()
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
c.Client.OnRequest = append(c.Client.OnRequest, c.InjectRequest)
c.Client.OnResponse = append(c.Client.OnResponse, c.OnResponse)
return &Client{
Client: hcl,
Session: &ses,
}
return c
}
// WithContext returns a shallow copy of Client with the given context. It's
// used for method timeouts and such. This method is thread-safe.
func (c *Client) WithContext(ctx context.Context) *Client {
return &Client{
Client: c.Client.WithContext(ctx),
Session: c.Session,
Client: c.Client.WithContext(ctx),
Session: c.Session,
AcquireOptions: c.AcquireOptions,
}
}
func (c *Client) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {c.Session.Token},
"User-Agent": {c.Session.UserAgent},
})
ctx := c.AcquireOptions.Context(r.GetContext())
return c.Session.Limiter.Acquire(ctx, r.GetPath())
}
func (c *Client) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return c.Session.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}
// Session keeps a single session. This is typically wrapped around Client.
type Session struct {
Limiter *rate.Limiter
@ -65,17 +80,3 @@ type Session struct {
Token string
UserAgent string
}
func (s *Session) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {s.Token},
"User-Agent": {s.UserAgent},
})
// Rate limit stuff
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
}
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}

View file

@ -19,8 +19,10 @@ import (
const ExtraDelay = 250 * time.Millisecond
// ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
// exceeds the deadline of the context.Context.
var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline")
// exceeds the deadline of the context.Context or api.AcquireOptions.DontWait
// is set to true
var ErrTimedOutEarly = errors.New(
"rate: rate limit exceeds context deadline or is blocked acquire options")
// This makes me suicidal.
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
@ -43,6 +45,25 @@ type CustomRateLimit struct {
Reset time.Duration
}
type contextKey uint8
const (
// AcquireOptionsKey is the key used to store the AcquireOptions in the
// context.
acquireOptionsKey contextKey = iota
)
type AcquireOptions struct {
// DontWait prevents rate.Limiters from waiting for a rate limit. Instead
// they will return an rate.ErrTimedOutEarly.
DontWait bool
}
// Context wraps the given ctx to have the AcquireOptions.
func (opts AcquireOptions) Context(ctx context.Context) context.Context {
return context.WithValue(ctx, acquireOptionsKey, opts)
}
type bucket struct {
lock moreatomic.CtxMutex
custom *CustomRateLimit
@ -99,6 +120,13 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
// Acquire acquires the rate limiter for the given URL bucket.
func (l *Limiter) Acquire(ctx context.Context, path string) error {
var options AcquireOptions
if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil {
// Zero value are default anyways, so we can ignore ok.
options, _ = untypedOptions.(AcquireOptions)
}
b := l.getBucket(path, true)
if err := b.lock.Lock(ctx); err != nil {
@ -118,7 +146,9 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
}
if until.After(now) {
if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
if options.DontWait {
return ErrTimedOutEarly
} else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
return ErrTimedOutEarly
}