API: Added WithContext API, closes #15

This commit is contained in:
diamondburned (Forefront) 2020-05-03 14:02:03 -07:00
parent 892c88d808
commit a0bccd9c35
7 changed files with 141 additions and 80 deletions

View File

@ -3,6 +3,7 @@
package api
import (
"context"
"net/http"
"github.com/diamondburned/arikawa/api/rate"
@ -24,10 +25,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
type Client struct {
*httputil.Client
Limiter *rate.Limiter
Token string
UserAgent string
Session
}
func NewClient(token string) *Client {
@ -35,28 +33,48 @@ func NewClient(token string) *Client {
}
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
cli := &Client{
Client: httpClient,
ses := Session{
Limiter: rate.NewLimiter(APIPath),
Token: token,
UserAgent: UserAgent,
}
cli.DefaultOptions = []httputil.RequestOption{
func(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {cli.Token},
"User-Agent": {cli.UserAgent},
"X-RateLimit-Precision": {"millisecond"},
})
hcl := httpClient.Copy()
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
// Rate limit stuff
return cli.Limiter.Acquire(r.GetContext(), r.GetPath())
},
return &Client{
Client: hcl,
Session: ses,
}
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error {
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}
return cli
}
func (c *Client) WithContext(ctx context.Context) *Client {
return &Client{
Client: c.Client.WithContext(ctx),
Session: c.Session,
}
}
// Session keeps a single session. This is typically wrapped around Client.
type Session struct {
Limiter *rate.Limiter
Token string
UserAgent string
}
func (s *Session) InjectRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {s.Token},
"User-Agent": {s.UserAgent},
"X-RateLimit-Precision": {"millisecond"},
})
// Rate limit stuff
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
}
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}

20
api/api_test.go Normal file
View File

@ -0,0 +1,20 @@
package api
import (
"context"
"errors"
"testing"
)
func TestContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // lol
client := NewClient("no. 3-chan").WithContext(ctx)
// This should fail.
_, err := client.Me()
if err == nil || !errors.Is(err, context.Canceled) {
t.Fatal("Unexpected error:", err)
}
}

View File

@ -3,6 +3,7 @@
package state
import (
"context"
"sync"
"github.com/diamondburned/arikawa/discord"
@ -46,19 +47,7 @@ type State struct {
// List of channels with few messages, so it doesn't bother hitting the API
// again.
fewMessages map[discord.Snowflake]struct{}
fewMutex sync.Mutex
}
func NewFromSession(s *session.Session, store Store) (*State, error) {
state := &State{
Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.Snowflake]struct{}{},
}
return state, state.hookSession()
fewMutex *sync.Mutex
}
func New(token string) (*State, error) {
@ -74,9 +63,24 @@ func NewWithStore(token string, store Store) (*State, error) {
return NewFromSession(s, store)
}
// Unhook removes all state handlers from the session handlers.
func (s *State) Unhook() {
s.unhooker()
func NewFromSession(s *session.Session, store Store) (*State, error) {
state := &State{
Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.Snowflake]struct{}{},
fewMutex: new(sync.Mutex),
}
return state, state.hookSession()
}
func (s *State) WithContext(ctx context.Context) *State {
copied := *s
copied.Client = copied.Client.WithContext(ctx)
return &copied
}
//// Helper methods

View File

@ -22,20 +22,17 @@ type Client struct {
json.Driver
SchemaEncoder
// DefaultOptions, if not nil, will be copied and prefixed on each Request.
DefaultOptions []RequestOption
// OnRequest, if not nil, will be copied and prefixed on each Request.
OnRequest []RequestOption
// OnResponse is called after every Do() call. Response might be nil if Do()
// errors out. The error returned will override Do's if it's not nil.
OnResponse func(httpdriver.Request, httpdriver.Response) error
OnResponse []ResponseFunc
// Default to the global Retries variable (5).
Retries uint
}
// ResponseNoop is used for (*Client).OnResponse.
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
return nil
context context.Context
}
func NewClient() *Client {
@ -44,12 +41,32 @@ func NewClient() *Client {
Driver: json.Default,
SchemaEncoder: &DefaultSchema{},
Retries: Retries,
OnResponse: ResponseNoop,
context: context.Background(),
}
}
// Copy returns a shallow copy of the client.
func (c *Client) Copy() *Client {
cl := new(Client)
*cl = *c
return cl
}
// WithContext returns a client copy of the client with the given context.
func (c *Client) WithContext(ctx context.Context) *Client {
c = c.Copy()
c.context = ctx
return c
}
// Context is a shared context for all future calls. It's Background by
// default.
func (c *Client) Context() context.Context {
return c.context
}
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
for _, opt := range c.DefaultOptions {
for _, opt := range c.OnRequest {
if err := opt(r); err != nil {
return err
}
@ -67,8 +84,8 @@ func (c *Client) MeanwhileMultipart(
writer func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// We want to cancel the request if our bodyWriter fails
ctx, cancel := context.WithCancel(context.Background())
// We want to cancel the request if our bodyWriter fails.
ctx, cancel := context.WithCancel(c.context)
defer cancel()
r, w := io.Pipe()
@ -93,7 +110,8 @@ func (c *Client) MeanwhileMultipart(
WithContentType(body.FormDataContentType()),
)
resp, err := c.RequestCtx(ctx, method, url, opts...)
// Request with the current client and our own context:
resp, err := c.WithContext(ctx).Request(method, url, opts...)
if err != nil && bgErr != nil {
return nil, bgErr
}
@ -109,13 +127,10 @@ func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
return r.GetBody().Close()
}
func (c *Client) RequestCtxJSON(
ctx context.Context,
to interface{}, method, url string, opts ...RequestOption) error {
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
opts = PrependOptions(opts, JSONRequest)
r, err := c.RequestCtx(ctx, method, url, opts...)
r, err := c.Request(method, url, opts...)
if err != nil {
return err
}
@ -135,11 +150,8 @@ func (c *Client) RequestCtxJSON(
return nil
}
func (c *Client) RequestCtx(
ctx context.Context,
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
req, err := c.Client.NewRequest(ctx, method, url)
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
req, err := c.Client.NewRequest(c.context, method, url)
if err != nil {
return nil, RequestError{err}
}
@ -165,8 +177,10 @@ func (c *Client) RequestCtx(
}
// Call OnResponse() even if the request failed.
if err := c.OnResponse(req, r); err != nil {
return nil, err
for _, fn := range c.OnResponse {
if err := fn(req, r); err != nil {
return nil, err
}
}
// If all retries failed:
@ -197,11 +211,3 @@ func (c *Client) RequestCtx(
return r, nil
}
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
return c.RequestCtx(context.Background(), method, url, opts...)
}
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
}

View File

@ -6,11 +6,27 @@ import (
)
type JSONError struct {
error
err error
}
func (j JSONError) Error() string {
return "JSON decoding failed: " + j.err.Error()
}
func (j JSONError) Unwrap() error {
return j.err
}
type RequestError struct {
error
err error
}
func (r RequestError) Error() string {
return "Request failed: " + r.err.Error()
}
func (r RequestError) Unwrap() error {
return r.err
}
type HTTPError struct {

View File

@ -9,18 +9,14 @@ import (
)
// DefaultClient implements Client and wraps around the stdlib Client.
type DefaultClient struct {
Client http.Client
}
type DefaultClient http.Client
var _ Client = (*DefaultClient)(nil)
// WrapClient wraps around the standard library's http.Client and returns an
// implementation that's compatible with the Client driver interface.
func WrapClient(client http.Client) Client {
return &DefaultClient{
Client: client,
}
return DefaultClient(client)
}
// NewClient creates a new client around the standard library's http.Client. The
@ -31,7 +27,7 @@ func NewClient() Client {
})
}
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
func (d DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
r, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
@ -39,11 +35,11 @@ func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Req
return (*DefaultRequest)(r), nil
}
func (d *DefaultClient) Do(req Request) (Response, error) {
func (d DefaultClient) Do(req Request) (Response, error) {
// Implementations can safely assert this.
request := req.(*DefaultRequest)
r, err := d.Client.Do((*http.Request)(request))
r, err := (*http.Client)(&d).Do((*http.Request)(request))
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
)
type RequestOption func(httpdriver.Request) error
type ResponseFunc func(httpdriver.Request, httpdriver.Response) error
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
if len(opts) == 0 {