From a0bccd9c350e41c7c40a8d2f7866269a89c2f3e0 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 3 May 2020 14:02:03 -0700 Subject: [PATCH] API: Added WithContext API, closes #15 --- api/api.go | 60 +++++++++++++++--------- api/api_test.go | 20 ++++++++ state/state.go | 36 +++++++------- utils/httputil/client.go | 70 +++++++++++++++------------- utils/httputil/errors.go | 20 +++++++- utils/httputil/httpdriver/default.go | 14 ++---- utils/httputil/options.go | 1 + 7 files changed, 141 insertions(+), 80 deletions(-) create mode 100644 api/api_test.go diff --git a/api/api.go b/api/api.go index f7245b3..c2c9689 100644 --- a/api/api.go +++ b/api/api.go @@ -3,6 +3,7 @@ package api import ( + "context" "net/http" "github.com/diamondburned/arikawa/api/rate" @@ -24,10 +25,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)" type Client struct { *httputil.Client - Limiter *rate.Limiter - - Token string - UserAgent string + Session } func NewClient(token string) *Client { @@ -35,28 +33,48 @@ func NewClient(token string) *Client { } func NewCustomClient(token string, httpClient *httputil.Client) *Client { - cli := &Client{ - Client: httpClient, + ses := Session{ Limiter: rate.NewLimiter(APIPath), Token: token, UserAgent: UserAgent, } - cli.DefaultOptions = []httputil.RequestOption{ - func(r httpdriver.Request) error { - r.AddHeader(http.Header{ - "Authorization": {cli.Token}, - "User-Agent": {cli.UserAgent}, - "X-RateLimit-Precision": {"millisecond"}, - }) + hcl := httpClient.Copy() + hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest) + hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse) - // Rate limit stuff - return cli.Limiter.Acquire(r.GetContext(), r.GetPath()) - }, + return &Client{ + Client: hcl, + Session: ses, } - cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error { - return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp)) - } - - return cli +} + +func (c *Client) WithContext(ctx context.Context) *Client { + return &Client{ + Client: c.Client.WithContext(ctx), + Session: c.Session, + } +} + +// Session keeps a single session. This is typically wrapped around Client. +type Session struct { + Limiter *rate.Limiter + + Token string + UserAgent string +} + +func (s *Session) InjectRequest(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Authorization": {s.Token}, + "User-Agent": {s.UserAgent}, + "X-RateLimit-Precision": {"millisecond"}, + }) + + // Rate limit stuff + return s.Limiter.Acquire(r.GetContext(), r.GetPath()) +} + +func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error { + return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp)) } diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..5fcbb8a --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,20 @@ +package api + +import ( + "context" + "errors" + "testing" +) + +func TestContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // lol + + client := NewClient("no. 3-chan").WithContext(ctx) + + // This should fail. + _, err := client.Me() + if err == nil || !errors.Is(err, context.Canceled) { + t.Fatal("Unexpected error:", err) + } +} diff --git a/state/state.go b/state/state.go index fa1686c..6505422 100644 --- a/state/state.go +++ b/state/state.go @@ -3,6 +3,7 @@ package state import ( + "context" "sync" "github.com/diamondburned/arikawa/discord" @@ -46,19 +47,7 @@ type State struct { // List of channels with few messages, so it doesn't bother hitting the API // again. fewMessages map[discord.Snowflake]struct{} - fewMutex sync.Mutex -} - -func NewFromSession(s *session.Session, store Store) (*State, error) { - state := &State{ - Session: s, - Store: store, - Handler: handler.New(), - StateLog: func(err error) {}, - fewMessages: map[discord.Snowflake]struct{}{}, - } - - return state, state.hookSession() + fewMutex *sync.Mutex } func New(token string) (*State, error) { @@ -74,9 +63,24 @@ func NewWithStore(token string, store Store) (*State, error) { return NewFromSession(s, store) } -// Unhook removes all state handlers from the session handlers. -func (s *State) Unhook() { - s.unhooker() +func NewFromSession(s *session.Session, store Store) (*State, error) { + state := &State{ + Session: s, + Store: store, + Handler: handler.New(), + StateLog: func(err error) {}, + fewMessages: map[discord.Snowflake]struct{}{}, + fewMutex: new(sync.Mutex), + } + + return state, state.hookSession() +} + +func (s *State) WithContext(ctx context.Context) *State { + copied := *s + copied.Client = copied.Client.WithContext(ctx) + + return &copied } //// Helper methods diff --git a/utils/httputil/client.go b/utils/httputil/client.go index b5c461b..d37aae7 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -22,20 +22,17 @@ type Client struct { json.Driver SchemaEncoder - // DefaultOptions, if not nil, will be copied and prefixed on each Request. - DefaultOptions []RequestOption + // OnRequest, if not nil, will be copied and prefixed on each Request. + OnRequest []RequestOption // OnResponse is called after every Do() call. Response might be nil if Do() // errors out. The error returned will override Do's if it's not nil. - OnResponse func(httpdriver.Request, httpdriver.Response) error + OnResponse []ResponseFunc // Default to the global Retries variable (5). Retries uint -} -// ResponseNoop is used for (*Client).OnResponse. -func ResponseNoop(httpdriver.Request, httpdriver.Response) error { - return nil + context context.Context } func NewClient() *Client { @@ -44,12 +41,32 @@ func NewClient() *Client { Driver: json.Default, SchemaEncoder: &DefaultSchema{}, Retries: Retries, - OnResponse: ResponseNoop, + context: context.Background(), } } +// Copy returns a shallow copy of the client. +func (c *Client) Copy() *Client { + cl := new(Client) + *cl = *c + return cl +} + +// WithContext returns a client copy of the client with the given context. +func (c *Client) WithContext(ctx context.Context) *Client { + c = c.Copy() + c.context = ctx + return c +} + +// Context is a shared context for all future calls. It's Background by +// default. +func (c *Client) Context() context.Context { + return c.context +} + func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error { - for _, opt := range c.DefaultOptions { + for _, opt := range c.OnRequest { if err := opt(r); err != nil { return err } @@ -67,8 +84,8 @@ func (c *Client) MeanwhileMultipart( writer func(*multipart.Writer) error, method, url string, opts ...RequestOption) (httpdriver.Response, error) { - // We want to cancel the request if our bodyWriter fails - ctx, cancel := context.WithCancel(context.Background()) + // We want to cancel the request if our bodyWriter fails. + ctx, cancel := context.WithCancel(c.context) defer cancel() r, w := io.Pipe() @@ -93,7 +110,8 @@ func (c *Client) MeanwhileMultipart( WithContentType(body.FormDataContentType()), ) - resp, err := c.RequestCtx(ctx, method, url, opts...) + // Request with the current client and our own context: + resp, err := c.WithContext(ctx).Request(method, url, opts...) if err != nil && bgErr != nil { return nil, bgErr } @@ -109,13 +127,10 @@ func (c *Client) FastRequest(method, url string, opts ...RequestOption) error { return r.GetBody().Close() } -func (c *Client) RequestCtxJSON( - ctx context.Context, - to interface{}, method, url string, opts ...RequestOption) error { - +func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error { opts = PrependOptions(opts, JSONRequest) - r, err := c.RequestCtx(ctx, method, url, opts...) + r, err := c.Request(method, url, opts...) if err != nil { return err } @@ -135,11 +150,8 @@ func (c *Client) RequestCtxJSON( return nil } -func (c *Client) RequestCtx( - ctx context.Context, - method, url string, opts ...RequestOption) (httpdriver.Response, error) { - - req, err := c.Client.NewRequest(ctx, method, url) +func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) { + req, err := c.Client.NewRequest(c.context, method, url) if err != nil { return nil, RequestError{err} } @@ -165,8 +177,10 @@ func (c *Client) RequestCtx( } // Call OnResponse() even if the request failed. - if err := c.OnResponse(req, r); err != nil { - return nil, err + for _, fn := range c.OnResponse { + if err := fn(req, r); err != nil { + return nil, err + } } // If all retries failed: @@ -197,11 +211,3 @@ func (c *Client) RequestCtx( return r, nil } - -func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) { - return c.RequestCtx(context.Background(), method, url, opts...) -} - -func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error { - return c.RequestCtxJSON(context.Background(), to, method, url, opts...) -} diff --git a/utils/httputil/errors.go b/utils/httputil/errors.go index 0cb4967..9045b35 100644 --- a/utils/httputil/errors.go +++ b/utils/httputil/errors.go @@ -6,11 +6,27 @@ import ( ) type JSONError struct { - error + err error +} + +func (j JSONError) Error() string { + return "JSON decoding failed: " + j.err.Error() +} + +func (j JSONError) Unwrap() error { + return j.err } type RequestError struct { - error + err error +} + +func (r RequestError) Error() string { + return "Request failed: " + r.err.Error() +} + +func (r RequestError) Unwrap() error { + return r.err } type HTTPError struct { diff --git a/utils/httputil/httpdriver/default.go b/utils/httputil/httpdriver/default.go index 83a14bb..7d11938 100644 --- a/utils/httputil/httpdriver/default.go +++ b/utils/httputil/httpdriver/default.go @@ -9,18 +9,14 @@ import ( ) // DefaultClient implements Client and wraps around the stdlib Client. -type DefaultClient struct { - Client http.Client -} +type DefaultClient http.Client var _ Client = (*DefaultClient)(nil) // WrapClient wraps around the standard library's http.Client and returns an // implementation that's compatible with the Client driver interface. func WrapClient(client http.Client) Client { - return &DefaultClient{ - Client: client, - } + return DefaultClient(client) } // NewClient creates a new client around the standard library's http.Client. The @@ -31,7 +27,7 @@ func NewClient() Client { }) } -func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) { +func (d DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) { r, err := http.NewRequestWithContext(ctx, method, url, nil) if err != nil { return nil, err @@ -39,11 +35,11 @@ func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Req return (*DefaultRequest)(r), nil } -func (d *DefaultClient) Do(req Request) (Response, error) { +func (d DefaultClient) Do(req Request) (Response, error) { // Implementations can safely assert this. request := req.(*DefaultRequest) - r, err := d.Client.Do((*http.Request)(request)) + r, err := (*http.Client)(&d).Do((*http.Request)(request)) if err != nil { return nil, err } diff --git a/utils/httputil/options.go b/utils/httputil/options.go index 43c7d26..4094716 100644 --- a/utils/httputil/options.go +++ b/utils/httputil/options.go @@ -9,6 +9,7 @@ import ( ) type RequestOption func(httpdriver.Request) error +type ResponseFunc func(httpdriver.Request, httpdriver.Response) error func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption { if len(opts) == 0 {