1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-01 12:34:28 +00:00
arikawa/api/rate/rate.go

247 lines
5.2 KiB
Go
Raw Normal View History

2020-01-08 07:10:37 +00:00
package rate
import (
"context"
"errors"
"fmt"
2020-01-08 07:10:37 +00:00
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
2021-06-02 02:53:19 +00:00
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
2020-01-08 07:10:37 +00:00
)
// ExtraDelay is padding that Limiter.Release adds to the reset time reported
// by the X-RateLimit-Reset header, to stay clear of Discord's rate limits.
// Both litcord and discordgo pad the reset time the same way, with discordgo
// citing its own experiments.
const ExtraDelay time.Duration = 250 * time.Millisecond
var (
	// ErrTimedOutEarly is returned by Limiter.Acquire when it would have to
	// wait for a rate limit but either AcquireOptions.DontWait is set to true
	// or the rate limit lifts only after the deadline of the context.Context.
	ErrTimedOutEarly = errors.New("rate: rate limit exceeds context deadline or is blocked acquire options")
)
2020-01-08 07:10:37 +00:00
// This makes me suicidal.
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
type Limiter struct {
// Only 1 per bucket
CustomLimits []*CustomRateLimit
Prefix string
2020-12-31 18:24:51 +00:00
// global is a pointer to prevent ARM-compatibility alignment.
global *int64 // atomic guarded, unixnano
bucketMu sync.Mutex
buckets map[string]*bucket
2020-01-08 07:10:37 +00:00
}
// CustomRateLimit replaces header-based rate limiting for every bucket whose
// key contains Contains. For such a bucket, Limiter.Release ignores the
// response headers and instead moves the bucket's reset time to now+Reset, at
// most once per Reset interval.
type CustomRateLimit struct {
	// Contains is the substring matched against a bucket key.
	Contains string

	// Reset is the interval between resets of a matching bucket.
	Reset time.Duration
}
2021-05-29 20:28:45 +00:00
// contextKey is the unexported type of this package's context keys. Because
// the type is unexported, keys from other packages can never compare equal
// to these.
type contextKey uint8

// acquireOptionsKey is the context key under which AcquireOptions.Context
// stores the AcquireOptions that Limiter.Acquire reads back.
const acquireOptionsKey = contextKey(0)
// AcquireOptions adjusts how Limiter.Acquire behaves for calls whose context
// carries it; attach it to a context with AcquireOptions.Context.
type AcquireOptions struct {
	// DontWait prevents rate.Limiters from waiting for a rate limit. Instead
	// they will return a rate.ErrTimedOutEarly.
	DontWait bool
}
// Context returns a copy of ctx that carries opts. Limiter.Acquire calls made
// with the returned context honor these options.
func (opts AcquireOptions) Context(ctx context.Context) context.Context {
	var key interface{} = acquireOptionsKey
	return context.WithValue(ctx, key, opts)
}
2020-01-08 07:10:37 +00:00
type bucket struct {
lock moreatomic.CtxMutex
2020-01-08 07:10:37 +00:00
custom *CustomRateLimit
remaining uint64
reset time.Time
lastReset time.Time // only for custom
}
// newBucket returns a fresh bucket with one request remaining, a zero reset
// time and no custom limit.
func newBucket() *bucket {
	b := new(bucket)
	b.lock = *moreatomic.NewCtxMutex()
	b.remaining = 1
	return b
}
func NewLimiter(prefix string) *Limiter {
2020-01-08 07:10:37 +00:00
return &Limiter{
Prefix: prefix,
2020-12-31 18:24:51 +00:00
global: new(int64),
buckets: map[string]*bucket{},
CustomLimits: []*CustomRateLimit{},
2020-01-08 07:10:37 +00:00
}
}
func (l *Limiter) getBucket(path string, store bool) *bucket {
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
2020-01-08 18:43:15 +00:00
l.bucketMu.Lock()
defer l.bucketMu.Unlock()
bc, ok := l.buckets[path]
2020-01-08 07:10:37 +00:00
if !ok && !store {
return nil
}
if !ok {
bc := newBucket()
2020-01-08 07:10:37 +00:00
for _, limit := range l.CustomLimits {
if strings.Contains(path, limit.Contains) {
bc.custom = limit
break
}
}
l.buckets[path] = bc
2020-01-08 07:10:37 +00:00
return bc
}
return bc
2020-01-08 07:10:37 +00:00
}
// Acquire acquires the rate limiter for the given URL bucket.
2020-01-08 07:10:37 +00:00
func (l *Limiter) Acquire(ctx context.Context, path string) error {
2021-05-29 20:28:45 +00:00
var options AcquireOptions
if untypedOptions := ctx.Value(acquireOptionsKey); untypedOptions != nil {
// Zero value are default anyways, so we can ignore ok.
options, _ = untypedOptions.(AcquireOptions)
}
2020-01-08 07:10:37 +00:00
b := l.getBucket(path, true)
if err := b.lock.Lock(ctx); err != nil {
2020-01-08 07:10:37 +00:00
return err
}
// Deadline until the limiter is released.
until := time.Time{}
now := time.Now()
2020-01-08 07:10:37 +00:00
if b.remaining == 0 && b.reset.After(now) {
2020-01-08 07:10:37 +00:00
// out of turns, gotta wait
until = b.reset
2020-01-08 07:10:37 +00:00
} else {
// maybe global rate limit has it
2020-12-31 18:24:51 +00:00
until = time.Unix(0, atomic.LoadInt64(l.global))
}
2020-01-08 07:10:37 +00:00
if until.After(now) {
2021-05-29 20:28:45 +00:00
if options.DontWait {
return ErrTimedOutEarly
} else if deadline, ok := ctx.Deadline(); ok && until.After(deadline) {
return ErrTimedOutEarly
2020-01-08 07:10:37 +00:00
}
select {
case <-ctx.Done():
b.lock.Unlock()
2020-01-08 07:10:37 +00:00
return ctx.Err()
case <-time.After(until.Sub(now)):
2020-01-08 07:10:37 +00:00
}
}
if b.remaining > 0 {
b.remaining--
}
return nil
}
// Release releases the URL from the locks. This doesn't need a context for
// timing out, since it doesn't block that much.
2020-01-08 07:10:37 +00:00
func (l *Limiter) Release(path string, headers http.Header) error {
b := l.getBucket(path, false)
if b == nil {
return nil
}
// TryUnlock because Release may be called when Acquire has not been.
defer b.lock.TryUnlock()
2020-01-08 07:10:37 +00:00
// Check custom limiter
if b.custom != nil {
now := time.Now()
if now.Sub(b.lastReset) >= b.custom.Reset {
b.lastReset = now
b.reset = now.Add(b.custom.Reset)
}
return nil
}
// Check if headers is nil or not:
if headers == nil {
return nil
}
2020-01-08 07:10:37 +00:00
var (
// boolean
global = headers.Get("X-RateLimit-Global")
// seconds
remaining = headers.Get("X-RateLimit-Remaining")
reset = headers.Get("X-RateLimit-Reset") // float
2020-01-08 07:10:37 +00:00
retryAfter = headers.Get("Retry-After")
)
switch {
case retryAfter != "":
i, err := strconv.Atoi(retryAfter)
if err != nil {
return fmt.Errorf("invalid retryAfter %q: %w", retryAfter, err)
2020-01-08 07:10:37 +00:00
}
at := time.Now().Add(time.Duration(i) * time.Second)
2020-01-08 07:10:37 +00:00
2020-12-31 18:24:51 +00:00
if global != "" { // probably "true"
atomic.StoreInt64(l.global, at.UnixNano())
2020-01-08 07:10:37 +00:00
} else {
b.reset = at
}
case reset != "":
unix, err := strconv.ParseFloat(reset, 64)
if err != nil {
return fmt.Errorf("invalid reset %q: %w", reset, err)
2020-01-08 07:10:37 +00:00
}
sec := int64(unix)
nsec := int64((unix - float64(sec)) * float64(time.Second))
b.reset = time.Unix(sec, nsec).Add(ExtraDelay)
2020-01-08 07:10:37 +00:00
}
if remaining != "" {
u, err := strconv.ParseUint(remaining, 10, 64)
if err != nil {
return fmt.Errorf("invalid remaining %q: %w", remaining, err)
2020-01-08 07:10:37 +00:00
}
b.remaining = u
}
return nil
}