diff --git a/api/api.go b/api/api.go index d7c899c..20946b4 100644 --- a/api/api.go +++ b/api/api.go @@ -26,6 +26,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa/v3)" type Client struct { *httputil.Client *Session + AcquireOptions rate.AcquireOptions } func NewClient(token string) *Client { @@ -33,31 +34,45 @@ func NewClient(token string) *Client { } func NewCustomClient(token string, httpClient *httputil.Client) *Client { - ses := Session{ - Limiter: rate.NewLimiter(Path), - Token: token, - UserAgent: UserAgent, + c := &Client{ + Session: &Session{ + Limiter: rate.NewLimiter(Path), + Token: token, + UserAgent: UserAgent, + }, + Client: httpClient.Copy(), } - hcl := httpClient.Copy() - hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest) - hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse) + c.Client.OnRequest = append(c.Client.OnRequest, c.InjectRequest) + c.Client.OnResponse = append(c.Client.OnResponse, c.OnResponse) - return &Client{ - Client: hcl, - Session: &ses, - } + return c } // WithContext returns a shallow copy of Client with the given context. It's // used for method timeouts and such. This method is thread-safe. func (c *Client) WithContext(ctx context.Context) *Client { return &Client{ - Client: c.Client.WithContext(ctx), - Session: c.Session, + Client: c.Client.WithContext(ctx), + Session: c.Session, + AcquireOptions: c.AcquireOptions, } } +func (c *Client) InjectRequest(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Authorization": {c.Session.Token}, + "User-Agent": {c.Session.UserAgent}, + }) + + ctx := c.AcquireOptions.Context(r.GetContext()) + return c.Session.Limiter.Acquire(ctx, r.GetPath()) +} + +func (c *Client) OnResponse(r httpdriver.Request, resp httpdriver.Response) error { + return c.Session.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp)) +} + // Session keeps a single session. This is typically wrapped around Client. type Session struct { Limiter *rate.Limiter @@ -65,17 +80,3 @@ type Session struct { Token string UserAgent string } - -func (s *Session) InjectRequest(r httpdriver.Request) error { - r.AddHeader(http.Header{ - "Authorization": {s.Token}, - "User-Agent": {s.UserAgent}, - }) - - // Rate limit stuff - return s.Limiter.Acquire(r.GetContext(), r.GetPath()) -} - -func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error { - return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp)) -} diff --git a/api/rate/rate.go b/api/rate/rate.go index 3e8508b..20440ac 100644 --- a/api/rate/rate.go +++ b/api/rate/rate.go @@ -19,8 +19,10 @@ import ( const ExtraDelay = 250 * time.Millisecond // ErrTimedOutEarly is the error returned by Limiter.Acquire, if a rate limit -// exceeds the deadline of the context.Context. -var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline") +// exceeds the deadline of the context.Context or api.AcquireOptions.DontWait +// is set to true +var ErrTimedOutEarly = errors.New( + "rate: rate limit exceeds context deadline or is blocked acquire options") // This makes me suicidal. // https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go @@ -43,6 +45,25 @@ type CustomRateLimit struct { Reset time.Duration } +type contextKey uint8 + +const ( + // AcquireOptionsKey is the key used to store the AcquireOptions in the + // context. + acquireOptionsKey contextKey = iota +) + +type AcquireOptions struct { + // DontWait prevents rate.Limiters from waiting for a rate limit. Instead + // they will return an rate.ErrTimedOutEarly. + DontWait bool +} + +// Context wraps the given ctx to have the AcquireOptions. +func (opts AcquireOptions) Context(ctx context.Context) context.Context { + return context.WithValue(ctx, acquireOptionsKey, opts) +} + type bucket struct { lock moreatomic.CtxMutex custom *CustomRateLimit @@ -99,6 +120,13 @@ func (l *Limiter) getBucket(path string, store bool) *bucket { // Acquire acquires the rate limiter for the given URL bucket. func (l *Limiter) Acquire(ctx context.Context, path string) error { + var options AcquireOptions + + if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil { + // Zero value are default anyways, so we can ignore ok. + options, _ = untypedOptions.(AcquireOptions) + } + b := l.getBucket(path, true) if err := b.lock.Lock(ctx); err != nil { @@ -118,7 +146,9 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error { } if until.After(now) { - if deadline, ok := ctx.Deadline(); ok && until.After(deadline) { + if options.DontWait { + return ErrTimedOutEarly + } else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) { return ErrTimedOutEarly }