API: Added WithContext API, closes #15
This commit is contained in:
parent
892c88d808
commit
a0bccd9c35
60
api/api.go
60
api/api.go
|
@ -3,6 +3,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/api/rate"
|
"github.com/diamondburned/arikawa/api/rate"
|
||||||
|
@ -24,10 +25,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
*httputil.Client
|
*httputil.Client
|
||||||
Limiter *rate.Limiter
|
Session
|
||||||
|
|
||||||
Token string
|
|
||||||
UserAgent string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(token string) *Client {
|
func NewClient(token string) *Client {
|
||||||
|
@ -35,28 +33,48 @@ func NewClient(token string) *Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
||||||
cli := &Client{
|
ses := Session{
|
||||||
Client: httpClient,
|
|
||||||
Limiter: rate.NewLimiter(APIPath),
|
Limiter: rate.NewLimiter(APIPath),
|
||||||
Token: token,
|
Token: token,
|
||||||
UserAgent: UserAgent,
|
UserAgent: UserAgent,
|
||||||
}
|
}
|
||||||
|
|
||||||
cli.DefaultOptions = []httputil.RequestOption{
|
hcl := httpClient.Copy()
|
||||||
func(r httpdriver.Request) error {
|
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
|
||||||
r.AddHeader(http.Header{
|
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
|
||||||
"Authorization": {cli.Token},
|
|
||||||
"User-Agent": {cli.UserAgent},
|
|
||||||
"X-RateLimit-Precision": {"millisecond"},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Rate limit stuff
|
return &Client{
|
||||||
return cli.Limiter.Acquire(r.GetContext(), r.GetPath())
|
Client: hcl,
|
||||||
},
|
Session: ses,
|
||||||
}
|
}
|
||||||
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error {
|
}
|
||||||
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
|
||||||
}
|
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||||
|
return &Client{
|
||||||
return cli
|
Client: c.Client.WithContext(ctx),
|
||||||
|
Session: c.Session,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session keeps a single session. This is typically wrapped around Client.
|
||||||
|
type Session struct {
|
||||||
|
Limiter *rate.Limiter
|
||||||
|
|
||||||
|
Token string
|
||||||
|
UserAgent string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) InjectRequest(r httpdriver.Request) error {
|
||||||
|
r.AddHeader(http.Header{
|
||||||
|
"Authorization": {s.Token},
|
||||||
|
"User-Agent": {s.UserAgent},
|
||||||
|
"X-RateLimit-Precision": {"millisecond"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Rate limit stuff
|
||||||
|
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
||||||
|
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContext(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // lol
|
||||||
|
|
||||||
|
client := NewClient("no. 3-chan").WithContext(ctx)
|
||||||
|
|
||||||
|
// This should fail.
|
||||||
|
_, err := client.Me()
|
||||||
|
if err == nil || !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@
|
||||||
package state
|
package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/discord"
|
"github.com/diamondburned/arikawa/discord"
|
||||||
|
@ -46,19 +47,7 @@ type State struct {
|
||||||
// List of channels with few messages, so it doesn't bother hitting the API
|
// List of channels with few messages, so it doesn't bother hitting the API
|
||||||
// again.
|
// again.
|
||||||
fewMessages map[discord.Snowflake]struct{}
|
fewMessages map[discord.Snowflake]struct{}
|
||||||
fewMutex sync.Mutex
|
fewMutex *sync.Mutex
|
||||||
}
|
|
||||||
|
|
||||||
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
|
||||||
state := &State{
|
|
||||||
Session: s,
|
|
||||||
Store: store,
|
|
||||||
Handler: handler.New(),
|
|
||||||
StateLog: func(err error) {},
|
|
||||||
fewMessages: map[discord.Snowflake]struct{}{},
|
|
||||||
}
|
|
||||||
|
|
||||||
return state, state.hookSession()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(token string) (*State, error) {
|
func New(token string) (*State, error) {
|
||||||
|
@ -74,9 +63,24 @@ func NewWithStore(token string, store Store) (*State, error) {
|
||||||
return NewFromSession(s, store)
|
return NewFromSession(s, store)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unhook removes all state handlers from the session handlers.
|
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
||||||
func (s *State) Unhook() {
|
state := &State{
|
||||||
s.unhooker()
|
Session: s,
|
||||||
|
Store: store,
|
||||||
|
Handler: handler.New(),
|
||||||
|
StateLog: func(err error) {},
|
||||||
|
fewMessages: map[discord.Snowflake]struct{}{},
|
||||||
|
fewMutex: new(sync.Mutex),
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, state.hookSession()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) WithContext(ctx context.Context) *State {
|
||||||
|
copied := *s
|
||||||
|
copied.Client = copied.Client.WithContext(ctx)
|
||||||
|
|
||||||
|
return &copied
|
||||||
}
|
}
|
||||||
|
|
||||||
//// Helper methods
|
//// Helper methods
|
||||||
|
|
|
@ -22,20 +22,17 @@ type Client struct {
|
||||||
json.Driver
|
json.Driver
|
||||||
SchemaEncoder
|
SchemaEncoder
|
||||||
|
|
||||||
// DefaultOptions, if not nil, will be copied and prefixed on each Request.
|
// OnRequest, if not nil, will be copied and prefixed on each Request.
|
||||||
DefaultOptions []RequestOption
|
OnRequest []RequestOption
|
||||||
|
|
||||||
// OnResponse is called after every Do() call. Response might be nil if Do()
|
// OnResponse is called after every Do() call. Response might be nil if Do()
|
||||||
// errors out. The error returned will override Do's if it's not nil.
|
// errors out. The error returned will override Do's if it's not nil.
|
||||||
OnResponse func(httpdriver.Request, httpdriver.Response) error
|
OnResponse []ResponseFunc
|
||||||
|
|
||||||
// Default to the global Retries variable (5).
|
// Default to the global Retries variable (5).
|
||||||
Retries uint
|
Retries uint
|
||||||
}
|
|
||||||
|
|
||||||
// ResponseNoop is used for (*Client).OnResponse.
|
context context.Context
|
||||||
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient() *Client {
|
func NewClient() *Client {
|
||||||
|
@ -44,12 +41,32 @@ func NewClient() *Client {
|
||||||
Driver: json.Default,
|
Driver: json.Default,
|
||||||
SchemaEncoder: &DefaultSchema{},
|
SchemaEncoder: &DefaultSchema{},
|
||||||
Retries: Retries,
|
Retries: Retries,
|
||||||
OnResponse: ResponseNoop,
|
context: context.Background(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy returns a shallow copy of the client.
|
||||||
|
func (c *Client) Copy() *Client {
|
||||||
|
cl := new(Client)
|
||||||
|
*cl = *c
|
||||||
|
return cl
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext returns a client copy of the client with the given context.
|
||||||
|
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||||
|
c = c.Copy()
|
||||||
|
c.context = ctx
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context is a shared context for all future calls. It's Background by
|
||||||
|
// default.
|
||||||
|
func (c *Client) Context() context.Context {
|
||||||
|
return c.context
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
|
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
|
||||||
for _, opt := range c.DefaultOptions {
|
for _, opt := range c.OnRequest {
|
||||||
if err := opt(r); err != nil {
|
if err := opt(r); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -67,8 +84,8 @@ func (c *Client) MeanwhileMultipart(
|
||||||
writer func(*multipart.Writer) error,
|
writer func(*multipart.Writer) error,
|
||||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||||
|
|
||||||
// We want to cancel the request if our bodyWriter fails
|
// We want to cancel the request if our bodyWriter fails.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(c.context)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
r, w := io.Pipe()
|
r, w := io.Pipe()
|
||||||
|
@ -93,7 +110,8 @@ func (c *Client) MeanwhileMultipart(
|
||||||
WithContentType(body.FormDataContentType()),
|
WithContentType(body.FormDataContentType()),
|
||||||
)
|
)
|
||||||
|
|
||||||
resp, err := c.RequestCtx(ctx, method, url, opts...)
|
// Request with the current client and our own context:
|
||||||
|
resp, err := c.WithContext(ctx).Request(method, url, opts...)
|
||||||
if err != nil && bgErr != nil {
|
if err != nil && bgErr != nil {
|
||||||
return nil, bgErr
|
return nil, bgErr
|
||||||
}
|
}
|
||||||
|
@ -109,13 +127,10 @@ func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
|
||||||
return r.GetBody().Close()
|
return r.GetBody().Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RequestCtxJSON(
|
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
|
||||||
ctx context.Context,
|
|
||||||
to interface{}, method, url string, opts ...RequestOption) error {
|
|
||||||
|
|
||||||
opts = PrependOptions(opts, JSONRequest)
|
opts = PrependOptions(opts, JSONRequest)
|
||||||
|
|
||||||
r, err := c.RequestCtx(ctx, method, url, opts...)
|
r, err := c.Request(method, url, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -135,11 +150,8 @@ func (c *Client) RequestCtxJSON(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RequestCtx(
|
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||||
ctx context.Context,
|
req, err := c.Client.NewRequest(c.context, method, url)
|
||||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
|
||||||
|
|
||||||
req, err := c.Client.NewRequest(ctx, method, url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RequestError{err}
|
return nil, RequestError{err}
|
||||||
}
|
}
|
||||||
|
@ -165,8 +177,10 @@ func (c *Client) RequestCtx(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call OnResponse() even if the request failed.
|
// Call OnResponse() even if the request failed.
|
||||||
if err := c.OnResponse(req, r); err != nil {
|
for _, fn := range c.OnResponse {
|
||||||
return nil, err
|
if err := fn(req, r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If all retries failed:
|
// If all retries failed:
|
||||||
|
@ -197,11 +211,3 @@ func (c *Client) RequestCtx(
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
|
||||||
return c.RequestCtx(context.Background(), method, url, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
|
|
||||||
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,11 +6,27 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONError struct {
|
type JSONError struct {
|
||||||
error
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j JSONError) Error() string {
|
||||||
|
return "JSON decoding failed: " + j.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j JSONError) Unwrap() error {
|
||||||
|
return j.err
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestError struct {
|
type RequestError struct {
|
||||||
error
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RequestError) Error() string {
|
||||||
|
return "Request failed: " + r.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RequestError) Unwrap() error {
|
||||||
|
return r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
type HTTPError struct {
|
type HTTPError struct {
|
||||||
|
|
|
@ -9,18 +9,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultClient implements Client and wraps around the stdlib Client.
|
// DefaultClient implements Client and wraps around the stdlib Client.
|
||||||
type DefaultClient struct {
|
type DefaultClient http.Client
|
||||||
Client http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Client = (*DefaultClient)(nil)
|
var _ Client = (*DefaultClient)(nil)
|
||||||
|
|
||||||
// WrapClient wraps around the standard library's http.Client and returns an
|
// WrapClient wraps around the standard library's http.Client and returns an
|
||||||
// implementation that's compatible with the Client driver interface.
|
// implementation that's compatible with the Client driver interface.
|
||||||
func WrapClient(client http.Client) Client {
|
func WrapClient(client http.Client) Client {
|
||||||
return &DefaultClient{
|
return DefaultClient(client)
|
||||||
Client: client,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new client around the standard library's http.Client. The
|
// NewClient creates a new client around the standard library's http.Client. The
|
||||||
|
@ -31,7 +27,7 @@ func NewClient() Client {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
|
func (d DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
|
||||||
r, err := http.NewRequestWithContext(ctx, method, url, nil)
|
r, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -39,11 +35,11 @@ func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Req
|
||||||
return (*DefaultRequest)(r), nil
|
return (*DefaultRequest)(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultClient) Do(req Request) (Response, error) {
|
func (d DefaultClient) Do(req Request) (Response, error) {
|
||||||
// Implementations can safely assert this.
|
// Implementations can safely assert this.
|
||||||
request := req.(*DefaultRequest)
|
request := req.(*DefaultRequest)
|
||||||
|
|
||||||
r, err := d.Client.Do((*http.Request)(request))
|
r, err := (*http.Client)(&d).Do((*http.Request)(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestOption func(httpdriver.Request) error
|
type RequestOption func(httpdriver.Request) error
|
||||||
|
type ResponseFunc func(httpdriver.Request, httpdriver.Response) error
|
||||||
|
|
||||||
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
|
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
|
||||||
if len(opts) == 0 {
|
if len(opts) == 0 {
|
||||||
|
|
Loading…
Reference in New Issue