mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-11-13 15:55:50 +00:00
API: Fixed major rate limiters not working
This commit is contained in:
parent
165ef71cb5
commit
f33dc2ee75
10
api/api.go
10
api/api.go
|
|
@ -10,10 +10,11 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
BaseEndpoint = "https://discordapp.com/api"
|
||||
BaseEndpoint = "https://discordapp.com"
|
||||
APIVersion = "6"
|
||||
APIPath = "/api/v" + APIVersion
|
||||
|
||||
Endpoint = BaseEndpoint + "/v" + APIVersion + "/"
|
||||
Endpoint = BaseEndpoint + APIPath + "/"
|
||||
EndpointGateway = Endpoint + "gateway"
|
||||
EndpointGatewayBot = EndpointGateway + "/bot"
|
||||
)
|
||||
|
|
@ -30,7 +31,7 @@ type Client struct {
|
|||
func NewClient(token string) *Client {
|
||||
cli := &Client{
|
||||
Client: httputil.DefaultClient,
|
||||
Limiter: rate.NewLimiter(),
|
||||
Limiter: rate.NewLimiter(APIPath),
|
||||
Token: token,
|
||||
}
|
||||
|
||||
|
|
@ -46,6 +47,9 @@ func NewClient(token string) *Client {
|
|||
// Rate limit stuff
|
||||
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
|
||||
}
|
||||
tw.Cancel = func(r *http.Request, err error) {
|
||||
cli.Limiter.Cancel(r.URL.Path)
|
||||
}
|
||||
tw.Post = func(r *http.Response) error {
|
||||
return cli.Limiter.Release(r.Request.URL.Path, r.Header)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
csync "github.com/sasha-s/go-csync"
|
||||
"github.com/sasha-s/go-csync"
|
||||
)
|
||||
|
||||
// ExtraDelay because Discord is trash. I've seen this in both litcord and
|
||||
|
|
@ -25,17 +25,22 @@ type Limiter struct {
|
|||
// Only 1 per bucket
|
||||
CustomLimits []*CustomRateLimit
|
||||
|
||||
// These callbacks will only be called for valid buckets. They will also be
|
||||
// called right before locking. Returning false will not rate limit.
|
||||
OnAcquire func(path string) bool
|
||||
OnCancel func(path string) bool
|
||||
OnRelease func(path string) bool // false means not unlocking
|
||||
|
||||
Prefix string
|
||||
|
||||
global *int64 // atomic guarded, unixnano
|
||||
buckets sync.Map
|
||||
globalRate time.Duration
|
||||
}
|
||||
|
||||
type CustomRateLimit struct {
|
||||
// This string will match on a Printf format string.
|
||||
// e.g. /guilds/%s/channels/%s/...
|
||||
Contains string
|
||||
|
||||
Reset time.Duration
|
||||
Reset time.Duration
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
|
|
@ -49,16 +54,25 @@ type bucket struct {
|
|||
lastReset time.Time // only for custom
|
||||
}
|
||||
|
||||
func NewLimiter() *Limiter {
|
||||
func returnTrue(string) bool {
|
||||
// time.Sleep(time.Nanosecond)
|
||||
return true
|
||||
}
|
||||
|
||||
func NewLimiter(prefix string) *Limiter {
|
||||
return &Limiter{
|
||||
Prefix: prefix,
|
||||
global: new(int64),
|
||||
buckets: sync.Map{},
|
||||
CustomLimits: []*CustomRateLimit{},
|
||||
OnAcquire: returnTrue,
|
||||
OnCancel: returnTrue,
|
||||
OnRelease: returnTrue,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||
path = ParseBucketKey(path)
|
||||
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
|
||||
|
||||
bc, ok := l.buckets.Load(path)
|
||||
if !ok && !store {
|
||||
|
|
@ -87,6 +101,10 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
|||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||
b := l.getBucket(path, true)
|
||||
|
||||
if !l.OnAcquire(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Acquire lock with a timeout
|
||||
if err := b.lock.CLock(ctx); err != nil {
|
||||
return err
|
||||
|
|
@ -111,6 +129,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
|||
if sleep > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.lock.Unlock()
|
||||
return ctx.Err()
|
||||
case <-time.After(sleep):
|
||||
}
|
||||
|
|
@ -123,6 +142,22 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (l *Limiter) Cancel(path string) error {
|
||||
b := l.getBucket(path, false)
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
if !l.OnCancel(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TryLock would either not lock because it's already locked, or lock
|
||||
// because it isn't.
|
||||
b.lock.TryLock()
|
||||
b.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Release releases the URL from the locks. This doesn't need a context for
|
||||
// timing out, it doesn't block that much.
|
||||
func (l *Limiter) Release(path string, headers http.Header) error {
|
||||
|
|
@ -131,7 +166,11 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
defer b.lock.Unlock()
|
||||
defer func() {
|
||||
if l.OnRelease(path) {
|
||||
b.lock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Check custom limiter
|
||||
if b.custom != nil {
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/gorilla/schema v1.1.0
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 // indirect
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
nhooyr.io/websocket v1.7.4
|
||||
)
|
||||
|
|
|
|||
6
go.sum
6
go.sum
|
|
@ -33,11 +33,9 @@ go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
|
|||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
type TransportWrapper struct {
|
||||
Default http.RoundTripper
|
||||
Pre func(*http.Request) error
|
||||
Cancel func(*http.Request, error)
|
||||
Post func(*http.Response) error
|
||||
}
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ func NewTransportWrapper() *TransportWrapper {
|
|||
return &TransportWrapper{
|
||||
Default: http.DefaultTransport,
|
||||
Pre: func(*http.Request) error { return nil },
|
||||
Cancel: func(*http.Request, error) {},
|
||||
Post: func(*http.Response) error { return nil },
|
||||
}
|
||||
}
|
||||
|
|
@ -27,6 +29,7 @@ func (c *TransportWrapper) RoundTrip(req *http.Request) (*http.Response, error)
|
|||
|
||||
r, err := c.Default.RoundTrip(req)
|
||||
if err != nil {
|
||||
c.Cancel(req, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue