mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-19 21:32:49 +00:00
API: Potential rate limit fix for reactions
This commit is contained in:
parent
43e5eeafde
commit
d65807ce15
|
@ -31,8 +31,10 @@ type Limiter struct {
|
||||||
|
|
||||||
Prefix string
|
Prefix string
|
||||||
|
|
||||||
global *int64 // atomic guarded, unixnano
|
global int64 // atomic guarded, unixnano
|
||||||
buckets sync.Map
|
|
||||||
|
bucketMu sync.Mutex
|
||||||
|
buckets map[string]*bucket
|
||||||
}
|
}
|
||||||
|
|
||||||
type CustomRateLimit struct {
|
type CustomRateLimit struct {
|
||||||
|
@ -60,8 +62,7 @@ func newBucket() *bucket {
|
||||||
func NewLimiter(prefix string) *Limiter {
|
func NewLimiter(prefix string) *Limiter {
|
||||||
return &Limiter{
|
return &Limiter{
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
global: new(int64),
|
buckets: map[string]*bucket{},
|
||||||
buckets: sync.Map{},
|
|
||||||
CustomLimits: []*CustomRateLimit{},
|
CustomLimits: []*CustomRateLimit{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -69,7 +70,10 @@ func NewLimiter(prefix string) *Limiter {
|
||||||
func (l *Limiter) getBucket(path string, store bool) *bucket {
|
func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||||
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
|
path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix))
|
||||||
|
|
||||||
bc, ok := l.buckets.Load(path)
|
l.bucketMu.Lock()
|
||||||
|
defer l.bucketMu.Unlock()
|
||||||
|
|
||||||
|
bc, ok := l.buckets[path]
|
||||||
if !ok && !store {
|
if !ok && !store {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -84,11 +88,11 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
l.buckets.Store(path, bc)
|
l.buckets[path] = bc
|
||||||
return bc
|
return bc
|
||||||
}
|
}
|
||||||
|
|
||||||
return bc.(*bucket)
|
return bc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
|
@ -107,7 +111,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
until = b.reset
|
until = b.reset
|
||||||
} else {
|
} else {
|
||||||
// maybe global rate limit has it
|
// maybe global rate limit has it
|
||||||
until = time.Unix(0, atomic.LoadInt64(l.global))
|
until = time.Unix(0, atomic.LoadInt64(&l.global))
|
||||||
}
|
}
|
||||||
|
|
||||||
if until.After(now) {
|
if until.After(now) {
|
||||||
|
@ -178,7 +182,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
||||||
at := time.Now().Add(time.Duration(i) * time.Second)
|
at := time.Now().Add(time.Duration(i) * time.Second)
|
||||||
|
|
||||||
if global != "" { // probably true
|
if global != "" { // probably true
|
||||||
atomic.StoreInt64(l.global, at.UnixNano())
|
atomic.StoreInt64(&l.global, at.UnixNano())
|
||||||
} else {
|
} else {
|
||||||
b.reset = at
|
b.reset = at
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,8 @@ type Env struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
env Env
|
globalEnv Env
|
||||||
err error
|
globalErr error
|
||||||
once sync.Once
|
once sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,41 +33,41 @@ func Must(t *testing.T) Env {
|
||||||
|
|
||||||
func GetEnv() (Env, error) {
|
func GetEnv() (Env, error) {
|
||||||
once.Do(getEnv)
|
once.Do(getEnv)
|
||||||
return env, err
|
return globalEnv, globalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnv() {
|
func getEnv() {
|
||||||
var token = os.Getenv("BOT_TOKEN")
|
var token = os.Getenv("BOT_TOKEN")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
err = errors.New("missing $BOT_TOKEN")
|
globalErr = errors.New("missing $BOT_TOKEN")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var cid = os.Getenv("CHANNEL_ID")
|
var cid = os.Getenv("CHANNEL_ID")
|
||||||
if cid == "" {
|
if cid == "" {
|
||||||
err = errors.New("missing $CHANNEL_ID")
|
globalErr = errors.New("missing $CHANNEL_ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := discord.ParseSnowflake(cid)
|
id, err := discord.ParseSnowflake(cid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrap(err, "invalid $CHANNEL_ID")
|
globalErr = errors.Wrap(err, "invalid $CHANNEL_ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var sid = os.Getenv("VOICE_ID")
|
var sid = os.Getenv("VOICE_ID")
|
||||||
if sid == "" {
|
if sid == "" {
|
||||||
err = errors.New("missing $VOICE_ID")
|
globalErr = errors.New("missing $VOICE_ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vid, err := discord.ParseSnowflake(sid)
|
vid, err := discord.ParseSnowflake(sid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrap(err, "invalid $VOICE_ID")
|
globalErr = errors.Wrap(err, "invalid $VOICE_ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
env = Env{
|
globalEnv = Env{
|
||||||
BotToken: token,
|
BotToken: token,
|
||||||
ChannelID: discord.ChannelID(id),
|
ChannelID: discord.ChannelID(id),
|
||||||
VoiceChID: discord.ChannelID(vid),
|
VoiceChID: discord.ChannelID(vid),
|
||||||
|
|
|
@ -144,7 +144,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||||
|
// Error for the actual Do method.
|
||||||
var doErr error
|
var doErr error
|
||||||
|
// Error that represents the latest error in the chain.
|
||||||
|
var onRespErr error
|
||||||
|
|
||||||
var r httpdriver.Response
|
var r httpdriver.Response
|
||||||
var status int
|
var status int
|
||||||
|
@ -178,9 +181,6 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
||||||
|
|
||||||
r, doErr = c.Client.Do(q)
|
r, doErr = c.Client.Do(q)
|
||||||
|
|
||||||
// Error that represents the latest error in the chain.
|
|
||||||
var onRespErr error
|
|
||||||
|
|
||||||
// Call OnResponse() even if the request failed.
|
// Call OnResponse() even if the request failed.
|
||||||
for _, fn := range c.OnResponse {
|
for _, fn := range c.OnResponse {
|
||||||
// Be sure to call ALL OnResponse handlers.
|
// Be sure to call ALL OnResponse handlers.
|
||||||
|
@ -189,12 +189,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if onRespErr != nil {
|
if onRespErr != nil || doErr != nil {
|
||||||
return nil, errors.Wrap(err, "OnResponse handler failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retry if the request failed.
|
|
||||||
if doErr != nil {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,6 +200,10 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if onRespErr != nil {
|
||||||
|
return nil, errors.Wrap(onRespErr, "OnResponse handler failed")
|
||||||
|
}
|
||||||
|
|
||||||
// If all retries failed:
|
// If all retries failed:
|
||||||
if doErr != nil {
|
if doErr != nil {
|
||||||
return nil, RequestError{doErr}
|
return nil, RequestError{doErr}
|
||||||
|
|
|
@ -40,7 +40,7 @@ type HTTPError struct {
|
||||||
func (err HTTPError) Error() string {
|
func (err HTTPError) Error() string {
|
||||||
switch {
|
switch {
|
||||||
case err.Message != "":
|
case err.Message != "":
|
||||||
return "Discord error: " + err.Message
|
return fmt.Sprintf("Discord %d error: %s", err.Status, err.Message)
|
||||||
|
|
||||||
case err.Code > 0:
|
case err.Code > 0:
|
||||||
return fmt.Sprintf("Discord returned status %d error code %d",
|
return fmt.Sprintf("Discord returned status %d error code %d",
|
||||||
|
|
Loading…
Reference in a new issue