mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-25 11:29:22 +00:00
API: Added WithContext API, closes #15
This commit is contained in:
parent
892c88d808
commit
a0bccd9c35
60
api/api.go
60
api/api.go
|
@ -3,6 +3,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/diamondburned/arikawa/api/rate"
|
||||
|
@ -24,10 +25,7 @@ var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
|
|||
|
||||
type Client struct {
|
||||
*httputil.Client
|
||||
Limiter *rate.Limiter
|
||||
|
||||
Token string
|
||||
UserAgent string
|
||||
Session
|
||||
}
|
||||
|
||||
func NewClient(token string) *Client {
|
||||
|
@ -35,28 +33,48 @@ func NewClient(token string) *Client {
|
|||
}
|
||||
|
||||
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
||||
cli := &Client{
|
||||
Client: httpClient,
|
||||
ses := Session{
|
||||
Limiter: rate.NewLimiter(APIPath),
|
||||
Token: token,
|
||||
UserAgent: UserAgent,
|
||||
}
|
||||
|
||||
cli.DefaultOptions = []httputil.RequestOption{
|
||||
func(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Authorization": {cli.Token},
|
||||
"User-Agent": {cli.UserAgent},
|
||||
"X-RateLimit-Precision": {"millisecond"},
|
||||
})
|
||||
hcl := httpClient.Copy()
|
||||
hcl.OnRequest = append(hcl.OnRequest, ses.InjectRequest)
|
||||
hcl.OnResponse = append(hcl.OnResponse, ses.OnResponse)
|
||||
|
||||
// Rate limit stuff
|
||||
return cli.Limiter.Acquire(r.GetContext(), r.GetPath())
|
||||
},
|
||||
return &Client{
|
||||
Client: hcl,
|
||||
Session: ses,
|
||||
}
|
||||
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error {
|
||||
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||
}
|
||||
|
||||
return cli
|
||||
}
|
||||
|
||||
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||
return &Client{
|
||||
Client: c.Client.WithContext(ctx),
|
||||
Session: c.Session,
|
||||
}
|
||||
}
|
||||
|
||||
// Session keeps a single session. This is typically wrapped around Client.
|
||||
type Session struct {
|
||||
Limiter *rate.Limiter
|
||||
|
||||
Token string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
func (s *Session) InjectRequest(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Authorization": {s.Token},
|
||||
"User-Agent": {s.UserAgent},
|
||||
"X-RateLimit-Precision": {"millisecond"},
|
||||
})
|
||||
|
||||
// Rate limit stuff
|
||||
return s.Limiter.Acquire(r.GetContext(), r.GetPath())
|
||||
}
|
||||
|
||||
func (s *Session) OnResponse(r httpdriver.Request, resp httpdriver.Response) error {
|
||||
return s.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||
}
|
||||
|
|
20
api/api_test.go
Normal file
20
api/api_test.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // lol
|
||||
|
||||
client := NewClient("no. 3-chan").WithContext(ctx)
|
||||
|
||||
// This should fail.
|
||||
_, err := client.Me()
|
||||
if err == nil || !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("Unexpected error:", err)
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
|
@ -46,19 +47,7 @@ type State struct {
|
|||
// List of channels with few messages, so it doesn't bother hitting the API
|
||||
// again.
|
||||
fewMessages map[discord.Snowflake]struct{}
|
||||
fewMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
||||
state := &State{
|
||||
Session: s,
|
||||
Store: store,
|
||||
Handler: handler.New(),
|
||||
StateLog: func(err error) {},
|
||||
fewMessages: map[discord.Snowflake]struct{}{},
|
||||
}
|
||||
|
||||
return state, state.hookSession()
|
||||
fewMutex *sync.Mutex
|
||||
}
|
||||
|
||||
func New(token string) (*State, error) {
|
||||
|
@ -74,9 +63,24 @@ func NewWithStore(token string, store Store) (*State, error) {
|
|||
return NewFromSession(s, store)
|
||||
}
|
||||
|
||||
// Unhook removes all state handlers from the session handlers.
|
||||
func (s *State) Unhook() {
|
||||
s.unhooker()
|
||||
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
||||
state := &State{
|
||||
Session: s,
|
||||
Store: store,
|
||||
Handler: handler.New(),
|
||||
StateLog: func(err error) {},
|
||||
fewMessages: map[discord.Snowflake]struct{}{},
|
||||
fewMutex: new(sync.Mutex),
|
||||
}
|
||||
|
||||
return state, state.hookSession()
|
||||
}
|
||||
|
||||
func (s *State) WithContext(ctx context.Context) *State {
|
||||
copied := *s
|
||||
copied.Client = copied.Client.WithContext(ctx)
|
||||
|
||||
return &copied
|
||||
}
|
||||
|
||||
//// Helper methods
|
||||
|
|
|
@ -22,20 +22,17 @@ type Client struct {
|
|||
json.Driver
|
||||
SchemaEncoder
|
||||
|
||||
// DefaultOptions, if not nil, will be copied and prefixed on each Request.
|
||||
DefaultOptions []RequestOption
|
||||
// OnRequest, if not nil, will be copied and prefixed on each Request.
|
||||
OnRequest []RequestOption
|
||||
|
||||
// OnResponse is called after every Do() call. Response might be nil if Do()
|
||||
// errors out. The error returned will override Do's if it's not nil.
|
||||
OnResponse func(httpdriver.Request, httpdriver.Response) error
|
||||
OnResponse []ResponseFunc
|
||||
|
||||
// Default to the global Retries variable (5).
|
||||
Retries uint
|
||||
}
|
||||
|
||||
// ResponseNoop is used for (*Client).OnResponse.
|
||||
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
||||
return nil
|
||||
context context.Context
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
|
@ -44,12 +41,32 @@ func NewClient() *Client {
|
|||
Driver: json.Default,
|
||||
SchemaEncoder: &DefaultSchema{},
|
||||
Retries: Retries,
|
||||
OnResponse: ResponseNoop,
|
||||
context: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a shallow copy of the client.
|
||||
func (c *Client) Copy() *Client {
|
||||
cl := new(Client)
|
||||
*cl = *c
|
||||
return cl
|
||||
}
|
||||
|
||||
// WithContext returns a client copy of the client with the given context.
|
||||
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||
c = c.Copy()
|
||||
c.context = ctx
|
||||
return c
|
||||
}
|
||||
|
||||
// Context is a shared context for all future calls. It's Background by
|
||||
// default.
|
||||
func (c *Client) Context() context.Context {
|
||||
return c.context
|
||||
}
|
||||
|
||||
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
|
||||
for _, opt := range c.DefaultOptions {
|
||||
for _, opt := range c.OnRequest {
|
||||
if err := opt(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -67,8 +84,8 @@ func (c *Client) MeanwhileMultipart(
|
|||
writer func(*multipart.Writer) error,
|
||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
|
||||
// We want to cancel the request if our bodyWriter fails
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// We want to cancel the request if our bodyWriter fails.
|
||||
ctx, cancel := context.WithCancel(c.context)
|
||||
defer cancel()
|
||||
|
||||
r, w := io.Pipe()
|
||||
|
@ -93,7 +110,8 @@ func (c *Client) MeanwhileMultipart(
|
|||
WithContentType(body.FormDataContentType()),
|
||||
)
|
||||
|
||||
resp, err := c.RequestCtx(ctx, method, url, opts...)
|
||||
// Request with the current client and our own context:
|
||||
resp, err := c.WithContext(ctx).Request(method, url, opts...)
|
||||
if err != nil && bgErr != nil {
|
||||
return nil, bgErr
|
||||
}
|
||||
|
@ -109,13 +127,10 @@ func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
|
|||
return r.GetBody().Close()
|
||||
}
|
||||
|
||||
func (c *Client) RequestCtxJSON(
|
||||
ctx context.Context,
|
||||
to interface{}, method, url string, opts ...RequestOption) error {
|
||||
|
||||
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
|
||||
opts = PrependOptions(opts, JSONRequest)
|
||||
|
||||
r, err := c.RequestCtx(ctx, method, url, opts...)
|
||||
r, err := c.Request(method, url, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -135,11 +150,8 @@ func (c *Client) RequestCtxJSON(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) RequestCtx(
|
||||
ctx context.Context,
|
||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
|
||||
req, err := c.Client.NewRequest(ctx, method, url)
|
||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
req, err := c.Client.NewRequest(c.context, method, url)
|
||||
if err != nil {
|
||||
return nil, RequestError{err}
|
||||
}
|
||||
|
@ -165,8 +177,10 @@ func (c *Client) RequestCtx(
|
|||
}
|
||||
|
||||
// Call OnResponse() even if the request failed.
|
||||
if err := c.OnResponse(req, r); err != nil {
|
||||
return nil, err
|
||||
for _, fn := range c.OnResponse {
|
||||
if err := fn(req, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If all retries failed:
|
||||
|
@ -197,11 +211,3 @@ func (c *Client) RequestCtx(
|
|||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
return c.RequestCtx(context.Background(), method, url, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
|
||||
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
|
||||
}
|
||||
|
|
|
@ -6,11 +6,27 @@ import (
|
|||
)
|
||||
|
||||
type JSONError struct {
|
||||
error
|
||||
err error
|
||||
}
|
||||
|
||||
func (j JSONError) Error() string {
|
||||
return "JSON decoding failed: " + j.err.Error()
|
||||
}
|
||||
|
||||
func (j JSONError) Unwrap() error {
|
||||
return j.err
|
||||
}
|
||||
|
||||
type RequestError struct {
|
||||
error
|
||||
err error
|
||||
}
|
||||
|
||||
func (r RequestError) Error() string {
|
||||
return "Request failed: " + r.err.Error()
|
||||
}
|
||||
|
||||
func (r RequestError) Unwrap() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
type HTTPError struct {
|
||||
|
|
|
@ -9,18 +9,14 @@ import (
|
|||
)
|
||||
|
||||
// DefaultClient implements Client and wraps around the stdlib Client.
|
||||
type DefaultClient struct {
|
||||
Client http.Client
|
||||
}
|
||||
type DefaultClient http.Client
|
||||
|
||||
var _ Client = (*DefaultClient)(nil)
|
||||
|
||||
// WrapClient wraps around the standard library's http.Client and returns an
|
||||
// implementation that's compatible with the Client driver interface.
|
||||
func WrapClient(client http.Client) Client {
|
||||
return &DefaultClient{
|
||||
Client: client,
|
||||
}
|
||||
return DefaultClient(client)
|
||||
}
|
||||
|
||||
// NewClient creates a new client around the standard library's http.Client. The
|
||||
|
@ -31,7 +27,7 @@ func NewClient() Client {
|
|||
})
|
||||
}
|
||||
|
||||
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
|
||||
func (d DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
|
||||
r, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -39,11 +35,11 @@ func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Req
|
|||
return (*DefaultRequest)(r), nil
|
||||
}
|
||||
|
||||
func (d *DefaultClient) Do(req Request) (Response, error) {
|
||||
func (d DefaultClient) Do(req Request) (Response, error) {
|
||||
// Implementations can safely assert this.
|
||||
request := req.(*DefaultRequest)
|
||||
|
||||
r, err := d.Client.Do((*http.Request)(request))
|
||||
r, err := (*http.Client)(&d).Do((*http.Request)(request))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
)
|
||||
|
||||
type RequestOption func(httpdriver.Request) error
|
||||
type ResponseFunc func(httpdriver.Request, httpdriver.Response) error
|
||||
|
||||
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
|
||||
if len(opts) == 0 {
|
||||
|
|
Loading…
Reference in a new issue