mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-19 21:32:49 +00:00
API: Potential rate limit fix for reactions
This commit is contained in:
parent
43e5eeafde
commit
d65807ce15
|
@ -31,8 +31,10 @@ type Limiter struct {
|
|||
|
||||
Prefix string
|
||||
|
||||
global *int64 // atomic guarded, unixnano
|
||||
buckets sync.Map
|
||||
global int64 // atomic guarded, unixnano
|
||||
|
||||
bucketMu sync.Mutex
|
||||
buckets map[string]*bucket
|
||||
}
|
||||
|
||||
type CustomRateLimit struct {
|
||||
|
@ -60,8 +62,7 @@ func newBucket() *bucket {
|
|||
func NewLimiter(prefix string) *Limiter {
|
||||
return &Limiter{
|
||||
Prefix: prefix,
|
||||
global: new(int64),
|
||||
buckets: sync.Map{},
|
||||
buckets: map[string]*bucket{},
|
||||
CustomLimits: []*CustomRateLimit{},
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +70,10 @@ func NewLimiter(prefix string) *Limiter {
|
|||
func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
|
||||
|
||||
bc, ok := l.buckets.Load(path)
|
||||
l.bucketMu.Lock()
|
||||
defer l.bucketMu.Unlock()
|
||||
|
||||
bc, ok := l.buckets[path]
|
||||
if !ok && !store {
|
||||
return nil
|
||||
}
|
||||
|
@ -84,11 +88,11 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
|||
}
|
||||
}
|
||||
|
||||
l.buckets.Store(path, bc)
|
||||
l.buckets[path] = bc
|
||||
return bc
|
||||
}
|
||||
|
||||
return bc.(*bucket)
|
||||
return bc
|
||||
}
|
||||
|
||||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||
|
@ -107,7 +111,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
|||
until = b.reset
|
||||
} else {
|
||||
// maybe global rate limit has it
|
||||
until = time.Unix(0, atomic.LoadInt64(l.global))
|
||||
until = time.Unix(0, atomic.LoadInt64(&l.global))
|
||||
}
|
||||
|
||||
if until.After(now) {
|
||||
|
@ -178,7 +182,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
|||
at := time.Now().Add(time.Duration(i) * time.Second)
|
||||
|
||||
if global != "" { // probably true
|
||||
atomic.StoreInt64(l.global, at.UnixNano())
|
||||
atomic.StoreInt64(&l.global, at.UnixNano())
|
||||
} else {
|
||||
b.reset = at
|
||||
}
|
||||
|
|
|
@ -18,9 +18,9 @@ type Env struct {
|
|||
}
|
||||
|
||||
var (
|
||||
env Env
|
||||
err error
|
||||
once sync.Once
|
||||
globalEnv Env
|
||||
globalErr error
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func Must(t *testing.T) Env {
|
||||
|
@ -33,41 +33,41 @@ func Must(t *testing.T) Env {
|
|||
|
||||
func GetEnv() (Env, error) {
|
||||
once.Do(getEnv)
|
||||
return env, err
|
||||
return globalEnv, globalErr
|
||||
}
|
||||
|
||||
func getEnv() {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
err = errors.New("missing $BOT_TOKEN")
|
||||
globalErr = errors.New("missing $BOT_TOKEN")
|
||||
return
|
||||
}
|
||||
|
||||
var cid = os.Getenv("CHANNEL_ID")
|
||||
if cid == "" {
|
||||
err = errors.New("missing $CHANNEL_ID")
|
||||
globalErr = errors.New("missing $CHANNEL_ID")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := discord.ParseSnowflake(cid)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "invalid $CHANNEL_ID")
|
||||
globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
|
||||
return
|
||||
}
|
||||
|
||||
var sid = os.Getenv("VOICE_ID")
|
||||
if sid == "" {
|
||||
err = errors.New("missing $VOICE_ID")
|
||||
globalErr = errors.New("missing $VOICE_ID")
|
||||
return
|
||||
}
|
||||
|
||||
vid, err := discord.ParseSnowflake(sid)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "invalid $VOICE_ID")
|
||||
globalErr = errors.Wrap(err, "invalid $VOICE_ID")
|
||||
return
|
||||
}
|
||||
|
||||
env = Env{
|
||||
globalEnv = Env{
|
||||
BotToken: token,
|
||||
ChannelID: discord.ChannelID(id),
|
||||
VoiceChID: discord.ChannelID(vid),
|
||||
|
|
|
@ -144,7 +144,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request
|
|||
}
|
||||
|
||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
// Error for the actual Do method.
|
||||
var doErr error
|
||||
// Error that represents the latest error in the chain.
|
||||
var onRespErr error
|
||||
|
||||
var r httpdriver.Response
|
||||
var status int
|
||||
|
@ -178,9 +181,6 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
|||
|
||||
r, doErr = c.Client.Do(q)
|
||||
|
||||
// Error that represents the latest error in the chain.
|
||||
var onRespErr error
|
||||
|
||||
// Call OnResponse() even if the request failed.
|
||||
for _, fn := range c.OnResponse {
|
||||
// Be sure to call ALL OnResponse handlers.
|
||||
|
@ -189,12 +189,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
|||
}
|
||||
}
|
||||
|
||||
if onRespErr != nil {
|
||||
return nil, errors.Wrap(err, "OnResponse handler failed")
|
||||
}
|
||||
|
||||
// Retry if the request failed.
|
||||
if doErr != nil {
|
||||
if onRespErr != nil || doErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -205,6 +200,10 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
|||
break
|
||||
}
|
||||
|
||||
if onRespErr != nil {
|
||||
return nil, errors.Wrap(onRespErr, "OnResponse handler failed")
|
||||
}
|
||||
|
||||
// If all retries failed:
|
||||
if doErr != nil {
|
||||
return nil, RequestError{doErr}
|
||||
|
|
|
@ -40,7 +40,7 @@ type HTTPError struct {
|
|||
func (err HTTPError) Error() string {
|
||||
switch {
|
||||
case err.Message != "":
|
||||
return "Discord error: " + err.Message
|
||||
return fmt.Sprintf("Discord %d error: %s", err.Status, err.Message)
|
||||
|
||||
case err.Code > 0:
|
||||
return fmt.Sprintf("Discord returned status %d error code %d",
|
||||
|
|
Loading…
Reference in a new issue