mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-21 03:57:26 +00:00
API: Fixed unlock of unlocked mutex bug
This commit is contained in:
parent
04a33c3ee4
commit
14b9d8f43a
|
@ -47,11 +47,11 @@ func NewClient(token string) *Client {
|
|||
// Rate limit stuff
|
||||
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
|
||||
}
|
||||
tw.Cancel = func(r *http.Request, err error) {
|
||||
cli.Limiter.Cancel(r.URL.Path)
|
||||
tw.Post = func(r *http.Request, resp *http.Response) error {
|
||||
if resp == nil {
|
||||
return cli.Limiter.Release(r.URL.Path, nil)
|
||||
}
|
||||
tw.Post = func(r *http.Response) error {
|
||||
return cli.Limiter.Release(r.Request.URL.Path, r.Header)
|
||||
return cli.Limiter.Release(r.URL.Path, resp.Header)
|
||||
}
|
||||
|
||||
cli.Client.Transport = tw
|
||||
|
|
|
@ -25,12 +25,6 @@ type Limiter struct {
|
|||
// Only 1 per bucket
|
||||
CustomLimits []*CustomRateLimit
|
||||
|
||||
// These callbacks will only be called for valid buckets. They will also be
|
||||
// called right before locking. Returning false will not rate limit.
|
||||
OnAcquire func(path string) bool
|
||||
OnCancel func(path string) bool
|
||||
OnRelease func(path string) bool // false means not unlocking
|
||||
|
||||
Prefix string
|
||||
|
||||
global *int64 // atomic guarded, unixnano
|
||||
|
@ -54,20 +48,12 @@ type bucket struct {
|
|||
lastReset time.Time // only for custom
|
||||
}
|
||||
|
||||
func returnTrue(string) bool {
|
||||
// time.Sleep(time.Nanosecond)
|
||||
return true
|
||||
}
|
||||
|
||||
func NewLimiter(prefix string) *Limiter {
|
||||
return &Limiter{
|
||||
Prefix: prefix,
|
||||
global: new(int64),
|
||||
buckets: sync.Map{},
|
||||
CustomLimits: []*CustomRateLimit{},
|
||||
OnAcquire: returnTrue,
|
||||
OnCancel: returnTrue,
|
||||
OnRelease: returnTrue,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,10 +87,6 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
|||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||
b := l.getBucket(path, true)
|
||||
|
||||
if !l.OnAcquire(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Acquire lock with a timeout
|
||||
if err := b.lock.CLock(ctx); err != nil {
|
||||
return err
|
||||
|
@ -142,22 +124,6 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (l *Limiter) Cancel(path string) error {
|
||||
b := l.getBucket(path, false)
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
if !l.OnCancel(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TryLock would either not lock because it's already locked, or lock
|
||||
// because it isn't.
|
||||
b.lock.TryLock()
|
||||
b.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Release releases the URL from the locks. This doesn't need a context for
|
||||
// timing out, it doesn't block that much.
|
||||
func (l *Limiter) Release(path string, headers http.Header) error {
|
||||
|
@ -167,9 +133,9 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
|||
}
|
||||
|
||||
defer func() {
|
||||
if l.OnRelease(path) {
|
||||
// Try and lock the bucket, to prevent unlocking an unlocked lock:
|
||||
b.lock.TryLock()
|
||||
b.lock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Check custom limiter
|
||||
|
@ -184,6 +150,11 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Check if headers is nil or not:
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// boolean
|
||||
global = headers.Get("X-RateLimit-Global")
|
||||
|
|
|
@ -7,8 +7,7 @@ import (
|
|||
type TransportWrapper struct {
|
||||
Default http.RoundTripper
|
||||
Pre func(*http.Request) error
|
||||
Cancel func(*http.Request, error)
|
||||
Post func(*http.Response) error
|
||||
Post func(*http.Request, *http.Response) error
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = (*TransportWrapper)(nil)
|
||||
|
@ -17,25 +16,21 @@ func NewTransportWrapper() *TransportWrapper {
|
|||
return &TransportWrapper{
|
||||
Default: http.DefaultTransport,
|
||||
Pre: func(*http.Request) error { return nil },
|
||||
Cancel: func(*http.Request, error) {},
|
||||
Post: func(*http.Response) error { return nil },
|
||||
Post: func(*http.Request, *http.Response) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) {
|
||||
if err := c.Pre(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := c.Default.RoundTrip(req)
|
||||
if err != nil {
|
||||
c.Cancel(req, err)
|
||||
return nil, err
|
||||
r, err = c.Default.RoundTrip(req)
|
||||
|
||||
// Call Post regardless of error:
|
||||
if postErr := c.Post(req, r); postErr != nil {
|
||||
return r, postErr
|
||||
}
|
||||
|
||||
if err := c.Post(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
return r, err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue