1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-19 02:58:01 +00:00
arikawa/api/rate/rate.go
Maximilian von Lindern 0a8b24339b
API: Added timeout if deadline is after rate limit (#173)
* Rate: don't sleep if sleep exceeds context deadline

* Httputil: add Client.Timeout

* Bot: set default API timeout to 5 minutes

* Rate: reduce calls to time.Now in Acquire

* API: Optimize to use deadline instead of recalculating

Co-authored-by: diamondburned <datutbrus@gmail.com>
2020-11-25 12:08:42 -08:00

209 lines
4.1 KiB
Go

package rate
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/diamondburned/arikawa/v2/internal/moreatomic"
"github.com/pkg/errors"
)
// ExtraDelay is padding added to the reset time reported in a bucket's
// X-RateLimit-Reset header before the bucket is treated as reset. Other
// Discord libraries (litcord, discordgo) pad in the same way; discordgo
// attributes the value to experimentation.
const ExtraDelay time.Duration = 250 * time.Millisecond

// ErrTimedOutEarly is returned by Limiter.Acquire when the wait imposed by a
// rate limit would last past the context's deadline. Acquire returns it right
// away instead of sleeping.
var ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline")
// Limiter rate-limits requests per bucket. Request paths are mapped to bucket
// keys; each bucket tracks its own remaining-request count and reset time, and
// a single global deadline applies to every bucket. Intended usage is to call
// Acquire before sending a request and Release with the response headers
// afterwards.
//
// The design follows discordgo's rate limiter:
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
type Limiter struct {
	// CustomLimits replaces header-based limiting for matching buckets. Only
	// the first entry whose Contains is a substring of a bucket's key applies
	// to that bucket, and the match is made once, when the bucket is created.
	CustomLimits []*CustomRateLimit

	// Prefix is trimmed from request paths before they are turned into
	// bucket keys.
	Prefix string

	global  *int64   // global deadline as Unix nanoseconds; read and written atomically
	buckets sync.Map // bucket key -> *bucket
}
// CustomRateLimit replaces header-based rate limiting for buckets whose key
// contains Contains. For such buckets Limiter.Release ignores response headers.
// Instead, once at least Reset has elapsed since the current window began,
// Release starts a new window ending Reset from now. Acquire waits for that
// window to end once the bucket has no remaining requests, so a matching
// bucket allows roughly one request per Reset interval.
type CustomRateLimit struct {
	Contains string        // substring matched against the bucket key
	Reset    time.Duration // length of each custom window
}
// bucket holds the rate limit state for one bucket key.
type bucket struct {
	// lock is held from a successful Limiter.Acquire until the matching
	// Limiter.Release, serializing requests within the bucket.
	lock   moreatomic.CtxMutex
	custom *CustomRateLimit // matching custom limit, or nil to use response headers

	remaining uint64    // requests left before reset; starts at 1 and is decremented by Acquire
	reset     time.Time // time at which the bucket's limit lifts
	lastReset time.Time // only for custom: start of the current custom window
}
// newBucket returns a fresh bucket that permits a single request before any
// rate limit information has been recorded for it.
func newBucket() *bucket {
	fresh := &bucket{remaining: 1}
	fresh.lock = *moreatomic.NewCtxMutex()
	return fresh
}
// NewLimiter returns a Limiter with no custom limits and no global deadline
// in effect. prefix is stripped from request paths before they are mapped to
// buckets.
func NewLimiter(prefix string) *Limiter {
	lim := &Limiter{Prefix: prefix}
	lim.global = new(int64)
	lim.CustomLimits = make([]*CustomRateLimit, 0)
	return lim
}
// getBucket returns the bucket for path after stripping l.Prefix and
// normalizing it with ParseBucketKey. If no bucket exists yet, it returns nil
// when store is false. When store is true, it creates the bucket, attaching
// the first CustomLimits entry whose Contains is a substring of the key.
//
// Concurrent callers asking for the same new key all receive the same
// *bucket, so they share its lock.
func (l *Limiter) getBucket(path string, store bool) *bucket {
	key := ParseBucketKey(strings.TrimPrefix(path, l.Prefix))

	if existing, ok := l.buckets.Load(key); ok {
		return existing.(*bucket)
	}
	if !store {
		return nil
	}

	b := newBucket()
	for _, limit := range l.CustomLimits {
		if strings.Contains(key, limit.Contains) {
			b.custom = limit
			break
		}
	}

	// LoadOrStore rather than Store: another goroutine may have created this
	// bucket between the Load above and here. Overwriting it would give the
	// two callers different locks for the same bucket.
	actual, _ := l.buckets.LoadOrStore(key, b)
	return actual.(*bucket)
}
// Acquire waits until a request to path may be sent, then consumes one
// request from path's bucket.
//
// A request must wait when the bucket has no requests left and its reset
// time is in the future, or when the global deadline is in the future. In
// that case Acquire sleeps until the later of the two.
//
// On success the bucket's lock stays held until Release is called for the
// same path. On any error the lock is not held. Errors are:
//   - the error from locking the bucket with ctx;
//   - ErrTimedOutEarly, returned without sleeping, if the wait would end after
//     ctx's deadline;
//   - ctx.Err() if ctx is done while sleeping.
func (l *Limiter) Acquire(ctx context.Context, path string) error {
	b := l.getBucket(path, true)

	if err := b.lock.Lock(ctx); err != nil {
		return err
	}

	now := time.Now()

	// Wait for whichever limit lifts last. The global deadline always
	// applies. The bucket's reset only matters once it is out of requests.
	until := time.Unix(0, atomic.LoadInt64(l.global))
	if b.remaining == 0 && b.reset.After(until) {
		until = b.reset
	}

	if until.After(now) {
		if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
			// Like the ctx.Done path below, give the lock back: the caller
			// gets an error and holds no lock to release.
			b.lock.Unlock()
			return ErrTimedOutEarly
		}

		timer := time.NewTimer(until.Sub(now))
		select {
		case <-ctx.Done():
			timer.Stop()
			b.lock.Unlock()
			return ctx.Err()
		case <-timer.C:
		}
	}

	if b.remaining > 0 {
		b.remaining--
	}

	return nil
}
// Release records the rate limit information in headers for path's bucket
// and unlocks the bucket. It takes no context because it never blocks for
// long.
//
// Buckets with a CustomRateLimit ignore headers. For them, Release starts a
// new window ending CustomRateLimit.Reset from now, once at least Reset has
// passed since the current window began. For all other buckets:
//   - Retry-After, in seconds (fractions allowed), sets the global deadline
//     if X-RateLimit-Global is non-empty, or the bucket's reset time
//     otherwise. X-RateLimit-Reset is ignored when Retry-After is present.
//   - X-RateLimit-Reset, a Unix timestamp in seconds (fractions allowed),
//     sets the bucket's reset time, padded by ExtraDelay.
//   - X-RateLimit-Remaining sets the bucket's remaining request count.
//
// Release returns nil without doing anything if path has no bucket. With nil
// headers, a non-custom bucket is only unlocked. A malformed header value is
// returned as an error, and the bucket is still unlocked.
func (l *Limiter) Release(path string, headers http.Header) error {
	b := l.getBucket(path, false)
	if b == nil {
		return nil
	}

	// TryUnlock because Release may be called when Acquire has not been.
	defer b.lock.TryUnlock()

	// Custom limits replace header-based limiting entirely.
	if b.custom != nil {
		now := time.Now()
		if now.Sub(b.lastReset) >= b.custom.Reset {
			b.lastReset = now
			b.reset = now.Add(b.custom.Reset)
		}
		return nil
	}

	if headers == nil {
		return nil
	}

	var (
		global     = headers.Get("X-RateLimit-Global")    // boolean; only presence is checked
		remaining  = headers.Get("X-RateLimit-Remaining") // integer count
		reset      = headers.Get("X-RateLimit-Reset")     // Unix seconds, float
		retryAfter = headers.Get("Retry-After")           // seconds, possibly fractional
	)

	switch {
	case retryAfter != "":
		// ParseFloat accepts every integer Atoi did, plus fractional seconds,
		// which would otherwise be rejected and the limit never recorded.
		secs, err := strconv.ParseFloat(retryAfter, 64)
		if err != nil {
			return errors.Wrapf(err, "invalid retryAfter %q", retryAfter)
		}

		at := time.Now().Add(time.Duration(secs * float64(time.Second)))

		if global != "" { // probably true
			atomic.StoreInt64(l.global, at.UnixNano())
		} else {
			b.reset = at
		}

	case reset != "":
		unix, err := strconv.ParseFloat(reset, 64)
		if err != nil {
			return errors.Wrap(err, "invalid reset "+reset)
		}

		sec := int64(unix)
		nsec := int64((unix - float64(sec)) * float64(time.Second))
		b.reset = time.Unix(sec, nsec).Add(ExtraDelay)
	}

	if remaining != "" {
		u, err := strconv.ParseUint(remaining, 10, 64)
		if err != nil {
			return errors.Wrap(err, "invalid remaining "+remaining)
		}
		b.remaining = u
	}

	return nil
}