mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-02-01 09:27:18 +00:00
API: Fixed unlock of unlocked mutex bug
This commit is contained in:
parent
04a33c3ee4
commit
14b9d8f43a
|
@ -47,11 +47,11 @@ func NewClient(token string) *Client {
|
||||||
// Rate limit stuff
|
// Rate limit stuff
|
||||||
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
|
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
|
||||||
}
|
}
|
||||||
tw.Cancel = func(r *http.Request, err error) {
|
tw.Post = func(r *http.Request, resp *http.Response) error {
|
||||||
cli.Limiter.Cancel(r.URL.Path)
|
if resp == nil {
|
||||||
|
return cli.Limiter.Release(r.URL.Path, nil)
|
||||||
}
|
}
|
||||||
tw.Post = func(r *http.Response) error {
|
return cli.Limiter.Release(r.URL.Path, resp.Header)
|
||||||
return cli.Limiter.Release(r.Request.URL.Path, r.Header)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.Client.Transport = tw
|
cli.Client.Transport = tw
|
||||||
|
|
|
@ -25,12 +25,6 @@ type Limiter struct {
|
||||||
// Only 1 per bucket
|
// Only 1 per bucket
|
||||||
CustomLimits []*CustomRateLimit
|
CustomLimits []*CustomRateLimit
|
||||||
|
|
||||||
// These callbacks will only be called for valid buckets. They will also be
|
|
||||||
// called right before locking. Returning false will not rate limit.
|
|
||||||
OnAcquire func(path string) bool
|
|
||||||
OnCancel func(path string) bool
|
|
||||||
OnRelease func(path string) bool // false means not unlocking
|
|
||||||
|
|
||||||
Prefix string
|
Prefix string
|
||||||
|
|
||||||
global *int64 // atomic guarded, unixnano
|
global *int64 // atomic guarded, unixnano
|
||||||
|
@ -54,20 +48,12 @@ type bucket struct {
|
||||||
lastReset time.Time // only for custom
|
lastReset time.Time // only for custom
|
||||||
}
|
}
|
||||||
|
|
||||||
func returnTrue(string) bool {
|
|
||||||
// time.Sleep(time.Nanosecond)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLimiter(prefix string) *Limiter {
|
func NewLimiter(prefix string) *Limiter {
|
||||||
return &Limiter{
|
return &Limiter{
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
global: new(int64),
|
global: new(int64),
|
||||||
buckets: sync.Map{},
|
buckets: sync.Map{},
|
||||||
CustomLimits: []*CustomRateLimit{},
|
CustomLimits: []*CustomRateLimit{},
|
||||||
OnAcquire: returnTrue,
|
|
||||||
OnCancel: returnTrue,
|
|
||||||
OnRelease: returnTrue,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,10 +87,6 @@ func (l *Limiter) getBucket(path string, store bool) *bucket {
|
||||||
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
b := l.getBucket(path, true)
|
b := l.getBucket(path, true)
|
||||||
|
|
||||||
if !l.OnAcquire(path) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire lock with a timeout
|
// Acquire lock with a timeout
|
||||||
if err := b.lock.CLock(ctx); err != nil {
|
if err := b.lock.CLock(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -142,22 +124,6 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Limiter) Cancel(path string) error {
|
|
||||||
b := l.getBucket(path, false)
|
|
||||||
if b == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !l.OnCancel(path) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TryLock would either not lock because it's already locked, or lock
|
|
||||||
// because it isn't.
|
|
||||||
b.lock.TryLock()
|
|
||||||
b.lock.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release releases the URL from the locks. This doesn't need a context for
|
// Release releases the URL from the locks. This doesn't need a context for
|
||||||
// timing out, it doesn't block that much.
|
// timing out, it doesn't block that much.
|
||||||
func (l *Limiter) Release(path string, headers http.Header) error {
|
func (l *Limiter) Release(path string, headers http.Header) error {
|
||||||
|
@ -167,9 +133,9 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if l.OnRelease(path) {
|
// Try and lock the bucket, to prevent unlocking an unlocked lock:
|
||||||
|
b.lock.TryLock()
|
||||||
b.lock.Unlock()
|
b.lock.Unlock()
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Check custom limiter
|
// Check custom limiter
|
||||||
|
@ -184,6 +150,11 @@ func (l *Limiter) Release(path string, headers http.Header) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if headers is nil or not:
|
||||||
|
if headers == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// boolean
|
// boolean
|
||||||
global = headers.Get("X-RateLimit-Global")
|
global = headers.Get("X-RateLimit-Global")
|
||||||
|
|
|
@ -7,8 +7,7 @@ import (
|
||||||
type TransportWrapper struct {
|
type TransportWrapper struct {
|
||||||
Default http.RoundTripper
|
Default http.RoundTripper
|
||||||
Pre func(*http.Request) error
|
Pre func(*http.Request) error
|
||||||
Cancel func(*http.Request, error)
|
Post func(*http.Request, *http.Response) error
|
||||||
Post func(*http.Response) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ http.RoundTripper = (*TransportWrapper)(nil)
|
var _ http.RoundTripper = (*TransportWrapper)(nil)
|
||||||
|
@ -17,25 +16,21 @@ func NewTransportWrapper() *TransportWrapper {
|
||||||
return &TransportWrapper{
|
return &TransportWrapper{
|
||||||
Default: http.DefaultTransport,
|
Default: http.DefaultTransport,
|
||||||
Pre: func(*http.Request) error { return nil },
|
Pre: func(*http.Request) error { return nil },
|
||||||
Cancel: func(*http.Request, error) {},
|
Post: func(*http.Request, *http.Response) error { return nil },
|
||||||
Post: func(*http.Response) error { return nil },
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) {
|
||||||
if err := c.Pre(req); err != nil {
|
if err := c.Pre(req); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := c.Default.RoundTrip(req)
|
r, err = c.Default.RoundTrip(req)
|
||||||
if err != nil {
|
|
||||||
c.Cancel(req, err)
|
// Call Post regardless of error:
|
||||||
return nil, err
|
if postErr := c.Post(req, r); postErr != nil {
|
||||||
|
return r, postErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Post(r); err != nil {
|
return r, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue