mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 18:53:30 +00:00
rate: Add rate.AcquireOptions
This commit is contained in:
parent
8a7c6c48a7
commit
428ef4ac70
45
api/api.go
45
api/api.go
|
@ -26,6 +26,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa/v3)"
|
|||
type Client struct {
|
||||
*httputil.Client
|
||||
*Session
|
||||
AcquireOptions rate.AcquireOptions
|
||||
}
|
||||
|
||||
func NewClient(token string) *Client {
|
||||
|
@ -33,20 +34,19 @@ func NewClient(token string) *Client {
|
|||
}
|
||||
|
||||
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
||||
ses := Session{
|
||||
c := &Client{
|
||||
Session: &Session{
|
||||
Limiter: rate.NewLimiter(Path),
|
||||
Token: token,
|
||||
UserAgent: UserAgent,
|
||||
},
|
||||
Client: httpClient.Copy(),
|
||||
}
|
||||
|
||||
hcl := httpClient.Copy()
|
||||
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
|
||||
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
|
||||
c.Client.OnRequest = append(c.Client.OnRequest, c.InjectRequest)
|
||||
c.Client.OnResponse = append(c.Client.OnResponse, c.OnResponse)
|
||||
|
||||
return &Client{
|
||||
Client: hcl,
|
||||
Session: &ses,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of Client with the given context. It's
|
||||
|
@ -55,9 +55,24 @@ func (c *Client) WithContext(ctx context.Context) *Client {
|
|||
return &Client{
|
||||
Client: c.Client.WithContext(ctx),
|
||||
Session: c.Session,
|
||||
AcquireOptions: c.AcquireOptions,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) InjectRequest(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Authorization": {c.Session.Token},
|
||||
"User-Agent": {c.Session.UserAgent},
|
||||
})
|
||||
|
||||
ctx := c.AcquireOptions.Context(r.GetContext())
|
||||
return c.Session.Limiter.Acquire(ctx, r.GetPath())
|
||||
}
|
||||
|
||||
func (c *Client) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
||||
return c.Session.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||
}
|
||||
|
||||
// Session keeps a single session. This is typically wrapped around Client.
|
||||
type Session struct {
|
||||
Limiter *rate.Limiter
|
||||
|
@ -65,17 +80,3 @@ type Session struct {
|
|||
Token string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
func (s *Session) InjectRequest(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Authorization": {s.Token},
|
||||
"User-Agent": {s.UserAgent},
|
||||
})
|
||||
|
||||
// Rate limit stuff
|
||||
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
|
||||
}
|
||||
|
||||
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
||||
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||
}
|
||||
|
|
|
@ -19,8 +19,10 @@ import (
|
|||
const ExtraDelay = 250 * time.Millisecond
|
||||
|
||||
// ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit
|
||||
// exceeds the deadline of the context.Context.
|
||||
var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline")
|
||||
// exceeds the deadline of the context.Context or api.AcquireOptions.DontWait
|
||||
// is set to true
|
||||
var ErrTimedOutEarly = errors.New(
|
||||
"rate: rate limit exceeds context deadline or is blocked acquire options")
|
||||
|
||||
// This makes me suicidal.
|
||||
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
|
||||
|
@ -43,6 +45,25 @@ type CustomRateLimit struct {
|
|||
Reset time.Duration
|
||||
}
|
||||
|
||||
type contextKey uint8
|
||||
|
||||
const (
|
||||
// AcquireOptionsKey is the key used to store the AcquireOptions in the
|
||||
// context.
|
||||
acquireOptionsKey contextKey = iota
|
||||
)
|
||||
|
||||
type AcquireOptions struct {
|
||||
// DontWait prevents rate.Limiters from waiting for a rate limit. Instead
|
||||
// they will return an rate.ErrTimedOutEarly.
|
||||
DontWait bool
|
||||
}
|
||||
|
||||
// Context wraps the given ctx to have the AcquireOptions.
|
||||
func (opts AcquireOptions) Context(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, acquireOptionsKey, opts)
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
lock moreatomic.CtxMutex
|
||||
custom *CustomRateLimit
|
||||
|
@ -99,6 +120,13 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
|||
|
||||
// Acquire acquires the rate limiter for the given URL bucket.
|
||||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||
var options AcquireOptions
|
||||
|
||||
if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil {
|
||||
// Zero value are default anyways, so we can ignore ok.
|
||||
options, _ = untypedOptions.(AcquireOptions)
|
||||
}
|
||||
|
||||
b := l.getBucket(path, true)
|
||||
|
||||
if err := b.lock.Lock(ctx); err != nil {
|
||||
|
@ -118,7 +146,9 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
|||
}
|
||||
|
||||
if until.After(now) {
|
||||
if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
|
||||
if options.DontWait {
|
||||
return ErrTimedOutEarly
|
||||
} else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
|
||||
return ErrTimedOutEarly
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue