From 2afe683b7d2f69ee53ae024956225c42828ae5e1 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Sun, 19 Apr 2020 14:53:53 -0700 Subject: [PATCH] Utils: Added HTTP drivers --- api/api.go | 46 +++---- api/guild.go | 80 ++++-------- api/message_send.go | 12 +- api/webhook.go | 42 +++---- gateway/gateway.go | 14 +-- session/session.go | 2 +- utils/httputil/client.go | 180 ++++++++++++++++----------- utils/httputil/http.go | 36 ------ utils/httputil/httpdriver/default.go | 103 +++++++++++++++ utils/httputil/httpdriver/driver.go | 55 ++++++++ utils/httputil/options.go | 71 ++++++----- 11 files changed, 380 insertions(+), 261 deletions(-) delete mode 100644 utils/httputil/http.go create mode 100644 utils/httputil/httpdriver/default.go create mode 100644 utils/httputil/httpdriver/driver.go diff --git a/api/api.go b/api/api.go index 5452846..f7245b3 100644 --- a/api/api.go +++ b/api/api.go @@ -7,6 +7,7 @@ import ( "github.com/diamondburned/arikawa/api/rate" "github.com/diamondburned/arikawa/utils/httputil" + "github.com/diamondburned/arikawa/utils/httputil/httpdriver" ) var ( @@ -22,39 +23,40 @@ var ( var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)" type Client struct { - httputil.Client + *httputil.Client Limiter *rate.Limiter - Token string + Token string + UserAgent string } func NewClient(token string) *Client { + return NewCustomClient(token, httputil.NewClient()) +} + +func NewCustomClient(token string, httpClient *httputil.Client) *Client { cli := &Client{ - Client: httputil.DefaultClient, - Limiter: rate.NewLimiter(APIPath), - Token: token, + Client: httpClient, + Limiter: rate.NewLimiter(APIPath), + Token: token, + UserAgent: UserAgent, } - tw := httputil.NewTransportWrapper() - tw.Pre = func(r *http.Request) error { - if cli.Token != "" { - r.Header.Set("Authorization", cli.Token) - } + cli.DefaultOptions = []httputil.RequestOption{ + func(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Authorization": {cli.Token}, + "User-Agent": {cli.UserAgent}, + "X-RateLimit-Precision": {"millisecond"}, + }) - r.Header.Set("User-Agent", UserAgent) - r.Header.Set("X-RateLimit-Precision", "millisecond") - - // Rate limit stuff - return cli.Limiter.Acquire(r.Context(), r.URL.Path) + // Rate limit stuff + return cli.Limiter.Acquire(r.GetContext(), r.GetPath()) + }, } - tw.Post = func(r *http.Request, resp *http.Response) error { - if resp == nil { - return cli.Limiter.Release(r.URL.Path, nil) - } - return cli.Limiter.Release(r.URL.Path, resp.Header) + cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error { + return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp)) } - cli.Client.Transport = tw - return cli } diff --git a/api/guild.go b/api/guild.go index bc91788..fa4546e 100644 --- a/api/guild.go +++ b/api/guild.go @@ -32,8 +32,7 @@ type CreateGuildData struct { func (c *Client) CreateGuild(data CreateGuildData) (*discord.Guild, error) { var g *discord.Guild - return g, c.RequestJSON(&g, "POST", Endpoint+"guilds", - httputil.WithJSONBody(c, data)) + return g, c.RequestJSON(&g, "POST", Endpoint+"guilds", httputil.WithJSONBody(c, data)) } func (c *Client) Guild(guildID discord.Snowflake) (*discord.Guild, error) { @@ -76,23 +75,17 @@ func (c *Client) Guilds(max uint) ([]discord.Guild, error) { } // GuildsBefore fetches guilds. Check GuildsRange. -func (c *Client) GuildsBefore( - before discord.Snowflake, limit uint) ([]discord.Guild, error) { - +func (c *Client) GuildsBefore(before discord.Snowflake, limit uint) ([]discord.Guild, error) { return c.GuildsRange(before, 0, limit) } // GuildsAfter fetches guilds. Check GuildsRange. -func (c *Client) GuildsAfter( - after discord.Snowflake, limit uint) ([]discord.Guild, error) { - +func (c *Client) GuildsAfter(after discord.Snowflake, limit uint) ([]discord.Guild, error) { return c.GuildsRange(0, after, limit) } // GuildsRange fetches guilds. The limit is 1-100. -func (c *Client) GuildsRange( - before, after discord.Snowflake, limit uint) ([]discord.Guild, error) { - +func (c *Client) GuildsRange(before, after discord.Snowflake, limit uint) ([]discord.Guild, error) { if limit == 0 { limit = 100 } @@ -163,21 +156,15 @@ func (c *Client) DeleteGuild(guildID discord.Snowflake) error { // GuildVoiceRegions is the same as /voice, but returns VIP ones as well if // available. -func (c *Client) VoiceRegionsGuild( - guildID discord.Snowflake) ([]discord.VoiceRegion, error) { - +func (c *Client) VoiceRegionsGuild(guildID discord.Snowflake) ([]discord.VoiceRegion, error) { var vrs []discord.VoiceRegion - return vrs, c.RequestJSON(&vrs, "GET", - EndpointGuilds+guildID.String()+"/regions") + return vrs, c.RequestJSON(&vrs, "GET", EndpointGuilds+guildID.String()+"/regions") } // Integrations requires MANAGE_GUILD. -func (c *Client) Integrations( - guildID discord.Snowflake) ([]discord.Integration, error) { - +func (c *Client) Integrations(guildID discord.Snowflake) ([]discord.Integration, error) { var ints []discord.Integration - return ints, c.RequestJSON(&ints, "GET", - EndpointGuilds+guildID.String()+"/integrations") + return ints, c.RequestJSON(&ints, "GET", EndpointGuilds+guildID.String()+"/integrations") } // AttachIntegration requires MANAGE_GUILD. @@ -214,46 +201,36 @@ func (c *Client) ModifyIntegration( return c.FastRequest( "PATCH", - EndpointGuilds+guildID.String()+ - "/integrations/"+integrationID.String(), + EndpointGuilds+guildID.String()+"/integrations/"+integrationID.String(), httputil.WithSchema(c, param), ) } -func (c *Client) SyncIntegration( - guildID, integrationID discord.Snowflake) error { - +func (c *Client) SyncIntegration(guildID, integrationID discord.Snowflake) error { return c.FastRequest("POST", EndpointGuilds+guildID.String()+ "/integrations/"+integrationID.String()+"/sync") } -func (c *Client) GuildEmbed( - guildID discord.Snowflake) (*discord.GuildEmbed, error) { - +func (c *Client) GuildEmbed(guildID discord.Snowflake) (*discord.GuildEmbed, error) { var ge *discord.GuildEmbed - return ge, c.RequestJSON(&ge, "GET", - EndpointGuilds+guildID.String()+"/embed") + return ge, c.RequestJSON(&ge, "GET", EndpointGuilds+guildID.String()+"/embed") } -// ModifyGuildEmbed should be used with care: if you still want the embed -// enabled, you need to set the Enabled boolean, even if it's already enabled. -// If you don't, JSON will default it to false. -func (c *Client) ModifyGuildEmbed( - guildID discord.Snowflake, - data discord.GuildEmbed) (*discord.GuildEmbed, error) { - - return &data, c.RequestJSON(&data, "PATCH", - EndpointGuilds+guildID.String()+"/embed") +// ModifyGuildEmbed modifies the guild embed and updates the passed in +// GuildEmbed data. +// +// This method should be used with care: if you still want the embed enabled, +// you need to set the Enabled boolean, even if it's already enabled. If you +// don't, JSON will default it to false. +func (c *Client) ModifyGuildEmbed(guildID discord.Snowflake, data *discord.GuildEmbed) error { + return c.RequestJSON(&data, "PATCH", EndpointGuilds+guildID.String()+"/embed") } // GuildVanityURL returns *Invite, but only Code and Uses are filled. Requires // MANAGE_GUILD. -func (c *Client) GuildVanityURL( - guildID discord.Snowflake) (*discord.Invite, error) { - +func (c *Client) GuildVanityURL(guildID discord.Snowflake) (*discord.Invite, error) { var inv *discord.Invite - return inv, c.RequestJSON(&inv, "GET", - EndpointGuilds+guildID.String()+"/vanity-url") + return inv, c.RequestJSON(&inv, "GET", EndpointGuilds+guildID.String()+"/vanity-url") } type GuildImageType string @@ -266,20 +243,15 @@ const ( GuildBanner4 GuildImageType = "banner4" ) -func (c *Client) GuildImageURL( - guildID discord.Snowflake, img GuildImageType) string { - - return EndpointGuilds + guildID.String() + - "/widget.png?style=" + string(img) +func (c *Client) GuildImageURL(guildID discord.Snowflake, img GuildImageType) string { + return EndpointGuilds + guildID.String() + "/widget.png?style=" + string(img) } -func (c *Client) GuildImage( - guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) { - +func (c *Client) GuildImage(guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) { r, err := c.Request("GET", c.GuildImageURL(guildID, img)) if err != nil { return nil, err } - return r.Body, nil + return r.GetBody(), nil } diff --git a/api/message_send.go b/api/message_send.go index a532dbe..8739ecb 100644 --- a/api/message_send.go +++ b/api/message_send.go @@ -27,8 +27,7 @@ func (c *Client) SendMessageComplex( if len(data.Files) == 0 { // No files, so no need for streaming. - return msg, c.RequestJSON(&msg, "POST", URL, - httputil.WithJSONBody(c, data)) + return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data)) } writer := func(mw *multipart.Writer) error { @@ -40,9 +39,10 @@ func (c *Client) SendMessageComplex( return nil, err } - defer resp.Body.Close() + var body = resp.GetBody() + defer body.Close() - return msg, c.DecodeStream(resp.Body, &msg) + return msg, c.DecodeStream(body, &msg) } const AttachmentSpoilerPrefix = "SPOILER_" @@ -83,9 +83,7 @@ type ExecuteWebhookData struct { AvatarURL discord.URL `json:"avatar_url,omitempty"` } -func (data *ExecuteWebhookData) WriteMultipart( - c json.Driver, body *multipart.Writer) error { - +func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error { return writeMultipart(c, body, data, data.Files) } diff --git a/api/webhook.go b/api/webhook.go index dd62157..f6b344d 100644 --- a/api/webhook.go +++ b/api/webhook.go @@ -35,28 +35,21 @@ func (c *Client) CreateWebhook( } // Webhooks requires MANAGE_WEBHOOKS. -func (c *Client) Webhooks( - guildID discord.Snowflake) ([]discord.Webhook, error) { - +func (c *Client) Webhooks(guildID discord.Snowflake) ([]discord.Webhook, error) { var ws []discord.Webhook - return ws, c.RequestJSON(&ws, "GET", - EndpointGuilds+guildID.String()+"/webhooks") + return ws, c.RequestJSON(&ws, "GET", EndpointGuilds+guildID.String()+"/webhooks") } -func (c *Client) Webhook( - webhookID discord.Snowflake) (*discord.Webhook, error) { - +func (c *Client) Webhook(webhookID discord.Snowflake) (*discord.Webhook, error) { var w *discord.Webhook - return w, c.RequestJSON(&w, "GET", - EndpointWebhooks+webhookID.String()) + return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String()) } func (c *Client) WebhookWithToken( webhookID discord.Snowflake, token string) (*discord.Webhook, error) { var w *discord.Webhook - return w, c.RequestJSON(&w, "GET", - EndpointWebhooks+webhookID.String()+"/"+token) + return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String()+"/"+token) } type ModifyWebhookData struct { @@ -70,8 +63,7 @@ func (c *Client) ModifyWebhook( data ModifyWebhookData) (*discord.Webhook, error) { var w *discord.Webhook - return w, c.RequestJSON(&w, "PATCH", - EndpointWebhooks+webhookID.String()) + return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String()) } func (c *Client) ModifyWebhookWithToken( @@ -79,26 +71,24 @@ func (c *Client) ModifyWebhookWithToken( data ModifyWebhookData, token string) (*discord.Webhook, error) { var w *discord.Webhook - return w, c.RequestJSON(&w, "PATCH", - EndpointWebhooks+webhookID.String()+"/"+token) + return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String()+"/"+token) } func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error { return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()) } -func (c *Client) DeleteWebhookWithToken( - webhookID discord.Snowflake, token string) error { - - return c.FastRequest("DELETE", - EndpointWebhooks+webhookID.String()+"/"+token) +func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error { + return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token) } // ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will // wait for the message to be delivered and will return the message body. This // also means the returned message will only be there if wait is true. func (c *Client) ExecuteWebhook( - webhookID discord.Snowflake, token string, wait bool, + webhookID discord.Snowflake, + token string, + wait bool, data ExecuteWebhookData) (*discord.Message, error) { for i, embed := range data.Embeds { @@ -112,8 +102,7 @@ func (c *Client) ExecuteWebhook( param.Set("wait", "true") } - var URL = EndpointWebhooks + webhookID.String() + "/" + token + - "?" + param.Encode() + var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode() var msg *discord.Message if len(data.Files) == 0 { @@ -131,12 +120,13 @@ func (c *Client) ExecuteWebhook( return nil, err } - defer resp.Body.Close() + var body = resp.GetBody() + defer body.Close() if !wait { // Since we didn't tell Discord to wait, we have nothing to parse. return nil, nil } - return msg, c.DecodeStream(resp.Body, &msg) + return msg, c.DecodeStream(body, &msg) } diff --git a/gateway/gateway.go b/gateway/gateway.go index 3d7ad25..b12cbbb 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -70,22 +70,22 @@ type SessionStartLimit struct { ResetAfter discord.Milliseconds `json:"reset_after"` } -// GatewayURL asks Discord for a Websocket URL to the Gateway. -func GatewayURL() (string, error) { +// URL asks Discord for a Websocket URL to the Gateway. +func URL() (string, error) { var g GatewayBotData - return g.URL, httputil.DefaultClient.RequestJSON( + return g.URL, httputil.NewClient().RequestJSON( &g, "GET", EndpointGateway, ) } -// GatewayBot fetches the Gateway URL along with extra metadata. The token +// BotURL fetches the Gateway URL along with extra metadata. The token // passed in will NOT be prefixed with Bot. -func GatewayBot(token string) (*GatewayBotData, error) { +func BotURL(token string) (*GatewayBotData, error) { var g *GatewayBotData - return g, httputil.DefaultClient.RequestJSON( + return g, httputil.NewClient().RequestJSON( &g, "GET", EndpointGatewayBot, httputil.WithHeaders(http.Header{ @@ -144,7 +144,7 @@ func NewGateway(token string) (*Gateway, error) { // NewGatewayWithDriver connects to the Gateway and authenticates automatically. func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { - URL, err := GatewayURL() + URL, err := URL() if err != nil { return nil, errors.Wrap(err, "Failed to get gateway endpoint") } diff --git a/session/session.go b/session/session.go index e4e251d..65fd284 100644 --- a/session/session.go +++ b/session/session.go @@ -46,7 +46,7 @@ func New(token string) (*Session, error) { s.Handler = handler.New() s.Client = api.NewClient(token) - // Open a gateway + // Create a gateway g, err := gateway.NewGateway(token) if err != nil { return nil, errors.Wrap(err, "Failed to connect to Gateway") diff --git a/utils/httputil/client.go b/utils/httputil/client.go index 73653bd..002b248 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -3,44 +3,69 @@ package httputil import ( + "bytes" "context" "io" - "io/ioutil" "mime/multipart" - "net/http" - "time" + "github.com/diamondburned/arikawa/utils/httputil/httpdriver" "github.com/diamondburned/arikawa/utils/json" + "github.com/pkg/errors" ) // Retries is the default attempts to retry if the API returns an error before -// giving up. +// giving up. If the value is smaller than 1, then requests will retry forever. var Retries uint = 5 type Client struct { - http.Client + httpdriver.Client json.Driver SchemaEncoder + // DefaultOptions, if not nil, will be copied and prefixed on each Request. + DefaultOptions []RequestOption + + // OnResponse is called after every Do() call. Response might be nil if Do() + // errors out. The error returned will override Do's if it's not nil. + OnResponse func(httpdriver.Request, httpdriver.Response) error + + // Default to the global Retries variable (5). Retries uint } -var DefaultClient = NewClient() +// ResponseNoop is used for (*Client).OnResponse. +func ResponseNoop(httpdriver.Request, httpdriver.Response) error { + return nil +} -func NewClient() Client { - return Client{ - Client: http.Client{ - Timeout: 10 * time.Second, - }, +func NewClient() *Client { + return &Client{ + Client: httpdriver.NewClient(), Driver: json.Default{}, SchemaEncoder: &DefaultSchema{}, Retries: Retries, + OnResponse: ResponseNoop, } } +func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error { + for _, opt := range c.DefaultOptions { + if err := opt(r); err != nil { + return err + } + } + for _, opt := range extra { + if err := opt(r); err != nil { + return err + } + } + + return nil +} + func (c *Client) MeanwhileMultipart( - multipartWriter func(*multipart.Writer) error, - method, url string, opts ...RequestOption) (*http.Response, error) { + writer func(*multipart.Writer) error, + method, url string, opts ...RequestOption) (httpdriver.Response, error) { // We want to cancel the request if our bodyWriter fails ctx, cancel := context.WithCancel(context.Background()) @@ -52,7 +77,7 @@ func (c *Client) MeanwhileMultipart( var bgErr error go func() { - if err := multipartWriter(body); err != nil { + if err := writer(body); err != nil { bgErr = err cancel() } @@ -61,119 +86,122 @@ func (c *Client) MeanwhileMultipart( w.Close() }() - resp, err := c.RequestCtx(ctx, method, url, - append([]RequestOption{ - WithBody(r), - WithContentType(body.FormDataContentType()), - }, opts...)...) + // Prepend the multipart writer and the correct Content-Type header options. + opts = PrependOptions( + opts, + WithBody(r), + WithContentType(body.FormDataContentType()), + ) + resp, err := c.RequestCtx(ctx, method, url, opts...) if err != nil && bgErr != nil { - if resp.Body != nil { - resp.Body.Close() - } - return nil, bgErr } - return resp, err } -func (c *Client) FastRequest( - method, url string, opts ...RequestOption) error { - +func (c *Client) FastRequest(method, url string, opts ...RequestOption) error { r, err := c.Request(method, url, opts...) if err != nil { return err } - return r.Body.Close() + return r.GetBody().Close() } -func (c *Client) RequestCtx(ctx context.Context, - method, url string, opts ...RequestOption) (*http.Response, error) { +func (c *Client) RequestCtxJSON( + ctx context.Context, + to interface{}, method, url string, opts ...RequestOption) error { - req, err := http.NewRequestWithContext(ctx, method, url, nil) + opts = PrependOptions(opts, JSONRequest) + + r, err := c.RequestCtx(ctx, method, url, opts...) + if err != nil { + return err + } + + var body, status = r.GetBody(), r.GetStatus() + defer body.Close() + + // No content, working as intended (tm) + if status == httpdriver.NoContent { + return nil + } + + if err := c.DecodeStream(body, to); err != nil { + return JSONError{err} + } + + return nil +} + +func (c *Client) RequestCtx( + ctx context.Context, + method, url string, opts ...RequestOption) (httpdriver.Response, error) { + + req, err := c.Client.NewRequest(ctx, method, url) if err != nil { return nil, RequestError{err} } - for _, opt := range opts { - if err := opt(req); err != nil { - return nil, err - } + if err := c.applyOptions(req, opts); err != nil { + return nil, errors.Wrap(err, "Failed to apply options") } - var r *http.Response + var r httpdriver.Response + var status int - for i := uint(0); i < c.Retries; i++ { + for i := uint(0); c.Retries < 1 || i < c.Retries; i++ { r, err = c.Client.Do(req) if err != nil { continue } - if r.StatusCode < 200 || r.StatusCode > 299 { + if status = r.GetStatus(); status < 200 || status > 299 { continue } break } + // Call OnResponse() even if the request failed. + if err := c.OnResponse(req, r); err != nil { + return nil, err + } + // If all retries failed: if err != nil { return nil, RequestError{err} } // Response received, but with a failure status code: - if r.StatusCode < 200 || r.StatusCode > 299 { + if status < 200 || status > 299 { + // Try and parse the body. + var body = r.GetBody() + defer body.Close() + + // This rarely happens, so we can (probably) make an exception for it. + buf := bytes.Buffer{} + buf.ReadFrom(body) + httpErr := &HTTPError{ - Status: r.StatusCode, + Status: status, + Body: buf.Bytes(), } - b, err := ioutil.ReadAll(r.Body) - if err != nil { - return nil, httpErr - } + // Optionally unmarshal the error. + c.Unmarshal(httpErr.Body, &httpErr) - httpErr.Body = b - - c.Unmarshal(b, &httpErr) return nil, httpErr } return r, nil } -func (c *Client) RequestCtxJSON(ctx context.Context, - to interface{}, method, url string, opts ...RequestOption) error { - - r, err := c.RequestCtx(ctx, method, url, - append([]RequestOption{JSONRequest}, opts...)...) - if err != nil { - return err - } - - defer r.Body.Close() - - // No content, working as intended (tm) - if r.StatusCode == http.StatusNoContent { - return nil - } - - if err := c.DecodeStream(r.Body, to); err != nil { - return JSONError{err} - } - - return nil -} - -func (c *Client) Request( - method, url string, opts ...RequestOption) (*http.Response, error) { - +func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) { return c.RequestCtx(context.Background(), method, url, opts...) } -func (c *Client) RequestJSON( - to interface{}, method, url string, opts ...RequestOption) error { - +func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error { return c.RequestCtxJSON(context.Background(), to, method, url, opts...) } diff --git a/utils/httputil/http.go b/utils/httputil/http.go deleted file mode 100644 index 9ca71be..0000000 --- a/utils/httputil/http.go +++ /dev/null @@ -1,36 +0,0 @@ -package httputil - -import ( - "net/http" -) - -type TransportWrapper struct { - Default http.RoundTripper - Pre func(*http.Request) error - Post func(*http.Request, *http.Response) error -} - -var _ http.RoundTripper = (*TransportWrapper)(nil) - -func NewTransportWrapper() *TransportWrapper { - return &TransportWrapper{ - Default: http.DefaultTransport, - Pre: func(*http.Request) error { return nil }, - Post: func(*http.Request, *http.Response) error { return nil }, - } -} - -func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) { - if err := c.Pre(req); err != nil { - return nil, err - } - - r, err = c.Default.RoundTrip(req) - - // Call Post regardless of error: - if postErr := c.Post(req, r); postErr != nil { - return r, postErr - } - - return r, err -} diff --git a/utils/httputil/httpdriver/default.go b/utils/httputil/httpdriver/default.go new file mode 100644 index 0000000..83a14bb --- /dev/null +++ b/utils/httputil/httpdriver/default.go @@ -0,0 +1,103 @@ +package httpdriver + +import ( + "context" + "io" + "net/http" + "net/url" + "time" +) + +// DefaultClient implements Client and wraps around the stdlib Client. +type DefaultClient struct { + Client http.Client +} + +var _ Client = (*DefaultClient)(nil) + +// WrapClient wraps around the standard library's http.Client and returns an +// implementation that's compatible with the Client driver interface. +func WrapClient(client http.Client) Client { + return &DefaultClient{ + Client: client, + } +} + +// NewClient creates a new client around the standard library's http.Client. The +// client will have a timeout of 10 seconds. +func NewClient() Client { + return WrapClient(http.Client{ + Timeout: 10 * time.Second, + }) +} + +func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) { + r, err := http.NewRequestWithContext(ctx, method, url, nil) + if err != nil { + return nil, err + } + return (*DefaultRequest)(r), nil +} + +func (d *DefaultClient) Do(req Request) (Response, error) { + // Implementations can safely assert this. + request := req.(*DefaultRequest) + + r, err := d.Client.Do((*http.Request)(request)) + if err != nil { + return nil, err + } + + return (*DefaultResponse)(r), nil +} + +// DefaultRequest wraps around the stdlib Request and satisfies the Request +// interface. +type DefaultRequest http.Request + +var _ Request = (*DefaultRequest)(nil) + +func (r *DefaultRequest) GetPath() string { + return r.URL.Path +} + +func (r *DefaultRequest) GetContext() context.Context { + return (*http.Request)(r).Context() +} + +func (r *DefaultRequest) AddQuery(values url.Values) { + var qs = r.URL.Query() + for k, v := range values { + qs[k] = append(qs[k], v...) + } + + r.URL.RawQuery = qs.Encode() +} + +func (r *DefaultRequest) AddHeader(header http.Header) { + for key, values := range header { + r.Header[key] = append(r.Header[key], values...) + } +} + +func (r *DefaultRequest) WithBody(body io.ReadCloser) { + r.Body = body +} + +// DefaultResponse wraps around the stdlib Response and satisfies the Response +// interface. +type DefaultResponse http.Response + +var _ Response = (*DefaultResponse)(nil) + +func (r *DefaultResponse) GetStatus() int { + return r.StatusCode +} + +func (r *DefaultResponse) GetHeader() http.Header { + return r.Header +} + +func (r *DefaultResponse) GetBody() io.ReadCloser { + return r.Body +} diff --git a/utils/httputil/httpdriver/driver.go b/utils/httputil/httpdriver/driver.go new file mode 100644 index 0000000..037f4bf --- /dev/null +++ b/utils/httputil/httpdriver/driver.go @@ -0,0 +1,55 @@ +// Package httpdriver provides interfaces and implementations of a simple HTTP +// client. +package httpdriver + +import ( + "context" + "io" + "net/http" + "net/url" +) + +// NoContent is the status code for HTTP 204, or http.StatusNoContent. +const NoContent = 204 + +// Client is a generic interface used as an adapter allowing for custom HTTP +// client implementations, such as fasthttp. +type Client interface { + NewRequest(ctx context.Context, method, url string) (Request, error) + Do(req Request) (Response, error) +} + +// Request is a generic interface for a normal HTTP request. It should be +// constructed using (Requester).NewRequest(). +type Request interface { + // GetPath should return the URL path, for example "/endpoint/thing". + GetPath() string + // GetContext should return the same context that's passed into NewRequest. + // For implementations that don't support this, it can remove a + // context.Background(). + GetContext() context.Context + // AddHeader appends headers. + AddHeader(http.Header) + // AddQuery appends URL query values. + AddQuery(url.Values) + // WithBody should automatically close the ReadCloser on finish. This is + // similar to the stdlib's Request behavior. + WithBody(io.ReadCloser) +} + +// Response is returned from (Requester).DoContext(). +type Response interface { + GetStatus() int + GetHeader() http.Header + // Body's ReadCloser will always be closed when done, unless DoContext() + // returns an error. + GetBody() io.ReadCloser +} + +// OptHeader returns the response header, or nil if from is nil. +func OptHeader(from Response) http.Header { + if from == nil { + return nil + } + return from.GetHeader() +} diff --git a/utils/httputil/options.go b/utils/httputil/options.go index f1910e6..43c7d26 100644 --- a/utils/httputil/options.go +++ b/utils/httputil/options.go @@ -4,86 +4,93 @@ import ( "io" "net/http" + "github.com/diamondburned/arikawa/utils/httputil/httpdriver" "github.com/diamondburned/arikawa/utils/json" ) -type RequestOption func(*http.Request) error +type RequestOption func(httpdriver.Request) error -func JSONRequest(r *http.Request) error { - r.Header.Set("Content-Type", "application/json") +func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption { + if len(opts) == 0 { + return prepend + } + return append(prepend, opts...) +} + +func JSONRequest(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Content-Type": {"application/json"}, + }) return nil } -func MultipartRequest(r *http.Request) error { - r.Header.Set("Content-Type", "multipart/form-data") +func MultipartRequest(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Content-Type": {"multipart/form-data"}, + }) return nil } func WithHeaders(headers http.Header) RequestOption { - return func(r *http.Request) error { - for key, values := range headers { - r.Header[key] = append(r.Header[key], values...) - } + return func(r httpdriver.Request) error { + r.AddHeader(headers) return nil } } func WithContentType(ctype string) RequestOption { - return func(r *http.Request) error { - r.Header.Set("Content-Type", ctype) + return func(r httpdriver.Request) error { + r.AddHeader(http.Header{ + "Content-Type": {ctype}, + }) return nil } } func WithSchema(schema SchemaEncoder, v interface{}) RequestOption { - return func(r *http.Request) error { + return func(r httpdriver.Request) error { params, err := schema.Encode(v) if err != nil { return err } - var qs = r.URL.Query() - for k, v := range params { - qs[k] = append(qs[k], v...) - } - - r.URL.RawQuery = qs.Encode() + r.AddQuery(params) return nil } } func WithBody(body io.ReadCloser) RequestOption { - return func(r *http.Request) error { - // tee := io.TeeReader(body, os.Stderr) - // r.Body = ioutil.NopCloser(tee) - r.Body = body - r.ContentLength = -1 + return func(r httpdriver.Request) error { + r.WithBody(body) return nil } } +// WithJSONBody inserts a JSON body into the request. This ignores JSON errors. func WithJSONBody(json json.Driver, v interface{}) RequestOption { if v == nil { - return func(*http.Request) error { + return func(httpdriver.Request) error { return nil } } - var err error var rp, wp = io.Pipe() go func() { - err = json.EncodeStream(wp, v) + json.EncodeStream(wp, v) wp.Close() }() - return func(r *http.Request) error { - if err != nil { - return err - } + return func(r httpdriver.Request) error { + // TODO: maybe do something to this? + // if err != nil { + // return err + // } - r.Header.Set("Content-Type", "application/json") - r.Body = rp + r.AddHeader(http.Header{ + "Content-Type": {"application/json"}, + }) + r.WithBody(rp) return nil } }