1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-19 21:32:49 +00:00

API: Potential rate limit fix for reactions

This commit is contained in:
diamondburned 2020-12-11 17:58:52 -08:00
parent 43e5eeafde
commit d65807ce15
4 changed files with 32 additions and 29 deletions

View file

@ -31,8 +31,10 @@ type Limiter struct {
Prefix string Prefix string
global *int64 // atomic guarded, unixnano global int64 // atomic guarded, unixnano
buckets sync.Map
bucketMu sync.Mutex
buckets map[string]*bucket
} }
type CustomRateLimit struct { type CustomRateLimit struct {
@ -60,8 +62,7 @@ func newBucket() *bucket {
func NewLimiter(prefix string) *Limiter { func NewLimiter(prefix string) *Limiter {
return &Limiter{ return &Limiter{
Prefix: prefix, Prefix: prefix,
global: new(int64), buckets: map[string]*bucket{},
buckets: sync.Map{},
CustomLimits: []*CustomRateLimit{}, CustomLimits: []*CustomRateLimit{},
} }
} }
@ -69,7 +70,10 @@ func NewLimiter(prefix string) *Limiter {
func (l *Limiter) getBucket(path string, store bool) *bucket { func (l *Limiter) getBucket(path string, store bool) *bucket {
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix)) path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
bc, ok := l.buckets.Load(path) l.bucketMu.Lock()
defer l.bucketMu.Unlock()
bc, ok := l.buckets[path]
if !ok && !store { if !ok && !store {
return nil return nil
} }
@ -84,11 +88,11 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
} }
} }
l.buckets.Store(path, bc) l.buckets[path] = bc
return bc return bc
} }
return bc.(*bucket) return bc
} }
func (l *Limiter) Acquire(ctx context.Context, path string) error { func (l *Limiter) Acquire(ctx context.Context, path string) error {
@ -107,7 +111,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
until = b.reset until = b.reset
} else { } else {
// maybe global rate limit has it // maybe global rate limit has it
until = time.Unix(0, atomic.LoadInt64(l.global)) until = time.Unix(0, atomic.LoadInt64(&l.global))
} }
if until.After(now) { if until.After(now) {
@ -178,7 +182,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
at := time.Now().Add(time.Duration(i) * time.Second) at := time.Now().Add(time.Duration(i) * time.Second)
if global != "" { // probably true if global != "" { // probably true
atomic.StoreInt64(l.global, at.UnixNano()) atomic.StoreInt64(&l.global, at.UnixNano())
} else { } else {
b.reset = at b.reset = at
} }

View file

@ -18,8 +18,8 @@ type Env struct {
} }
var ( var (
env Env globalEnv Env
err error globalErr error
once sync.Once once sync.Once
) )
@ -33,41 +33,41 @@ func Must(t *testing.T) Env {
func GetEnv() (Env, error) { func GetEnv() (Env, error) {
once.Do(getEnv) once.Do(getEnv)
return env, err return globalEnv, globalErr
} }
func getEnv() { func getEnv() {
var token = os.Getenv("BOT_TOKEN") var token = os.Getenv("BOT_TOKEN")
if token == "" { if token == "" {
err = errors.New("missing $BOT_TOKEN") globalErr = errors.New("missing $BOT_TOKEN")
return return
} }
var cid = os.Getenv("CHANNEL_ID") var cid = os.Getenv("CHANNEL_ID")
if cid == "" { if cid == "" {
err = errors.New("missing $CHANNEL_ID") globalErr = errors.New("missing $CHANNEL_ID")
return return
} }
id, err := discord.ParseSnowflake(cid) id, err := discord.ParseSnowflake(cid)
if err != nil { if err != nil {
err = errors.Wrap(err, "invalid $CHANNEL_ID") globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
return return
} }
var sid = os.Getenv("VOICE_ID") var sid = os.Getenv("VOICE_ID")
if sid == "" { if sid == "" {
err = errors.New("missing $VOICE_ID") globalErr = errors.New("missing $VOICE_ID")
return return
} }
vid, err := discord.ParseSnowflake(sid) vid, err := discord.ParseSnowflake(sid)
if err != nil { if err != nil {
err = errors.Wrap(err, "invalid $VOICE_ID") globalErr = errors.Wrap(err, "invalid $VOICE_ID")
return return
} }
env = Env{ globalEnv = Env{
BotToken: token, BotToken: token,
ChannelID: discord.ChannelID(id), ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid), VoiceChID: discord.ChannelID(vid),

View file

@ -144,7 +144,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request
} }
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) { func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// Error for the actual Do method.
var doErr error var doErr error
// Error that represents the latest error in the chain.
var onRespErr error
var r httpdriver.Response var r httpdriver.Response
var status int var status int
@ -178,9 +181,6 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
r, doErr = c.Client.Do(q) r, doErr = c.Client.Do(q)
// Error that represents the latest error in the chain.
var onRespErr error
// Call OnResponse() even if the request failed. // Call OnResponse() even if the request failed.
for _, fn := range c.OnResponse { for _, fn := range c.OnResponse {
// Be sure to call ALL OnResponse handlers. // Be sure to call ALL OnResponse handlers.
@ -189,12 +189,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
} }
} }
if onRespErr != nil { if onRespErr != nil || doErr != nil {
return nil, errors.Wrap(err, "OnResponse handler failed")
}
// Retry if the request failed.
if doErr != nil {
continue continue
} }
@ -205,6 +200,10 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
break break
} }
if onRespErr != nil {
return nil, errors.Wrap(onRespErr, "OnResponse handler failed")
}
// If all retries failed: // If all retries failed:
if doErr != nil { if doErr != nil {
return nil, RequestError{doErr} return nil, RequestError{doErr}

View file

@ -40,7 +40,7 @@ type HTTPError struct {
func (err HTTPError) Error() string { func (err HTTPError) Error() string {
switch { switch {
case err.Message != "": case err.Message != "":
return "Discord error: " + err.Message return fmt.Sprintf("Discord %d error: %s", err.Status, err.Message)
case err.Code > 0: case err.Code > 0:
return fmt.Sprintf("Discord returned status %d error code %d", return fmt.Sprintf("Discord returned status %d error code %d",