1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-08 16:04:41 +00:00
arikawa/api/rate/rate.go

233 lines
4.4 KiB
Go
Raw Normal View History

2020-01-08 07:10:37 +00:00
package rate
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
"github.com/sasha-s/go-csync"
2020-01-08 07:10:37 +00:00
)
// ExtraDelay is the padding added on top of the reset time parsed from the
// X-RateLimit-Reset header before a bucket is treated as usable again. The
// original author notes that both litcord and discordgo apply a similar
// padding, with discordgo citing its own experiments as the reason.
const ExtraDelay = 250 * time.Millisecond
// This makes me suicidal.
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
type Limiter struct {
// Only 1 per bucket
CustomLimits []*CustomRateLimit
// These callbacks will only be called for valid buckets. They will also be
// called right before locking. Returning false will not rate limit.
OnAcquire func(path string) bool
OnCancel func(path string) bool
OnRelease func(path string) bool // false means not unlocking
Prefix string
2020-01-08 07:10:37 +00:00
global *int64 // atomic guarded, unixnano
buckets sync.Map
globalRate time.Duration
}
// CustomRateLimit describes a fixed-interval limit that replaces the
// header-driven logic for every bucket whose key matches it.
type CustomRateLimit struct {
	// Contains is the substring looked for in the bucket key.
	Contains string

	// Reset is the interval between resets of a matching bucket.
	Reset time.Duration
}
// bucket holds the rate-limit state tracked for a single bucket key.
type bucket struct {
	// lock is taken by a successful Acquire and unlocked by Release or Cancel.
	lock csync.Mutex
	// custom is the first matching entry of Limiter.CustomLimits, or nil when
	// the bucket is driven by response headers instead.
	custom *CustomRateLimit
	// remaining is the number of requests left; Acquire decrements it and
	// Release overwrites it from the X-RateLimit-Remaining header.
	remaining uint64
	// limit is not referenced in the visible part of this file.
	limit uint
	// reset is the instant Acquire waits for once remaining has reached zero.
	reset time.Time
	// lastReset is when reset was last recomputed; it is only used for
	// buckets that have a custom limit.
	lastReset time.Time // only for custom
}
// returnTrue ignores its argument and always reports true. It is the default
// value for the Limiter callbacks.
func returnTrue(_ string) bool { return true }
func NewLimiter(prefix string) *Limiter {
2020-01-08 07:10:37 +00:00
return &Limiter{
Prefix: prefix,
global: new(int64),
buckets: sync.Map{},
CustomLimits: []*CustomRateLimit{},
OnAcquire: returnTrue,
OnCancel: returnTrue,
OnRelease: returnTrue,
2020-01-08 07:10:37 +00:00
}
}
func (l *Limiter) getBucket(path string, store bool) *bucket {
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
2020-01-08 18:43:15 +00:00
2020-01-08 07:10:37 +00:00
bc, ok := l.buckets.Load(path)
if !ok && !store {
return nil
}
if !ok {
bc := &bucket{
remaining: 1,
}
for _, limit := range l.CustomLimits {
if strings.Contains(path, limit.Contains) {
bc.custom = limit
break
}
}
l.buckets.Store(path, bc)
return bc
}
return bc.(*bucket)
}
func (l *Limiter) Acquire(ctx context.Context, path string) error {
b := l.getBucket(path, true)
if !l.OnAcquire(path) {
return nil
}
2020-01-08 07:10:37 +00:00
// Acquire lock with a timeout
if err := b.lock.CLock(ctx); err != nil {
return err
}
// Time to sleep
var sleep time.Duration
if b.remaining == 0 && b.reset.After(time.Now()) {
// out of turns, gotta wait
sleep = b.reset.Sub(time.Now())
} else {
// maybe global rate limit has it
now := time.Now()
until := time.Unix(0, atomic.LoadInt64(l.global))
if until.After(now) {
sleep = until.Sub(now)
}
}
if sleep > 0 {
select {
case <-ctx.Done():
b.lock.Unlock()
2020-01-08 07:10:37 +00:00
return ctx.Err()
case <-time.After(sleep):
}
}
if b.remaining > 0 {
b.remaining--
}
return nil
}
// Cancel unlocks the bucket for path without touching its rate-limit state.
// It does nothing when no bucket exists for path or when OnCancel returns
// false. The returned error is always nil.
func (l *Limiter) Cancel(path string) error {
	if bkt := l.getBucket(path, false); bkt != nil && l.OnCancel(path) {
		// TryLock either fails because the mutex is already locked or
		// succeeds because it was free; either way it is locked afterwards,
		// so the Unlock below is always applied to a locked mutex.
		bkt.lock.TryLock()
		bkt.lock.Unlock()
	}

	return nil
}
2020-01-08 07:10:37 +00:00
// Release releases the URL from the locks. This doesn't need a context for
// timing out, it doesn't block that much.
func (l *Limiter) Release(path string, headers http.Header) error {
b := l.getBucket(path, false)
if b == nil {
return nil
}
defer func() {
if l.OnRelease(path) {
b.lock.Unlock()
}
}()
2020-01-08 07:10:37 +00:00
// Check custom limiter
if b.custom != nil {
now := time.Now()
if now.Sub(b.lastReset) >= b.custom.Reset {
b.lastReset = now
b.reset = now.Add(b.custom.Reset)
}
return nil
}
var (
// boolean
global = headers.Get("X-RateLimit-Global")
// seconds
remaining = headers.Get("X-RateLimit-Remaining")
reset = headers.Get("X-RateLimit-Reset")
retryAfter = headers.Get("Retry-After")
)
switch {
case retryAfter != "":
i, err := strconv.Atoi(retryAfter)
if err != nil {
return errors.Wrap(err, "Invalid retryAfter "+retryAfter)
}
at := time.Now().Add(time.Duration(i) * time.Millisecond)
if global != "" { // probably true
atomic.StoreInt64(l.global, at.UnixNano())
} else {
b.reset = at
}
case reset != "":
unix, err := strconv.ParseFloat(reset, 64)
if err != nil {
return errors.Wrap(err, "Invalid reset "+reset)
}
b.reset = time.Unix(0, int64(unix*float64(time.Second))).
Add(ExtraDelay)
2020-01-08 07:10:37 +00:00
}
if remaining != "" {
u, err := strconv.ParseUint(remaining, 10, 64)
if err != nil {
return errors.Wrap(err, "Invalid remaining "+remaining)
}
b.remaining = u
}
return nil
}