API: Fixed unlock of unlocked mutex bug

This commit is contained in:
diamondburned (Forefront) 2020-03-01 14:25:54 -08:00
parent 04a33c3ee4
commit 14b9d8f43a
3 changed files with 22 additions and 56 deletions

View File

@ -47,11 +47,11 @@ func NewClient(token string) *Client {
// Rate limit stuff // Rate limit stuff
return cli.Limiter.Acquire(r.Context(), r.URL.Path) return cli.Limiter.Acquire(r.Context(), r.URL.Path)
} }
tw.Cancel = func(r *http.Request, err error) { tw.Post = func(r *http.Request, resp *http.Response) error {
cli.Limiter.Cancel(r.URL.Path) if resp == nil {
} return cli.Limiter.Release(r.URL.Path, nil)
tw.Post = func(r *http.Response) error { }
return cli.Limiter.Release(r.Request.URL.Path, r.Header) return cli.Limiter.Release(r.URL.Path, resp.Header)
} }
cli.Client.Transport = tw cli.Client.Transport = tw

View File

@ -25,12 +25,6 @@ type Limiter struct {
// Only 1 per bucket // Only 1 per bucket
CustomLimits []*CustomRateLimit CustomLimits []*CustomRateLimit
// These callbacks will only be called for valid buckets. They will also be
// called right before locking. Returning false will not rate limit.
OnAcquire func(path string) bool
OnCancel func(path string) bool
OnRelease func(path string) bool // false means not unlocking
Prefix string Prefix string
global *int64 // atomic guarded, unixnano global *int64 // atomic guarded, unixnano
@ -54,20 +48,12 @@ type bucket struct {
lastReset time.Time // only for custom lastReset time.Time // only for custom
} }
func returnTrue(string) bool {
// time.Sleep(time.Nanosecond)
return true
}
func NewLimiter(prefix string) *Limiter { func NewLimiter(prefix string) *Limiter {
return &Limiter{ return &Limiter{
Prefix: prefix, Prefix: prefix,
global: new(int64), global: new(int64),
buckets: sync.Map{}, buckets: sync.Map{},
CustomLimits: []*CustomRateLimit{}, CustomLimits: []*CustomRateLimit{},
OnAcquire: returnTrue,
OnCancel: returnTrue,
OnRelease: returnTrue,
} }
} }
@ -101,10 +87,6 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
func (l *Limiter) Acquire(ctx context.Context, path string) error { func (l *Limiter) Acquire(ctx context.Context, path string) error {
b := l.getBucket(path, true) b := l.getBucket(path, true)
if !l.OnAcquire(path) {
return nil
}
// Acquire lock with a timeout // Acquire lock with a timeout
if err := b.lock.CLock(ctx); err != nil { if err := b.lock.CLock(ctx); err != nil {
return err return err
@ -142,22 +124,6 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
return nil return nil
} }
func (l *Limiter) Cancel(path string) error {
b := l.getBucket(path, false)
if b == nil {
return nil
}
if !l.OnCancel(path) {
return nil
}
// TryLock would either not lock because it's already locked, or lock
// because it isn't.
b.lock.TryLock()
b.lock.Unlock()
return nil
}
// Release releases the URL from the locks. This doesn't need a context for // Release releases the URL from the locks. This doesn't need a context for
// timing out, it doesn't block that much. // timing out, it doesn't block that much.
func (l *Limiter) Release(path string, headers http.Header) error { func (l *Limiter) Release(path string, headers http.Header) error {
@ -167,9 +133,9 @@ func (l *Limiter) Release(path string, headers http.Header) error {
} }
defer func() { defer func() {
if l.OnRelease(path) { // Try and lock the bucket, to prevent unlocking an unlocked lock:
b.lock.Unlock() b.lock.TryLock()
} b.lock.Unlock()
}() }()
// Check custom limiter // Check custom limiter
@ -184,6 +150,11 @@ func (l *Limiter) Release(path string, headers http.Header) error {
return nil return nil
} }
// Check if headers is nil or not:
if headers == nil {
return nil
}
var ( var (
// boolean // boolean
global = headers.Get("X-RateLimit-Global") global = headers.Get("X-RateLimit-Global")

View File

@ -7,8 +7,7 @@ import (
type TransportWrapper struct { type TransportWrapper struct {
Default http.RoundTripper Default http.RoundTripper
Pre func(*http.Request) error Pre func(*http.Request) error
Cancel func(*http.Request, error) Post func(*http.Request, *http.Response) error
Post func(*http.Response) error
} }
var _ http.RoundTripper = (*TransportWrapper)(nil) var _ http.RoundTripper = (*TransportWrapper)(nil)
@ -17,25 +16,21 @@ func NewTransportWrapper() *TransportWrapper {
return &TransportWrapper{ return &TransportWrapper{
Default: http.DefaultTransport, Default: http.DefaultTransport,
Pre: func(*http.Request) error { return nil }, Pre: func(*http.Request) error { return nil },
Cancel: func(*http.Request, error) {}, Post: func(*http.Request, *http.Response) error { return nil },
Post: func(*http.Response) error { return nil },
} }
} }
func (c *TransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) {
if err := c.Pre(req); err != nil { if err := c.Pre(req); err != nil {
return nil, err return nil, err
} }
r, err := c.Default.RoundTrip(req) r, err = c.Default.RoundTrip(req)
if err != nil {
c.Cancel(req, err) // Call Post regardless of error:
return nil, err if postErr := c.Post(req, r); postErr != nil {
return r, postErr
} }
if err := c.Post(r); err != nil { return r, err
return nil, err
}
return r, nil
} }