diff --git a/api/rate/rate.go b/api/rate/rate.go index a379b10..9a358e4 100644 --- a/api/rate/rate.go +++ b/api/rate/rate.go @@ -31,8 +31,10 @@ type Limiter struct { Prefix string - global *int64 // atomic guarded, unixnano - buckets sync.Map + global int64 // atomic guarded, unixnano + + bucketMu sync.Mutex + buckets map[string]*bucket } type CustomRateLimit struct { @@ -60,8 +62,7 @@ func newBucket() *bucket { func NewLimiter(prefix string) *Limiter { return &Limiter{ Prefix: prefix, - global: new(int64), - buckets: sync.Map{}, + buckets: map[string]*bucket{}, CustomLimits: []*CustomRateLimit{}, } } @@ -69,7 +70,10 @@ func NewLimiter(prefix string) *Limiter { func (l *Limiter) getBucket(path string, store bool) *bucket { path = ParseBucketKey(strings.TrimPrefix(path, l.Prefix)) - bc, ok := l.buckets.Load(path) + l.bucketMu.Lock() + defer l.bucketMu.Unlock() + + bc, ok := l.buckets[path] if !ok && !store { return nil } @@ -84,11 +88,11 @@ func (l *Limiter) getBucket(path string, store bool) *bucket { } } - l.buckets.Store(path, bc) + l.buckets[path] = bc return bc } - return bc.(*bucket) + return bc } func (l *Limiter) Acquire(ctx context.Context, path string) error { @@ -107,7 +111,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error { until = b.reset } else { // maybe global rate limit has it - until = time.Unix(0, atomic.LoadInt64(l.global)) + until = time.Unix(0, atomic.LoadInt64(&l.global)) } if until.After(now) { @@ -178,7 +182,7 @@ func (l *Limiter) Release(path string, headers http.Header) error { at := time.Now().Add(time.Duration(i) * time.Second) if global != "" { // probably true - atomic.StoreInt64(l.global, at.UnixNano()) + atomic.StoreInt64(&l.global, at.UnixNano()) } else { b.reset = at } diff --git a/internal/testenv/testenv.go b/internal/testenv/testenv.go index 4ce5044..00058ee 100644 --- a/internal/testenv/testenv.go +++ b/internal/testenv/testenv.go @@ -18,9 +18,9 @@ type Env struct { } var ( - env Env - err error - once sync.Once + globalEnv Env + globalErr error + once sync.Once ) func Must(t *testing.T) Env { @@ -33,41 +33,41 @@ func Must(t *testing.T) Env { func GetEnv() (Env, error) { once.Do(getEnv) - return env, err + return globalEnv, globalErr } func getEnv() { var token = os.Getenv("BOT_TOKEN") if token == "" { - err = errors.New("missing $BOT_TOKEN") + globalErr = errors.New("missing $BOT_TOKEN") return } var cid = os.Getenv("CHANNEL_ID") if cid == "" { - err = errors.New("missing $CHANNEL_ID") + globalErr = errors.New("missing $CHANNEL_ID") return } id, err := discord.ParseSnowflake(cid) if err != nil { - err = errors.Wrap(err, "invalid $CHANNEL_ID") + globalErr = errors.Wrap(err, "invalid $CHANNEL_ID") return } var sid = os.Getenv("VOICE_ID") if sid == "" { - err = errors.New("missing $VOICE_ID") + globalErr = errors.New("missing $VOICE_ID") return } vid, err := discord.ParseSnowflake(sid) if err != nil { - err = errors.Wrap(err, "invalid $VOICE_ID") + globalErr = errors.Wrap(err, "invalid $VOICE_ID") return } - env = Env{ + globalEnv = Env{ BotToken: token, ChannelID: discord.ChannelID(id), VoiceChID: discord.ChannelID(vid), diff --git a/utils/httputil/client.go b/utils/httputil/client.go index 46c849b..b070490 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -144,7 +144,10 @@ func (c *Client) RequestJSON(to interface{}, method, url string, opts ...Request } func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) { + // Error for the actual Do method. var doErr error + // Error that represents the latest error in the chain. + var onRespErr error var r httpdriver.Response var status int @@ -178,9 +181,6 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver. r, doErr = c.Client.Do(q) - // Error that represents the latest error in the chain. - var onRespErr error - // Call OnResponse() even if the request failed. for _, fn := range c.OnResponse { // Be sure to call ALL OnResponse handlers. @@ -189,12 +189,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver. } } - if onRespErr != nil { - return nil, errors.Wrap(err, "OnResponse handler failed") - } - - // Retry if the request failed. - if doErr != nil { + if onRespErr != nil || doErr != nil { continue } @@ -205,6 +200,10 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver. break } + if onRespErr != nil { + return nil, errors.Wrap(onRespErr, "OnResponse handler failed") + } + // If all retries failed: if doErr != nil { return nil, RequestError{doErr} diff --git a/utils/httputil/errors.go b/utils/httputil/errors.go index 5e1b5cd..19fe015 100644 --- a/utils/httputil/errors.go +++ b/utils/httputil/errors.go @@ -40,7 +40,7 @@ type HTTPError struct { func (err HTTPError) Error() string { switch { case err.Message != "": - return "Discord error: " + err.Message + return fmt.Sprintf("Discord %d error: %s", err.Status, err.Message) case err.Code > 0: return fmt.Sprintf("Discord returned status %d error code %d",