1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-10-02 07:18:49 +00:00

API: Potential rate limit fix for reactions

This commit is contained in:
diamondburned 2020-12-11 17:58:52 -08:00
parent 43e5eeafde
commit d65807ce15
4 changed files with 32 additions and 29 deletions

View file

@ -31,8 +31,10 @@ type Limiter struct {
Prefix string
global *int64 // atomic guarded, unixnano
buckets sync.Map
global int64 // atomic guarded, unixnano
bucketMu sync.Mutex
buckets map[string]*bucket
}
type CustomRateLimit struct {
@ -60,8 +62,7 @@ func newBucket() *bucket {
func NewLimiter(prefix string) *Limiter {
return &Limiter{
Prefix: prefix,
global: new(int64),
buckets: sync.Map{},
buckets: map[string]*bucket{},
CustomLimits: []*CustomRateLimit{},
}
}
@ -69,7 +70,10 @@ func NewLimiter(prefix string) *Limiter {
func (l *Limiter) getBucket(path string, store bool) *bucket {
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
bc, ok := l.buckets.Load(path)
l.bucketMu.Lock()
defer l.bucketMu.Unlock()
bc, ok := l.buckets[path]
if !ok && !store {
return nil
}
@ -84,11 +88,11 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
}
}
l.buckets.Store(path, bc)
l.buckets[path] = bc
return bc
}
return bc.(*bucket)
return bc
}
func (l *Limiter) Acquire(ctx context.Context, path string) error {
@ -107,7 +111,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
until = b.reset
} else {
// maybe global rate limit has it
until = time.Unix(0, atomic.LoadInt64(l.global))
until = time.Unix(0, atomic.LoadInt64(&l.global))
}
if until.After(now) {
@ -178,7 +182,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
at := time.Now().Add(time.Duration(i) * time.Second)
if global != "" { // probably true
atomic.StoreInt64(l.global, at.UnixNano())
atomic.StoreInt64(&l.global, at.UnixNano())
} else {
b.reset = at
}

View file

@ -18,9 +18,9 @@ type Env struct {
}
var (
env Env
err error
once sync.Once
globalEnv Env
globalErr error
once sync.Once
)
func Must(t *testing.T) Env {
@ -33,41 +33,41 @@ func Must(t *testing.T) Env {
func GetEnv() (Env, error) {
once.Do(getEnv)
return env, err
return globalEnv, globalErr
}
func getEnv() {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
err = errors.New("missing $BOT_TOKEN")
globalErr = errors.New("missing $BOT_TOKEN")
return
}
var cid = os.Getenv("CHANNEL_ID")
if cid == "" {
err = errors.New("missing $CHANNEL_ID")
globalErr = errors.New("missing $CHANNEL_ID")
return
}
id, err := discord.ParseSnowflake(cid)
if err != nil {
err = errors.Wrap(err, "invalid $CHANNEL_ID")
globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
return
}
var sid = os.Getenv("VOICE_ID")
if sid == "" {
err = errors.New("missing $VOICE_ID")
globalErr = errors.New("missing $VOICE_ID")
return
}
vid, err := discord.ParseSnowflake(sid)
if err != nil {
err = errors.Wrap(err, "invalid $VOICE_ID")
globalErr = errors.Wrap(err, "invalid $VOICE_ID")
return
}
env = Env{
globalEnv = Env{
BotToken: token,
ChannelID: discord.ChannelID(id),
VoiceChID: discord.ChannelID(vid),

View file

@ -144,7 +144,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request
}
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// Error for the actual Do method.
var doErr error
// Error that represents the latest error in the chain.
var onRespErr error
var r httpdriver.Response
var status int
@ -178,9 +181,6 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
r, doErr = c.Client.Do(q)
// Error that represents the latest error in the chain.
var onRespErr error
// Call OnResponse() even if the request failed.
for _, fn := range c.OnResponse {
// Be sure to call ALL OnResponse handlers.
@ -189,12 +189,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
}
}
if onRespErr != nil {
return nil, errors.Wrap(err, "OnResponse handler failed")
}
// Retry if the request failed.
if doErr != nil {
if onRespErr != nil || doErr != nil {
continue
}
@ -205,6 +200,10 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
break
}
if onRespErr != nil {
return nil, errors.Wrap(onRespErr, "OnResponse handler failed")
}
// If all retries failed:
if doErr != nil {
return nil, RequestError{doErr}

View file

@ -40,7 +40,7 @@ type HTTPError struct {
func (err HTTPError) Error() string {
switch {
case err.Message != "":
return "Discord error: " + err.Message
return fmt.Sprintf("Discord %d error: %s", err.Status, err.Message)
case err.Code > 0:
return fmt.Sprintf("Discord returned status %d error code %d",