API: Added WithContext API, closes #15

This commit is contained in:
diamondburned (Forefront) 2020-05-03 14:02:03 -07:00
parent 892c88d808
commit a0bccd9c35
7 changed files with 141 additions and 80 deletions

View File

@ -3,6 +3,7 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/diamondburned/arikawa/api/rate" "github.com/diamondburned/arikawa/api/rate"
@ -24,10 +25,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
type Client struct { type Client struct {
*httputil.Client *httputil.Client
Limiter *rate.Limiter Session
Token string
UserAgent string
} }
func NewClient(token string) *Client { func NewClient(token string) *Client {
@ -35,28 +33,48 @@ func NewClient(token string) *Client {
} }
func NewCustomClient(token string, httpClient *httputil.Client) *Client { func NewCustomClient(token string, httpClient *httputil.Client) *Client {
cli := &Client{ ses := Session{
Client: httpClient,
Limiter: rate.NewLimiter(APIPath), Limiter: rate.NewLimiter(APIPath),
Token: token, Token: token,
UserAgent: UserAgent, UserAgent: UserAgent,
} }
cli.DefaultOptions = []httputil.RequestOption{ hcl := httpClient.Copy()
func(r httpdriver.Request) error { hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
r.AddHeader(http.Header{ hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
"Authorization": {cli.Token},
"User-Agent": {cli.UserAgent},
"X-RateLimit-Precision": {"millisecond"},
})
// Rate limit stuff return &Client{
return cli.Limiter.Acquire(r.GetContext(), r.GetPath()) Client: hcl,
}, Session: ses,
} }
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error { }
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
} func (c *Client) WithContext(ctx context.Context) *Client {
return &Client{
return cli Client: c.Client.WithContext(ctx),
Session: c.Session,
}
}
// Session keeps a single session. This is typically wrapped around Client.
type Session struct {
Limiter *rate.Limiter
Token string
UserAgent string
}
func (s *Session) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {s.Token},
"User-Agent": {s.UserAgent},
"X-RateLimit-Precision": {"millisecond"},
})
// Rate limit stuff
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
}
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
} }

20
api/api_test.go Normal file
View File

@ -0,0 +1,20 @@
package api
import (
"context"
"errors"
"testing"
)
func TestContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // lol
client := NewClient("no. 3-chan").WithContext(ctx)
// This should fail.
_, err := client.Me()
if err == nil || !errors.Is(err, context.Canceled) {
t.Fatal("Unexpected error:", err)
}
}

View File

@ -3,6 +3,7 @@
package state package state
import ( import (
"context"
"sync" "sync"
"github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/discord"
@ -46,19 +47,7 @@ type State struct {
// List of channels with few messages, so it doesn't bother hitting the API // List of channels with few messages, so it doesn't bother hitting the API
// again. // again.
fewMessages map[discord.Snowflake]struct{} fewMessages map[discord.Snowflake]struct{}
fewMutex sync.Mutex fewMutex *sync.Mutex
}
func NewFromSession(s *session.Session, store Store) (*State, error) {
state := &State{
Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.Snowflake]struct{}{},
}
return state, state.hookSession()
} }
func New(token string) (*State, error) { func New(token string) (*State, error) {
@ -74,9 +63,24 @@ func NewWithStore(token string, store Store) (*State, error) {
return NewFromSession(s, store) return NewFromSession(s, store)
} }
// Unhook removes all state handlers from the session handlers. func NewFromSession(s *session.Session, store Store) (*State, error) {
func (s *State) Unhook() { state := &State{
s.unhooker() Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.Snowflake]struct{}{},
fewMutex: new(sync.Mutex),
}
return state, state.hookSession()
}
func (s *State) WithContext(ctx context.Context) *State {
copied := *s
copied.Client = copied.Client.WithContext(ctx)
return &copied
} }
//// Helper methods //// Helper methods

View File

@ -22,20 +22,17 @@ type Client struct {
json.Driver json.Driver
SchemaEncoder SchemaEncoder
// DefaultOptions, if not nil, will be copied and prefixed on each Request. // OnRequest, if not nil, will be copied and prefixed on each Request.
DefaultOptions []RequestOption OnRequest []RequestOption
// OnResponse is called after every Do() call. Response might be nil if Do() // OnResponse is called after every Do() call. Response might be nil if Do()
// errors out. The error returned will override Do's if it's not nil. // errors out. The error returned will override Do's if it's not nil.
OnResponse func(httpdriver.Request, httpdriver.Response) error OnResponse []ResponseFunc
// Default to the global Retries variable (5). // Default to the global Retries variable (5).
Retries uint Retries uint
}
// ResponseNoop is used for (*Client).OnResponse. context context.Context
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
return nil
} }
func NewClient() *Client { func NewClient() *Client {
@ -44,12 +41,32 @@ func NewClient() *Client {
Driver: json.Default, Driver: json.Default,
SchemaEncoder: &DefaultSchema{}, SchemaEncoder: &DefaultSchema{},
Retries: Retries, Retries: Retries,
OnResponse: ResponseNoop, context: context.Background(),
} }
} }
// Copy returns a shallow copy of the client.
func (c *Client) Copy() *Client {
cl := new(Client)
*cl = *c
return cl
}
// WithContext returns a client copy of the client with the given context.
func (c *Client) WithContext(ctx context.Context) *Client {
c = c.Copy()
c.context = ctx
return c
}
// Context is a shared context for all future calls. It's Background by
// default.
func (c *Client) Context() context.Context {
return c.context
}
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error { func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
for _, opt := range c.DefaultOptions { for _, opt := range c.OnRequest {
if err := opt(r); err != nil { if err := opt(r); err != nil {
return err return err
} }
@ -67,8 +84,8 @@ func (c *Client) MeanwhileMultipart(
writer func(*multipart.Writer) error, writer func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (httpdriver.Response, error) { method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// We want to cancel the request if our bodyWriter fails // We want to cancel the request if our bodyWriter fails.
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(c.context)
defer cancel() defer cancel()
r, w := io.Pipe() r, w := io.Pipe()
@ -93,7 +110,8 @@ func (c *Client) MeanwhileMultipart(
WithContentType(body.FormDataContentType()), WithContentType(body.FormDataContentType()),
) )
resp, err := c.RequestCtx(ctx, method, url, opts...) // Request with the current client and our own context:
resp, err := c.WithContext(ctx).Request(method, url, opts...)
if err != nil && bgErr != nil { if err != nil && bgErr != nil {
return nil, bgErr return nil, bgErr
} }
@ -109,13 +127,10 @@ func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
return r.GetBody().Close() return r.GetBody().Close()
} }
func (c *Client) RequestCtxJSON( func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
ctx context.Context,
to interface{}, method, url string, opts ...RequestOption) error {
opts = PrependOptions(opts, JSONRequest) opts = PrependOptions(opts, JSONRequest)
r, err := c.RequestCtx(ctx, method, url, opts...) r, err := c.Request(method, url, opts...)
if err != nil { if err != nil {
return err return err
} }
@ -135,11 +150,8 @@ func (c *Client) RequestCtxJSON(
return nil return nil
} }
func (c *Client) RequestCtx( func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
ctx context.Context, req, err := c.Client.NewRequest(c.context, method, url)
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
req, err := c.Client.NewRequest(ctx, method, url)
if err != nil { if err != nil {
return nil, RequestError{err} return nil, RequestError{err}
} }
@ -165,8 +177,10 @@ func (c *Client) RequestCtx(
} }
// Call OnResponse() even if the request failed. // Call OnResponse() even if the request failed.
if err := c.OnResponse(req, r); err != nil { for _, fn := range c.OnResponse {
return nil, err if err := fn(req, r); err != nil {
return nil, err
}
} }
// If all retries failed: // If all retries failed:
@ -197,11 +211,3 @@ func (c *Client) RequestCtx(
return r, nil return r, nil
} }
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
return c.RequestCtx(context.Background(), method, url, opts...)
}
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
}

View File

@ -6,11 +6,27 @@ import (
) )
type JSONError struct { type JSONError struct {
error err error
}
func (j JSONError) Error() string {
return "JSON decoding failed: " + j.err.Error()
}
func (j JSONError) Unwrap() error {
return j.err
} }
type RequestError struct { type RequestError struct {
error err error
}
func (r RequestError) Error() string {
return "Request failed: " + r.err.Error()
}
func (r RequestError) Unwrap() error {
return r.err
} }
type HTTPError struct { type HTTPError struct {

View File

@ -9,18 +9,14 @@ import (
) )
// DefaultClient implements Client and wraps around the stdlib Client. // DefaultClient implements Client and wraps around the stdlib Client.
type DefaultClient struct { type DefaultClient http.Client
Client http.Client
}
var _ Client = (*DefaultClient)(nil) var _ Client = (*DefaultClient)(nil)
// WrapClient wraps around the standard library's http.Client and returns an // WrapClient wraps around the standard library's http.Client and returns an
// implementation that's compatible with the Client driver interface. // implementation that's compatible with the Client driver interface.
func WrapClient(client http.Client) Client { func WrapClient(client http.Client) Client {
return &DefaultClient{ return DefaultClient(client)
Client: client,
}
} }
// NewClient creates a new client around the standard library's http.Client. The // NewClient creates a new client around the standard library's http.Client. The
@ -31,7 +27,7 @@ func NewClient() Client {
}) })
} }
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) { func (d DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
r, err := http.NewRequestWithContext(ctx, method, url, nil) r, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -39,11 +35,11 @@ func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Req
return (*DefaultRequest)(r), nil return (*DefaultRequest)(r), nil
} }
func (d *DefaultClient) Do(req Request) (Response, error) { func (d DefaultClient) Do(req Request) (Response, error) {
// Implementations can safely assert this. // Implementations can safely assert this.
request := req.(*DefaultRequest) request := req.(*DefaultRequest)
r, err := d.Client.Do((*http.Request)(request)) r, err := (*http.Client)(&d).Do((*http.Request)(request))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,6 +9,7 @@ import (
) )
type RequestOption func(httpdriver.Request) error type RequestOption func(httpdriver.Request) error
type ResponseFunc func(httpdriver.Request, httpdriver.Response) error
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption { func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
if len(opts) == 0 { if len(opts) == 0 {