From 14b9d8f43acd170a84754b28ee97bc76630678fa Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 1 Mar 2020 14:25:54 -0800 Subject: [PATCH] API: Fixed unlock of unlocked mutex bug --- api/api.go | 10 ++++----- api/rate/rate.go | 45 +++++++-------------------------------- internal/httputil/http.go | 23 ++++++++------------ 3 files changed, 22 insertions(+), 56 deletions(-) diff --git a/api/api.go b/api/api.go index 54472c8..d7d7ac7 100644 --- a/api/api.go +++ b/api/api.go @@ -47,11 +47,11 @@ func NewClient(token string) *Client { // Rate limit stuff return cli.Limiter.Acquire(r.Context(), r.URL.Path) } - tw.Cancel = func(r *http.Request, err error) { - cli.Limiter.Cancel(r.URL.Path) - } - tw.Post = func(r *http.Response) error { - return cli.Limiter.Release(r.Request.URL.Path, r.Header) + tw.Post = func(r *http.Request, resp *http.Response) error { + if resp == nil { + return cli.Limiter.Release(r.URL.Path, nil) + } + return cli.Limiter.Release(r.URL.Path, resp.Header) } cli.Client.Transport = tw diff --git a/api/rate/rate.go b/api/rate/rate.go index d10ead0..0515208 100644 --- a/api/rate/rate.go +++ b/api/rate/rate.go @@ -25,12 +25,6 @@ type Limiter struct { // Only 1 per bucket CustomLimits []*CustomRateLimit - // These callbacks will only be called for valid buckets. They will also be - // called right before locking. Returning false will not rate limit. - OnAcquire func(path string) bool - OnCancel func(path string) bool - OnRelease func(path string) bool // false means not unlocking - Prefix string global *int64 // atomic guarded, unixnano @@ -54,20 +48,12 @@ type bucket struct { lastReset time.Time // only for custom } -func returnTrue(string) bool { - // time.Sleep(time.Nanosecond) - return true -} - func NewLimiter(prefix string) *Limiter { return &Limiter{ Prefix: prefix, global: new(int64), buckets: sync.Map{}, CustomLimits: []*CustomRateLimit{}, - OnAcquire: returnTrue, - OnCancel: returnTrue, - OnRelease: returnTrue, } } @@ -101,10 +87,6 @@ func (l *Limiter) getBucket(path string, store bool) *bucket { func (l *Limiter) Acquire(ctx context.Context, path string) error { b := l.getBucket(path, true) - if !l.OnAcquire(path) { - return nil - } - // Acquire lock with a timeout if err := b.lock.CLock(ctx); err != nil { return err @@ -142,22 +124,6 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error { return nil } -func (l *Limiter) Cancel(path string) error { - b := l.getBucket(path, false) - if b == nil { - return nil - } - if !l.OnCancel(path) { - return nil - } - - // TryLock would either not lock because it's already locked, or lock - // because it isn't. - b.lock.TryLock() - b.lock.Unlock() - return nil -} - // Release releases the URL from the locks. This doesn't need a context for // timing out, it doesn't block that much. func (l *Limiter) Release(path string, headers http.Header) error { @@ -167,9 +133,9 @@ func (l *Limiter) Release(path string, headers http.Header) error { } defer func() { - if l.OnRelease(path) { - b.lock.Unlock() - } + // Try and lock the bucket, to prevent unlocking an unlocked lock: + b.lock.TryLock() + b.lock.Unlock() }() // Check custom limiter @@ -184,6 +150,11 @@ func (l *Limiter) Release(path string, headers http.Header) error { return nil } + // Check if headers is nil or not: + if headers == nil { + return nil + } + var ( // boolean global = headers.Get("X-RateLimit-Global") diff --git a/internal/httputil/http.go b/internal/httputil/http.go index 8ba6f09..9ca71be 100644 --- a/internal/httputil/http.go +++ b/internal/httputil/http.go @@ -7,8 +7,7 @@ import ( type TransportWrapper struct { Default http.RoundTripper Pre func(*http.Request) error - Cancel func(*http.Request, error) - Post func(*http.Response) error + Post func(*http.Request, *http.Response) error } var _ http.RoundTripper = (*TransportWrapper)(nil) @@ -17,25 +16,21 @@ func NewTransportWrapper() *TransportWrapper { return &TransportWrapper{ Default: http.DefaultTransport, Pre: func(*http.Request) error { return nil }, - Cancel: func(*http.Request, error) {}, - Post: func(*http.Response) error { return nil }, + Post: func(*http.Request, *http.Response) error { return nil }, } } -func (c *TransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { +func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) { if err := c.Pre(req); err != nil { return nil, err } - r, err := c.Default.RoundTrip(req) - if err != nil { - c.Cancel(req, err) - return nil, err + r, err = c.Default.RoundTrip(req) + + // Call Post regardless of error: + if postErr := c.Post(req, r); postErr != nil { + return r, postErr } - if err := c.Post(r); err != nil { - return nil, err - } - - return r, nil + return r, err }